From 33f25c5af9538b77e4cc43d800c1e4cae4995322 Mon Sep 17 00:00:00 2001 From: John Langford Date: Fri, 26 Dec 2014 21:40:35 -0500 Subject: detemplate learner --- vowpalwabbit/active.cc | 10 +++---- vowpalwabbit/autolink.cc | 4 +-- vowpalwabbit/bfgs.cc | 12 ++++---- vowpalwabbit/binary.cc | 4 +-- vowpalwabbit/bs.cc | 8 ++--- vowpalwabbit/cb_algs.cc | 16 +++++----- vowpalwabbit/cbify.cc | 68 +++++++++++++++++++++--------------------- vowpalwabbit/csoaa.cc | 20 ++++++------- vowpalwabbit/ect.cc | 8 ++--- vowpalwabbit/ftrl_proximal.cc | 6 ++-- vowpalwabbit/gd.cc | 34 +++++++-------------- vowpalwabbit/gd_mf.cc | 8 ++--- vowpalwabbit/kernel_svm.cc | 8 ++--- vowpalwabbit/lda_core.cc | 14 ++++----- vowpalwabbit/learner.h | 54 ++++++++++++--------------------- vowpalwabbit/log_multi.cc | 10 +++---- vowpalwabbit/lrq.cc | 6 ++-- vowpalwabbit/mf.cc | 6 ++-- vowpalwabbit/nn.cc | 10 +++---- vowpalwabbit/oaa.cc | 6 ++-- vowpalwabbit/print.cc | 4 +-- vowpalwabbit/scorer.cc | 12 ++++---- vowpalwabbit/search.cc | 12 ++++---- vowpalwabbit/sender.cc | 10 +++---- vowpalwabbit/stagewise_poly.cc | 12 ++++---- vowpalwabbit/topk.cc | 6 ++-- 26 files changed, 170 insertions(+), 198 deletions(-) (limited to 'vowpalwabbit') diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 7f0d1cc7..a6337641 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -169,15 +169,15 @@ namespace ACTIVE { learner* ret = new learner(&data, all.l); if (vm.count("simulation")) { - ret->set_learn >(); - ret->set_predict >(); + ret->set_learn(predict_or_learn_simulation); + ret->set_predict(predict_or_learn_simulation); } else { all.active = true; - ret->set_learn >(); - ret->set_predict >(); - ret->set_finish_example(); + ret->set_learn(predict_or_learn_active); + ret->set_predict(predict_or_learn_active); + ret->set_finish_example(return_active_example); } return make_base(ret); diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index d4e00f68..57d6a247 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -50,8 +50,8 @@ namespace ALINK { all.file_options << " --autolink " << data.d; learner* ret = new learner(&data, all.l); - ret->set_learn >(); - ret->set_predict >(); + ret->set_learn(predict_or_learn); + ret->set_predict(predict_or_learn); return make_base(ret); } } diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index fb08315a..a5d4110b 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -1019,12 +1019,12 @@ base_learner* setup(vw& all, po::variables_map& vm) all.reg.stride_shift = 2; learner* l = new learner(&b, 1 << all.reg.stride_shift); - l->set_learn(); - l->set_predict(); - l->set_save_load(); - l->set_init_driver(); - l->set_end_pass(); - l->set_finish(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_init_driver(init_driver); + l->set_end_pass(end_pass); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 1843a398..50c8875e 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -29,8 +29,8 @@ namespace BINARY { all.sd->binary_label = true; //Create new learner learner* ret = new learner(NULL, all.l); - ret->set_learn >(); - ret->set_predict >(); + ret->set_learn(predict_or_learn); + ret->set_predict(predict_or_learn); return make_base(ret); } } diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index 339c464b..8382d04d 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -281,10 +281,10 @@ namespace BS { data.all = &all; learner* l = new learner(&data, all.l, data.B); - l->set_learn >(); - l->set_predict >(); - l->set_finish_example(); - l->set_finish(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_finish_example(finish_example); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index 84523820..c02ae6ed 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -567,18 +567,18 @@ namespace CB_ALGS learner* l = new learner(&c, all.l, problem_multiplier); if (eval) { - l->set_learn(); - l->set_predict(); - l->set_finish_example(); + l->set_learn(learn_eval); + l->set_predict(predict_eval); + l->set_finish_example(eval_finish_example); } else { - l->set_learn >(); - l->set_predict >(); - l->set_finish_example(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_finish_example(finish_example); } - l->set_init_driver(); - l->set_finish(); + l->set_init_driver(init_driver); + l->set_finish(finish); // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. l->increment = all.l->increment; diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index ae0e1a42..ab8323ea 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -404,51 +404,51 @@ namespace CBIFY { data.cs = all.cost_sensitive; data.second_cs_label.costs.resize(data.k); data.second_cs_label.costs.end = data.second_cs_label.costs.begin+data.k; - float epsilon = 0.05f; - if (vm.count("epsilon")) - epsilon = vm["epsilon"].as(); - data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); - data.generic_explorer.reset(new GenericExplorer(*data.scorer.get(), (u32)data.k)); - l = new learner(&data, all.l, cover + 1); - l->set_learn >(); - l->set_predict >(); + float epsilon = 0.05f; + if (vm.count("epsilon")) + epsilon = vm["epsilon"].as(); + data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); + data.generic_explorer.reset(new GenericExplorer(*data.scorer.get(), (u32)data.k)); + l = new learner(&data, all.l, cover + 1); + l->set_learn(predict_or_learn_cover); + l->set_predict(predict_or_learn_cover); } else if (vm.count("bag")) { size_t bags = (uint32_t)vm["bag"].as(); - for (size_t i = 0; i < bags; i++) - { - data.policies.push_back(unique_ptr>(new vw_policy(i))); - } - data.bootstrap_explorer.reset(new BootstrapExplorer(data.policies, (u32)data.k)); - l = new learner(&data, all.l, bags); - l->set_learn >(); - l->set_predict >(); + for (size_t i = 0; i < bags; i++) + { + data.policies.push_back(unique_ptr>(new vw_policy(i))); + } + data.bootstrap_explorer.reset(new BootstrapExplorer(data.policies, (u32)data.k)); + l = new learner(&data, all.l, bags); + l->set_learn(predict_or_learn_bag); + l->set_predict(predict_or_learn_bag); } else if (vm.count("first") ) { - uint32_t tau = (uint32_t)vm["first"].as(); - data.policy.reset(new vw_policy()); - data.tau_explorer.reset(new TauFirstExplorer(*data.policy.get(), (u32)tau, (u32)data.k)); - l = new learner(&data, all.l, 1); - l->set_learn >(); - l->set_predict >(); + uint32_t tau = (uint32_t)vm["first"].as(); + data.policy.reset(new vw_policy()); + data.tau_explorer.reset(new TauFirstExplorer(*data.policy.get(), (u32)tau, (u32)data.k)); + l = new learner(&data, all.l, 1); + l->set_learn(predict_or_learn_first); + l->set_predict(predict_or_learn_first); } else { - float epsilon = 0.05f; - if (vm.count("epsilon")) - epsilon = vm["epsilon"].as(); - data.policy.reset(new vw_policy()); - data.greedy_explorer.reset(new EpsilonGreedyExplorer(*data.policy.get(), epsilon, (u32)data.k)); - l = new learner(&data, all.l, 1); - l->set_learn >(); - l->set_predict >(); + float epsilon = 0.05f; + if (vm.count("epsilon")) + epsilon = vm["epsilon"].as(); + data.policy.reset(new vw_policy()); + data.greedy_explorer.reset(new EpsilonGreedyExplorer(*data.policy.get(), epsilon, (u32)data.k)); + l = new learner(&data, all.l, 1); + l->set_learn(predict_or_learn_greedy); + l->set_predict(predict_or_learn_greedy); } - - l->set_finish_example(); - l->set_finish(); - l->set_init_driver(); + + l->set_finish_example(finish_example); + l->set_finish(finish); + l->set_init_driver(init_driver); return make_base(l); } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 5392ad91..e14cf8f5 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -83,9 +83,9 @@ namespace CSOAA { all.sd->k = nb_actions; learner* l = new learner(&c, all.l, nb_actions); - l->set_learn >(); - l->set_predict >(); - l->set_finish_example(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_finish_example(finish_example); return make_base(l); } } @@ -710,15 +710,15 @@ namespace LabelDict { ld.read_example_this_loop = 0; ld.need_to_clear = false; learner* l = new learner(&ld, all.l); - l->set_learn >(); - l->set_predict >(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); if (ld.is_singleline) - l->set_finish_example(); + l->set_finish_example(finish_singleline_example); else - l->set_finish_example(); - l->set_finish(); - l->set_end_examples(); - l->set_end_pass(); + l->set_finish_example(finish_multiline_example); + l->set_finish(finish); + l->set_end_examples(end_examples); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index acb83ef4..f847af76 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -391,10 +391,10 @@ namespace ECT data.all = &all; learner* l = new learner(&data, all.l, wpp); - l->set_learn(); - l->set_predict(); - l->set_finish_example(); - l->set_finish(); + l->set_learn(learn); + l->set_predict(predict); + l->set_finish_example(finish_example); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index 6d0ca216..a0f437b0 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -218,9 +218,9 @@ namespace FTRL { } learner* l = new learner(&b, 1 << all.reg.stride_shift); - l->set_learn(); - l->set_predict(); - l->set_save_load(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); return make_base(l); } diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 47978957..8e4b3b2b 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -794,14 +794,14 @@ uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off) all.normalized_idx = normalized; if (feature_mask_off) { - ret->set_learn >(); - ret->set_update >(); + ret->set_learn(learn); + ret->set_update(update); return next; } else { - ret->set_learn >(); - ret->set_update >(); + ret->set_learn(learn); + ret->set_update(update); return next; } } @@ -902,25 +902,14 @@ base_learner* setup(vw& all, po::variables_map& vm) if (all.reg_mode % 2) if (all.audit || all.hash_inv) - { - ret->set_predict >(); - g.predict = predict; - } + g.predict = predict; else - { - ret->set_predict >(); - g.predict = predict; - } + g.predict = predict; else if (all.audit || all.hash_inv) - { - ret->set_predict >(); - g.predict = predict; - } + g.predict = predict; else - { - ret->set_predict >(); - g.predict = predict; - } + g.predict = predict; + ret->set_predict(g.predict); uint32_t stride; if (all.power_t == 0.5) @@ -931,9 +920,8 @@ base_learner* setup(vw& all, po::variables_map& vm) all.reg.stride_shift = ceil_log_2(stride-1); ret->increment = ((uint64_t)1 << all.reg.stride_shift); - ret->set_save_load(); - - ret->set_end_pass(); + ret->set_save_load(save_load); + ret->set_end_pass(end_pass); return make_base(ret); } } diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index f6f9beaf..1ae18985 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -331,10 +331,10 @@ void mf_train(vw& all, example& ec) all.eta *= powf((float)(all.sd->t), all.power_t); learner* l = new learner(&data, 1 << all.reg.stride_shift); - l->set_learn(); - l->set_predict(); - l->set_save_load(); - l->set_end_pass(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 6c47b89e..bfe1f927 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -899,10 +899,10 @@ namespace KSVM params.all->reg.stride_shift = 0; learner* l = new learner(¶ms, 1); - l->set_learn(); - l->set_predict(); - l->set_save_load(); - l->set_finish(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_finish(finish); return make_base(l); } } diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 525408e2..a8241557 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -787,13 +787,13 @@ base_learner* setup(vw&all, po::variables_map& vm) ld.decay_levels.push_back(0.f); learner* l = new learner(&ld, 1 << all.reg.stride_shift); - l->set_learn(); - l->set_predict(); - l->set_save_load(); - l->set_finish_example(); - l->set_end_examples(); - l->set_end_pass(); - l->set_finish(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_finish_example(finish_example); + l->set_end_examples(end_examples); + l->set_end_pass(end_pass); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 8d00b6e7..0edeecb4 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -61,20 +61,13 @@ namespace LEARNER const learn_data generic_learn_fd = {NULL, NULL, generic_learner, generic_learner, NULL}; const func_data generic_func_fd = {NULL, NULL, generic_func}; - template - inline void tlearn(void* d, base_learner& base, example& ec) - { T(*(R*)d, base, ec); } + typedef void (*tlearn)(void* d, base_learner& base, example& ec); - template - inline void tsl(void* d, io_buf& io, bool read, bool text) - { T(*(R*)d, io, read, text); } + typedef void (*tsl)(void* d, io_buf& io, bool read, bool text); - template - inline void tfunc(void* d) { T(*(R*)d); } + typedef void (*tfunc)(void*d); - template - inline void tend_example(vw& all, void* d, example& ec) - { T(all, *(R*)d, ec); } + typedef void (*tend_example)(vw& all, void* d, example& ec); template struct learner { @@ -98,11 +91,10 @@ public: learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - template - inline void set_learn() + inline void set_learn(void (*u)(T& data, base_learner& base, example&)) { - learn_fd.learn_f = tlearn; - learn_fd.update_f = tlearn; + learn_fd.learn_f = (tlearn)u; + learn_fd.update_f = (tlearn)u; } inline void predict(example& ec, size_t i=0) @@ -111,10 +103,9 @@ public: learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - template - inline void set_predict() + inline void set_predict(void (*u)(T& data, base_learner& base, example&)) { - learn_fd.predict_f = tlearn; + learn_fd.predict_f = (tlearn)u; } inline void update(example& ec, size_t i=0) @@ -123,23 +114,20 @@ public: learn_fd.update_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - template - inline void set_update() + inline void set_update(void (*u)(T& data, base_learner& base, example&)) { - learn_fd.update_f = tlearn; + learn_fd.update_f = (tlearn)u; } //called anytime saving or loading needs to happen. Autorecursive. inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } - template - inline void set_save_load() - { save_load_fd.save_load_f = tsl; + inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool)) + { save_load_fd.save_load_f = (tsl)sl; save_load_fd.data = learn_fd.data; save_load_fd.base = learn_fd.base;} //called to clean up state. Autorecursive. - template - void set_finish() { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc); } + void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } inline void finish() { if (finisher_fd.data) @@ -153,27 +141,23 @@ public: void end_pass(){ end_pass_fd.func(end_pass_fd.data); if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive - template - void set_end_pass() {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, tfunc);} + void set_end_pass(void (*f)(T&)) {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);} //called after parsing of examples is complete. Autorecursive. void end_examples() { end_examples_fd.func(end_examples_fd.data); if (end_examples_fd.base) end_examples_fd.base->end_examples(); } - template - void set_end_examples() {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc);} + void set_end_examples(void (*f)(T&)) {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);} //Called at the beginning by the driver. Explicitly not recursive. void init_driver() { init_fd.func(init_fd.data);} - template - void set_init_driver() { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc); } + void set_init_driver(void (*f)(T&)) { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } //called after learn example for each example. Explicitly not recursive. inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} - template - void set_finish_example() + void set_finish_example(void (*f)(vw& all, T&, example&)) {finish_example_fd.data = learn_fd.data; - finish_example_fd.finish_example_f = tend_example;} + finish_example_fd.finish_example_f = (tend_example)f;} inline learner() { diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index d68a519f..2378106f 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -532,11 +532,11 @@ namespace LOG_MULTI data->max_predictors = data->k - 1; learner* l = new learner(data, all.l, data->max_predictors); - l->set_save_load(); - l->set_learn(); - l->set_predict(); - l->set_finish_example(); - l->set_finish(); + l->set_save_load(save_load_tree); + l->set_learn(learn); + l->set_predict(predict); + l->set_finish_example(finish_example); + l->set_finish(finish); init_tree(*data); diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index 013c8e3b..557411e5 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -244,9 +244,9 @@ namespace LRQ { all.wpp = all.wpp * (1 + maxk); learner* l = new learner(&lrq, all.l, 1 + maxk); - l->set_learn >(); - l->set_predict >(); - l->set_end_pass(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_end_pass(reset_seed); // TODO: leaks memory ? return make_base(l); diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index e933adb8..65706731 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -204,9 +204,9 @@ base_learner* setup(vw& all, po::variables_map& vm) { all.random_positive_weights = true; learner* l = new learner(data, all.l, 2*data->rank+1); - l->set_learn(); - l->set_predict >(); - l->set_finish(); + l->set_learn(learn); + l->set_predict(predict); + l->set_finish(finish); return make_base(l); } } diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index 5afc989b..4f623381 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -366,11 +366,11 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.save_xsubi = n.xsubi; n.increment = all.l->increment;//Indexing of output layer is odd. learner* l = new learner(&n, all.l, n.k+1); - l->set_learn >(); - l->set_predict >(); - l->set_finish(); - l->set_finish_example(); - l->set_end_pass(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_finish(finish); + l->set_finish_example(finish_example); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index df1db482..47ee84b0 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -88,9 +88,9 @@ namespace OAA { all.p->lp = mc_label; learner* l = new learner(&data, all.l, data.k); - l->set_learn >(); - l->set_predict >(); - l->set_finish_example(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_finish_example(finish_example); return make_base(l); } diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 144f7350..30f78cff 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -55,8 +55,8 @@ namespace PRINT all.reg.stride_shift = 0; learner* ret = new learner(&p, 1); - ret->set_learn(); - ret->set_predict(); + ret->set_learn(learn); + ret->set_predict(learn); return make_base(ret); } } diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index 815796fc..a4b73855 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -62,20 +62,20 @@ namespace Scorer { string link = vm["link"].as(); if (!vm.count("link") || link.compare("identity") == 0) { - l->set_learn >(); - l->set_predict >(); + l->set_learn(predict_or_learn ); + l->set_predict(predict_or_learn ); } else if (link.compare("logistic") == 0) { all.file_options << " --link=logistic "; - l->set_learn >(); - l->set_predict >(); + l->set_learn(predict_or_learn ); + l->set_predict(predict_or_learn); } else if (link.compare("glf1") == 0) { all.file_options << " --link=glf1 "; - l->set_learn >(); - l->set_predict >(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); } else { diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 7236f191..ea218da2 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -1982,12 +1982,12 @@ namespace Search { priv.start_clock_time = clock(); learner* l = new learner(&sch, all.l, priv.total_number_of_policies); - l->set_learn >(); - l->set_predict >(); - l->set_finish_example(); - l->set_end_examples(); - l->set_finish(); - l->set_end_pass(); + l->set_learn(search_predict_or_learn); + l->set_predict(search_predict_or_learn); + l->set_finish_example(finish_example); + l->set_end_examples(end_examples); + l->set_finish(search_finish); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index d31e199f..6222b690 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -114,11 +114,11 @@ void end_examples(sender& s) s.delay_ring = calloc_or_die(all.p->ring_size); learner* l = new learner(&s, 1); - l->set_learn(); - l->set_predict(); - l->set_finish(); - l->set_finish_example(); - l->set_end_examples(); + l->set_learn(learn); + l->set_predict(learn); + l->set_finish(finish); + l->set_finish_example(finish_example); + l->set_end_examples(end_examples); return make_base(l); } diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index daccb3cd..5639366d 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -699,12 +699,12 @@ namespace StagewisePoly all.file_options << " --stage_poly"; learner *l = new learner(&poly, all.l); - l->set_learn(); - l->set_predict(); - l->set_finish(); - l->set_save_load(); - l->set_finish_example(); - l->set_end_pass(); + l->set_learn(learn); + l->set_predict(predict); + l->set_finish(finish); + l->set_save_load(save_load); + l->set_finish_example(finish_example); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 546f38c0..70b5c9cd 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -120,9 +120,9 @@ namespace TOPK { data.all = &all; learner* l = new learner(&data, all.l); - l->set_learn >(); - l->set_predict >(); - l->set_finish_example(); + l->set_learn(predict_or_learn); + l->set_predict(predict_or_learn); + l->set_finish_example(finish_example); return make_base(l); } -- cgit v1.2.3