github.com/moses-smt/vowpal_wabbit.git
author     John Langford <jl@hunch.net>    2014-12-27 07:17:30 +0300
committer  John Langford <jl@hunch.net>    2014-12-27 07:17:30 +0300
commit     0f9e6ce12e7f0c9c28a75acadd8e7bc50f932b0c (patch)
tree       3f1943142e466da5425fbf46cbca3c905bb52e85 /vowpalwabbit
parent     33f25c5af9538b77e4cc43d800c1e4cae4995322 (diff)
remove new for learner
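
The learner<T> constructors forced every reduction's setup() to heap-allocate
its learner with new and wire it up through a pointer. This commit replaces
them with a family of init_learner() factory functions (see the learner.h
diff below) that allocate via calloc_or_die and return a reference, so call
sites switch from -> to . and pass the learner's address to make_base().
Excerpted from the active.cc hunk below, the before/after pattern repeated
across the .cc files is:

    // before: direct construction, pointer syntax
    learner<active>* ret = new learner<active>(&data, all.l);
    ret->set_learn(predict_or_learn_simulation<true>);
    return make_base(ret);

    // after: factory function, reference syntax
    learner<active>& ret = init_learner(&data, all.l);
    ret.set_learn(predict_or_learn_simulation<true>);
    return make_base(&ret);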
Diffstat (limited to 'vowpalwabbit')
-rw-r--r--  vowpalwabbit/active.cc          |  14
-rw-r--r--  vowpalwabbit/autolink.cc        |   8
-rw-r--r--  vowpalwabbit/bfgs.cc            |  18
-rw-r--r--  vowpalwabbit/binary.cc          |   8
-rw-r--r--  vowpalwabbit/bs.cc              |  12
-rw-r--r--  vowpalwabbit/cb_algs.cc         |  22
-rw-r--r--  vowpalwabbit/cbify.cc           |   8
-rw-r--r--  vowpalwabbit/csoaa.cc           |  28
-rw-r--r--  vowpalwabbit/ect.cc             |  12
-rw-r--r--  vowpalwabbit/ftrl_proximal.cc   |  10
-rw-r--r--  vowpalwabbit/gd.cc              |  28
-rw-r--r--  vowpalwabbit/gd_mf.cc           |  12
-rw-r--r--  vowpalwabbit/global_data.h      |   2
-rw-r--r--  vowpalwabbit/kernel_svm.cc      |  12
-rw-r--r--  vowpalwabbit/lda_core.cc        |  18
-rw-r--r--  vowpalwabbit/learner.h          | 283
-rw-r--r--  vowpalwabbit/log_multi.cc       |  14
-rw-r--r--  vowpalwabbit/lrq.cc             |  10
-rw-r--r--  vowpalwabbit/mf.cc              |  10
-rw-r--r--  vowpalwabbit/nn.cc              |  16
-rw-r--r--  vowpalwabbit/oaa.cc             |  10
-rw-r--r--  vowpalwabbit/print.cc           |   8
-rw-r--r--  vowpalwabbit/scorer.cc          |  16
-rw-r--r--  vowpalwabbit/search.cc          |  18
-rw-r--r--  vowpalwabbit/sender.cc          |  14
-rw-r--r--  vowpalwabbit/stagewise_poly.cc  |  18
-rw-r--r--  vowpalwabbit/topk.cc            |  10
27 files changed, 324 insertions(+), 315 deletions(-)
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc
index a6337641..c1f7f985 100644
--- a/vowpalwabbit/active.cc
+++ b/vowpalwabbit/active.cc
@@ -166,20 +166,20 @@ namespace ACTIVE {
data.all=&all;
//Create new learner
- learner<active>* ret = new learner<active>(&data, all.l);
+ learner<active>& ret = init_learner(&data, all.l);
if (vm.count("simulation"))
{
- ret->set_learn(predict_or_learn_simulation<true>);
- ret->set_predict(predict_or_learn_simulation<false>);
+ ret.set_learn(predict_or_learn_simulation<true>);
+ ret.set_predict(predict_or_learn_simulation<false>);
}
else
{
all.active = true;
- ret->set_learn(predict_or_learn_active<true>);
- ret->set_predict(predict_or_learn_active<false>);
- ret->set_finish_example(return_active_example);
+ ret.set_learn(predict_or_learn_active<true>);
+ ret.set_predict(predict_or_learn_active<false>);
+ ret.set_finish_example(return_active_example);
}
- return make_base(ret);
+ return make_base(&ret);
}
}
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc
index 57d6a247..d0abeb70 100644
--- a/vowpalwabbit/autolink.cc
+++ b/vowpalwabbit/autolink.cc
@@ -49,9 +49,9 @@ namespace ALINK {
all.file_options << " --autolink " << data.d;
- learner<autolink>* ret = new learner<autolink>(&data, all.l);
- ret->set_learn(predict_or_learn<true>);
- ret->set_predict(predict_or_learn<false>);
- return make_base(ret);
+ learner<autolink>& ret = init_learner(&data, all.l);
+ ret.set_learn(predict_or_learn<true>);
+ ret.set_predict(predict_or_learn<false>);
+ return make_base(&ret);
}
}
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index a5d4110b..3b6143ab 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -1018,14 +1018,14 @@ base_learner* setup(vw& all, po::variables_map& vm)
all.bfgs = true;
all.reg.stride_shift = 2;
- learner<bfgs>* l = new learner<bfgs>(&b, 1 << all.reg.stride_shift);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_save_load(save_load);
- l->set_init_driver(init_driver);
- l->set_end_pass(end_pass);
- l->set_finish(finish);
-
- return make_base(l);
+ learner<bfgs>& l = init_learner(&b, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_init_driver(init_driver);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
+
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc
index 50c8875e..52556af6 100644
--- a/vowpalwabbit/binary.cc
+++ b/vowpalwabbit/binary.cc
@@ -28,9 +28,9 @@ namespace BINARY {
{//parse and set arguments
all.sd->binary_label = true;
//Create new learner
- learner<char>* ret = new learner<char>(NULL, all.l);
- ret->set_learn(predict_or_learn<true>);
- ret->set_predict(predict_or_learn<false>);
- return make_base(ret);
+ learner<char>& ret = init_learner<char>(NULL, all.l);
+ ret.set_learn(predict_or_learn<true>);
+ ret.set_predict(predict_or_learn<false>);
+ return make_base(&ret);
}
}
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index 8382d04d..564ff3aa 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -280,12 +280,12 @@ namespace BS {
data.pred_vec.reserve(data.B);
data.all = &all;
- learner<bs>* l = new learner<bs>(&data, all.l, data.B);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_finish_example(finish_example);
- l->set_finish(finish);
+ learner<bs>& l = init_learner(&data, all.l, data.B);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index c02ae6ed..2cc3cd7b 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -564,25 +564,25 @@ namespace CB_ALGS
else
all.p->lp = CB::cb_label;
- learner<cb>* l = new learner<cb>(&c, all.l, problem_multiplier);
+ learner<cb>& l = init_learner(&c, all.l, problem_multiplier);
if (eval)
{
- l->set_learn(learn_eval);
- l->set_predict(predict_eval);
- l->set_finish_example(eval_finish_example);
+ l.set_learn(learn_eval);
+ l.set_predict(predict_eval);
+ l.set_finish_example(eval_finish_example);
}
else
{
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_finish_example(finish_example);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
}
- l->set_init_driver(init_driver);
- l->set_finish(finish);
+ l.set_init_driver(init_driver);
+ l.set_finish(finish);
// preserve the increment of the base learner since we are
// _adding_ to the number of problems rather than multiplying.
- l->increment = all.l->increment;
+ l.increment = all.l->increment;
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index ab8323ea..9339dabf 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -409,7 +409,7 @@ namespace CBIFY {
epsilon = vm["epsilon"].as<float>();
data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k));
data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k));
- l = new learner<cbify>(&data, all.l, cover + 1);
+ l = &init_learner(&data, all.l, cover + 1);
l->set_learn(predict_or_learn_cover<true>);
l->set_predict(predict_or_learn_cover<false>);
}
@@ -421,7 +421,7 @@ namespace CBIFY {
data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i)));
}
data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k));
- l = new learner<cbify>(&data, all.l, bags);
+ l = &init_learner(&data, all.l, bags);
l->set_learn(predict_or_learn_bag<true>);
l->set_predict(predict_or_learn_bag<false>);
}
@@ -430,7 +430,7 @@ namespace CBIFY {
uint32_t tau = (uint32_t)vm["first"].as<size_t>();
data.policy.reset(new vw_policy());
data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k));
- l = new learner<cbify>(&data, all.l, 1);
+ l = &init_learner(&data, all.l, 1);
l->set_learn(predict_or_learn_first<true>);
l->set_predict(predict_or_learn_first<false>);
}
@@ -441,7 +441,7 @@ namespace CBIFY {
epsilon = vm["epsilon"].as<float>();
data.policy.reset(new vw_policy());
data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k));
- l = new learner<cbify>(&data, all.l, 1);
+ l = &init_learner(&data, all.l, 1);
l->set_learn(predict_or_learn_greedy<true>);
l->set_predict(predict_or_learn_greedy<false>);
}
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index e14cf8f5..06150b2a 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -82,11 +82,11 @@ namespace CSOAA {
all.p->lp = cs_label;
all.sd->k = nb_actions;
- learner<csoaa>* l = new learner<csoaa>(&c, all.l, nb_actions);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_finish_example(finish_example);
- return make_base(l);
+ learner<csoaa>& l = init_learner(&c, all.l, nb_actions);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
+ return make_base(&l);
}
}
@@ -709,17 +709,17 @@ namespace LabelDict {
ld.read_example_this_loop = 0;
ld.need_to_clear = false;
- learner<ldf>* l = new learner<ldf>(&ld, all.l);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
+ learner<ldf>& l = init_learner(&ld, all.l);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
if (ld.is_singleline)
- l->set_finish_example(finish_singleline_example);
+ l.set_finish_example(finish_singleline_example);
else
- l->set_finish_example(finish_multiline_example);
- l->set_finish(finish);
- l->set_end_examples(end_examples);
- l->set_end_pass(end_pass);
- return make_base(l);
+ l.set_finish_example(finish_multiline_example);
+ l.set_finish(finish);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ return make_base(&l);
}
void global_print_newline(vw& all)
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index f847af76..b9f7b536 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -390,12 +390,12 @@ namespace ECT
size_t wpp = create_circuit(all, data, data.k, data.errors+1);
data.all = &all;
- learner<ect>* l = new learner<ect>(&data, all.l, wpp);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_finish_example(finish_example);
- l->set_finish(finish);
+ learner<ect>& l = init_learner(&data, all.l, wpp);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
index a0f437b0..0d3c58ff 100644
--- a/vowpalwabbit/ftrl_proximal.cc
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -217,12 +217,12 @@ namespace FTRL {
cerr << "ftrl_beta = " << b.ftrl_beta << endl;
}
- learner<ftrl>* l = new learner<ftrl>(&b, 1 << all.reg.stride_shift);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_save_load(save_load);
+ learner<ftrl>& l = init_learner(&b, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
- return make_base(l);
+ return make_base(&l);
}
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 8e4b3b2b..8ccc0373 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -789,25 +789,25 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text)
}
template<bool invariant, bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next>
-uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
all.normalized_idx = normalized;
if (feature_mask_off)
{
- ret->set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>);
- ret->set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>);
+ ret.set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>);
+ ret.set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>);
return next;
}
else
{
- ret->set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>);
- ret->set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>);
+ ret.set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>);
+ ret.set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>);
return next;
}
}
template<bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next>
-uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
if (all.invariant_updates)
return set_learn<true, sqrt_rate, adaptive, normalized, spare, next>(all, ret, feature_mask_off);
@@ -816,7 +816,7 @@ uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off)
}
template<bool sqrt_rate, uint32_t adaptive, uint32_t spare>
-uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
// select the appropriate learn function based on adaptive, normalization, and feature mask
if (all.normalized_updates)
@@ -826,7 +826,7 @@ uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off)
}
template<bool sqrt_rate>
-uint32_t set_learn(vw& all, learner<gd>* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
if (all.adaptive)
return set_learn<sqrt_rate, 1, 2>(all, ret, feature_mask_off);
@@ -898,7 +898,7 @@ base_learner* setup(vw& all, po::variables_map& vm)
cerr << "Warning: the learning rate for the last pass is multiplied by: " << pow((double)all.eta_decay_rate, (double)all.numpasses)
<< " adjust --decay_learning_rate larger to avoid this." << endl;
- learner<gd>* ret = new learner<gd>(&g, 1);
+ learner<gd>& ret = init_learner(&g, 1);
if (all.reg_mode % 2)
if (all.audit || all.hash_inv)
@@ -909,7 +909,7 @@ base_learner* setup(vw& all, po::variables_map& vm)
g.predict = predict<false, true>;
else
g.predict = predict<false, false>;
- ret->set_predict(g.predict);
+ ret.set_predict(g.predict);
uint32_t stride;
if (all.power_t == 0.5)
@@ -918,10 +918,10 @@ base_learner* setup(vw& all, po::variables_map& vm)
stride = set_learn<false>(all, ret, feature_mask_off);
all.reg.stride_shift = ceil_log_2(stride-1);
- ret->increment = ((uint64_t)1 << all.reg.stride_shift);
+ ret.increment = ((uint64_t)1 << all.reg.stride_shift);
- ret->set_save_load(save_load);
- ret->set_end_pass(end_pass);
- return make_base(ret);
+ ret.set_save_load(save_load);
+ ret.set_end_pass(end_pass);
+ return make_base(&ret);
}
}
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index 1ae18985..ed5e7df9 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -330,12 +330,12 @@ void mf_train(vw& all, example& ec)
}
all.eta *= powf((float)(all.sd->t), all.power_t);
- learner<gdmf>* l = new learner<gdmf>(&data, 1 << all.reg.stride_shift);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_save_load(save_load);
- l->set_end_pass(end_pass);
+ learner<gdmf>& l = init_learner(&data, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_end_pass(end_pass);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index 86bd2d15..f0d4a6f1 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -170,7 +170,7 @@ struct vw {
node_socks socks;
- LEARNER::base_learner* l;//the top level learner
+ LEARNER::base_learner& l;//the top level learner
LEARNER::base_learner* scorer;//a scoring function
LEARNER::base_learner* cost_sensitive;//a cost sensitive learning algorithm.
diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc
index bfe1f927..cf9d82d2 100644
--- a/vowpalwabbit/kernel_svm.cc
+++ b/vowpalwabbit/kernel_svm.cc
@@ -898,11 +898,11 @@ namespace KSVM
params.all->reg.weight_mask = (uint32_t)LONG_MAX;
params.all->reg.stride_shift = 0;
- learner<svm_params>* l = new learner<svm_params>(&params, 1);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_save_load(save_load);
- l->set_finish(finish);
- return make_base(l);
+ learner<svm_params>& l = init_learner(&params, 1);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish(finish);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index a8241557..8f27ea78 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -786,15 +786,15 @@ base_learner* setup(vw&all, po::variables_map& vm)
ld.decay_levels.push_back(0.f);
- learner<lda>* l = new learner<lda>(&ld, 1 << all.reg.stride_shift);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_save_load(save_load);
- l->set_finish_example(finish_example);
- l->set_end_examples(end_examples);
- l->set_end_pass(end_pass);
- l->set_finish(finish);
+ learner<lda>& l = init_learner(&ld, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h
index 0edeecb4..ef76f14d 100644
--- a/vowpalwabbit/learner.h
+++ b/vowpalwabbit/learner.h
@@ -6,6 +6,7 @@ license as described in the file LICENSE.
#pragma once
// This is the interface for a learning algorithm
#include<iostream>
+#include"memory.h"
using namespace std;
struct vw;
@@ -62,147 +63,155 @@ namespace LEARNER
const func_data generic_func_fd = {NULL, NULL, generic_func};
typedef void (*tlearn)(void* d, base_learner& base, example& ec);
-
typedef void (*tsl)(void* d, io_buf& io, bool read, bool text);
-
typedef void (*tfunc)(void*d);
-
typedef void (*tend_example)(vw& all, void* d, example& ec);
-template<class T>
-struct learner {
-private:
- func_data init_fd;
- learn_data learn_fd;
- finish_example_data finish_example_fd;
- save_load_data save_load_fd;
- func_data end_pass_fd;
- func_data end_examples_fd;
- func_data finisher_fd;
+ template<class T> learner<T>& init_learner();
+ template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight);
+ template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1);
-public:
- size_t weights; //this stores the number of "weight vectors" required by the learner.
- size_t increment;
-
- //called once for each example. Must work under reduction.
- inline void learn(example& ec, size_t i=0)
- {
- ec.ft_offset += (uint32_t)(increment*i);
- learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec);
- ec.ft_offset -= (uint32_t)(increment*i);
- }
- inline void set_learn(void (*u)(T& data, base_learner& base, example&))
- {
- learn_fd.learn_f = (tlearn)u;
- learn_fd.update_f = (tlearn)u;
- }
-
- inline void predict(example& ec, size_t i=0)
- {
- ec.ft_offset += (uint32_t)(increment*i);
- learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec);
- ec.ft_offset -= (uint32_t)(increment*i);
- }
- inline void set_predict(void (*u)(T& data, base_learner& base, example&))
- {
- learn_fd.predict_f = (tlearn)u;
- }
-
- inline void update(example& ec, size_t i=0)
- {
- ec.ft_offset += (uint32_t)(increment*i);
- learn_fd.update_f(learn_fd.data, *learn_fd.base, ec);
- ec.ft_offset -= (uint32_t)(increment*i);
- }
- inline void set_update(void (*u)(T& data, base_learner& base, example&))
- {
- learn_fd.update_f = (tlearn)u;
- }
-
- //called anytime saving or loading needs to happen. Autorecursive.
- inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); }
- inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool))
- { save_load_fd.save_load_f = (tsl)sl;
- save_load_fd.data = learn_fd.data;
- save_load_fd.base = learn_fd.base;}
-
- //called to clean up state. Autorecursive.
- void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); }
- inline void finish()
- {
- if (finisher_fd.data)
- {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); }
- if (finisher_fd.base) {
- finisher_fd.base->finish();
- delete finisher_fd.base;
+ template<class T>
+ struct learner {
+ private:
+ func_data init_fd;
+ learn_data learn_fd;
+ finish_example_data finish_example_fd;
+ save_load_data save_load_fd;
+ func_data end_pass_fd;
+ func_data end_examples_fd;
+ func_data finisher_fd;
+
+ public:
+ size_t weights; //this stores the number of "weight vectors" required by the learner.
+ size_t increment;
+
+ //called once for each example. Must work under reduction.
+ inline void learn(example& ec, size_t i=0)
+ {
+ ec.ft_offset += (uint32_t)(increment*i);
+ learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec);
+ ec.ft_offset -= (uint32_t)(increment*i);
+ }
+ inline void set_learn(void (*u)(T& data, base_learner& base, example&))
+ {
+ learn_fd.learn_f = (tlearn)u;
+ learn_fd.update_f = (tlearn)u;
+ }
+
+ inline void predict(example& ec, size_t i=0)
+ {
+ ec.ft_offset += (uint32_t)(increment*i);
+ learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec);
+ ec.ft_offset -= (uint32_t)(increment*i);
+ }
+ inline void set_predict(void (*u)(T& data, base_learner& base, example&))
+ { learn_fd.predict_f = (tlearn)u; }
+
+ inline void update(example& ec, size_t i=0)
+ {
+ ec.ft_offset += (uint32_t)(increment*i);
+ learn_fd.update_f(learn_fd.data, *learn_fd.base, ec);
+ ec.ft_offset -= (uint32_t)(increment*i);
+ }
+ inline void set_update(void (*u)(T& data, base_learner& base, example&))
+ { learn_fd.update_f = (tlearn)u; }
+
+ //called anytime saving or loading needs to happen. Autorecursive.
+ inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); }
+ inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool))
+ { save_load_fd.save_load_f = (tsl)sl;
+ save_load_fd.data = learn_fd.data;
+ save_load_fd.base = learn_fd.base;}
+
+ //called to clean up state. Autorecursive.
+ void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); }
+ inline void finish()
+ {
+ if (finisher_fd.data)
+ {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); }
+ if (finisher_fd.base) {
+ finisher_fd.base->finish();
+ delete finisher_fd.base;
+ }
+ }
+
+ void end_pass(){
+ end_pass_fd.func(end_pass_fd.data);
+ if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive
+ void set_end_pass(void (*f)(T&))
+ {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);}
+
+ //called after parsing of examples is complete. Autorecursive.
+ void end_examples()
+ { end_examples_fd.func(end_examples_fd.data);
+ if (end_examples_fd.base) end_examples_fd.base->end_examples(); }
+ void set_end_examples(void (*f)(T&))
+ {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);}
+
+ //Called at the beginning by the driver. Explicitly not recursive.
+ void init_driver() { init_fd.func(init_fd.data);}
+ void set_init_driver(void (*f)(T&))
+ { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); }
+
+ //called after learn example for each example. Explicitly not recursive.
+ inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);}
+ void set_finish_example(void (*f)(vw& all, T&, example&))
+ {finish_example_fd.data = learn_fd.data;
+ finish_example_fd.finish_example_f = (tend_example)f;}
+
+ friend learner<T>& init_learner<>();
+ friend learner<T>& init_learner<>(T* dat, size_t params_per_weight);
+ friend learner<T>& init_learner<>(T* dat, base_learner* base, size_t ws);
+ };
+
+ template<class T> learner<T>& init_learner()
+ {
+ learner<T>& ret = calloc_or_die<learner<T> >();
+ ret.weights = 1;
+ ret.increment = 1;
+ ret.learn_fd = LEARNER::generic_learn_fd;
+ ret.finish_example_fd.data = NULL;
+ ret.finish_example_fd.finish_example_f = return_simple_example;
+ ret.end_pass_fd = LEARNER::generic_func_fd;
+ ret.end_examples_fd = LEARNER::generic_func_fd;
+ ret.init_fd = LEARNER::generic_func_fd;
+ ret.finisher_fd = LEARNER::generic_func_fd;
+ ret.save_load_fd = LEARNER::generic_save_load_fd;
+ return ret;
}
- }
-
- void end_pass(){
- end_pass_fd.func(end_pass_fd.data);
- if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive
- void set_end_pass(void (*f)(T&)) {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);}
-
- //called after parsing of examples is complete. Autorecursive.
- void end_examples()
- { end_examples_fd.func(end_examples_fd.data);
- if (end_examples_fd.base) end_examples_fd.base->end_examples(); }
- void set_end_examples(void (*f)(T&)) {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);}
-
- //Called at the beginning by the driver. Explicitly not recursive.
- void init_driver() { init_fd.func(init_fd.data);}
- void set_init_driver(void (*f)(T&)) { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); }
-
- //called after learn example for each example. Explicitly not recursive.
- inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);}
- void set_finish_example(void (*f)(vw& all, T&, example&))
- {finish_example_fd.data = learn_fd.data;
- finish_example_fd.finish_example_f = (tend_example)f;}
-
- inline learner()
- {
- weights = 1;
- increment = 1;
-
- learn_fd = LEARNER::generic_learn_fd;
- finish_example_fd.data = NULL;
- finish_example_fd.finish_example_f = return_simple_example;
- end_pass_fd = LEARNER::generic_func_fd;
- end_examples_fd = LEARNER::generic_func_fd;
- init_fd = LEARNER::generic_func_fd;
- finisher_fd = LEARNER::generic_func_fd;
- save_load_fd = LEARNER::generic_save_load_fd;
- }
-
- inline learner(T* dat, size_t params_per_weight)
- { // the constructor for all learning algorithms.
- *this = learner();
-
- learn_fd.data = dat;
-
- finisher_fd.data = dat;
- finisher_fd.base = NULL;
- finisher_fd.func = LEARNER::generic_func;
-
- increment = params_per_weight;
- }
-
- inline learner(T* dat, base_learner* base, size_t ws = 1)
- { //the reduction constructor, with separate learn and predict functions
- *this = *(learner<T>*)base;
-
- learn_fd.data = dat;
- learn_fd.base = base;
-
- finisher_fd.data = dat;
- finisher_fd.base = base;
- finisher_fd.func = LEARNER::generic_func;
-
- weights = ws;
- increment = base->increment * weights;
- }
-};
-
- template<class T> base_learner* make_base(learner<T>* base)
- { return (base_learner*)base; }
+
+ template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight)
+ { // the constructor for all learning algorithms.
+ learner<T>& ret = init_learner<T>();
+
+ ret.learn_fd.data = dat;
+
+ ret.finisher_fd.data = dat;
+ ret.finisher_fd.base = NULL;
+ ret.finisher_fd.func = LEARNER::generic_func;
+
+ ret.increment = params_per_weight;
+ return ret;
+ }
+
+ template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1)
+ { //the reduction constructor, with separate learn and predict functions
+ learner<T>& ret = calloc_or_die<learner<T> >();
+ ret = *(learner<T>*)base;
+
+ ret.learn_fd.data = dat;
+ ret.learn_fd.base = base;
+
+ ret.finisher_fd.data = dat;
+ ret.finisher_fd.base = base;
+ ret.finisher_fd.func = LEARNER::generic_func;
+
+ ret.weights = ws;
+ ret.increment = base->increment * ret.weights;
+ return ret;
+ }
+
+ template<class T> base_learner* make_base(learner<T>* base)
+ { return (base_learner*)base; }
}
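
The learner.h hunk above is the core of the change: the three constructors
become three init_learner() overloads. The zero-argument overload does the
calloc_or_die and installs the generic defaults, the (dat, params_per_weight)
overload serves bottom-level learners, and the (dat, base, ws) overload
serves reductions, copying the base's function tables and composing
increments as base->increment * ws. A minimal sketch of a reduction written
against the new interface (my_reduction, its functions, and the setup body
are illustrative only, not part of this commit):

    #include "learner.h"
    using namespace LEARNER;

    struct my_reduction { vw* all; };  // hypothetical per-reduction state

    template<bool is_learn>
    void predict_or_learn(my_reduction& d, base_learner& base, example& ec)
    { // a real reduction would transform ec before delegating
      if (is_learn) base.learn(ec);
      else base.predict(ec);
    }

    base_learner* setup(vw& all, po::variables_map& vm)
    { my_reduction& d = calloc_or_die<my_reduction>();
      d.all = &all;
      // reduction overload: copies all.l's function tables; increment
      // becomes all.l->increment * 1, so learners stacked on top of this
      // one offset their examples past its single copy of the base problem
      learner<my_reduction>& l = init_learner(&d, all.l, 1);
      l.set_learn(predict_or_learn<true>);
      l.set_predict(predict_or_learn<false>);
      return make_base(&l);
    }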
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index 2378106f..c50a73a0 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -531,15 +531,15 @@ namespace LOG_MULTI
data->max_predictors = data->k - 1;
- learner<log_multi>* l = new learner<log_multi>(data, all.l, data->max_predictors);
- l->set_save_load(save_load_tree);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_finish_example(finish_example);
- l->set_finish(finish);
+ learner<log_multi>& l = init_learner(data, all.l, data->max_predictors);
+ l.set_save_load(save_load_tree);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
init_tree(*data);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc
index 557411e5..24207ce1 100644
--- a/vowpalwabbit/lrq.cc
+++ b/vowpalwabbit/lrq.cc
@@ -243,12 +243,12 @@ namespace LRQ {
cerr<<endl;
all.wpp = all.wpp * (1 + maxk);
- learner<LRQstate>* l = new learner<LRQstate>(&lrq, all.l, 1 + maxk);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_end_pass(reset_seed);
+ learner<LRQstate>& l = init_learner(&lrq, all.l, 1 + maxk);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_end_pass(reset_seed);
// TODO: leaks memory ?
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index 65706731..3e1ac244 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -203,10 +203,10 @@ base_learner* setup(vw& all, po::variables_map& vm) {
all.random_positive_weights = true;
- learner<mf>* l = new learner<mf>(data, all.l, 2*data->rank+1);
- l->set_learn(learn);
- l->set_predict(predict<false>);
- l->set_finish(finish);
- return make_base(l);
+ learner<mf>& l = init_learner(data, all.l, 2*data->rank+1);
+ l.set_learn(learn);
+ l.set_predict(predict<false>);
+ l.set_finish(finish);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index 4f623381..bcc3ab73 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -365,13 +365,13 @@ CONVERSE: // That's right, I'm using goto. So sue me.
n.save_xsubi = n.xsubi;
n.increment = all.l->increment;//Indexing of output layer is odd.
- learner<nn>* l = new learner<nn>(&n, all.l, n.k+1);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_finish(finish);
- l->set_finish_example(finish_example);
- l->set_end_pass(end_pass);
-
- return make_base(l);
+ learner<nn>& l = init_learner(&n, all.l, n.k+1);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish(finish);
+ l.set_finish_example(finish_example);
+ l.set_end_pass(end_pass);
+
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 47ee84b0..7b3ddd1a 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -87,11 +87,11 @@ namespace OAA {
data.all = &all;
all.p->lp = mc_label;
- learner<oaa>* l = new learner<oaa>(&data, all.l, data.k);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_finish_example(finish_example);
+ learner<oaa>& l = init_learner(&data, all.l, data.k);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc
index 30f78cff..b8dcddca 100644
--- a/vowpalwabbit/print.cc
+++ b/vowpalwabbit/print.cc
@@ -54,9 +54,9 @@ namespace PRINT
all.reg.weight_mask = (length << all.reg.stride_shift) - 1;
all.reg.stride_shift = 0;
- learner<print>* ret = new learner<print>(&p, 1);
- ret->set_learn(learn);
- ret->set_predict(learn);
- return make_base(ret);
+ learner<print>& ret = init_learner(&p, 1);
+ ret.set_learn(learn);
+ ret.set_predict(learn);
+ return make_base(&ret);
}
}
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index a4b73855..c82fc9df 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -57,25 +57,25 @@ namespace Scorer {
vm = add_options(all, link_opts);
- learner<scorer>* l = new learner<scorer>(&s, all.l);
+ learner<scorer>& l = init_learner(&s, all.l);
string link = vm["link"].as<string>();
if (!vm.count("link") || link.compare("identity") == 0)
{
- l->set_learn(predict_or_learn<true, noop> );
- l->set_predict(predict_or_learn<false, noop> );
+ l.set_learn(predict_or_learn<true, noop> );
+ l.set_predict(predict_or_learn<false, noop> );
}
else if (link.compare("logistic") == 0)
{
all.file_options << " --link=logistic ";
- l->set_learn(predict_or_learn<true, logistic> );
- l->set_predict(predict_or_learn<false, logistic>);
+ l.set_learn(predict_or_learn<true, logistic> );
+ l.set_predict(predict_or_learn<false, logistic>);
}
else if (link.compare("glf1") == 0)
{
all.file_options << " --link=glf1 ";
- l->set_learn(predict_or_learn<true, glf1>);
- l->set_predict(predict_or_learn<false, glf1>);
+ l.set_learn(predict_or_learn<true, glf1>);
+ l.set_predict(predict_or_learn<false, glf1>);
}
else
{
@@ -83,6 +83,6 @@ namespace Scorer {
throw exception();
}
- return make_base(l);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index ea218da2..8c894f51 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -1981,15 +1981,15 @@ namespace Search {
priv.start_clock_time = clock();
- learner<search>* l = new learner<search>(&sch, all.l, priv.total_number_of_policies);
- l->set_learn(search_predict_or_learn<true>);
- l->set_predict(search_predict_or_learn<false>);
- l->set_finish_example(finish_example);
- l->set_end_examples(end_examples);
- l->set_finish(search_finish);
- l->set_end_pass(end_pass);
-
- return make_base(l);
+ learner<search>& l = init_learner(&sch, all.l, priv.total_number_of_policies);
+ l.set_learn(search_predict_or_learn<true>);
+ l.set_predict(search_predict_or_learn<false>);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_finish(search_finish);
+ l.set_end_pass(end_pass);
+
+ return make_base(&l);
}
float action_hamming_loss(action a, const action* A, size_t sz) {
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index 6222b690..9033ee31 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -113,13 +113,13 @@ void end_examples(sender& s)
s.all = &all;
s.delay_ring = calloc_or_die<example*>(all.p->ring_size);
- learner<sender>* l = new learner<sender>(&s, 1);
- l->set_learn(learn);
- l->set_predict(learn);
- l->set_finish(finish);
- l->set_finish_example(finish_example);
- l->set_end_examples(end_examples);
- return make_base(l);
+ learner<sender>& l = init_learner(&s, 1);
+ l.set_learn(learn);
+ l.set_predict(learn);
+ l.set_finish(finish);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc
index 5639366d..6550c16e 100644
--- a/vowpalwabbit/stagewise_poly.cc
+++ b/vowpalwabbit/stagewise_poly.cc
@@ -698,14 +698,14 @@ namespace StagewisePoly
//following is so that saved models know to load us.
all.file_options << " --stage_poly";
- learner<stagewise_poly> *l = new learner<stagewise_poly>(&poly, all.l);
- l->set_learn(learn);
- l->set_predict(predict);
- l->set_finish(finish);
- l->set_save_load(save_load);
- l->set_finish_example(finish_example);
- l->set_end_pass(end_pass);
-
- return make_base(l);
+ learner<stagewise_poly>& l = init_learner(&poly, all.l);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish(finish);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_pass(end_pass);
+
+ return make_base(&l);
}
}
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc
index 70b5c9cd..c4e15bd5 100644
--- a/vowpalwabbit/topk.cc
+++ b/vowpalwabbit/topk.cc
@@ -119,11 +119,11 @@ namespace TOPK {
data.all = &all;
- learner<topk>* l = new learner<topk>(&data, all.l);
- l->set_learn(predict_or_learn<true>);
- l->set_predict(predict_or_learn<false>);
- l->set_finish_example(finish_example);
+ learner<topk>& l = init_learner(&data, all.l);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
- return make_base(l);
+ return make_base(&l);
}
}