/* Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ #include #include "multiclass.h" #include "simple_label.h" #include "reductions.h" namespace OAA { struct oaa{ size_t k; bool shouldOutput; vw* all; }; template void predict_or_learn(oaa& o, LEARNER::base_learner& base, example& ec) { MULTICLASS::multiclass mc_label_data = ec.l.multi; if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1)) cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl; ec.l.simple = {0.f, mc_label_data.weight, 0.f}; stringstream outputStringStream; uint32_t prediction = 1; float score = INT_MIN; for (uint32_t i = 1; i <= o.k; i++) { if (is_learn) { if (mc_label_data.label == i) ec.l.simple.label = 1; else ec.l.simple.label = -1; base.learn(ec, i-1); } else base.predict(ec, i-1); if (ec.partial_prediction > score) { score = ec.partial_prediction; prediction = i; } if (o.shouldOutput) { if (i > 1) outputStringStream << ' '; outputStringStream << i << ':' << ec.partial_prediction; } } ec.pred.multiclass = prediction; ec.l.multi = mc_label_data; if (o.shouldOutput) o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag); } LEARNER::base_learner* setup(vw& all) { new_options(all, "One-against-all options") ("oaa", po::value(), "Use one-against-all multiclass learning with labels"); if(missing_required(all)) return NULL; oaa& data = calloc_or_die(); data.k = all.vm["oaa"].as(); data.shouldOutput = all.raw_prediction > 0; data.all = &all; *all.file_options << " --oaa " << data.k; LEARNER::learner& l = init_learner(&data, setup_base(all), predict_or_learn, predict_or_learn, data.k); l.set_finish_example(MULTICLASS::finish_example); all.p->lp = MULTICLASS::mc_label; return make_base(l); } }