diff options
author | John Langford <jl@hunch.net> | 2014-12-27 05:40:35 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-12-27 05:40:35 +0300 |
commit | 33f25c5af9538b77e4cc43d800c1e4cae4995322 (patch) | |
tree | 2bcee53c27e927d4343f63eae6ebbe279b59ac5c /vowpalwabbit | |
parent | 8d77fce3ccff3a0c46948a703eea1d74730a3378 (diff) |
detemplate learner
Diffstat (limited to 'vowpalwabbit')
-rw-r--r-- | vowpalwabbit/active.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/autolink.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/bfgs.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/binary.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/bs.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/cb_algs.cc | 16 | ||||
-rw-r--r-- | vowpalwabbit/cbify.cc | 68 | ||||
-rw-r--r-- | vowpalwabbit/csoaa.cc | 20 | ||||
-rw-r--r-- | vowpalwabbit/ect.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/ftrl_proximal.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/gd.cc | 34 | ||||
-rw-r--r-- | vowpalwabbit/gd_mf.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/kernel_svm.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/lda_core.cc | 14 | ||||
-rw-r--r-- | vowpalwabbit/learner.h | 54 | ||||
-rw-r--r-- | vowpalwabbit/log_multi.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/lrq.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/mf.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/nn.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/oaa.cc | 6 | ||||
-rw-r--r-- | vowpalwabbit/print.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/scorer.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/search.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/sender.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/stagewise_poly.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/topk.cc | 6 |
26 files changed, 170 insertions, 198 deletions
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 7f0d1cc7..a6337641 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -169,15 +169,15 @@ namespace ACTIVE { learner<active>* ret = new learner<active>(&data, all.l); if (vm.count("simulation")) { - ret->set_learn<predict_or_learn_simulation<true> >(); - ret->set_predict<predict_or_learn_simulation<false> >(); + ret->set_learn(predict_or_learn_simulation<true>); + ret->set_predict(predict_or_learn_simulation<false>); } else { all.active = true; - ret->set_learn<predict_or_learn_active<true> >(); - ret->set_predict<predict_or_learn_active<false> >(); - ret->set_finish_example<return_active_example>(); + ret->set_learn(predict_or_learn_active<true>); + ret->set_predict(predict_or_learn_active<false>); + ret->set_finish_example(return_active_example); } return make_base(ret); diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index d4e00f68..57d6a247 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -50,8 +50,8 @@ namespace ALINK { all.file_options << " --autolink " << data.d; learner<autolink>* ret = new learner<autolink>(&data, all.l); - ret->set_learn<predict_or_learn<true> >(); - ret->set_predict<predict_or_learn<false> >(); + ret->set_learn(predict_or_learn<true>); + ret->set_predict(predict_or_learn<false>); return make_base(ret); } } diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index fb08315a..a5d4110b 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -1019,12 +1019,12 @@ base_learner* setup(vw& all, po::variables_map& vm) all.reg.stride_shift = 2; learner<bfgs>* l = new learner<bfgs>(&b, 1 << all.reg.stride_shift); - l->set_learn<learn>(); - l->set_predict<predict>(); - l->set_save_load<save_load>(); - l->set_init_driver<init_driver>(); - l->set_end_pass<end_pass>(); - l->set_finish<finish>(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_init_driver(init_driver); + l->set_end_pass(end_pass); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 1843a398..50c8875e 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -29,8 +29,8 @@ namespace BINARY { all.sd->binary_label = true; //Create new learner learner<char>* ret = new learner<char>(NULL, all.l); - ret->set_learn<predict_or_learn<true> >(); - ret->set_predict<predict_or_learn<false> >(); + ret->set_learn(predict_or_learn<true>); + ret->set_predict(predict_or_learn<false>); return make_base(ret); } } diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index 339c464b..8382d04d 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -281,10 +281,10 @@ namespace BS { data.all = &all; learner<bs>* l = new learner<bs>(&data, all.l, data.B); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_finish_example<finish_example>(); - l->set_finish<finish>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_finish_example(finish_example); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index 84523820..c02ae6ed 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -567,18 +567,18 @@ namespace CB_ALGS learner<cb>* l = new learner<cb>(&c, all.l, problem_multiplier); if (eval) { - l->set_learn<learn_eval>(); - l->set_predict<predict_eval>(); - l->set_finish_example<eval_finish_example>(); + l->set_learn(learn_eval); + l->set_predict(predict_eval); + l->set_finish_example(eval_finish_example); } else { - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_finish_example<finish_example>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_finish_example(finish_example); } - l->set_init_driver<init_driver>(); - l->set_finish<finish>(); + l->set_init_driver(init_driver); + l->set_finish(finish); // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. l->increment = all.l->increment; diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index ae0e1a42..ab8323ea 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -404,51 +404,51 @@ namespace CBIFY { data.cs = all.cost_sensitive; data.second_cs_label.costs.resize(data.k); data.second_cs_label.costs.end = data.second_cs_label.costs.begin+data.k; - float epsilon = 0.05f; - if (vm.count("epsilon")) - epsilon = vm["epsilon"].as<float>(); - data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); - data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); - l = new learner<cbify>(&data, all.l, cover + 1); - l->set_learn<predict_or_learn_cover<true> >(); - l->set_predict<predict_or_learn_cover<false> >(); + float epsilon = 0.05f; + if (vm.count("epsilon")) + epsilon = vm["epsilon"].as<float>(); + data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); + data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); + l = new learner<cbify>(&data, all.l, cover + 1); + l->set_learn(predict_or_learn_cover<true>); + l->set_predict(predict_or_learn_cover<false>); } else if (vm.count("bag")) { size_t bags = (uint32_t)vm["bag"].as<size_t>(); - for (size_t i = 0; i < bags; i++) - { - data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); - } - data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); - l = new learner<cbify>(&data, all.l, bags); - l->set_learn<predict_or_learn_bag<true> >(); - l->set_predict<predict_or_learn_bag<false> >(); + for (size_t i = 0; i < bags; i++) + { + data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); + } + data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); + l = new learner<cbify>(&data, all.l, bags); + l->set_learn(predict_or_learn_bag<true>); + l->set_predict(predict_or_learn_bag<false>); } else if (vm.count("first") ) { - uint32_t tau = (uint32_t)vm["first"].as<size_t>(); - data.policy.reset(new vw_policy()); - data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); - l = new learner<cbify>(&data, all.l, 1); - l->set_learn<predict_or_learn_first<true> >(); - l->set_predict<predict_or_learn_first<false> >(); + uint32_t tau = (uint32_t)vm["first"].as<size_t>(); + data.policy.reset(new vw_policy()); + data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); + l = new learner<cbify>(&data, all.l, 1); + l->set_learn(predict_or_learn_first<true>); + l->set_predict(predict_or_learn_first<false>); } else { - float epsilon = 0.05f; - if (vm.count("epsilon")) - epsilon = vm["epsilon"].as<float>(); - data.policy.reset(new vw_policy()); - data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); - l = new learner<cbify>(&data, all.l, 1); - l->set_learn<predict_or_learn_greedy<true> >(); - l->set_predict<predict_or_learn_greedy<false> >(); + float epsilon = 0.05f; + if (vm.count("epsilon")) + epsilon = vm["epsilon"].as<float>(); + data.policy.reset(new vw_policy()); + data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); + l = new learner<cbify>(&data, all.l, 1); + l->set_learn(predict_or_learn_greedy<true>); + l->set_predict(predict_or_learn_greedy<false>); } - - l->set_finish_example<finish_example>(); - l->set_finish<finish>(); - l->set_init_driver<init_driver>(); + + l->set_finish_example(finish_example); + l->set_finish(finish); + l->set_init_driver(init_driver); return make_base(l); } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 5392ad91..e14cf8f5 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -83,9 +83,9 @@ namespace CSOAA { all.sd->k = nb_actions; learner<csoaa>* l = new learner<csoaa>(&c, all.l, nb_actions); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_finish_example<finish_example>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_finish_example(finish_example); return make_base(l); } } @@ -710,15 +710,15 @@ namespace LabelDict { ld.read_example_this_loop = 0; ld.need_to_clear = false; learner<ldf>* l = new learner<ldf>(&ld, all.l); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); if (ld.is_singleline) - l->set_finish_example<finish_singleline_example>(); + l->set_finish_example(finish_singleline_example); else - l->set_finish_example<finish_multiline_example>(); - l->set_finish<finish>(); - l->set_end_examples<end_examples>(); - l->set_end_pass<end_pass>(); + l->set_finish_example(finish_multiline_example); + l->set_finish(finish); + l->set_end_examples(end_examples); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index acb83ef4..f847af76 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -391,10 +391,10 @@ namespace ECT data.all = &all; learner<ect>* l = new learner<ect>(&data, all.l, wpp); - l->set_learn<learn>(); - l->set_predict<predict>(); - l->set_finish_example<finish_example>(); - l->set_finish<finish>(); + l->set_learn(learn); + l->set_predict(predict); + l->set_finish_example(finish_example); + l->set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index 6d0ca216..a0f437b0 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -218,9 +218,9 @@ namespace FTRL { } learner<ftrl>* l = new learner<ftrl>(&b, 1 << all.reg.stride_shift); - l->set_learn<learn>(); - l->set_predict<predict>(); - l->set_save_load<save_load>(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); return make_base(l); } diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 47978957..8e4b3b2b 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -794,14 +794,14 @@ uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) all.normalized_idx = normalized; if (feature_mask_off) { - ret->set_learn<learn<invariant, sqrt_rate, true, adaptive, normalized, spare> >(); - ret->set_update<update<invariant, sqrt_rate, true, adaptive, normalized, spare> >(); + ret->set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>); + ret->set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>); return next; } else { - ret->set_learn<learn<invariant, sqrt_rate, false, adaptive, normalized, spare> >(); - ret->set_update<update<invariant, sqrt_rate, false, adaptive, normalized, spare> >(); + ret->set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>); + ret->set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>); return next; } } @@ -902,25 +902,14 @@ base_learner* setup(vw& all, po::variables_map& vm) if (all.reg_mode % 2) if (all.audit || all.hash_inv) - { - ret->set_predict<predict<true, true> >(); - g.predict = predict<true, true>; - } + g.predict = predict<true, true>; else - { - ret->set_predict<predict<true, false> >(); - g.predict = predict<true, false>; - } + g.predict = predict<true, false>; else if (all.audit || all.hash_inv) - { - ret->set_predict<predict<false, true> >(); - g.predict = predict<false, true>; - } + g.predict = predict<false, true>; else - { - ret->set_predict<predict<false, false> >(); - g.predict = predict<false, false>; - } + g.predict = predict<false, false>; + ret->set_predict(g.predict); uint32_t stride; if (all.power_t == 0.5) @@ -931,9 +920,8 @@ base_learner* setup(vw& all, po::variables_map& vm) all.reg.stride_shift = ceil_log_2(stride-1); ret->increment = ((uint64_t)1 << all.reg.stride_shift); - ret->set_save_load<save_load>(); - - ret->set_end_pass<end_pass>(); + ret->set_save_load(save_load); + ret->set_end_pass(end_pass); return make_base(ret); } } diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index f6f9beaf..1ae18985 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -331,10 +331,10 @@ void mf_train(vw& all, example& ec) all.eta *= powf((float)(all.sd->t), all.power_t); learner<gdmf>* l = new learner<gdmf>(&data, 1 << all.reg.stride_shift); - l->set_learn<learn>(); - l->set_predict<predict>(); - l->set_save_load<save_load>(); - l->set_end_pass<end_pass>(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 6c47b89e..bfe1f927 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -899,10 +899,10 @@ namespace KSVM params.all->reg.stride_shift = 0; learner<svm_params>* l = new learner<svm_params>(¶ms, 1); - l->set_learn<learn>(); - l->set_predict<predict>(); - l->set_save_load<save_load>(); - l->set_finish<finish>(); + l->set_learn(learn); + l->set_predict(predict); + l->set_save_load(save_load); + l->set_finish(finish); return make_base(l); } } diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 525408e2..a8241557 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -787,13 +787,13 @@ base_learner* setup(vw&all, po::variables_map& vm) ld.decay_levels.push_back(0.f);
learner<lda>* l = new learner<lda>(&ld, 1 << all.reg.stride_shift);
- l->set_learn<learn>();
- l->set_predict<predict>();
- l->set_save_load<save_load>();
- l->set_finish_example<finish_example>();
- l->set_end_examples<end_examples>();
- l->set_end_pass<end_pass>();
- l->set_finish<finish>();
+ l->set_learn(learn);
+ l->set_predict(predict);
+ l->set_save_load(save_load);
+ l->set_finish_example(finish_example);
+ l->set_end_examples(end_examples);
+ l->set_end_pass(end_pass);
+ l->set_finish(finish);
return make_base(l);
}
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 8d00b6e7..0edeecb4 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -61,20 +61,13 @@ namespace LEARNER const learn_data generic_learn_fd = {NULL, NULL, generic_learner, generic_learner, NULL}; const func_data generic_func_fd = {NULL, NULL, generic_func}; - template<class R, void (*T)(R&, base_learner& base, example& ec)> - inline void tlearn(void* d, base_learner& base, example& ec) - { T(*(R*)d, base, ec); } + typedef void (*tlearn)(void* d, base_learner& base, example& ec); - template<class R, void (*T)(R&, io_buf& io, bool read, bool text)> - inline void tsl(void* d, io_buf& io, bool read, bool text) - { T(*(R*)d, io, read, text); } + typedef void (*tsl)(void* d, io_buf& io, bool read, bool text); - template<class R, void (*T)(R&)> - inline void tfunc(void* d) { T(*(R*)d); } + typedef void (*tfunc)(void*d); - template<class R, void (*T)(vw& all, R&, example&)> - inline void tend_example(vw& all, void* d, example& ec) - { T(all, *(R*)d, ec); } + typedef void (*tend_example)(vw& all, void* d, example& ec); template<class T> struct learner { @@ -98,11 +91,10 @@ public: learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - template <void (*u)(T& data, base_learner& base, example&)> - inline void set_learn() + inline void set_learn(void (*u)(T& data, base_learner& base, example&)) { - learn_fd.learn_f = tlearn<T,u>; - learn_fd.update_f = tlearn<T,u>; + learn_fd.learn_f = (tlearn)u; + learn_fd.update_f = (tlearn)u; } inline void predict(example& ec, size_t i=0) @@ -111,10 +103,9 @@ public: learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - template <void (*u)(T& data, base_learner& base, example&)> - inline void set_predict() + inline void set_predict(void (*u)(T& data, base_learner& base, example&)) { - learn_fd.predict_f = tlearn<T,u>; + learn_fd.predict_f = (tlearn)u; } inline void update(example& ec, size_t i=0) @@ -123,23 +114,20 @@ public: learn_fd.update_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - template <void (*u)(T& data, base_learner& base, example&)> - inline void set_update() + inline void set_update(void (*u)(T& data, base_learner& base, example&)) { - learn_fd.update_f = tlearn<T,u>; + learn_fd.update_f = (tlearn)u; } //called anytime saving or loading needs to happen. Autorecursive. inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } - template <void (*sl)(T&, io_buf&, bool, bool)> - inline void set_save_load() - { save_load_fd.save_load_f = tsl<T,sl>; + inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool)) + { save_load_fd.save_load_f = (tsl)sl; save_load_fd.data = learn_fd.data; save_load_fd.base = learn_fd.base;} //called to clean up state. Autorecursive. - template <void (*f)(T&)> - void set_finish() { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T, f>); } + void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } inline void finish() { if (finisher_fd.data) @@ -153,27 +141,23 @@ public: void end_pass(){ end_pass_fd.func(end_pass_fd.data); if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive - template <void (*f)(T&)> - void set_end_pass() {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, tfunc<T,f>);} + void set_end_pass(void (*f)(T&)) {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);} //called after parsing of examples is complete. Autorecursive. void end_examples() { end_examples_fd.func(end_examples_fd.data); if (end_examples_fd.base) end_examples_fd.base->end_examples(); } - template <void (*f)(T&)> - void set_end_examples() {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T,f>);} + void set_end_examples(void (*f)(T&)) {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);} //Called at the beginning by the driver. Explicitly not recursive. void init_driver() { init_fd.func(init_fd.data);} - template <void (*f)(T&)> - void set_init_driver() { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T,f>); } + void set_init_driver(void (*f)(T&)) { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } //called after learn example for each example. Explicitly not recursive. inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} - template<void (*f)(vw& all, T&, example&)> - void set_finish_example() + void set_finish_example(void (*f)(vw& all, T&, example&)) {finish_example_fd.data = learn_fd.data; - finish_example_fd.finish_example_f = tend_example<T,f>;} + finish_example_fd.finish_example_f = (tend_example)f;} inline learner() { diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index d68a519f..2378106f 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -532,11 +532,11 @@ namespace LOG_MULTI data->max_predictors = data->k - 1;
learner<log_multi>* l = new learner<log_multi>(data, all.l, data->max_predictors);
- l->set_save_load<save_load_tree>();
- l->set_learn<learn>();
- l->set_predict<predict>();
- l->set_finish_example<finish_example>();
- l->set_finish<finish>();
+ l->set_save_load(save_load_tree);
+ l->set_learn(learn);
+ l->set_predict(predict);
+ l->set_finish_example(finish_example);
+ l->set_finish(finish);
init_tree(*data);
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index 013c8e3b..557411e5 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -244,9 +244,9 @@ namespace LRQ { all.wpp = all.wpp * (1 + maxk); learner<LRQstate>* l = new learner<LRQstate>(&lrq, all.l, 1 + maxk); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_end_pass<reset_seed>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_end_pass(reset_seed); // TODO: leaks memory ? return make_base(l); diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index e933adb8..65706731 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -204,9 +204,9 @@ base_learner* setup(vw& all, po::variables_map& vm) { all.random_positive_weights = true; learner<mf>* l = new learner<mf>(data, all.l, 2*data->rank+1); - l->set_learn<learn>(); - l->set_predict<predict<false> >(); - l->set_finish<finish>(); + l->set_learn(learn); + l->set_predict(predict<false>); + l->set_finish(finish); return make_base(l); } } diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index 5afc989b..4f623381 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -366,11 +366,11 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.save_xsubi = n.xsubi; n.increment = all.l->increment;//Indexing of output layer is odd. learner<nn>* l = new learner<nn>(&n, all.l, n.k+1); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_finish<finish>(); - l->set_finish_example<finish_example>(); - l->set_end_pass<end_pass>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_finish(finish); + l->set_finish_example(finish_example); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index df1db482..47ee84b0 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -88,9 +88,9 @@ namespace OAA { all.p->lp = mc_label; learner<oaa>* l = new learner<oaa>(&data, all.l, data.k); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_finish_example<finish_example>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_finish_example(finish_example); return make_base(l); } diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 144f7350..30f78cff 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -55,8 +55,8 @@ namespace PRINT all.reg.stride_shift = 0; learner<print>* ret = new learner<print>(&p, 1); - ret->set_learn<learn>(); - ret->set_predict<learn>(); + ret->set_learn(learn); + ret->set_predict(learn); return make_base(ret); } } diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index 815796fc..a4b73855 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -62,20 +62,20 @@ namespace Scorer { string link = vm["link"].as<string>(); if (!vm.count("link") || link.compare("identity") == 0) { - l->set_learn<predict_or_learn<true, noop> >(); - l->set_predict<predict_or_learn<false, noop> >(); + l->set_learn(predict_or_learn<true, noop> ); + l->set_predict(predict_or_learn<false, noop> ); } else if (link.compare("logistic") == 0) { all.file_options << " --link=logistic "; - l->set_learn<predict_or_learn<true, logistic> >(); - l->set_predict<predict_or_learn<false, logistic> >(); + l->set_learn(predict_or_learn<true, logistic> ); + l->set_predict(predict_or_learn<false, logistic>); } else if (link.compare("glf1") == 0) { all.file_options << " --link=glf1 "; - l->set_learn<predict_or_learn<true, glf1> >(); - l->set_predict<predict_or_learn<false, glf1> >(); + l->set_learn(predict_or_learn<true, glf1>); + l->set_predict(predict_or_learn<false, glf1>); } else { diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 7236f191..ea218da2 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -1982,12 +1982,12 @@ namespace Search { priv.start_clock_time = clock(); learner<search>* l = new learner<search>(&sch, all.l, priv.total_number_of_policies); - l->set_learn<search_predict_or_learn<true> >(); - l->set_predict<search_predict_or_learn<false> >(); - l->set_finish_example<finish_example>(); - l->set_end_examples<end_examples>(); - l->set_finish<search_finish>(); - l->set_end_pass<end_pass>(); + l->set_learn(search_predict_or_learn<true>); + l->set_predict(search_predict_or_learn<false>); + l->set_finish_example(finish_example); + l->set_end_examples(end_examples); + l->set_finish(search_finish); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index d31e199f..6222b690 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -114,11 +114,11 @@ void end_examples(sender& s) s.delay_ring = calloc_or_die<example*>(all.p->ring_size); learner<sender>* l = new learner<sender>(&s, 1); - l->set_learn<learn>(); - l->set_predict<learn>(); - l->set_finish<finish>(); - l->set_finish_example<finish_example>(); - l->set_end_examples<end_examples>(); + l->set_learn(learn); + l->set_predict(learn); + l->set_finish(finish); + l->set_finish_example(finish_example); + l->set_end_examples(end_examples); return make_base(l); } diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index daccb3cd..5639366d 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -699,12 +699,12 @@ namespace StagewisePoly all.file_options << " --stage_poly"; learner<stagewise_poly> *l = new learner<stagewise_poly>(&poly, all.l); - l->set_learn<learn>(); - l->set_predict<predict>(); - l->set_finish<finish>(); - l->set_save_load<save_load>(); - l->set_finish_example<finish_example>(); - l->set_end_pass<end_pass>(); + l->set_learn(learn); + l->set_predict(predict); + l->set_finish(finish); + l->set_save_load(save_load); + l->set_finish_example(finish_example); + l->set_end_pass(end_pass); return make_base(l); } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 546f38c0..70b5c9cd 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -120,9 +120,9 @@ namespace TOPK { data.all = &all; learner<topk>* l = new learner<topk>(&data, all.l); - l->set_learn<predict_or_learn<true> >(); - l->set_predict<predict_or_learn<false> >(); - l->set_finish_example<finish_example>(); + l->set_learn(predict_or_learn<true>); + l->set_predict(predict_or_learn<false>); + l->set_finish_example(finish_example); return make_base(l); } |