Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2015-01-05 21:55:31 +0300
committerJohn Langford <jl@hunch.net>2015-01-05 21:55:31 +0300
commit61153d9904cb0f64219c8c52629a7b7b3a21664b (patch)
tree716eed7af763c1009afee81017ceedf65a8115d3
parentaa8afff738f208cb59608b039aa6d25e34b6b961 (diff)
fix cb eval code
-rw-r--r--vowpalwabbit/cb_algs.cc19
1 files changed, 7 insertions, 12 deletions
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index b14328d3..43caa9d0 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -299,7 +299,7 @@ namespace CB_ALGS
for (size_t i=0; i<ld.event.costs.size(); i++)
ld.event.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction;
- ec.l.cb_eval = ld;
+ ec.pred.multiclass = ec.l.cb_eval.action;
}
void init_driver(cb&)
@@ -370,25 +370,22 @@ namespace CB_ALGS
c.known_cost = get_observed_cost(ld);
float chosen_loss = FLT_MAX;
if( know_all_cost_example(ld) ) {
- for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl ++) {
+ for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl++)
if (cl->action == ec.pred.multiclass)
chosen_loss = cl->cost;
- }
}
- else {
+ else
//we do not know exact cost of each action, so evaluate on generated cost-sensitive example currently stored in cb_cs_ld
- for (COST_SENSITIVE::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) {
+ for (COST_SENSITIVE::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++)
if (cl->class_index == ec.pred.multiclass)
{
chosen_loss = cl->x;
if (c.known_cost->action == ec.pred.multiclass && c.cb_type == CB_TYPE_DM)
chosen_loss += (c.known_cost->cost - chosen_loss) / c.known_cost->probability;
}
- }
- }
if (chosen_loss == FLT_MAX)
cerr << "warning: cb predicted an invalid class" << endl;
-
+
loss = chosen_loss;
}
@@ -416,13 +413,11 @@ namespace CB_ALGS
all.print(f, (float)ec.pred.multiclass, 0, ec.tag);
}
- print_update(all, c, is_test_label(ec.l.cb), ec);
+ print_update(all, c, is_test_label(ld), ec);
}
void finish(cb& c)
- {
- c.cb_cs_ld.costs.delete_v();
- }
+ { c.cb_cs_ld.costs.delete_v(); }
void finish_example(vw& all, cb& c, example& ec)
{