diff options
Diffstat (limited to 'vowpalwabbit/oaa.cc')
-rw-r--r-- | vowpalwabbit/oaa.cc | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 2328b00d..185264f5 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -7,7 +7,6 @@ license as described in the file LICENSE. #include "multiclass.h" #include "simple_label.h" #include "reductions.h" -#include "vw.h" namespace OAA { struct oaa{ @@ -60,21 +59,23 @@ namespace OAA { o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag); } - void finish_example(vw& all, oaa&, example& ec) { MULTICLASS::finish_example(all, ec); } - - LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all) { + new_options(all, "One-against-all options") + ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels"); + if(missing_required(all)) return NULL; + oaa& data = calloc_or_die<oaa>(); - data.k = vm["oaa"].as<size_t>(); + data.k = all.vm["oaa"].as<size_t>(); data.shouldOutput = all.raw_prediction > 0; data.all = &all; *all.file_options << " --oaa " << data.k; - all.p->lp = MULTICLASS::mc_label; - LEARNER::learner<oaa>& l = init_learner(&data, all.l, predict_or_learn<true>, - predict_or_learn<false>, data.k); - l.set_finish_example(finish_example); + LEARNER::learner<oaa>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, + predict_or_learn<false>, data.k); + l.set_finish_example(MULTICLASS::finish_example<oaa>); + all.p->lp = MULTICLASS::mc_label; return make_base(l); } } |