diff options
author | John Langford <jl@hunch.net> | 2015-01-05 21:55:31 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2015-01-05 21:55:31 +0300 |
commit | 61153d9904cb0f64219c8c52629a7b7b3a21664b (patch) | |
tree | 716eed7af763c1009afee81017ceedf65a8115d3 | |
parent | aa8afff738f208cb59608b039aa6d25e34b6b961 (diff) |
fix cb eval code
-rw-r--r-- | vowpalwabbit/cb_algs.cc | 19 |
1 files changed, 7 insertions, 12 deletions
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index b14328d3..43caa9d0 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -299,7 +299,7 @@ namespace CB_ALGS for (size_t i=0; i<ld.event.costs.size(); i++) ld.event.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; - ec.l.cb_eval = ld; + ec.pred.multiclass = ec.l.cb_eval.action; } void init_driver(cb&) @@ -370,25 +370,22 @@ namespace CB_ALGS c.known_cost = get_observed_cost(ld); float chosen_loss = FLT_MAX; if( know_all_cost_example(ld) ) { - for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl ++) { + for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl++) if (cl->action == ec.pred.multiclass) chosen_loss = cl->cost; - } } - else { + else //we do not know exact cost of each action, so evaluate on generated cost-sensitive example currently stored in cb_cs_ld - for (COST_SENSITIVE::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) { + for (COST_SENSITIVE::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) if (cl->class_index == ec.pred.multiclass) { chosen_loss = cl->x; if (c.known_cost->action == ec.pred.multiclass && c.cb_type == CB_TYPE_DM) chosen_loss += (c.known_cost->cost - chosen_loss) / c.known_cost->probability; } - } - } if (chosen_loss == FLT_MAX) cerr << "warning: cb predicted an invalid class" << endl; - + loss = chosen_loss; } @@ -416,13 +413,11 @@ namespace CB_ALGS all.print(f, (float)ec.pred.multiclass, 0, ec.tag); } - print_update(all, c, is_test_label(ec.l.cb), ec); + print_update(all, c, is_test_label(ld), ec); } void finish(cb& c) - { - c.cb_cs_ld.costs.delete_v(); - } + { c.cb_cs_ld.costs.delete_v(); } void finish_example(vw& all, cb& c, example& ec) { |