diff options
author | John Langford <jl@hunch.net> | 2014-12-27 07:17:30 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-12-27 07:17:30 +0300 |
commit | 0f9e6ce12e7f0c9c28a75acadd8e7bc50f932b0c (patch) | |
tree | 3f1943142e466da5425fbf46cbca3c905bb52e85 /vowpalwabbit | |
parent | 33f25c5af9538b77e4cc43d800c1e4cae4995322 (diff) |
remove new for learner
Diffstat (limited to 'vowpalwabbit')
-rw-r--r-- | vowpalwabbit/active.cc | 14 | ||||
-rw-r--r-- | vowpalwabbit/autolink.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/bfgs.cc | 18 | ||||
-rw-r--r-- | vowpalwabbit/binary.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/bs.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/cb_algs.cc | 22 | ||||
-rw-r--r-- | vowpalwabbit/cbify.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/csoaa.cc | 28 | ||||
-rw-r--r-- | vowpalwabbit/ect.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/ftrl_proximal.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/gd.cc | 28 | ||||
-rw-r--r-- | vowpalwabbit/gd_mf.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/global_data.h | 2 | ||||
-rw-r--r-- | vowpalwabbit/kernel_svm.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/lda_core.cc | 18 | ||||
-rw-r--r-- | vowpalwabbit/learner.h | 283 | ||||
-rw-r--r-- | vowpalwabbit/log_multi.cc | 14 | ||||
-rw-r--r-- | vowpalwabbit/lrq.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/mf.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/nn.cc | 16 | ||||
-rw-r--r-- | vowpalwabbit/oaa.cc | 10 | ||||
-rw-r--r-- | vowpalwabbit/print.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/scorer.cc | 16 | ||||
-rw-r--r-- | vowpalwabbit/search.cc | 18 | ||||
-rw-r--r-- | vowpalwabbit/sender.cc | 14 | ||||
-rw-r--r-- | vowpalwabbit/stagewise_poly.cc | 18 | ||||
-rw-r--r-- | vowpalwabbit/topk.cc | 10 |
27 files changed, 324 insertions, 315 deletions
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index a6337641..c1f7f985 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -166,20 +166,20 @@ namespace ACTIVE { data.all=&all; //Create new learner - learner<active>* ret = new learner<active>(&data, all.l); + learner<active>& ret = init_learner(&data, all.l); if (vm.count("simulation")) { - ret->set_learn(predict_or_learn_simulation<true>); - ret->set_predict(predict_or_learn_simulation<false>); + ret.set_learn(predict_or_learn_simulation<true>); + ret.set_predict(predict_or_learn_simulation<false>); } else { all.active = true; - ret->set_learn(predict_or_learn_active<true>); - ret->set_predict(predict_or_learn_active<false>); - ret->set_finish_example(return_active_example); + ret.set_learn(predict_or_learn_active<true>); + ret.set_predict(predict_or_learn_active<false>); + ret.set_finish_example(return_active_example); } - return make_base(ret); + return make_base(&ret); } } diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 57d6a247..d0abeb70 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -49,9 +49,9 @@ namespace ALINK { all.file_options << " --autolink " << data.d; - learner<autolink>* ret = new learner<autolink>(&data, all.l); - ret->set_learn(predict_or_learn<true>); - ret->set_predict(predict_or_learn<false>); - return make_base(ret); + learner<autolink>& ret = init_learner(&data, all.l); + ret.set_learn(predict_or_learn<true>); + ret.set_predict(predict_or_learn<false>); + return make_base(&ret); } } diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index a5d4110b..3b6143ab 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -1018,14 +1018,14 @@ base_learner* setup(vw& all, po::variables_map& vm) all.bfgs = true; all.reg.stride_shift = 2; - learner<bfgs>* l = new learner<bfgs>(&b, 1 << all.reg.stride_shift); - l->set_learn(learn); - l->set_predict(predict); - l->set_save_load(save_load); - l->set_init_driver(init_driver); - l->set_end_pass(end_pass); - l->set_finish(finish); - - return make_base(l); + learner<bfgs>& l = init_learner(&b, 1 << all.reg.stride_shift); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_init_driver(init_driver); + l.set_end_pass(end_pass); + l.set_finish(finish); + + return make_base(&l); } } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 50c8875e..52556af6 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -28,9 +28,9 @@ namespace BINARY { {//parse and set arguments all.sd->binary_label = true; //Create new learner - learner<char>* ret = new learner<char>(NULL, all.l); - ret->set_learn(predict_or_learn<true>); - ret->set_predict(predict_or_learn<false>); - return make_base(ret); + learner<char>& ret = init_learner<char>(NULL, all.l); + ret.set_learn(predict_or_learn<true>); + ret.set_predict(predict_or_learn<false>); + return make_base(&ret); } } diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index 8382d04d..564ff3aa 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -280,12 +280,12 @@ namespace BS { data.pred_vec.reserve(data.B); data.all = &all; - learner<bs>* l = new learner<bs>(&data, all.l, data.B); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_finish_example(finish_example); - l->set_finish(finish); + learner<bs>& l = init_learner(&data, all.l, data.B); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); + l.set_finish(finish); - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index c02ae6ed..2cc3cd7b 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -564,25 +564,25 @@ namespace CB_ALGS else all.p->lp = CB::cb_label; - learner<cb>* l = new learner<cb>(&c, all.l, problem_multiplier); + learner<cb>& l = init_learner(&c, all.l, problem_multiplier); if (eval) { - l->set_learn(learn_eval); - l->set_predict(predict_eval); - l->set_finish_example(eval_finish_example); + l.set_learn(learn_eval); + l.set_predict(predict_eval); + l.set_finish_example(eval_finish_example); } else { - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_finish_example(finish_example); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); } - l->set_init_driver(init_driver); - l->set_finish(finish); + l.set_init_driver(init_driver); + l.set_finish(finish); // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. - l->increment = all.l->increment; + l.increment = all.l->increment; - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index ab8323ea..9339dabf 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -409,7 +409,7 @@ namespace CBIFY { epsilon = vm["epsilon"].as<float>(); data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); - l = new learner<cbify>(&data, all.l, cover + 1); + l = &init_learner(&data, all.l, cover + 1); l->set_learn(predict_or_learn_cover<true>); l->set_predict(predict_or_learn_cover<false>); } @@ -421,7 +421,7 @@ namespace CBIFY { data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); } data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); - l = new learner<cbify>(&data, all.l, bags); + l = &init_learner(&data, all.l, bags); l->set_learn(predict_or_learn_bag<true>); l->set_predict(predict_or_learn_bag<false>); } @@ -430,7 +430,7 @@ namespace CBIFY { uint32_t tau = (uint32_t)vm["first"].as<size_t>(); data.policy.reset(new vw_policy()); data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); - l = new learner<cbify>(&data, all.l, 1); + l = &init_learner(&data, all.l, 1); l->set_learn(predict_or_learn_first<true>); l->set_predict(predict_or_learn_first<false>); } @@ -441,7 +441,7 @@ namespace CBIFY { epsilon = vm["epsilon"].as<float>(); data.policy.reset(new vw_policy()); data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); - l = new learner<cbify>(&data, all.l, 1); + l = &init_learner(&data, all.l, 1); l->set_learn(predict_or_learn_greedy<true>); l->set_predict(predict_or_learn_greedy<false>); } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index e14cf8f5..06150b2a 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -82,11 +82,11 @@ namespace CSOAA { all.p->lp = cs_label; all.sd->k = nb_actions; - learner<csoaa>* l = new learner<csoaa>(&c, all.l, nb_actions); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_finish_example(finish_example); - return make_base(l); + learner<csoaa>& l = init_learner(&c, all.l, nb_actions); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); + return make_base(&l); } } @@ -709,17 +709,17 @@ namespace LabelDict { ld.read_example_this_loop = 0; ld.need_to_clear = false; - learner<ldf>* l = new learner<ldf>(&ld, all.l); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); + learner<ldf>& l = init_learner(&ld, all.l); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); if (ld.is_singleline) - l->set_finish_example(finish_singleline_example); + l.set_finish_example(finish_singleline_example); else - l->set_finish_example(finish_multiline_example); - l->set_finish(finish); - l->set_end_examples(end_examples); - l->set_end_pass(end_pass); - return make_base(l); + l.set_finish_example(finish_multiline_example); + l.set_finish(finish); + l.set_end_examples(end_examples); + l.set_end_pass(end_pass); + return make_base(&l); } void global_print_newline(vw& all) diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index f847af76..b9f7b536 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -390,12 +390,12 @@ namespace ECT size_t wpp = create_circuit(all, data, data.k, data.errors+1); data.all = &all; - learner<ect>* l = new learner<ect>(&data, all.l, wpp); - l->set_learn(learn); - l->set_predict(predict); - l->set_finish_example(finish_example); - l->set_finish(finish); + learner<ect>& l = init_learner(&data, all.l, wpp); + l.set_learn(learn); + l.set_predict(predict); + l.set_finish_example(finish_example); + l.set_finish(finish); - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index a0f437b0..0d3c58ff 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -217,12 +217,12 @@ namespace FTRL { cerr << "ftrl_beta = " << b.ftrl_beta << endl; } - learner<ftrl>* l = new learner<ftrl>(&b, 1 << all.reg.stride_shift); - l->set_learn(learn); - l->set_predict(predict); - l->set_save_load(save_load); + learner<ftrl>& l = init_learner(&b, 1 << all.reg.stride_shift); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); - return make_base(l); + return make_base(&l); } diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 8e4b3b2b..8ccc0373 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -789,25 +789,25 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text) } template<bool invariant, bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next> -uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { all.normalized_idx = normalized; if (feature_mask_off) { - ret->set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>); - ret->set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>); + ret.set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>); + ret.set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>); return next; } else { - ret->set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>); - ret->set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>); + ret.set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>); + ret.set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>); return next; } } template<bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next> -uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { if (all.invariant_updates) return set_learn<true, sqrt_rate, adaptive, normalized, spare, next>(all, ret, feature_mask_off); @@ -816,7 +816,7 @@ uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) } template<bool sqrt_rate, uint32_t adaptive, uint32_t spare> -uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { // select the appropriate learn function based on adaptive, normalization, and feature mask if (all.normalized_updates) @@ -826,7 +826,7 @@ uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) } template<bool sqrt_rate> -uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off) +uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off) { if (all.adaptive) return set_learn<sqrt_rate, 1, 2>(all, ret, feature_mask_off); @@ -898,7 +898,7 @@ base_learner* setup(vw& all, po::variables_map& vm) cerr << "Warning: the learning rate for the last pass is multiplied by: " << pow((double)all.eta_decay_rate, (double)all.numpasses) << " adjust --decay_learning_rate larger to avoid this." << endl; - learner<gd>* ret = new learner<gd>(&g, 1); + learner<gd>& ret = init_learner(&g, 1); if (all.reg_mode % 2) if (all.audit || all.hash_inv) @@ -909,7 +909,7 @@ base_learner* setup(vw& all, po::variables_map& vm) g.predict = predict<false, true>; else g.predict = predict<false, false>; - ret->set_predict(g.predict); + ret.set_predict(g.predict); uint32_t stride; if (all.power_t == 0.5) @@ -918,10 +918,10 @@ base_learner* setup(vw& all, po::variables_map& vm) stride = set_learn<false>(all, ret, feature_mask_off); all.reg.stride_shift = ceil_log_2(stride-1); - ret->increment = ((uint64_t)1 << all.reg.stride_shift); + ret.increment = ((uint64_t)1 << all.reg.stride_shift); - ret->set_save_load(save_load); - ret->set_end_pass(end_pass); - return make_base(ret); + ret.set_save_load(save_load); + ret.set_end_pass(end_pass); + return make_base(&ret); } } diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index 1ae18985..ed5e7df9 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -330,12 +330,12 @@ void mf_train(vw& all, example& ec) } all.eta *= powf((float)(all.sd->t), all.power_t); - learner<gdmf>* l = new learner<gdmf>(&data, 1 << all.reg.stride_shift); - l->set_learn(learn); - l->set_predict(predict); - l->set_save_load(save_load); - l->set_end_pass(end_pass); + learner<gdmf>& l = init_learner(&data, 1 << all.reg.stride_shift); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_end_pass(end_pass); - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index 86bd2d15..f0d4a6f1 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -170,7 +170,7 @@ struct vw { node_socks socks; - LEARNER::base_learner* l;//the top level learner + LEARNER::base_learner& l;//the top level learner LEARNER::base_learner* scorer;//a scoring function LEARNER::base_learner* cost_sensitive;//a cost sensitive learning algorithm. diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index bfe1f927..cf9d82d2 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -898,11 +898,11 @@ namespace KSVM params.all->reg.weight_mask = (uint32_t)LONG_MAX; params.all->reg.stride_shift = 0; - learner<svm_params>* l = new learner<svm_params>(¶ms, 1); - l->set_learn(learn); - l->set_predict(predict); - l->set_save_load(save_load); - l->set_finish(finish); - return make_base(l); + learner<svm_params>& l = init_learner(¶ms, 1); + l.set_learn(learn); + l.set_predict(predict); + l.set_save_load(save_load); + l.set_finish(finish); + return make_base(&l); } } diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index a8241557..8f27ea78 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -786,15 +786,15 @@ base_learner* setup(vw&all, po::variables_map& vm) ld.decay_levels.push_back(0.f);
- learner<lda>* l = new learner<lda>(&ld, 1 << all.reg.stride_shift);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_save_load(save_load);
- l->set_finish_example(finish_example);
- l->set_end_examples(end_examples);
- l->set_end_pass(end_pass);
- l->set_finish(finish);
+ learner<lda>& l = init_learner(&ld, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 0edeecb4..ef76f14d 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -6,6 +6,7 @@ license as described in the file LICENSE. #pragma once // This is the interface for a learning algorithm #include<iostream> +#include"memory.h" using namespace std; struct vw; @@ -62,147 +63,155 @@ namespace LEARNER const func_data generic_func_fd = {NULL, NULL, generic_func}; typedef void (*tlearn)(void* d, base_learner& base, example& ec); - typedef void (*tsl)(void* d, io_buf& io, bool read, bool text); - typedef void (*tfunc)(void*d); - typedef void (*tend_example)(vw& all, void* d, example& ec); -template<class T> -struct learner { -private: - func_data init_fd; - learn_data learn_fd; - finish_example_data finish_example_fd; - save_load_data save_load_fd; - func_data end_pass_fd; - func_data end_examples_fd; - func_data finisher_fd; + template<class T> learner<T>& init_learner(); + template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight); + template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1); -public: - size_t weights; //this stores the number of "weight vectors" required by the learner. - size_t increment; - - //called once for each example. Must work under reduction. - inline void learn(example& ec, size_t i=0) - { - ec.ft_offset += (uint32_t)(increment*i); - learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); - ec.ft_offset -= (uint32_t)(increment*i); - } - inline void set_learn(void (*u)(T& data, base_learner& base, example&)) - { - learn_fd.learn_f = (tlearn)u; - learn_fd.update_f = (tlearn)u; - } - - inline void predict(example& ec, size_t i=0) - { - ec.ft_offset += (uint32_t)(increment*i); - learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec); - ec.ft_offset -= (uint32_t)(increment*i); - } - inline void set_predict(void (*u)(T& data, base_learner& base, example&)) - { - learn_fd.predict_f = (tlearn)u; - } - - inline void update(example& ec, size_t i=0) - { - ec.ft_offset += (uint32_t)(increment*i); - learn_fd.update_f(learn_fd.data, *learn_fd.base, ec); - ec.ft_offset -= (uint32_t)(increment*i); - } - inline void set_update(void (*u)(T& data, base_learner& base, example&)) - { - learn_fd.update_f = (tlearn)u; - } - - //called anytime saving or loading needs to happen. Autorecursive. - inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } - inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool)) - { save_load_fd.save_load_f = (tsl)sl; - save_load_fd.data = learn_fd.data; - save_load_fd.base = learn_fd.base;} - - //called to clean up state. Autorecursive. - void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } - inline void finish() - { - if (finisher_fd.data) - {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); } - if (finisher_fd.base) { - finisher_fd.base->finish(); - delete finisher_fd.base; + template<class T> + struct learner { + private: + func_data init_fd; + learn_data learn_fd; + finish_example_data finish_example_fd; + save_load_data save_load_fd; + func_data end_pass_fd; + func_data end_examples_fd; + func_data finisher_fd; + + public: + size_t weights; //this stores the number of "weight vectors" required by the learner. + size_t increment; + + //called once for each example. Must work under reduction. + inline void learn(example& ec, size_t i=0) + { + ec.ft_offset += (uint32_t)(increment*i); + learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); + ec.ft_offset -= (uint32_t)(increment*i); + } + inline void set_learn(void (*u)(T& data, base_learner& base, example&)) + { + learn_fd.learn_f = (tlearn)u; + learn_fd.update_f = (tlearn)u; + } + + inline void predict(example& ec, size_t i=0) + { + ec.ft_offset += (uint32_t)(increment*i); + learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec); + ec.ft_offset -= (uint32_t)(increment*i); + } + inline void set_predict(void (*u)(T& data, base_learner& base, example&)) + { learn_fd.predict_f = (tlearn)u; } + + inline void update(example& ec, size_t i=0) + { + ec.ft_offset += (uint32_t)(increment*i); + learn_fd.update_f(learn_fd.data, *learn_fd.base, ec); + ec.ft_offset -= (uint32_t)(increment*i); + } + inline void set_update(void (*u)(T& data, base_learner& base, example&)) + { learn_fd.update_f = (tlearn)u; } + + //called anytime saving or loading needs to happen. Autorecursive. + inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } + inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool)) + { save_load_fd.save_load_f = (tsl)sl; + save_load_fd.data = learn_fd.data; + save_load_fd.base = learn_fd.base;} + + //called to clean up state. Autorecursive. + void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } + inline void finish() + { + if (finisher_fd.data) + {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); } + if (finisher_fd.base) { + finisher_fd.base->finish(); + delete finisher_fd.base; + } + } + + void end_pass(){ + end_pass_fd.func(end_pass_fd.data); + if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive + void set_end_pass(void (*f)(T&)) + {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);} + + //called after parsing of examples is complete. Autorecursive. + void end_examples() + { end_examples_fd.func(end_examples_fd.data); + if (end_examples_fd.base) end_examples_fd.base->end_examples(); } + void set_end_examples(void (*f)(T&)) + {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);} + + //Called at the beginning by the driver. Explicitly not recursive. + void init_driver() { init_fd.func(init_fd.data);} + void set_init_driver(void (*f)(T&)) + { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } + + //called after learn example for each example. Explicitly not recursive. + inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} + void set_finish_example(void (*f)(vw& all, T&, example&)) + {finish_example_fd.data = learn_fd.data; + finish_example_fd.finish_example_f = (tend_example)f;} + + friend learner<T>& init_learner<>(); + friend learner<T>& init_learner<>(T* dat, size_t params_per_weight); + friend learner<T>& init_learner<>(T* dat, base_learner* base, size_t ws); + }; + + template<class T> learner<T>& init_learner() + { + learner<T>& ret = calloc_or_die<learner<T> >(); + ret.weights = 1; + ret.increment = 1; + ret.learn_fd = LEARNER::generic_learn_fd; + ret.finish_example_fd.data = NULL; + ret.finish_example_fd.finish_example_f = return_simple_example; + ret.end_pass_fd = LEARNER::generic_func_fd; + ret.end_examples_fd = LEARNER::generic_func_fd; + ret.init_fd = LEARNER::generic_func_fd; + ret.finisher_fd = LEARNER::generic_func_fd; + ret.save_load_fd = LEARNER::generic_save_load_fd; + return ret; } - } - - void end_pass(){ - end_pass_fd.func(end_pass_fd.data); - if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive - void set_end_pass(void (*f)(T&)) {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);} - - //called after parsing of examples is complete. Autorecursive. - void end_examples() - { end_examples_fd.func(end_examples_fd.data); - if (end_examples_fd.base) end_examples_fd.base->end_examples(); } - void set_end_examples(void (*f)(T&)) {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);} - - //Called at the beginning by the driver. Explicitly not recursive. - void init_driver() { init_fd.func(init_fd.data);} - void set_init_driver(void (*f)(T&)) { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } - - //called after learn example for each example. Explicitly not recursive. - inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} - void set_finish_example(void (*f)(vw& all, T&, example&)) - {finish_example_fd.data = learn_fd.data; - finish_example_fd.finish_example_f = (tend_example)f;} - - inline learner() - { - weights = 1; - increment = 1; - - learn_fd = LEARNER::generic_learn_fd; - finish_example_fd.data = NULL; - finish_example_fd.finish_example_f = return_simple_example; - end_pass_fd = LEARNER::generic_func_fd; - end_examples_fd = LEARNER::generic_func_fd; - init_fd = LEARNER::generic_func_fd; - finisher_fd = LEARNER::generic_func_fd; - save_load_fd = LEARNER::generic_save_load_fd; - } - - inline learner(T* dat, size_t params_per_weight) - { // the constructor for all learning algorithms. - *this = learner(); - - learn_fd.data = dat; - - finisher_fd.data = dat; - finisher_fd.base = NULL; - finisher_fd.func = LEARNER::generic_func; - - increment = params_per_weight; - } - - inline learner(T* dat, base_learner* base, size_t ws = 1) - { //the reduction constructor, with separate learn and predict functions - *this = *(learner<T>*)base; - - learn_fd.data = dat; - learn_fd.base = base; - - finisher_fd.data = dat; - finisher_fd.base = base; - finisher_fd.func = LEARNER::generic_func; - - weights = ws; - increment = base->increment * weights; - } -}; - - template<class T> base_learner* make_base(learner<T>* base) - { return (base_learner*)base; } + + template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight) + { // the constructor for all learning algorithms. + learner<T>& ret = init_learner<T>(); + + ret.learn_fd.data = dat; + + ret.finisher_fd.data = dat; + ret.finisher_fd.base = NULL; + ret.finisher_fd.func = LEARNER::generic_func; + + ret.increment = params_per_weight; + return ret; + } + + template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1) + { //the reduction constructor, with separate learn and predict functions + learner<T>& ret = calloc_or_die<learner<T> >(); + ret = *(learner<T>*)base; + + ret.learn_fd.data = dat; + ret.learn_fd.base = base; + + ret.finisher_fd.data = dat; + ret.finisher_fd.base = base; + ret.finisher_fd.func = LEARNER::generic_func; + + ret.weights = ws; + ret.increment = base->increment * ret.weights; + return ret; + } + + template<class T> base_learner* make_base(learner<T>* base) + { return (base_learner*)base; } } diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 2378106f..c50a73a0 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -531,15 +531,15 @@ namespace LOG_MULTI data->max_predictors = data->k - 1;
- learner<log_multi>* l = new learner<log_multi>(data, all.l, data->max_predictors);
- l->set_save_load(save_load_tree);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_finish_example(finish_example);
- l->set_finish(finish);
+ learner<log_multi>& l = init_learner(data, all.l, data->max_predictors);
+ l.set_save_load(save_load_tree);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
init_tree(*data);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index 557411e5..24207ce1 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -243,12 +243,12 @@ namespace LRQ { cerr<<endl; all.wpp = all.wpp * (1 + maxk); - learner<LRQstate>* l = new learner<LRQstate>(&lrq, all.l, 1 + maxk); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_end_pass(reset_seed); + learner<LRQstate>& l = init_learner(&lrq, all.l, 1 + maxk); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_end_pass(reset_seed); // TODO: leaks memory ? - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index 65706731..3e1ac244 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -203,10 +203,10 @@ base_learner* setup(vw& all, po::variables_map& vm) { all.random_positive_weights = true; - learner<mf>* l = new learner<mf>(data, all.l, 2*data->rank+1); - l->set_learn(learn); - l->set_predict(predict<false>); - l->set_finish(finish); - return make_base(l); + learner<mf>& l = init_learner(data, all.l, 2*data->rank+1); + l.set_learn(learn); + l.set_predict(predict<false>); + l.set_finish(finish); + return make_base(&l); } } diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index 4f623381..bcc3ab73 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -365,13 +365,13 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.save_xsubi = n.xsubi; n.increment = all.l->increment;//Indexing of output layer is odd. - learner<nn>* l = new learner<nn>(&n, all.l, n.k+1); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_finish(finish); - l->set_finish_example(finish_example); - l->set_end_pass(end_pass); - - return make_base(l); + learner<nn>& l = init_learner(&n, all.l, n.k+1); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish(finish); + l.set_finish_example(finish_example); + l.set_end_pass(end_pass); + + return make_base(&l); } } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 47ee84b0..7b3ddd1a 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -87,11 +87,11 @@ namespace OAA { data.all = &all; all.p->lp = mc_label; - learner<oaa>* l = new learner<oaa>(&data, all.l, data.k); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_finish_example(finish_example); + learner<oaa>& l = init_learner(&data, all.l, data.k); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 30f78cff..b8dcddca 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -54,9 +54,9 @@ namespace PRINT all.reg.weight_mask = (length << all.reg.stride_shift) - 1; all.reg.stride_shift = 0; - learner<print>* ret = new learner<print>(&p, 1); - ret->set_learn(learn); - ret->set_predict(learn); - return make_base(ret); + learner<print>& ret = init_learner(&p, 1); + ret.set_learn(learn); + ret.set_predict(learn); + return make_base(&ret); } } diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index a4b73855..c82fc9df 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -57,25 +57,25 @@ namespace Scorer { vm = add_options(all, link_opts); - learner<scorer>* l = new learner<scorer>(&s, all.l); + learner<scorer>& l = init_learner(&s, all.l); string link = vm["link"].as<string>(); if (!vm.count("link") || link.compare("identity") == 0) { - l->set_learn(predict_or_learn<true, noop> ); - l->set_predict(predict_or_learn<false, noop> ); + l.set_learn(predict_or_learn<true, noop> ); + l.set_predict(predict_or_learn<false, noop> ); } else if (link.compare("logistic") == 0) { all.file_options << " --link=logistic "; - l->set_learn(predict_or_learn<true, logistic> ); - l->set_predict(predict_or_learn<false, logistic>); + l.set_learn(predict_or_learn<true, logistic> ); + l.set_predict(predict_or_learn<false, logistic>); } else if (link.compare("glf1") == 0) { all.file_options << " --link=glf1 "; - l->set_learn(predict_or_learn<true, glf1>); - l->set_predict(predict_or_learn<false, glf1>); + l.set_learn(predict_or_learn<true, glf1>); + l.set_predict(predict_or_learn<false, glf1>); } else { @@ -83,6 +83,6 @@ namespace Scorer { throw exception(); } - return make_base(l); + return make_base(&l); } } diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index ea218da2..8c894f51 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -1981,15 +1981,15 @@ namespace Search { priv.start_clock_time = clock(); - learner<search>* l = new learner<search>(&sch, all.l, priv.total_number_of_policies); - l->set_learn(search_predict_or_learn<true>); - l->set_predict(search_predict_or_learn<false>); - l->set_finish_example(finish_example); - l->set_end_examples(end_examples); - l->set_finish(search_finish); - l->set_end_pass(end_pass); - - return make_base(l); + learner<search>& l = init_learner(&sch, all.l, priv.total_number_of_policies); + l.set_learn(search_predict_or_learn<true>); + l.set_predict(search_predict_or_learn<false>); + l.set_finish_example(finish_example); + l.set_end_examples(end_examples); + l.set_finish(search_finish); + l.set_end_pass(end_pass); + + return make_base(&l); } float action_hamming_loss(action a, const action* A, size_t sz) { diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index 6222b690..9033ee31 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -113,13 +113,13 @@ void end_examples(sender& s) s.all = &all; s.delay_ring = calloc_or_die<example*>(all.p->ring_size); - learner<sender>* l = new learner<sender>(&s, 1); - l->set_learn(learn); - l->set_predict(learn); - l->set_finish(finish); - l->set_finish_example(finish_example); - l->set_end_examples(end_examples); - return make_base(l); + learner<sender>& l = init_learner(&s, 1); + l.set_learn(learn); + l.set_predict(learn); + l.set_finish(finish); + l.set_finish_example(finish_example); + l.set_end_examples(end_examples); + return make_base(&l); } } diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index 5639366d..6550c16e 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -698,14 +698,14 @@ namespace StagewisePoly //following is so that saved models know to load us. all.file_options << " --stage_poly"; - learner<stagewise_poly> *l = new learner<stagewise_poly>(&poly, all.l); - l->set_learn(learn); - l->set_predict(predict); - l->set_finish(finish); - l->set_save_load(save_load); - l->set_finish_example(finish_example); - l->set_end_pass(end_pass); - - return make_base(l); + learner<stagewise_poly>& l = init_learner(&poly, all.l); + l.set_learn(learn); + l.set_predict(predict); + l.set_finish(finish); + l.set_save_load(save_load); + l.set_finish_example(finish_example); + l.set_end_pass(end_pass); + + return make_base(&l); } } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 70b5c9cd..c4e15bd5 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -119,11 +119,11 @@ namespace TOPK { data.all = &all; - learner<topk>* l = new learner<topk>(&data, all.l); - l->set_learn(predict_or_learn<true>); - l->set_predict(predict_or_learn<false>); - l->set_finish_example(finish_example); + learner<topk>& l = init_learner(&data, all.l); + l.set_learn(predict_or_learn<true>); + l.set_predict(predict_or_learn<false>); + l.set_finish_example(finish_example); - return make_base(l); + return make_base(&l); } } |