diff options
Diffstat (limited to 'vowpalwabbit/ect.cc')
-rw-r--r-- | vowpalwabbit/ect.cc | 51 |
1 files changed, 17 insertions, 34 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index dea87040..1a071eb2 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -17,7 +17,6 @@ license as described in the file LICENSE. #include "reductions.h" #include "multiclass.h" #include "simple_label.h" -#include "vw.h" using namespace std; using namespace LEARNER; @@ -51,8 +50,6 @@ namespace ECT uint32_t last_pair; v_array<bool> tournaments_won; - - vw* all; }; bool exists(v_array<size_t> db) @@ -184,7 +181,7 @@ namespace ECT return e.last_pair + (eliminations-1); } - uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec) + uint32_t ect_predict(ect& e, base_learner& base, example& ec) { if (e.k == (size_t)1) return 1; @@ -228,7 +225,7 @@ namespace ECT return false; } - void ect_train(vw& all, ect& e, base_learner& base, example& ec) + void ect_train(ect& e, base_learner& base, example& ec) { if (e.k == 1)//nothing to do return; @@ -318,25 +315,21 @@ namespace ECT } void predict(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1)) cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl; - ec.pred.multiclass = ect_predict(*all, e, base, ec); + ec.pred.multiclass = ect_predict(e, base, ec); ec.l.multi = mc; } void learn(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; predict(e, base, ec); uint32_t pred = ec.pred.multiclass; - if (mc.label != (uint32_t)-1 && all->training) - ect_train(*all, e, base, ec); + if (mc.label != (uint32_t)-1) + ect_train(e, base, ec); ec.l.multi = mc; ec.pred.multiclass = pred; } @@ -360,36 +353,26 @@ namespace ECT e.tournaments_won.delete_v(); } - void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); } - - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { - ect& data = calloc_or_die<ect>(); - po::options_description ect_opts("ECT options"); - ect_opts.add_options() - ("error", po::value<size_t>(), "error in ECT"); + new_options(all, "Error Correcting Tournament options") + ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels"); + if (missing_required(all)) return NULL; + new_options(all)("error", po::value<size_t>()->default_value(0), "error in ECT"); + add_options(all); - vm = add_options(all, ect_opts); - - //first parse for number of actions - data.k = (int)vm["ect"].as<size_t>(); - - //append ect with nb_actions to options_from_file so it is saved to regressor later - if (vm.count("error")) { - data.errors = (uint32_t)vm["error"].as<size_t>(); - } else - data.errors = 0; + ect& data = calloc_or_die<ect>(); + data.k = (int)all.vm["ect"].as<size_t>(); + data.errors = (uint32_t)all.vm["error"].as<size_t>(); //append error flag to options_from_file so it is saved in regressor file later *all.file_options << " --ect " << data.k << " --error " << data.errors; - all.p->lp = MULTICLASS::mc_label; size_t wpp = create_circuit(all, data, data.k, data.errors+1); - data.all = &all; - learner<ect>& l = init_learner(&data, all.l, learn, predict, wpp); - l.set_finish_example(finish_example); + learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp); + l.set_finish_example(MULTICLASS::finish_example<ect>); + all.p->lp = MULTICLASS::mc_label; l.set_finish(finish); - return make_base(l); } } |