diff options
Diffstat (limited to 'vowpalwabbit/ect.cc')
-rw-r--r-- | vowpalwabbit/ect.cc | 24 |
1 files changed, 7 insertions, 17 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index 25b34ee3..7a68fa13 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -51,8 +51,6 @@ namespace ECT uint32_t last_pair; v_array<bool> tournaments_won; - - vw* all; }; bool exists(v_array<size_t> db) @@ -184,7 +182,7 @@ namespace ECT return e.last_pair + (eliminations-1); } - uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec) + uint32_t ect_predict(ect& e, base_learner& base, example& ec) { if (e.k == (size_t)1) return 1; @@ -228,7 +226,7 @@ namespace ECT return false; } - void ect_train(vw& all, ect& e, base_learner& base, example& ec) + void ect_train(ect& e, base_learner& base, example& ec) { if (e.k == 1)//nothing to do return; @@ -318,25 +316,21 @@ namespace ECT } void predict(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1)) cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl; - ec.pred.multiclass = ect_predict(*all, e, base, ec); + ec.pred.multiclass = ect_predict(e, base, ec); ec.l.multi = mc; } void learn(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; predict(e, base, ec); uint32_t pred = ec.pred.multiclass; - if (mc.label != (uint32_t)-1 && all->training) - ect_train(*all, e, base, ec); + if (mc.label != (uint32_t)-1) + ect_train(e, base, ec); ec.l.multi = mc; ec.pred.multiclass = pred; } @@ -360,8 +354,6 @@ namespace ECT e.tournaments_won.delete_v(); } - void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); } - base_learner* setup(vw& all) { new_options(all, "Error Correcting Tournament options") @@ -376,14 +368,12 @@ namespace ECT //append error flag to options_from_file so it is saved in regressor file later *all.file_options << " --ect " << data.k << " --error " << data.errors; - all.p->lp = MULTICLASS::mc_label; size_t wpp = create_circuit(all, data, data.k, data.errors+1); - data.all = &all; learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp); - l.set_finish_example(finish_example); + l.set_finish_example(MULTICLASS::finish_example<ect>); + all.p->lp = MULTICLASS::mc_label; l.set_finish(finish); - return make_base(l); } } |