diff options
author | John Langford <jl@hunch.net> | 2014-12-27 19:50:53 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-12-27 19:50:53 +0300 |
commit | 3018c85ccf02127c1cc94c5ba2bc491aa2b45a9d (patch) | |
tree | d4384af3a9385edcdd6671e2c21f1fa931694360 | |
parent | 74baf926ce5e5eb79c9ae50d150638176132bb02 (diff) | |
parent | 43226da4724fe3aec7eb560264e3342e442e3fbd (diff) |
fix conflicts
59 files changed, 586 insertions, 630 deletions
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 1e87208e..1c6f1c3c 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -45,7 +45,7 @@ namespace ACTIVE { } template <bool is_learn> - void predict_or_learn_simulation(active& a, learner& base, example& ec) { + void predict_or_learn_simulation(active& a, base_learner& base, example& ec) { base.predict(ec); if (is_learn) @@ -67,7 +67,7 @@ namespace ACTIVE { } template <bool is_learn> - void predict_or_learn_active(active& a, learner& base, example& ec) { + void predict_or_learn_active(active& a, base_learner& base, example& ec) { if (is_learn) base.learn(ec); else @@ -151,7 +151,7 @@ namespace ACTIVE { VW::finish_example(all,&ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) {//parse and set arguments po::options_description opts("Active Learning options"); opts.add_options() @@ -170,20 +170,20 @@ namespace ACTIVE { data.active_c0 = vm["mellowness"].as<float>(); //Create new learner - learner* ret = new learner(&data, all.l); + learner<active>& ret = init_learner(&data, all.l); if (vm.count("simulation")) { - ret->set_learn<active, predict_or_learn_simulation<true> >(); - ret->set_predict<active, predict_or_learn_simulation<false> >(); + ret.set_learn(predict_or_learn_simulation<true>); + ret.set_predict(predict_or_learn_simulation<false>); } else { all.active = true; - ret->set_learn<active, predict_or_learn_active<true> >(); - ret->set_predict<active, predict_or_learn_active<false> >(); - ret->set_finish_example<active, return_active_example>(); + ret.set_learn(predict_or_learn_active<true>); + ret.set_predict(predict_or_learn_active<false>); + ret.set_finish_example(return_active_example); } - return ret; + return make_base(ret); } } diff --git a/vowpalwabbit/active.h b/vowpalwabbit/active.h index e8883ae9..d71950ab 100644 --- a/vowpalwabbit/active.h +++ b/vowpalwabbit/active.h @@ -1,4 +1,4 @@ #pragma once namespace ACTIVE { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index f34a228f..f2023565 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -12,7 +12,7 @@ namespace ALINK { }; template <bool is_learn> - void predict_or_learn(autolink& b, learner& base, example& ec) + void predict_or_learn(autolink& b, base_learner& base, example& ec) { base.predict(ec); float base_pred = ec.pred.scalar; @@ -41,7 +41,7 @@ namespace ALINK { ec.total_sum_feat_sq -= sum_sq; } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Autolink options"); opts.add_options() @@ -54,11 +54,11 @@ namespace ALINK { data.d = (uint32_t)vm["autolink"].as<size_t>(); data.stride_shift = all.reg.stride_shift; - all.file_options << " --autolink " << data.d; + *all.file_options << " --autolink " << data.d; - learner* ret = new learner(&data, all.l); - ret->set_learn<autolink, predict_or_learn<true> >(); - ret->set_predict<autolink, predict_or_learn<false> >(); - return ret; + learner<autolink>& ret = init_learner(&data, all.l); + ret.set_learn(predict_or_learn<true>); + ret.set_predict(predict_or_learn<false>); + return make_base(ret); } } diff --git a/vowpalwabbit/autolink.h b/vowpalwabbit/autolink.h index 830c0088..3bb70dc1 100644 --- a/vowpalwabbit/autolink.h +++ b/vowpalwabbit/autolink.h @@ -1,4 +1,4 @@ #pragma once namespace ALINK { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index 43e5e502..775affb1 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -821,13 +821,13 @@ void end_pass(bfgs& b) } // placeholder -void predict(bfgs& b, learner& base, example& ec) +void predict(bfgs& b, base_learner& base, example& ec) { vw* all = b.all; ec.pred.scalar = bfgs_predict(*all,ec); } -void learn(bfgs& b, learner& base, example& ec) +void learn(bfgs& b, base_learner& base, example& ec) { vw* all = b.all; assert(ec.in_use); @@ -968,7 +968,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text) b.backstep_on = true; } -learner* setup(vw& all, po::variables_map& vm) +base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("LBFGS options"); opts.add_options() @@ -1024,14 +1024,14 @@ learner* setup(vw& all, po::variables_map& vm) all.bfgs = true; all.reg.stride_shift = 2; - learner* l = new learner(&b, 1 << all.reg.stride_shift); - l->set_learn<bfgs, learn>(); - l->set_predict<bfgs, predict>(); - l->set_save_load<bfgs,save_load>(); - l->set_init_driver<bfgs,init_driver>(); - l->set_end_pass<bfgs,end_pass>(); - l->set_finish<bfgs,finish>(); + learner<bfgs>& l = init_learner(&b, 1 << all.reg.stride_shift); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_init_driver(init_driver); + l.set_end_pass(end_pass); + l.set_finish(finish); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/bfgs.h b/vowpalwabbit/bfgs.h index 2f30dcb9..1960662b 100644 --- a/vowpalwabbit/bfgs.h +++ b/vowpalwabbit/bfgs.h @@ -5,5 +5,5 @@ license as described in the file LICENSE. */ #pragma once namespace BFGS { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 727cf003..7eac0ba8 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -7,7 +7,7 @@ using namespace LEARNER; namespace BINARY { template <bool is_learn> - void predict_or_learn(float&, learner& base, example& ec) { + void predict_or_learn(char&, base_learner& base, example& ec) { if (is_learn) base.learn(ec); else @@ -24,7 +24,7 @@ namespace BINARY { ec.loss = ec.l.simple.weight; } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) {//parse and set arguments po::options_description opts("Binary options"); opts.add_options() @@ -35,9 +35,9 @@ namespace BINARY { all.sd->binary_label = true; //Create new learner - learner* ret = new learner(NULL, all.l); - ret->set_learn<float, predict_or_learn<true> >(); - ret->set_predict<float, predict_or_learn<false> >(); - return ret; + learner<char>& ret = init_learner<char>(NULL, all.l); + ret.set_learn(predict_or_learn<true>); + ret.set_predict(predict_or_learn<false>); + return make_base(ret); } } diff --git a/vowpalwabbit/binary.h b/vowpalwabbit/binary.h index 2df4f2e0..609de90b 100644 --- a/vowpalwabbit/binary.h +++ b/vowpalwabbit/binary.h @@ -1,4 +1,4 @@ #pragma once namespace BINARY { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index c199e1db..f55467a6 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -180,7 +180,7 @@ namespace BS { } template <bool is_learn> - void predict_or_learn(bs& d, learner& base, example& ec) + void predict_or_learn(bs& d, base_learner& base, example& ec) { vw* all = d.all; bool shouldOutput = all->raw_prediction > 0; @@ -239,7 +239,7 @@ namespace BS { d.pred_vec.~vector(); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { bs& data = calloc_or_die<bs>(); data.ub = FLT_MAX; @@ -254,7 +254,7 @@ namespace BS { data.B = (uint32_t)vm["bootstrap"].as<size_t>(); //append bs with number of samples to options_from_file so it is saved to regressor later - all.file_options << " --bootstrap " << data.B; + *all.file_options << " --bootstrap " << data.B; std::string type_string("mean"); @@ -275,17 +275,17 @@ namespace BS { } else //by default use mean data.bs_type = BS_TYPE_MEAN; - all.file_options << " --bs_type " << type_string; + *all.file_options << " --bs_type " << type_string; data.pred_vec.reserve(data.B); data.all = &all; - learner* l = new learner(&data, all.l, data.B); - l->set_learn<bs, predict_or_learn<true> >(); - l->set_predict<bs, predict_or_learn<false> >(); - l->set_finish_example<bs,finish_example>(); - l->set_finish<bs,finish>(); + learner<bs>& l = init_learner(&data, all.l, data.B); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); + l.set_finish(finish); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/bs.h b/vowpalwabbit/bs.h index d41593ef..e05bcd05 100644 --- a/vowpalwabbit/bs.h +++ b/vowpalwabbit/bs.h @@ -11,7 +11,7 @@ license as described in the file LICENSE. namespace BS { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); void print_result(int f, float res, float weight, v_array<char> tag, float lb, float ub); void output_example(vw& all, example* ec, float lb, float ub); diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index 395d496b..72400a1d 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -284,7 +284,7 @@ namespace CB_ALGS } template <bool is_learn> - void predict_or_learn(cb& c, learner& base, example& ec) { + void predict_or_learn(cb& c, base_learner& base, example& ec) { vw* all = c.all; CB::label ld = ec.l.cb; @@ -341,12 +341,12 @@ namespace CB_ALGS } } - void predict_eval(cb& c, learner& base, example& ec) { + void predict_eval(cb& c, base_learner& base, example& ec) { cout << "can not use a test label for evaluation" << endl; throw exception(); } - void learn_eval(cb& c, learner& base, example& ec) { + void learn_eval(cb& c, base_learner& base, example& ec) { vw* all = c.all; CB_EVAL::label ld = ec.l.cb_eval; @@ -497,7 +497,7 @@ namespace CB_ALGS VW::finish_example(all, &ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("CB options"); opts.add_options() @@ -515,7 +515,7 @@ namespace CB_ALGS uint32_t nb_actions = (uint32_t)vm["cb"].as<size_t>(); - all.file_options << " --cb " << nb_actions; + *all.file_options << " --cb " << nb_actions; all.sd->k = nb_actions; @@ -529,7 +529,7 @@ namespace CB_ALGS std::string type_string; type_string = vm["cb_type"].as<std::string>(); - all.file_options << " --cb_type " << type_string; + *all.file_options << " --cb_type " << type_string; if (type_string.compare("dr") == 0) c.cb_type = CB_TYPE_DR; @@ -556,7 +556,7 @@ namespace CB_ALGS else { //by default use doubly robust c.cb_type = CB_TYPE_DR; - all.file_options << " --cb_type dr"; + *all.file_options << " --cb_type dr"; } if (eval) @@ -564,25 +564,25 @@ namespace CB_ALGS else all.p->lp = CB::cb_label; - learner* l = new learner(&c, all.l, problem_multiplier); + learner<cb>& l = init_learner(&c, all.l, problem_multiplier); if (eval) { - l->set_learn<cb, learn_eval>(); - l->set_predict<cb, predict_eval>(); - l->set_finish_example<cb,eval_finish_example>(); + l.set_learn(learn_eval); + l.set_predict(predict_eval); + l.set_finish_example(eval_finish_example); } else { - l->set_learn<cb, predict_or_learn<true> >(); - l->set_predict<cb, predict_or_learn<false> >(); - l->set_finish_example<cb,finish_example>(); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); } - l->set_init_driver<cb,init_driver>(); - l->set_finish<cb,finish>(); + l.set_init_driver(init_driver); + l.set_finish(finish); // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. - l->increment = all.l->increment; + l.increment = all.l->increment; - return l; + return make_base(l); } } diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h index 972a45ad..7756fc6f 100644 --- a/vowpalwabbit/cb_algs.h +++ b/vowpalwabbit/cb_algs.h @@ -6,8 +6,8 @@ license as described in the file LICENSE. #pragma once //TODO: extend to handle CSOAA_LDF and WAP_LDF namespace CB_ALGS { - LEARNER::learner* setup(vw& all, po::variables_map& vm); - + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); + template <bool is_learn> float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base) { diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index e79c25ec..5b738d76 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -83,7 +83,7 @@ namespace CBIFY { COST_SENSITIVE::label cs_label; COST_SENSITIVE::label second_cs_label; - learner* cs; + base_learner* cs; vw* all; unique_ptr<vw_policy> policy; @@ -106,7 +106,7 @@ namespace CBIFY { } struct vw_context { - learner* l; + base_learner* l; example* e; cbify* data; bool recorded; @@ -158,7 +158,7 @@ namespace CBIFY { } template <bool is_learn> - void predict_or_learn_first(cbify& data, learner& base, example& ec) + void predict_or_learn_first(cbify& data, base_learner& base, example& ec) {//Explore tau times, then act according to optimal. MULTICLASS::multiclass ld = ec.l.multi; @@ -186,7 +186,7 @@ namespace CBIFY { } template <bool is_learn> - void predict_or_learn_greedy(cbify& data, learner& base, example& ec) + void predict_or_learn_greedy(cbify& data, base_learner& base, example& ec) {//Explore uniform random an epsilon fraction of the time. MULTICLASS::multiclass ld = ec.l.multi; @@ -213,7 +213,7 @@ namespace CBIFY { } template <bool is_learn> - void predict_or_learn_bag(cbify& data, learner& base, example& ec) + void predict_or_learn_bag(cbify& data, base_learner& base, example& ec) {//Randomize over predictions from a base set of predictors //Use CB to find current predictions. MULTICLASS::multiclass ld = ec.l.multi; @@ -284,7 +284,7 @@ namespace CBIFY { } template <bool is_learn> - void predict_or_learn_cover(cbify& data, learner& base, example& ec) + void predict_or_learn_cover(cbify& data, base_learner& base, example& ec) {//Randomize over predictions from a base set of predictors //Use cost sensitive oracle to cover actions to form distribution. MULTICLASS::multiclass ld = ec.l.multi; @@ -377,7 +377,7 @@ namespace CBIFY { CB::cb_label.delete_label(&data.cb_label); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) {//parse and set arguments po::options_description opts("CBIFY options"); opts.add_options() @@ -394,10 +394,10 @@ namespace CBIFY { data.all = &all; data.k = (uint32_t)vm["cbify"].as<size_t>(); - all.file_options << " --cbify " << data.k; + *all.file_options << " --cbify " << data.k; all.p->lp = MULTICLASS::mc_label; - learner* l; + learner<cbify>* l; data.recorder.reset(new vw_recorder()); data.mwt_explorer.reset(new MwtExplorer<vw_context>("vw", *data.recorder.get())); if (vm.count("cover")) @@ -406,52 +406,52 @@ namespace CBIFY { data.cs = all.cost_sensitive; data.second_cs_label.costs.resize(data.k); data.second_cs_label.costs.end = data.second_cs_label.costs.begin+data.k; - float epsilon = 0.05f; - if (vm.count("epsilon")) - epsilon = vm["epsilon"].as<float>(); - data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); - data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); - l = new learner(&data, all.l, cover + 1); - l->set_learn<cbify, predict_or_learn_cover<true> >(); - l->set_predict<cbify, predict_or_learn_cover<false> >(); + float epsilon = 0.05f; + if (vm.count("epsilon")) + epsilon = vm["epsilon"].as<float>(); + data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); + data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); + l = &init_learner(&data, all.l, cover + 1); + l->set_learn(predict_or_learn_cover<true>); + l->set_predict(predict_or_learn_cover<false>); } else if (vm.count("bag")) { size_t bags = (uint32_t)vm["bag"].as<size_t>(); - for (size_t i = 0; i < bags; i++) - { - data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); - } - data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); - l = new learner(&data, all.l, bags); - l->set_learn<cbify, predict_or_learn_bag<true> >(); - l->set_predict<cbify, predict_or_learn_bag<false> >(); + for (size_t i = 0; i < bags; i++) + { + data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); + } + data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); + l = &init_learner(&data, all.l, bags); + l->set_learn(predict_or_learn_bag<true>); + l->set_predict(predict_or_learn_bag<false>); } else if (vm.count("first") ) { - uint32_t tau = (uint32_t)vm["first"].as<size_t>(); - data.policy.reset(new vw_policy()); - data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); - l = new learner(&data, all.l, 1); - l->set_learn<cbify, predict_or_learn_first<true> >(); - l->set_predict<cbify, predict_or_learn_first<false> >(); + uint32_t tau = (uint32_t)vm["first"].as<size_t>(); + data.policy.reset(new vw_policy()); + data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); + l = &init_learner(&data, all.l, 1); + l->set_learn(predict_or_learn_first<true>); + l->set_predict(predict_or_learn_first<false>); } else { - float epsilon = 0.05f; - if (vm.count("epsilon")) - epsilon = vm["epsilon"].as<float>(); - data.policy.reset(new vw_policy()); - data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); - l = new learner(&data, all.l, 1); - l->set_learn<cbify, predict_or_learn_greedy<true> >(); - l->set_predict<cbify, predict_or_learn_greedy<false> >(); + float epsilon = 0.05f; + if (vm.count("epsilon")) + epsilon = vm["epsilon"].as<float>(); + data.policy.reset(new vw_policy()); + data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); + l = &init_learner(&data, all.l, 1); + l->set_learn(predict_or_learn_greedy<true>); + l->set_predict(predict_or_learn_greedy<false>); } - - l->set_finish_example<cbify,finish_example>(); - l->set_finish<cbify,finish>(); - l->set_init_driver<cbify,init_driver>(); - return l; + l->set_finish_example(finish_example); + l->set_finish(finish); + l->set_init_driver(init_driver); + + return make_base(*l); } } diff --git a/vowpalwabbit/cbify.h b/vowpalwabbit/cbify.h index 97e9e6fd..aed26b75 100644 --- a/vowpalwabbit/cbify.h +++ b/vowpalwabbit/cbify.h @@ -5,5 +5,5 @@ license as described in the file LICENSE. */ #pragma once namespace CBIFY { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 679e31ed..a5429dd0 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -13,9 +13,7 @@ license as described in the file LICENSE. #include "gd.h" // GD::foreach_feature() needed in subtract_example() using namespace std; - using namespace LEARNER; - using namespace COST_SENSITIVE; namespace CSOAA { @@ -24,7 +22,7 @@ namespace CSOAA { }; template <bool is_learn> - void predict_or_learn(csoaa& c, learner& base, example& ec) { + void predict_or_learn(csoaa& c, base_learner& base, example& ec) { vw* all = c.all; COST_SENSITIVE::label ld = ec.l.cs; uint32_t prediction = 1; @@ -68,7 +66,7 @@ namespace CSOAA { VW::finish_example(all, &ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("CSOAA options"); opts.add_options() @@ -83,17 +81,16 @@ namespace CSOAA { nb_actions = (uint32_t)vm["csoaa"].as<size_t>(); //append csoaa with nb_actions to file_options so it is saved to regressor later - all.file_options << " --csoaa " << nb_actions; + *all.file_options << " --csoaa " << nb_actions; all.p->lp = cs_label; all.sd->k = nb_actions; - learner* l = new learner(&c, all.l, nb_actions); - l->set_learn<csoaa, predict_or_learn<true> >(); - l->set_predict<csoaa, predict_or_learn<false> >(); - l->set_finish_example<csoaa,finish_example>(); - all.cost_sensitive = all.l; - return l; + learner<csoaa>& l = init_learner(&c, all.l, nb_actions); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); + return make_base(l); } } @@ -111,7 +108,7 @@ namespace CSOAA_AND_WAP_LDF { float csoaa_example_t; vw* all; - learner* base; + base_learner* base; }; namespace LabelDict { @@ -300,7 +297,7 @@ namespace LabelDict { ec->indices.decr(); } - void make_single_prediction(ldf& data, learner& base, example& ec) { + void make_single_prediction(ldf& data, base_learner& base, example& ec) { COST_SENSITIVE::label ld = ec.l.cs; label_data simple_label; simple_label.initial = 0.; @@ -339,7 +336,7 @@ namespace LabelDict { return isTest; } - void do_actual_learning_wap(vw& all, ldf& data, learner& base, size_t start_K) + void do_actual_learning_wap(vw& all, ldf& data, base_learner& base, size_t start_K) { size_t K = data.ec_seq.size(); vector<COST_SENSITIVE::wclass*> all_costs; @@ -394,7 +391,7 @@ namespace LabelDict { } } - void do_actual_learning_oaa(vw& all, ldf& data, learner& base, size_t start_K) + void do_actual_learning_oaa(vw& all, ldf& data, base_learner& base, size_t start_K) { size_t K = data.ec_seq.size(); float min_cost = FLT_MAX; @@ -447,7 +444,7 @@ namespace LabelDict { } template <bool is_learn> - void do_actual_learning(vw& all, ldf& data, learner& base) + void do_actual_learning(vw& all, ldf& data, base_learner& base) { //cdbg << "do_actual_learning size=" << data.ec_seq.size() << endl; if (data.ec_seq.size() <= 0) return; // nothing to do @@ -620,7 +617,7 @@ namespace LabelDict { } template <bool is_learn> - void predict_or_learn(ldf& data, learner& base, example &ec) { + void predict_or_learn(ldf& data, base_learner& base, example &ec) { vw* all = data.all; data.base = &base; bool is_test_ec = COST_SENSITIVE::example_is_test(ec); @@ -653,7 +650,7 @@ namespace LabelDict { } } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("LDF Options"); opts.add_options() @@ -674,12 +671,12 @@ namespace LabelDict { if( vm.count("csoaa_ldf") ){ ldf_arg = vm["csoaa_ldf"].as<string>(); - all.file_options << " --csoaa_ldf " << ldf_arg; + *all.file_options << " --csoaa_ldf " << ldf_arg; } else { ldf_arg = vm["wap_ldf"].as<string>(); ld.is_wap = true; - all.file_options << " --wap_ldf " << ldf_arg; + *all.file_options << " --wap_ldf " << ldf_arg; } if ( vm.count("ldf_override") ) ldf_arg = vm["ldf_override"].as<string>(); @@ -718,18 +715,17 @@ namespace LabelDict { ld.read_example_this_loop = 0; ld.need_to_clear = false; - learner* l = new learner(&ld, all.l); - l->set_learn<ldf, predict_or_learn<true> >(); - l->set_predict<ldf, predict_or_learn<false> >(); + learner<ldf>& l = init_learner(&ld, all.l); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); if (ld.is_singleline) - l->set_finish_example<ldf,finish_singleline_example>(); + l.set_finish_example(finish_singleline_example); else - l->set_finish_example<ldf,finish_multiline_example>(); - l->set_finish<ldf,finish>(); - l->set_end_examples<ldf,end_examples>(); - l->set_end_pass<ldf,end_pass>(); - all.cost_sensitive = all.l; - return l; + l.set_finish_example(finish_multiline_example); + l.set_finish(finish); + l.set_end_examples(end_examples); + l.set_end_pass(end_pass); + return make_base(l); } void global_print_newline(vw& all) diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h index bfaca0f2..79f5e4d2 100644 --- a/vowpalwabbit/csoaa.h +++ b/vowpalwabbit/csoaa.h @@ -5,11 +5,11 @@ license as described in the file LICENSE. */ #pragma once namespace CSOAA { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } namespace CSOAA_AND_WAP_LDF { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); namespace LabelDict { bool ec_is_example_header(example& ec); // example headers look like "0:-1" or just "shared" diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index e73b2442..1f99eaaa 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -184,7 +184,7 @@ namespace ECT return e.last_pair + (eliminations-1); } - uint32_t ect_predict(vw& all, ect& e, learner& base, example& ec) + uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec) { if (e.k == (size_t)1) return 1; @@ -228,7 +228,7 @@ namespace ECT return false; } - void ect_train(vw& all, ect& e, learner& base, example& ec) + void ect_train(vw& all, ect& e, base_learner& base, example& ec) { if (e.k == 1)//nothing to do return; @@ -317,7 +317,7 @@ namespace ECT } } - void predict(ect& e, learner& base, example& ec) { + void predict(ect& e, base_learner& base, example& ec) { vw* all = e.all; MULTICLASS::multiclass mc = ec.l.multi; @@ -327,7 +327,7 @@ namespace ECT ec.l.multi = mc; } - void learn(ect& e, learner& base, example& ec) + void learn(ect& e, base_learner& base, example& ec) { vw* all = e.all; @@ -366,7 +366,7 @@ namespace ECT VW::finish_example(all, &ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("ECT options"); opts.add_options() @@ -387,18 +387,18 @@ namespace ECT } else data.errors = 0; //append error flag to options_from_file so it is saved in regressor file later - all.file_options << " --ect " << data.k << " --error " << data.errors; + *all.file_options << " --ect " << data.k << " --error " << data.errors; all.p->lp = MULTICLASS::mc_label; size_t wpp = create_circuit(all, data, data.k, data.errors+1); data.all = &all; - learner* l = new learner(&data, all.l, wpp); - l->set_learn<ect, learn>(); - l->set_predict<ect, predict>(); - l->set_finish_example<ect,finish_example>(); - l->set_finish<ect,finish>(); + learner<ect>& l = init_learner(&data, all.l, wpp); + l.set_learn(learn); + l.set_predict(predict); + l.set_finish_example(finish_example); + l.set_finish(finish); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/ect.h b/vowpalwabbit/ect.h index a7c392c0..81129791 100644 --- a/vowpalwabbit/ect.h +++ b/vowpalwabbit/ect.h @@ -6,5 +6,5 @@ license as described in the file LICENSE. #pragma once namespace ECT { - LEARNER::learner* setup(vw&, po::variables_map&); + LEARNER::base_learner* setup(vw&, po::variables_map&); } diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index 2707cbcf..15ad1a06 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -135,7 +135,7 @@ namespace FTRL { } //void learn(void* a, void* d, example* ec) { - void learn(ftrl& a, learner& base, example& ec) { + void learn(ftrl& a, base_learner& base, example& ec) { vw* all = a.all; assert(ec.in_use); @@ -170,15 +170,15 @@ namespace FTRL { } // placeholder - void predict(ftrl& b, learner& base, example& ec) + void predict(ftrl& b, base_learner& base, example& ec) { vw* all = b.all; //ec.l.simple.prediction = ftrl_predict(*all,ec); ec.pred.scalar = ftrl_predict(*all,ec); } - learner* setup(vw& all, po::variables_map& vm) { - + base_learner* setup(vw& all, po::variables_map& vm) + { ftrl& b = calloc_or_die<ftrl>(); b.all = &all; b.ftrl_beta = 0.0; @@ -217,13 +217,10 @@ namespace FTRL { cerr << "ftrl_beta = " << b.ftrl_beta << endl; } - learner* l = new learner(&b, 1 << all.reg.stride_shift); - l->set_learn<ftrl, learn>(); - l->set_predict<ftrl, predict>(); - l->set_save_load<ftrl,save_load>(); - - return l; + learner<ftrl>& l = init_learner(&b, 1 << all.reg.stride_shift); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + return make_base(l); } - - } // end namespace diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h index 934d91c8..59bf4653 100644 --- a/vowpalwabbit/ftrl_proximal.h +++ b/vowpalwabbit/ftrl_proximal.h @@ -7,6 +7,6 @@ license as described in the file LICENSE. #define FTRL_PROXIMAL_H namespace FTRL { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } #endif diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 66575e38..b407bc2f 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -40,7 +40,7 @@ namespace GD float neg_norm_power; float neg_power_t; float update_multiplier; - void (*predict)(gd&, learner&, example&); + void (*predict)(gd&, base_learner&, example&); vw* all; }; @@ -346,7 +346,7 @@ float finalize_prediction(shared_data* sd, float ret) } template<bool l1, bool audit> -void predict(gd& g, learner& base, example& ec) +void predict(gd& g, base_learner& base, example& ec) { vw& all = *g.all; @@ -508,7 +508,7 @@ float compute_update(gd& g, example& ec) } template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare> -void update(gd& g, learner& base, example& ec) +void update(gd& g, base_learner& base, example& ec) {//invariant: not a test label, importance weight > 0 float update; if ( (update = compute_update<invariant, sqrt_rate, feature_mask_off, adaptive, normalized, spare> (g, ec)) != 0.) @@ -519,7 +519,7 @@ void update(gd& g, learner& base, example& ec) } template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare> -void learn(gd& g, learner& base, example& ec) +void learn(gd& g, base_learner& base, example& ec) {//invariant: not a test label, importance weight > 0 assert(ec.in_use); assert(ec.l.simple.label != FLT_MAX); @@ -789,25 +789,25 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text) } template<bool invariant, bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next> -uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { all.normalized_idx = normalized; if (feature_mask_off) { - ret->set_learn<gd, learn<invariant, sqrt_rate, true, adaptive, normalized, spare> >(); - ret->set_update<gd, update<invariant, sqrt_rate, true, adaptive, normalized, spare> >(); + ret.set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>); + ret.set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>); return next; } else { - ret->set_learn<gd, learn<invariant, sqrt_rate, false, adaptive, normalized, spare> >(); - ret->set_update<gd, update<invariant, sqrt_rate, false, adaptive, normalized, spare> >(); + ret.set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>); + ret.set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>); return next; } } template<bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next> -uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { if (all.invariant_updates) return set_learn<true, sqrt_rate, adaptive, normalized, spare, next>(all, ret, feature_mask_off); @@ -816,7 +816,7 @@ uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) } template<bool sqrt_rate, uint32_t adaptive, uint32_t spare> -uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { // select the appropriate learn function based on adaptive, normalization, and feature mask if (all.normalized_updates) @@ -826,7 +826,7 @@ uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) } template<bool sqrt_rate> -uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { if (all.adaptive) return set_learn<sqrt_rate, 1, 2>(all, ret, feature_mask_off); @@ -842,7 +842,7 @@ uint32_t ceil_log_2(uint32_t v) return 1 + ceil_log_2(v >> 1); } -learner* setup(vw& all, po::variables_map& vm) +base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Gradient Descent options"); opts.add_options() @@ -906,29 +906,18 @@ learner* setup(vw& all, po::variables_map& vm) cerr << "Warning: the learning rate for the last pass is multiplied by: " << pow((double)all.eta_decay_rate, (double)all.numpasses) << " adjust --decay_learning_rate larger to avoid this." << endl; - learner* ret = new learner(&g, 1); + learner<gd>& ret = init_learner(&g, 1); if (all.reg_mode % 2) if (all.audit || all.hash_inv) - { - ret->set_predict<gd, predict<true, true> >(); - g.predict = predict<true, true>; - } + g.predict = predict<true, true>; else - { - ret->set_predict<gd, predict<true, false> >(); - g.predict = predict<true, false>; - } + g.predict = predict<true, false>; else if (all.audit || all.hash_inv) - { - ret->set_predict<gd, predict<false, true> >(); - g.predict = predict<false, true>; - } + g.predict = predict<false, true>; else - { - ret->set_predict<gd, predict<false, false> >(); - g.predict = predict<false, false>; - } + g.predict = predict<false, false>; + ret.set_predict(g.predict); uint32_t stride; if (all.power_t == 0.5) @@ -937,11 +926,10 @@ learner* setup(vw& all, po::variables_map& vm) stride = set_learn<false>(all, ret, feature_mask_off); all.reg.stride_shift = ceil_log_2(stride-1); - ret->increment = ((uint64_t)1 << all.reg.stride_shift); + ret.increment = ((uint64_t)1 << all.reg.stride_shift); - ret->set_save_load<gd,save_load>(); - - ret->set_end_pass<gd, end_pass>(); - return ret; + ret.set_save_load(save_load); + ret.set_end_pass(end_pass); + return make_base(ret); } } diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h index be52e62f..de3964eb 100644 --- a/vowpalwabbit/gd.h +++ b/vowpalwabbit/gd.h @@ -24,7 +24,7 @@ namespace GD{ void compute_update(example* ec); void offset_train(regressor ®, example* &ec, float update, size_t offset); void train_one_example_single_thread(regressor& r, example* ex); - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text); void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text); void output_and_account_example(example* ec); diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index eec54663..532c44f3 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -272,14 +272,14 @@ void mf_train(vw& all, example& ec) all->current_pass++; } - void predict(gdmf& d, learner& base, example& ec) + void predict(gdmf& d, base_learner& base, example& ec) { vw* all = d.all; mf_predict(*all,ec); } - void learn(gdmf& d, learner& base, example& ec) + void learn(gdmf& d, base_learner& base, example& ec) { vw* all = d.all; @@ -288,7 +288,7 @@ void mf_train(vw& all, example& ec) mf_train(*all, ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Gdmf options"); opts.add_options() @@ -339,12 +339,12 @@ void mf_train(vw& all, example& ec) } all.eta *= powf((float)(all.sd->t), all.power_t); - learner* l = new learner(&data, 1 << all.reg.stride_shift); - l->set_learn<gdmf, learn>(); - l->set_predict<gdmf, predict>(); - l->set_save_load<gdmf,save_load>(); - l->set_end_pass<gdmf,end_pass>(); + learner<gdmf>& l = init_learner(&data, 1 << all.reg.stride_shift); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_end_pass(end_pass); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/gd_mf.h b/vowpalwabbit/gd_mf.h index 0705eff1..db093750 100644 --- a/vowpalwabbit/gd_mf.h +++ b/vowpalwabbit/gd_mf.h @@ -5,5 +5,5 @@ license as described in the file LICENSE. */ #pragma once namespace GDMF{ - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc index e27c646d..b4e68aff 100644 --- a/vowpalwabbit/global_data.cc +++ b/vowpalwabbit/global_data.cc @@ -250,6 +250,8 @@ vw::vw() data_filename = ""; + file_options = new std::stringstream; + bfgs = false; hessian_on = false; active = false; diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index f651426f..b610d75a 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -170,9 +170,9 @@ struct vw { node_socks socks; - LEARNER::learner* l;//the top level learner - LEARNER::learner* scorer;//a scoring function - LEARNER::learner* cost_sensitive;//a cost sensitive learning algorithm. + LEARNER::base_learner* l;//the top level learner + LEARNER::base_learner* scorer;//a scoring function + LEARNER::base_learner* cost_sensitive;//a cost sensitive learning algorithm. void learn(example*); @@ -199,7 +199,7 @@ struct vw { double normalized_sum_norm_x; po::options_description opts; - std::stringstream file_options; + std::stringstream* file_options; vector<std::string> args; void* /*Search::search*/ searchstr; @@ -266,7 +266,7 @@ struct vw { size_t length () { return ((size_t)1) << num_bits; }; uint32_t rank; - v_array<LEARNER::learner* (*)(vw& all, po::variables_map& vm)> reduction_stack; + v_array<LEARNER::base_learner* (*)(vw& all, po::variables_map& vm)> reduction_stack; //Prediction output v_array<int> final_prediction_sink; // set to send global predictions to. diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 1f72ffc6..ff12e2bf 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -395,7 +395,7 @@ namespace KSVM } } - void predict(svm_params& params, learner &base, example& ec) { + void predict(svm_params& params, base_learner &base, example& ec) { flat_example* fec = flatten_sort_example(*(params.all),&ec); if(fec) { svm_example* sec = &calloc_or_die<svm_example>(); @@ -733,7 +733,7 @@ namespace KSVM //cerr<<params.model->support_vec[0]->example_counter<<endl; } - void learn(svm_params& params, learner& base, example& ec) { + void learn(svm_params& params, base_learner& base, example& ec) { flat_example* fec = flatten_sort_example(*(params.all),&ec); // for(int i = 0;i < fec->feature_map_len;i++) // cout<<i<<":"<<fec->feature_map[i].x<<" "<<fec->feature_map[i].weight_index<<" "; @@ -790,7 +790,7 @@ namespace KSVM cerr<<"Done with finish \n"; } - LEARNER::learner* setup(vw &all, po::variables_map& vm) { + LEARNER::base_learner* setup(vw &all, po::variables_map& vm) { po::options_description opts("KSVM options"); opts.add_options() ("ksvm", "kernel svm") @@ -857,7 +857,7 @@ namespace KSVM params.lambda = all.l2_lambda; - all.file_options <<" --lambda "<< params.lambda; + *all.file_options <<" --lambda "<< params.lambda; cerr<<"Lambda = "<<params.lambda<<endl; @@ -868,7 +868,7 @@ namespace KSVM else kernel_type = string("linear"); - all.file_options <<" --kernel "<< kernel_type; + *all.file_options <<" --kernel "<< kernel_type; cerr<<"Kernel = "<<kernel_type<<endl; @@ -877,7 +877,7 @@ namespace KSVM float bandwidth = 1.; if(vm.count("bandwidth")) { bandwidth = vm["bandwidth"].as<float>(); - all.file_options <<" --bandwidth "<<bandwidth; + *all.file_options <<" --bandwidth "<<bandwidth; } cerr<<"bandwidth = "<<bandwidth<<endl; params.kernel_params = &calloc_or_die<double>(); @@ -888,7 +888,7 @@ namespace KSVM int degree = 2; if(vm.count("degree")) { degree = vm["degree"].as<int>(); - all.file_options <<" --degree "<<degree; + *all.file_options <<" --degree "<<degree; } cerr<<"degree = "<<degree<<endl; params.kernel_params = &calloc_or_die<int>(); @@ -900,11 +900,11 @@ namespace KSVM params.all->reg.weight_mask = (uint32_t)LONG_MAX; params.all->reg.stride_shift = 0; - learner* l = new learner(¶ms, 1); - l->set_learn<svm_params, learn>(); - l->set_predict<svm_params, predict>(); - l->set_save_load<svm_params, save_load>(); - l->set_finish<svm_params, finish>(); - return l; + learner<svm_params>& l = init_learner(¶ms, 1); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_finish(finish); + return make_base(l); } } diff --git a/vowpalwabbit/kernel_svm.h b/vowpalwabbit/kernel_svm.h index 30b3ac55..563d70e2 100644 --- a/vowpalwabbit/kernel_svm.h +++ b/vowpalwabbit/kernel_svm.h @@ -5,6 +5,4 @@ license as described in the file LICENSE. */ #pragma once namespace KSVM -{ - LEARNER::learner* setup(vw &all, po::variables_map& vm); -} +{ LEARNER::base_learner* setup(vw &all, po::variables_map& vm); } diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 8c8961af..1d03afbd 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -698,7 +698,7 @@ void save_load(lda& l, io_buf& model_file, bool read, bool text) l.doc_lengths.erase();
}
- void learn(lda& l, learner& base, example& ec)
+ void learn(lda& l, base_learner& base, example& ec)
{
size_t num_ex = l.examples.size();
l.examples.push_back(&ec);
@@ -716,7 +716,7 @@ void save_load(lda& l, io_buf& model_file, bool read, bool text) }
// placeholder
- void predict(lda& l, learner& base, example& ec)
+ void predict(lda& l, base_learner& base, example& ec)
{
learn(l, base, ec);
}
@@ -754,8 +754,8 @@ void end_examples(lda& l) }
- learner* setup(vw&all, po::variables_map& vm)
- {
+base_learner* setup(vw&all, po::variables_map& vm)
+{
po::options_description opts("Lda options");
opts.add_options()
("lda", po::value<uint32_t>(), "Run lda with <int> topics")
@@ -788,7 +788,7 @@ void end_examples(lda& l) all.random_weights = true;
all.add_constant = false;
- all.file_options << " --lda " << all.lda;
+ *all.file_options << " --lda " << all.lda;
if (all.eta > 1.)
{
@@ -799,23 +799,21 @@ void end_examples(lda& l) if (vm.count("minibatch")) {
size_t minibatch2 = next_pow2(ld.minibatch);
all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
- }
- ld.v.resize(all.lda*ld.minibatch);
-
- ld.decay_levels.push_back(0.f);
-
- learner* l = new learner(&ld, 1 << all.reg.stride_shift);
- l->set_learn<lda,learn>();
- l->set_predict<lda,predict>();
- l->set_save_load<lda,save_load>();
- l->set_finish_example<lda,finish_example>();
- l->set_end_examples<lda,end_examples>();
- l->set_end_pass<lda,end_pass>();
- l->set_finish<lda,finish>();
-
- return l;
- }
- - }
+ ld.v.resize(all.lda*ld.minibatch);
+
+ ld.decay_levels.push_back(0.f);
+
+ learner<lda>& l = init_learner(&ld, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
+
+ return make_base(l);
+}
+}
diff --git a/vowpalwabbit/lda_core.h b/vowpalwabbit/lda_core.h index 3a377d9f..2a065783 100644 --- a/vowpalwabbit/lda_core.h +++ b/vowpalwabbit/lda_core.h @@ -5,5 +5,5 @@ license as described in the file LICENSE. */ #pragma once namespace LDA{ - LEARNER::learner* setup(vw&, po::variables_map&); + LEARNER::base_learner* setup(vw&, po::variables_map&); } diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 71571c40..45815bcc 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -6,6 +6,7 @@ license as described in the file LICENSE. #pragma once // This is the interface for a learning algorithm #include<iostream> +#include"memory.h" using namespace std; struct vw; @@ -13,15 +14,16 @@ void return_simple_example(vw& all, void*, example& ec); namespace LEARNER { - struct learner; + template<class T> struct learner; + typedef learner<char> base_learner; struct func_data { void* data; - learner* base; + base_learner* base; void (*func)(void* data); }; - inline func_data tuple_dbf(void* data, learner* base, void (*func)(void* data)) + inline func_data tuple_dbf(void* data, base_learner* base, void (*func)(void* data)) { func_data foo; foo.data = data; @@ -32,200 +34,184 @@ namespace LEARNER struct learn_data { void* data; - learner* base; - void (*learn_f)(void* data, learner& base, example&); - void (*predict_f)(void* data, learner& base, example&); - void (*update_f)(void* data, learner& base, example&); + base_learner* base; + void (*learn_f)(void* data, base_learner& base, example&); + void (*predict_f)(void* data, base_learner& base, example&); + void (*update_f)(void* data, base_learner& base, example&); }; struct save_load_data{ void* data; - learner* base; + base_learner* base; void (*save_load_f)(void*, io_buf&, bool read, bool text); }; struct finish_example_data{ void* data; - learner* base; + base_learner* base; void (*finish_example_f)(vw&, void* data, example&); }; void generic_driver(vw& all); inline void generic_sl(void*, io_buf&, bool, bool) {} - inline void generic_learner(void* data, learner& base, example&) {} + inline void generic_learner(void* data, base_learner& base, example&) {} inline void generic_func(void* data) {} const save_load_data generic_save_load_fd = {NULL, NULL, generic_sl}; const learn_data generic_learn_fd = {NULL, NULL, generic_learner, generic_learner, NULL}; const func_data generic_func_fd = {NULL, NULL, generic_func}; - template<class R, void (*T)(R&, learner& base, example& ec)> - inline void tlearn(void* d, learner& base, example& ec) - { T(*(R*)d, base, ec); } - - template<class R, void (*T)(R&, io_buf& io, bool read, bool text)> - inline void tsl(void* d, io_buf& io, bool read, bool text) - { T(*(R*)d, io, read, text); } - - template<class R, void (*T)(R&)> - inline void tfunc(void* d) { T(*(R*)d); } - - template<class R, void (*T)(vw& all, R&, example&)> - inline void tend_example(vw& all, void* d, example& ec) - { T(all, *(R*)d, ec); } - - template <class T, void (*learn)(T* data, learner& base, example&), void (*predict)(T* data, learner& base, example&)> - struct learn_helper { - void (*learn_f)(void* data, learner& base, example&); - void (*predict_f)(void* data, learner& base, example&); - - learn_helper() - { learn_f = tlearn<T,learn>; - predict_f = tlearn<T,predict>; + typedef void (*tlearn)(void* d, base_learner& base, example& ec); + typedef void (*tsl)(void* d, io_buf& io, bool read, bool text); + typedef void (*tfunc)(void*d); + typedef void (*tend_example)(vw& all, void* d, example& ec); + + template<class T> learner<T>& init_learner(); + template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight); + template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1); + + template<class T> + struct learner { + private: + func_data init_fd; + learn_data learn_fd; + finish_example_data finish_example_fd; + save_load_data save_load_fd; + func_data end_pass_fd; + func_data end_examples_fd; + func_data finisher_fd; + + public: + size_t weights; //this stores the number of "weight vectors" required by the learner. + size_t increment; + + //called once for each example. Must work under reduction. + inline void learn(example& ec, size_t i=0) + { + ec.ft_offset += (uint32_t)(increment*i); + learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); + ec.ft_offset -= (uint32_t)(increment*i); } + inline void set_learn(void (*u)(T& data, base_learner& base, example&)) + { + learn_fd.learn_f = (tlearn)u; + learn_fd.update_f = (tlearn)u; + } + + inline void predict(example& ec, size_t i=0) + { + ec.ft_offset += (uint32_t)(increment*i); + learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec); + ec.ft_offset -= (uint32_t)(increment*i); + } + inline void set_predict(void (*u)(T& data, base_learner& base, example&)) + { learn_fd.predict_f = (tlearn)u; } + + inline void update(example& ec, size_t i=0) + { + ec.ft_offset += (uint32_t)(increment*i); + learn_fd.update_f(learn_fd.data, *learn_fd.base, ec); + ec.ft_offset -= (uint32_t)(increment*i); + } + inline void set_update(void (*u)(T& data, base_learner& base, example&)) + { learn_fd.update_f = (tlearn)u; } + + //called anytime saving or loading needs to happen. Autorecursive. + inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } + inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool)) + { save_load_fd.save_load_f = (tsl)sl; + save_load_fd.data = learn_fd.data; + save_load_fd.base = learn_fd.base;} + + //called to clean up state. Autorecursive. + void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } + inline void finish() + { + if (finisher_fd.data) + {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); } + if (finisher_fd.base) { + finisher_fd.base->finish(); + free(finisher_fd.base); + } + } + + void end_pass(){ + end_pass_fd.func(end_pass_fd.data); + if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive + void set_end_pass(void (*f)(T&)) + {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);} + + //called after parsing of examples is complete. Autorecursive. + void end_examples() + { end_examples_fd.func(end_examples_fd.data); + if (end_examples_fd.base) end_examples_fd.base->end_examples(); } + void set_end_examples(void (*f)(T&)) + {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);} + + //Called at the beginning by the driver. Explicitly not recursive. + void init_driver() { init_fd.func(init_fd.data);} + void set_init_driver(void (*f)(T&)) + { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } + + //called after learn example for each example. Explicitly not recursive. + inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} + void set_finish_example(void (*f)(vw& all, T&, example&)) + {finish_example_fd.data = learn_fd.data; + finish_example_fd.finish_example_f = (tend_example)f;} + + friend learner<T>& init_learner<>(); + friend learner<T>& init_learner<>(T* dat, size_t params_per_weight); + friend learner<T>& init_learner<>(T* dat, base_learner* base, size_t ws); }; - -struct learner { -private: - func_data init_fd; - learn_data learn_fd; - finish_example_data finish_example_fd; - save_load_data save_load_fd; - func_data end_pass_fd; - func_data end_examples_fd; - func_data finisher_fd; -public: - size_t weights; //this stores the number of "weight vectors" required by the learner. - size_t increment; - - //called once for each example. Must work under reduction. - inline void learn(example& ec, size_t i=0) - { - ec.ft_offset += (uint32_t)(increment*i); - learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); - ec.ft_offset -= (uint32_t)(increment*i); - } - template <class T, void (*u)(T& data, learner& base, example&)> - inline void set_learn() - { - learn_fd.learn_f = tlearn<T,u>; - learn_fd.update_f = tlearn<T,u>; - } - - inline void predict(example& ec, size_t i=0) - { - ec.ft_offset += (uint32_t)(increment*i); - learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec); - ec.ft_offset -= (uint32_t)(increment*i); - } - template <class T, void (*u)(T& data, learner& base, example&)> - inline void set_predict() - { - learn_fd.predict_f = tlearn<T,u>; - } - - inline void update(example& ec, size_t i=0) - { - ec.ft_offset += (uint32_t)(increment*i); - learn_fd.update_f(learn_fd.data, *learn_fd.base, ec); - ec.ft_offset -= (uint32_t)(increment*i); - } - template <class T, void (*u)(T& data, learner& base, example&)> - inline void set_update() - { - learn_fd.update_f = tlearn<T,u>; - } - - //called anytime saving or loading needs to happen. Autorecursive. - inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } - template <class T, void (*sl)(T&, io_buf&, bool, bool)> - inline void set_save_load() - { save_load_fd.save_load_f = tsl<T,sl>; - save_load_fd.data = learn_fd.data; - save_load_fd.base = learn_fd.base;} - - //called to clean up state. Autorecursive. - template <class T, void (*f)(T&)> - void set_finish() { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T, f>); } - inline void finish() - { - if (finisher_fd.data) - {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); } - if (finisher_fd.base) { - finisher_fd.base->finish(); - delete finisher_fd.base; + template<class T> learner<T>& init_learner() + { + learner<T>& ret = calloc_or_die<learner<T> >(); + ret.weights = 1; + ret.increment = 1; + ret.learn_fd = LEARNER::generic_learn_fd; + ret.finish_example_fd.data = NULL; + ret.finish_example_fd.finish_example_f = return_simple_example; + ret.end_pass_fd = LEARNER::generic_func_fd; + ret.end_examples_fd = LEARNER::generic_func_fd; + ret.init_fd = LEARNER::generic_func_fd; + ret.finisher_fd = LEARNER::generic_func_fd; + ret.save_load_fd = LEARNER::generic_save_load_fd; + return ret; } - } - - void end_pass(){ - end_pass_fd.func(end_pass_fd.data); - if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive - template <class T, void (*f)(T&)> - void set_end_pass() {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, tfunc<T,f>);} - - //called after parsing of examples is complete. Autorecursive. - void end_examples() - { end_examples_fd.func(end_examples_fd.data); - if (end_examples_fd.base) end_examples_fd.base->end_examples(); } - template <class T, void (*f)(T&)> - void set_end_examples() {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T,f>);} - - //Called at the beginning by the driver. Explicitly not recursive. - void init_driver() { init_fd.func(init_fd.data);} - template <class T, void (*f)(T&)> - void set_init_driver() { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T,f>); } - - //called after learn example for each example. Explicitly not recursive. - inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} - template<class T, void (*f)(vw& all, T&, example&)> - void set_finish_example() - {finish_example_fd.data = learn_fd.data; - finish_example_fd.finish_example_f = tend_example<T,f>;} - - inline learner() - { - weights = 1; - increment = 1; - - learn_fd = LEARNER::generic_learn_fd; - finish_example_fd.data = NULL; - finish_example_fd.finish_example_f = return_simple_example; - end_pass_fd = LEARNER::generic_func_fd; - end_examples_fd = LEARNER::generic_func_fd; - init_fd = LEARNER::generic_func_fd; - finisher_fd = LEARNER::generic_func_fd; - save_load_fd = LEARNER::generic_save_load_fd; - } - - inline learner(void* dat, size_t params_per_weight) - { // the constructor for all learning algorithms. - *this = learner(); - - learn_fd.data = dat; - - finisher_fd.data = dat; - finisher_fd.base = NULL; - finisher_fd.func = LEARNER::generic_func; - - increment = params_per_weight; - } - - inline learner(void *dat, learner* base, size_t ws = 1) - { //the reduction constructor, with separate learn and predict functions - *this = *base; - - learn_fd.data = dat; - learn_fd.base = base; - - finisher_fd.data = dat; - finisher_fd.base = base; - finisher_fd.func = LEARNER::generic_func; - - weights = ws; - increment = base->increment * weights; - } -}; - + + template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight) + { // the constructor for all learning algorithms. + learner<T>& ret = init_learner<T>(); + + ret.learn_fd.data = dat; + + ret.finisher_fd.data = dat; + ret.finisher_fd.base = NULL; + ret.finisher_fd.func = LEARNER::generic_func; + + ret.increment = params_per_weight; + return ret; + } + + template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1) + { //the reduction constructor, with separate learn and predict functions + learner<T>& ret = calloc_or_die<learner<T> >(); + ret = *(learner<T>*)base; + + ret.learn_fd.data = dat; + ret.learn_fd.base = base; + + ret.finisher_fd.data = dat; + ret.finisher_fd.base = base; + ret.finisher_fd.func = LEARNER::generic_func; + + ret.weights = ws; + ret.increment = base->increment * ret.weights; + return ret; + } + + template<class T> base_learner* make_base(learner<T>& base) + { return (base_learner*)&base; } } diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index c1077320..b263d008 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -81,8 +81,8 @@ namespace LOG_MULTI v_array<node> nodes;
- uint32_t max_predictors;
- uint32_t predictors_used;
+ size_t max_predictors;
+ size_t predictors_used;
bool progress;
uint32_t swap_resist;
@@ -246,7 +246,7 @@ namespace LOG_MULTI return b.nodes[current].internal;
}
- void train_node(log_multi& b, learner& base, example& ec, uint32_t& current, uint32_t& class_index)
+ void train_node(log_multi& b, base_learner& base, example& ec, uint32_t& current, uint32_t& class_index)
{
if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
ec.l.simple.label = -1.f;
@@ -297,7 +297,7 @@ namespace LOG_MULTI return n.right;
}
- void predict(log_multi& b, learner& base, example& ec)
+ void predict(log_multi& b, base_learner& base, example& ec)
{
MULTICLASS::multiclass mc = ec.l.multi;
@@ -316,7 +316,7 @@ namespace LOG_MULTI ec.l.multi = mc;
}
- void learn(log_multi& b, learner& base, example& ec)
+ void learn(log_multi& b, base_learner& base, example& ec)
{
// verify_min_dfs(b, b.nodes[0]);
@@ -413,10 +413,10 @@ namespace LOG_MULTI if (read)
for (uint32_t j = 1; j < temp; j++)
b.nodes.push_back(init_node());
- text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
+ text_len = sprintf(buff, "max_predictors = %ld ",b.max_predictors);
bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
- text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
+ text_len = sprintf(buff, "predictors_used = %ld ",b.predictors_used);
bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
text_len = sprintf(buff, "progress = %d ",b.progress);
@@ -502,7 +502,7 @@ namespace LOG_MULTI VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm) //learner setup
+ base_learner* setup(vw& all, po::variables_map& vm) //learner setup
{
po::options_description opts("Log Multi options");
opts.add_options()
@@ -513,23 +513,23 @@ namespace LOG_MULTI if(!vm.count("log_multi"))
return NULL;
- log_multi* data = (log_multi*)calloc(1, sizeof(log_multi));
+ log_multi& data = calloc_or_die<log_multi>(); - data->k = (uint32_t)vm["log_multi"].as<size_t>();
- data->swap_resist = 4;
+ data.k = (uint32_t)vm["log_multi"].as<size_t>();
+ data.swap_resist = 4;
if (vm.count("swap_resistance"))
- data->swap_resist = vm["swap_resistance"].as<uint32_t>();
+ data.swap_resist = vm["swap_resistance"].as<uint32_t>();
//append log_multi with nb_actions to options_from_file so it is saved to regressor later
- all.file_options << " --log_multi " << data->k;
+ *all.file_options << " --log_multi " << data.k;
if (vm.count("no_progress"))
- data->progress = false;
+ data.progress = false;
else
- data->progress = true;
+ data.progress = true;
- data->all = &all;
+ data.all = &all;
(all.p->lp) = MULTICLASS::mc_label;
string loss_function = "quantile";
@@ -537,17 +537,17 @@ namespace LOG_MULTI delete(all.loss);
all.loss = getLossFunction(&all, loss_function, loss_parameter);
- data->max_predictors = data->k - 1;
+ data.max_predictors = data.k - 1;
- learner* l = new learner(data, all.l, data->max_predictors);
- l->set_save_load<log_multi,save_load_tree>();
- l->set_learn<log_multi,learn>();
- l->set_predict<log_multi,predict>();
- l->set_finish_example<log_multi,finish_example>();
- l->set_finish<log_multi,finish>();
+ learner<log_multi>& l = init_learner(&data, all.l, data.max_predictors);
+ l.set_save_load(save_load_tree);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
- init_tree(*data);
+ init_tree(data);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/log_multi.h b/vowpalwabbit/log_multi.h index 26ecb435..5e1ee3bf 100644 --- a/vowpalwabbit/log_multi.h +++ b/vowpalwabbit/log_multi.h @@ -6,5 +6,5 @@ license as described in the file LICENSE. #pragma once namespace LOG_MULTI { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index ee826b72..c44b8692 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -62,7 +62,7 @@ namespace { namespace LRQ { template <bool is_learn> - void predict_or_learn(LRQstate& lrq, learner& base, example& ec) + void predict_or_learn(LRQstate& lrq, base_learner& base, example& ec) { vw& all = *lrq.all; @@ -187,7 +187,7 @@ namespace LRQ { } } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) {//parse and set arguments po::options_description opts("Lrq options"); opts.add_options() @@ -197,37 +197,37 @@ namespace LRQ { if(!vm.count("lrq")) return NULL; - LRQstate* lrq = (LRQstate*)calloc(1, sizeof (LRQstate)); - unsigned int maxk = 0; - lrq->all = &all; + LRQstate& lrq = calloc_or_die<LRQstate>(); + size_t maxk = 0; + lrq.all = &all; size_t random_seed = 0; if (vm.count("random_seed")) random_seed = vm["random_seed"].as<size_t> (); - lrq->initial_seed = lrq->seed = random_seed | 8675309; + lrq.initial_seed = lrq.seed = random_seed | 8675309; if (vm.count("lrqdropout")) - lrq->dropout = true; + lrq.dropout = true; else - lrq->dropout = false; + lrq.dropout = false; - all.file_options << " --lrqdropout "; + *all.file_options << " --lrqdropout "; - lrq->lrpairs = vm["lrq"].as<vector<string> > (); + lrq.lrpairs = vm["lrq"].as<vector<string> > (); - for (vector<string>::iterator i = lrq->lrpairs.begin (); - i != lrq->lrpairs.end (); + for (vector<string>::iterator i = lrq.lrpairs.begin (); + i != lrq.lrpairs.end (); ++i) - all.file_options << " --lrq " << *i; + *all.file_options << " --lrq " << *i; if (! all.quiet) { cerr << "creating low rank quadratic features for pairs: "; - if (lrq->dropout) + if (lrq.dropout) cerr << "(using dropout) "; } - for (vector<string>::iterator i = lrq->lrpairs.begin (); - i != lrq->lrpairs.end (); + for (vector<string>::iterator i = lrq.lrpairs.begin (); + i != lrq.lrpairs.end (); ++i) { if(!all.quiet){ @@ -241,8 +241,8 @@ namespace LRQ { unsigned int k = atoi (i->c_str () + 2); - lrq->lrindices[(int) (*i)[0]] = 1; - lrq->lrindices[(int) (*i)[1]] = 1; + lrq.lrindices[(int) (*i)[0]] = 1; + lrq.lrindices[(int) (*i)[1]] = 1; maxk = max (maxk, k); } @@ -251,12 +251,12 @@ namespace LRQ { cerr<<endl; all.wpp = all.wpp * (1 + maxk); - learner* l = new learner(lrq, all.l, 1 + maxk); - l->set_learn<LRQstate, predict_or_learn<true> >(); - l->set_predict<LRQstate, predict_or_learn<false> >(); - l->set_end_pass<LRQstate,reset_seed>(); + learner<LRQstate>& l = init_learner(&lrq, all.l, 1 + maxk); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_end_pass(reset_seed); // TODO: leaks memory ? - return l; + return make_base(l); } } diff --git a/vowpalwabbit/lrq.h b/vowpalwabbit/lrq.h index af0aae77..376bd6e5 100644 --- a/vowpalwabbit/lrq.h +++ b/vowpalwabbit/lrq.h @@ -1,4 +1,4 @@ #pragma once namespace LRQ { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index e5d99939..31896efa 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -22,7 +22,7 @@ namespace MF { struct mf { vector<string> pairs; - uint32_t rank; + size_t rank; uint32_t increment; @@ -43,7 +43,7 @@ struct mf { }; template <bool cache_sub_predictions> -void predict(mf& data, learner& base, example& ec) { +void predict(mf& data, base_learner& base, example& ec) { float prediction = 0; if (cache_sub_predictions) data.sub_predictions.resize(2*data.rank+1, true); @@ -102,7 +102,7 @@ void predict(mf& data, learner& base, example& ec) { ec.pred.scalar = GD::finalize_prediction(data.all->sd, ec.partial_prediction); } -void learn(mf& data, learner& base, example& ec) { +void learn(mf& data, base_learner& base, example& ec) { // predict with current weights predict<true>(data, base, ec); float predicted = ec.pred.scalar; @@ -188,7 +188,7 @@ void finish(mf& o) { o.sub_predictions.delete_v(); } - learner* setup(vw& all, po::variables_map& vm) { + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("MF options"); opts.add_options() ("new_mf", po::value<size_t>(), "rank for reduction-based matrix factorization"); @@ -196,23 +196,23 @@ void finish(mf& o) { if(!vm.count("new_mf")) return NULL; - mf* data = new mf; + mf& data = calloc_or_die<mf>(); // copy global data locally - data->all = &all; - data->rank = (uint32_t)vm["new_mf"].as<size_t>(); + data.all = &all; + data.rank = (uint32_t)vm["new_mf"].as<size_t>(); // store global pairs in local data structure and clear global pairs // for eventual calls to base learner - data->pairs = all.pairs; + data.pairs = all.pairs; all.pairs.clear(); all.random_positive_weights = true; - learner* l = new learner(data, all.l, 2*data->rank+1); - l->set_learn<mf, learn>(); - l->set_predict<mf, predict<false> >(); - l->set_finish<mf,finish>(); - return l; + learner<mf>& l = init_learner(&data, all.l, 2*data.rank+1); + l.set_learn(learn); + l.set_predict(predict<false>); + l.set_finish(finish); + return make_base(l); } } diff --git a/vowpalwabbit/mf.h b/vowpalwabbit/mf.h index 73fc7f00..90ddc33a 100644 --- a/vowpalwabbit/mf.h +++ b/vowpalwabbit/mf.h @@ -5,5 +5,5 @@ license as described in the file LICENSE. */ #pragma once namespace MF{ - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index 3a9d8e45..854a0e88 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -96,7 +96,7 @@ namespace NN { } template <bool is_learn> - void predict_or_learn(nn& n, learner& base, example& ec) + void predict_or_learn(nn& n, base_learner& base, example& ec) { bool shouldOutput = n.all->raw_prediction > 0; @@ -308,7 +308,7 @@ CONVERSE: // That's right, I'm using goto. So sue me. free (n.output_layer.atomics[nn_output_namespace].begin); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("NN options"); opts.add_options() @@ -324,11 +324,11 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.all = &all; //first parse for number of hidden units n.k = (uint32_t)vm["nn"].as<size_t>(); - all.file_options << " --nn " << n.k; + *all.file_options << " --nn " << n.k; if ( vm.count("dropout") ) { n.dropout = true; - all.file_options << " --dropout "; + *all.file_options << " --dropout "; } if ( vm.count("meanfield") ) { @@ -347,7 +347,7 @@ CONVERSE: // That's right, I'm using goto. So sue me. if (vm.count ("inpass")) { n.inpass = true; - all.file_options << " --inpass"; + *all.file_options << " --inpass"; } @@ -366,13 +366,13 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.save_xsubi = n.xsubi; n.increment = all.l->increment;//Indexing of output layer is odd. - learner* l = new learner(&n, all.l, n.k+1); - l->set_learn<nn, predict_or_learn<true> >(); - l->set_predict<nn, predict_or_learn<false> >(); - l->set_finish<nn, finish>(); - l->set_finish_example<nn, finish_example>(); - l->set_end_pass<nn,end_pass>(); - - return l; + learner<nn>& l = init_learner(&n, all.l, n.k+1); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish(finish); + l.set_finish_example(finish_example); + l.set_end_pass(end_pass); + + return make_base(l); } } diff --git a/vowpalwabbit/nn.h b/vowpalwabbit/nn.h index fb090873..820157d7 100644 --- a/vowpalwabbit/nn.h +++ b/vowpalwabbit/nn.h @@ -6,5 +6,5 @@ license as described in the file LICENSE. #pragma once namespace NN { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc index 57cf7264..8db5a795 100644 --- a/vowpalwabbit/noop.cc +++ b/vowpalwabbit/noop.cc @@ -10,7 +10,7 @@ license as described in the file LICENSE. using namespace LEARNER; namespace NOOP { - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Noop options"); opts.add_options() @@ -19,6 +19,6 @@ namespace NOOP { if(!vm.count("noop")) return NULL; - return new learner(); + return &init_learner<char>(); } } diff --git a/vowpalwabbit/noop.h b/vowpalwabbit/noop.h index e314498a..ac8842e9 100644 --- a/vowpalwabbit/noop.h +++ b/vowpalwabbit/noop.h @@ -4,6 +4,5 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace NOOP { - LEARNER::learner* setup(vw& all, po::variables_map& vm); -} +namespace NOOP +{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);} diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 1a132045..869ce94d 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -3,12 +3,7 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ -#include <float.h> -#include <limits.h> -#include <math.h> -#include <stdio.h> #include <sstream> - #include "multiclass.h" #include "simple_label.h" #include "reductions.h" @@ -16,7 +11,6 @@ license as described in the file LICENSE. using namespace std; using namespace LEARNER; -using namespace MULTICLASS; namespace OAA { struct oaa{ @@ -26,8 +20,8 @@ namespace OAA { }; template <bool is_learn> - void predict_or_learn(oaa& o, learner& base, example& ec) { - multiclass mc_label_data = ec.l.multi; + void predict_or_learn(oaa& o, base_learner& base, example& ec) { + MULTICLASS::multiclass mc_label_data = ec.l.multi; if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1)) cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl; @@ -75,7 +69,7 @@ namespace OAA { VW::finish_example(all, &ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("One-against-all options"); opts.add_options() @@ -85,20 +79,19 @@ namespace OAA { return NULL; oaa& data = calloc_or_die<oaa>(); - //first parse for number of actions - data.k = vm["oaa"].as<size_t>(); - //append oaa with nb_actions to options_from_file so it is saved to regressor later - all.file_options << " --oaa " << data.k; + data.k = vm["oaa"].as<size_t>(); data.shouldOutput = all.raw_prediction > 0; data.all = &all; - all.p->lp = mc_label; - learner* l = new learner(&data, all.l, data.k); - l->set_learn<oaa, predict_or_learn<true> >(); - l->set_predict<oaa, predict_or_learn<false> >(); - l->set_finish_example<oaa, finish_example>(); + *all.file_options << " --oaa " << data.k; + all.p->lp = MULTICLASS::mc_label; + + learner<oaa>& l = init_learner(&data, all.l, data.k); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h index 6b47127f..de1b08ab 100644 --- a/vowpalwabbit/oaa.h +++ b/vowpalwabbit/oaa.h @@ -5,6 +5,4 @@ license as described in the file LICENSE. */ #pragma once namespace OAA -{ - LEARNER::learner* setup(vw& all, po::variables_map& vm); -} +{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index 56e3a89b..201389b4 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -354,7 +354,7 @@ void parse_feature_tweaks(vw& all, po::variables_map& vm) if (vm.count("affix")) { parse_affix_argument(all, vm["affix"].as<string>()); - all.file_options << " --affix " << vm["affix"].as<string>(); + *all.file_options << " --affix " << vm["affix"].as<string>(); } if(vm.count("ngram")){ @@ -746,9 +746,9 @@ void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp) } } -LEARNER::learner* setup_base(vw& all, po::variables_map& vm) +LEARNER::base_learner* setup_base(vw& all, po::variables_map& vm) { - LEARNER::learner* ret = all.reduction_stack.pop()(all,vm); + LEARNER::base_learner* ret = all.reduction_stack.pop()(all,vm); if (ret == NULL) return setup_next(all,vm); else @@ -760,6 +760,7 @@ void parse_reductions(vw& all, po::variables_map& vm) //Base algorithms all.reduction_stack.push_back(GD::setup); all.reduction_stack.push_back(KSVM::setup); + all.reduction_stack.push_back(FTRL::setup); all.reduction_stack.push_back(SENDER::setup); all.reduction_stack.push_back(GDMF::setup); all.reduction_stack.push_back(PRINT::setup); @@ -896,7 +897,7 @@ vw* parse_args(int argc, char *argv[]) parse_regressor_args(*all, vm, io_temp); int temp_argc = 0; - char** temp_argv = VW::get_argv_from_string(all->file_options.str(), temp_argc); + char** temp_argv = VW::get_argv_from_string(all->file_options->str(), temp_argc); add_to_args(*all, temp_argc, temp_argv); for (int i = 0; i < temp_argc; i++) free(temp_argv[i]); @@ -910,7 +911,7 @@ vw* parse_args(int argc, char *argv[]) po::store(pos, vm); po::notify(vm); - all->file_options.str(""); + all->file_options->str(""); parse_feature_tweaks(*all, vm); //feature tweaks @@ -966,14 +967,14 @@ vw* parse_args(int argc, char *argv[]) } namespace VW { - void cmd_string_replace_value( std::stringstream& ss, string flag_to_replace, string new_value ) + void cmd_string_replace_value( std::stringstream*& ss, string flag_to_replace, string new_value ) { flag_to_replace.append(" "); //add a space to make sure we obtain the right flag in case 2 flags start with the same set of characters - string cmd = ss.str(); + string cmd = ss->str(); size_t pos = cmd.find(flag_to_replace); if( pos == string::npos ) //flag currently not present in command string, so just append it to command string - ss << " " << flag_to_replace << new_value; + *ss << " " << flag_to_replace << new_value; else { //flag is present, need to replace old value with new value @@ -989,7 +990,7 @@ namespace VW { else //replace characters between pos and pos_after_value by new_value cmd.replace(pos,pos_after_value-pos,new_value); - ss.str(cmd); + ss->str(cmd); } } @@ -1044,7 +1045,7 @@ namespace VW { { finalize_regressor(all, all.final_regressor_name); all.l->finish(); - delete all.l; + free_it(all.l); if (all.reg.weight_vector != NULL) free(all.reg.weight_vector); free_parser(all); @@ -1053,6 +1054,7 @@ namespace VW { all.p->parse_name.delete_v(); free(all.p); free(all.sd); + delete all.file_options; for (size_t i = 0; i < all.final_prediction_sink.size(); i++) if (all.final_prediction_sink[i] != 1) io_buf::close_file_or_socket(all.final_prediction_sink[i]); diff --git a/vowpalwabbit/parse_args.h b/vowpalwabbit/parse_args.h index 63e47ee3..23531050 100644 --- a/vowpalwabbit/parse_args.h +++ b/vowpalwabbit/parse_args.h @@ -7,4 +7,4 @@ license as described in the file LICENSE. #include "global_data.h" vw* parse_args(int argc, char *argv[]); -LEARNER::learner* setup_next(vw& all, po::variables_map& vm); +LEARNER::base_learner* setup_next(vw& all, po::variables_map& vm); diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc index daee30da..def3a8de 100644 --- a/vowpalwabbit/parse_regressor.cc +++ b/vowpalwabbit/parse_regressor.cc @@ -229,16 +229,16 @@ void save_load_header(vw& all, io_buf& model_file, bool read, bool text) "", read, "\n",1, text); - text_len = sprintf(buff, "options:%s\n", all.file_options.str().c_str()); - uint32_t len = (uint32_t)all.file_options.str().length()+1; - memcpy(buff2, all.file_options.str().c_str(),len); + text_len = sprintf(buff, "options:%s\n", all.file_options->str().c_str()); + uint32_t len = (uint32_t)all.file_options->str().length()+1; + memcpy(buff2, all.file_options->str().c_str(),len); if (read) len = buf_size; bin_text_read_write(model_file,buff2, len, "", read, buff, text_len, text); if (read) - all.file_options.str(buff2); + all.file_options->str(buff2); } } @@ -348,7 +348,7 @@ void parse_mask_regressor_args(vw& all, po::variables_map& vm){ } } else { // If no initial regressor, just clear out the options loaded from the header. - all.file_options.str(""); + all.file_options->str(""); } } } diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 503b99e5..9f4a93f2 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -22,7 +22,7 @@ namespace PRINT cout << " "; } - void learn(print& p, learner& base, example& ec) + void learn(print& p, base_learner& base, example& ec) { label_data& ld = ec.l.simple; if (ld.label != FLT_MAX) @@ -45,7 +45,7 @@ namespace PRINT cout << endl; } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Print options"); opts.add_options() @@ -61,9 +61,9 @@ namespace PRINT all.reg.weight_mask = (length << all.reg.stride_shift) - 1; all.reg.stride_shift = 0; - learner* ret = new learner(&p, 1); - ret->set_learn<print,learn>(); - ret->set_predict<print,learn>(); - return ret; + learner<print>& ret = init_learner(&p, 1); + ret.set_learn(learn); + ret.set_predict(learn); + return make_base(ret); } } diff --git a/vowpalwabbit/print.h b/vowpalwabbit/print.h index d3628eff..2c855eaa 100644 --- a/vowpalwabbit/print.h +++ b/vowpalwabbit/print.h @@ -4,6 +4,5 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace PRINT { - LEARNER::learner* setup(vw& all, po::variables_map& vm); -} +namespace PRINT +{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);} diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index a889a2ed..51be45f2 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -10,7 +10,7 @@ namespace Scorer { }; template <bool is_learn, float (*link)(float in)> - void predict_or_learn(scorer& s, learner& base, example& ec) + void predict_or_learn(scorer& s, base_learner& base, example& ec) { s.all->set_minmax(s.all->sd, ec.l.simple.label); @@ -45,7 +45,7 @@ namespace Scorer { return in; } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Link options"); opts.add_options() @@ -56,23 +56,23 @@ namespace Scorer { scorer& s = calloc_or_die<scorer>(); s.all = &all; - learner* l = new learner(&s, all.l); + learner<scorer>& l = init_learner(&s, all.l); if (!vm.count("link") || link.compare("identity") == 0) { - l->set_learn<scorer, predict_or_learn<true, noop> >(); - l->set_predict<scorer, predict_or_learn<false, noop> >(); + l.set_learn(predict_or_learn<true, noop> ); + l.set_predict(predict_or_learn<false, noop> ); } else if (link.compare("logistic") == 0) { - all.file_options << " --link=logistic "; - l->set_learn<scorer, predict_or_learn<true, logistic> >(); - l->set_predict<scorer, predict_or_learn<false, logistic> >(); + *all.file_options << " --link=logistic "; + l.set_learn(predict_or_learn<true, logistic> ); + l.set_predict(predict_or_learn<false, logistic>); } else if (link.compare("glf1") == 0) { - all.file_options << " --link=glf1 "; - l->set_learn<scorer, predict_or_learn<true, glf1> >(); - l->set_predict<scorer, predict_or_learn<false, glf1> >(); + *all.file_options << " --link=glf1 "; + l.set_learn(predict_or_learn<true, glf1>); + l.set_predict(predict_or_learn<false, glf1>); } else { @@ -80,6 +80,6 @@ namespace Scorer { throw exception(); } - return l; + return make_base(l); } } diff --git a/vowpalwabbit/scorer.h b/vowpalwabbit/scorer.h index 3405b2e9..2d0ec294 100644 --- a/vowpalwabbit/scorer.h +++ b/vowpalwabbit/scorer.h @@ -1,4 +1,4 @@ #pragma once namespace Scorer { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 2907f5c7..e93b4ad6 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -175,7 +175,7 @@ namespace Search { v_array<size_t> timesteps; v_array<float> learn_losses; - LEARNER::learner* base_learner; + LEARNER::base_learner* base_learner; clock_t start_clock_time; example*empty_example; @@ -1417,7 +1417,7 @@ namespace Search { } template <bool is_learn> - void search_predict_or_learn(search& sch, learner& base, example& ec) { + void search_predict_or_learn(search& sch, base_learner& base, example& ec) { search_private& priv = *sch.priv; vw* all = priv.all; priv.base_learner = &base; @@ -1653,7 +1653,7 @@ namespace Search { template<class T> void check_option(T& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string) { if (vm.count(opt_name)) { ret = vm[opt_name].as<T>(); - all.file_options << " --" << opt_name << " " << ret; + *all.file_options << " --" << opt_name << " " << ret; } else if (strlen(required_error_string)>0) { std::cerr << required_error_string << endl; if (! vm.count("help")) @@ -1664,7 +1664,7 @@ namespace Search { void check_option(bool& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, const char* mismatch_error_string) { if (vm.count(opt_name)) { ret = true; - all.file_options << " --" << opt_name; + *all.file_options << " --" << opt_name; } else ret = false; } @@ -1764,7 +1764,7 @@ namespace Search { delete[] cstr; } - learner* setup(vw&all, po::variables_map& vm) { + base_learner* setup(vw&all, po::variables_map& vm) { po::options_description opts("Search Options"); opts.add_options() ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF") @@ -1985,17 +1985,17 @@ namespace Search { if (!vm.count("csoaa") && !vm.count("csoaa_ldf") && !vm.count("wap_ldf") && !vm.count("cb")) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"])); - learner* base = setup_base(all,vm); + base_learner* base = setup_base(all,vm); - learner* l = new learner(&sch, all.l, priv.total_number_of_policies); - l->set_learn<search, search_predict_or_learn<true> >(); - l->set_predict<search, search_predict_or_learn<false> >(); - l->set_finish_example<search,finish_example>(); - l->set_end_examples<search,end_examples>(); - l->set_finish<search,search_finish>(); - l->set_end_pass<search,end_pass>(); - - return l; + learner<search>& l = init_learner(&sch, all.l, priv.total_number_of_policies); + l.set_learn(search_predict_or_learn<true>); + l.set_predict(search_predict_or_learn<false>); + l.set_finish_example(finish_example); + l.set_end_examples(end_examples); + l.set_finish(search_finish); + l.set_end_pass(end_pass); + + return make_base(l); } float action_hamming_loss(action a, const action* A, size_t sz) { diff --git a/vowpalwabbit/search.h b/vowpalwabbit/search.h index 07c73158..08633c5e 100644 --- a/vowpalwabbit/search.h +++ b/vowpalwabbit/search.h @@ -241,5 +241,5 @@ namespace Search { bool size_equal(size_t a, size_t b); // our interface within VW - LEARNER::learner* setup(vw&, po::variables_map&); + LEARNER::base_learner* setup(vw&, po::variables_map&); } diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index 0a0a284a..26fdba51 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -69,7 +69,7 @@ void receive_result(sender& s) return_simple_example(*(s.all), NULL, *ec); } - void learn(sender& s, learner& base, example& ec) + void learn(sender& s, base_learner& base, example& ec) { if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index) receive_result(s); @@ -100,7 +100,8 @@ void end_examples(sender& s) delete s.buf; } -learner* setup(vw& all, po::variables_map& vm) + + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("Sender options"); opts.add_options() @@ -120,13 +121,13 @@ learner* setup(vw& all, po::variables_map& vm) s.all = &all; s.delay_ring = calloc_or_die<example*>(all.p->ring_size); - learner* l = new learner(&s, 1); - l->set_learn<sender, learn>(); - l->set_predict<sender, learn>(); - l->set_finish<sender, finish>(); - l->set_finish_example<sender, finish_example>(); - l->set_end_examples<sender, end_examples>(); - return l; + learner<sender>& l = init_learner(&s, 1); + l.set_learn(learn); + l.set_predict(learn); + l.set_finish(finish); + l.set_finish_example(finish_example); + l.set_end_examples(end_examples); + return make_base(l); } } diff --git a/vowpalwabbit/sender.h b/vowpalwabbit/sender.h index 39713cdb..55f10754 100644 --- a/vowpalwabbit/sender.h +++ b/vowpalwabbit/sender.h @@ -4,6 +4,5 @@ individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ #pragma once -namespace SENDER{ - LEARNER::learner* setup(vw& all, po::variables_map& vm); -} +namespace SENDER +{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index fb3dc425..fe4f6fbd 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -502,7 +502,7 @@ namespace StagewisePoly } } - void predict(stagewise_poly &poly, learner &base, example &ec) + void predict(stagewise_poly &poly, base_learner &base, example &ec) { poly.original_ec = &ec; synthetic_create(poly, ec, false); @@ -511,7 +511,7 @@ namespace StagewisePoly ec.updated_prediction = poly.synth_ec.updated_prediction; } - void learn(stagewise_poly &poly, learner &base, example &ec) + void learn(stagewise_poly &poly, base_learner &base, example &ec) { bool training = poly.all->training && ec.l.simple.label != FLT_MAX; poly.original_ec = &ec; @@ -656,7 +656,7 @@ namespace StagewisePoly //#endif //DEBUG } - learner *setup(vw &all, po::variables_map &vm) + base_learner *setup(vw &all, po::variables_map &vm) { po::options_description opts("Stagewise poly options"); opts.add_options() @@ -697,16 +697,16 @@ namespace StagewisePoly poly.next_batch_sz = poly.batch_sz; //following is so that saved models know to load us. - all.file_options << " --stage_poly"; + *all.file_options << " --stage_poly"; - learner *l = new learner(&poly, all.l); - l->set_learn<stagewise_poly, learn>(); - l->set_predict<stagewise_poly, predict>(); - l->set_finish<stagewise_poly, finish>(); - l->set_save_load<stagewise_poly, save_load>(); - l->set_finish_example<stagewise_poly,finish_example>(); - l->set_end_pass<stagewise_poly, end_pass>(); + learner<stagewise_poly>& l = init_learner(&poly, all.l); + l.set_learn(learn); + l.set_predict(predict); + l.set_finish(finish); + l.set_save_load(save_load); + l.set_finish_example(finish_example); + l.set_end_pass(end_pass); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/stagewise_poly.h b/vowpalwabbit/stagewise_poly.h index 4f5ac1fa..983b4382 100644 --- a/vowpalwabbit/stagewise_poly.h +++ b/vowpalwabbit/stagewise_poly.h @@ -6,5 +6,5 @@ license as described in the file LICENSE. #pragma once namespace StagewisePoly { - LEARNER::learner *setup(vw &all, po::variables_map &vm); + LEARNER::base_learner *setup(vw &all, po::variables_map &vm); } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 1fc86002..a7c80e1b 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -85,7 +85,7 @@ namespace TOPK { } template <bool is_learn> - void predict_or_learn(topk& d, learner& base, example& ec) + void predict_or_learn(topk& d, base_learner& base, example& ec) { if (example_is_newline(ec)) return;//do not predict newline @@ -111,7 +111,7 @@ namespace TOPK { VW::finish_example(all, &ec); } - learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all, po::variables_map& vm) { po::options_description opts("TOP K options"); opts.add_options() @@ -124,11 +124,11 @@ namespace TOPK { data.B = (uint32_t)vm["top"].as<size_t>(); data.all = &all; - learner* l = new learner(&data, all.l); - l->set_learn<topk, predict_or_learn<true> >(); - l->set_predict<topk, predict_or_learn<false> >(); - l->set_finish_example<topk,finish_example>(); + learner<topk>& l = init_learner(&data, all.l); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); - return l; + return make_base(l); } } diff --git a/vowpalwabbit/topk.h b/vowpalwabbit/topk.h index 2de41fe2..964ff618 100644 --- a/vowpalwabbit/topk.h +++ b/vowpalwabbit/topk.h @@ -6,5 +6,5 @@ license as described in the file LICENSE. #pragma once namespace TOPK { - LEARNER::learner* setup(vw& all, po::variables_map& vm); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm); } diff --git a/vowpalwabbit/vw.h b/vowpalwabbit/vw.h index 49bfaca3..0fa4ee77 100644 --- a/vowpalwabbit/vw.h +++ b/vowpalwabbit/vw.h @@ -18,7 +18,7 @@ namespace VW { */ vw* initialize(string s); - void cmd_string_replace_value( std::stringstream& ss, string flag_to_replace, string new_value ); + void cmd_string_replace_value( std::stringstream*& ss, string flag_to_replace, string new_value ); char** get_argv_from_string(string s, int& argc); |