Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-12-27 19:50:53 +0300
committerJohn Langford <jl@hunch.net>2014-12-27 19:50:53 +0300
commit3018c85ccf02127c1cc94c5ba2bc491aa2b45a9d (patch)
treed4384af3a9385edcdd6671e2c21f1fa931694360
parent74baf926ce5e5eb79c9ae50d150638176132bb02 (diff)
parent43226da4724fe3aec7eb560264e3342e442e3fbd (diff)
fix conflicts
-rw-r--r--vowpalwabbit/active.cc20
-rw-r--r--vowpalwabbit/active.h2
-rw-r--r--vowpalwabbit/autolink.cc14
-rw-r--r--vowpalwabbit/autolink.h2
-rw-r--r--vowpalwabbit/bfgs.cc22
-rw-r--r--vowpalwabbit/bfgs.h2
-rw-r--r--vowpalwabbit/binary.cc12
-rw-r--r--vowpalwabbit/binary.h2
-rw-r--r--vowpalwabbit/bs.cc20
-rw-r--r--vowpalwabbit/bs.h2
-rw-r--r--vowpalwabbit/cb_algs.cc36
-rw-r--r--vowpalwabbit/cb_algs.h4
-rw-r--r--vowpalwabbit/cbify.cc88
-rw-r--r--vowpalwabbit/cbify.h2
-rw-r--r--vowpalwabbit/csoaa.cc56
-rw-r--r--vowpalwabbit/csoaa.h4
-rw-r--r--vowpalwabbit/ect.cc24
-rw-r--r--vowpalwabbit/ect.h2
-rw-r--r--vowpalwabbit/ftrl_proximal.cc21
-rw-r--r--vowpalwabbit/ftrl_proximal.h2
-rw-r--r--vowpalwabbit/gd.cc58
-rw-r--r--vowpalwabbit/gd.h2
-rw-r--r--vowpalwabbit/gd_mf.cc18
-rw-r--r--vowpalwabbit/gd_mf.h2
-rw-r--r--vowpalwabbit/global_data.cc2
-rw-r--r--vowpalwabbit/global_data.h10
-rw-r--r--vowpalwabbit/kernel_svm.cc26
-rw-r--r--vowpalwabbit/kernel_svm.h4
-rw-r--r--vowpalwabbit/lda_core.cc44
-rw-r--r--vowpalwabbit/lda_core.h2
-rw-r--r--vowpalwabbit/learner.h334
-rw-r--r--vowpalwabbit/log_multi.cc50
-rw-r--r--vowpalwabbit/log_multi.h2
-rw-r--r--vowpalwabbit/lrq.cc46
-rw-r--r--vowpalwabbit/lrq.h2
-rw-r--r--vowpalwabbit/mf.cc26
-rw-r--r--vowpalwabbit/mf.h2
-rw-r--r--vowpalwabbit/nn.cc26
-rw-r--r--vowpalwabbit/nn.h2
-rw-r--r--vowpalwabbit/noop.cc4
-rw-r--r--vowpalwabbit/noop.h5
-rw-r--r--vowpalwabbit/oaa.cc31
-rw-r--r--vowpalwabbit/oaa.h4
-rw-r--r--vowpalwabbit/parse_args.cc22
-rw-r--r--vowpalwabbit/parse_args.h2
-rw-r--r--vowpalwabbit/parse_regressor.cc10
-rw-r--r--vowpalwabbit/print.cc12
-rw-r--r--vowpalwabbit/print.h5
-rw-r--r--vowpalwabbit/scorer.cc24
-rw-r--r--vowpalwabbit/scorer.h2
-rw-r--r--vowpalwabbit/search.cc30
-rw-r--r--vowpalwabbit/search.h2
-rw-r--r--vowpalwabbit/sender.cc19
-rw-r--r--vowpalwabbit/sender.h5
-rw-r--r--vowpalwabbit/stagewise_poly.cc24
-rw-r--r--vowpalwabbit/stagewise_poly.h2
-rw-r--r--vowpalwabbit/topk.cc14
-rw-r--r--vowpalwabbit/topk.h2
-rw-r--r--vowpalwabbit/vw.h2
59 files changed, 586 insertions, 630 deletions
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc
index 1e87208e..1c6f1c3c 100644
--- a/vowpalwabbit/active.cc
+++ b/vowpalwabbit/active.cc
@@ -45,7 +45,7 @@ namespace ACTIVE {
}
template <bool is_learn>
- void predict_or_learn_simulation(active& a, learner& base, example& ec) {
+ void predict_or_learn_simulation(active& a, base_learner& base, example& ec) {
base.predict(ec);
if (is_learn)
@@ -67,7 +67,7 @@ namespace ACTIVE {
}
template <bool is_learn>
- void predict_or_learn_active(active& a, learner& base, example& ec) {
+ void predict_or_learn_active(active& a, base_learner& base, example& ec) {
if (is_learn)
base.learn(ec);
else
@@ -151,7 +151,7 @@ namespace ACTIVE {
VW::finish_example(all,&ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{//parse and set arguments
po::options_description opts("Active Learning options");
opts.add_options()
@@ -170,20 +170,20 @@ namespace ACTIVE {
data.active_c0 = vm["mellowness"].as<float>();
//Create new learner
- learner* ret = new learner(&data, all.l);
+ learner<active>& ret = init_learner(&data, all.l);
if (vm.count("simulation"))
{
- ret->set_learn<active, predict_or_learn_simulation<true> >();
- ret->set_predict<active, predict_or_learn_simulation<false> >();
+ ret.set_learn(predict_or_learn_simulation<true>);
+ ret.set_predict(predict_or_learn_simulation<false>);
}
else
{
all.active = true;
- ret->set_learn<active, predict_or_learn_active<true> >();
- ret->set_predict<active, predict_or_learn_active<false> >();
- ret->set_finish_example<active, return_active_example>();
+ ret.set_learn(predict_or_learn_active<true>);
+ ret.set_predict(predict_or_learn_active<false>);
+ ret.set_finish_example(return_active_example);
}
- return ret;
+ return make_base(ret);
}
}
diff --git a/vowpalwabbit/active.h b/vowpalwabbit/active.h
index e8883ae9..d71950ab 100644
--- a/vowpalwabbit/active.h
+++ b/vowpalwabbit/active.h
@@ -1,4 +1,4 @@
#pragma once
namespace ACTIVE {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc
index f34a228f..f2023565 100644
--- a/vowpalwabbit/autolink.cc
+++ b/vowpalwabbit/autolink.cc
@@ -12,7 +12,7 @@ namespace ALINK {
};
template <bool is_learn>
- void predict_or_learn(autolink& b, learner& base, example& ec)
+ void predict_or_learn(autolink& b, base_learner& base, example& ec)
{
base.predict(ec);
float base_pred = ec.pred.scalar;
@@ -41,7 +41,7 @@ namespace ALINK {
ec.total_sum_feat_sq -= sum_sq;
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Autolink options");
opts.add_options()
@@ -54,11 +54,11 @@ namespace ALINK {
data.d = (uint32_t)vm["autolink"].as<size_t>();
data.stride_shift = all.reg.stride_shift;
- all.file_options << " --autolink " << data.d;
+ *all.file_options << " --autolink " << data.d;
- learner* ret = new learner(&data, all.l);
- ret->set_learn<autolink, predict_or_learn<true> >();
- ret->set_predict<autolink, predict_or_learn<false> >();
- return ret;
+ learner<autolink>& ret = init_learner(&data, all.l);
+ ret.set_learn(predict_or_learn<true>);
+ ret.set_predict(predict_or_learn<false>);
+ return make_base(ret);
}
}
diff --git a/vowpalwabbit/autolink.h b/vowpalwabbit/autolink.h
index 830c0088..3bb70dc1 100644
--- a/vowpalwabbit/autolink.h
+++ b/vowpalwabbit/autolink.h
@@ -1,4 +1,4 @@
#pragma once
namespace ALINK {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index 43e5e502..775affb1 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -821,13 +821,13 @@ void end_pass(bfgs& b)
}
// placeholder
-void predict(bfgs& b, learner& base, example& ec)
+void predict(bfgs& b, base_learner& base, example& ec)
{
vw* all = b.all;
ec.pred.scalar = bfgs_predict(*all,ec);
}
-void learn(bfgs& b, learner& base, example& ec)
+void learn(bfgs& b, base_learner& base, example& ec)
{
vw* all = b.all;
assert(ec.in_use);
@@ -968,7 +968,7 @@ void save_load(bfgs& b, io_buf& model_file, bool read, bool text)
b.backstep_on = true;
}
-learner* setup(vw& all, po::variables_map& vm)
+base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("LBFGS options");
opts.add_options()
@@ -1024,14 +1024,14 @@ learner* setup(vw& all, po::variables_map& vm)
all.bfgs = true;
all.reg.stride_shift = 2;
- learner* l = new learner(&b, 1 << all.reg.stride_shift);
- l->set_learn<bfgs, learn>();
- l->set_predict<bfgs, predict>();
- l->set_save_load<bfgs,save_load>();
- l->set_init_driver<bfgs,init_driver>();
- l->set_end_pass<bfgs,end_pass>();
- l->set_finish<bfgs,finish>();
+ learner<bfgs>& l = init_learner(&b, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_init_driver(init_driver);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/bfgs.h b/vowpalwabbit/bfgs.h
index 2f30dcb9..1960662b 100644
--- a/vowpalwabbit/bfgs.h
+++ b/vowpalwabbit/bfgs.h
@@ -5,5 +5,5 @@ license as described in the file LICENSE.
*/
#pragma once
namespace BFGS {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc
index 727cf003..7eac0ba8 100644
--- a/vowpalwabbit/binary.cc
+++ b/vowpalwabbit/binary.cc
@@ -7,7 +7,7 @@ using namespace LEARNER;
namespace BINARY {
template <bool is_learn>
- void predict_or_learn(float&, learner& base, example& ec) {
+ void predict_or_learn(char&, base_learner& base, example& ec) {
if (is_learn)
base.learn(ec);
else
@@ -24,7 +24,7 @@ namespace BINARY {
ec.loss = ec.l.simple.weight;
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{//parse and set arguments
po::options_description opts("Binary options");
opts.add_options()
@@ -35,9 +35,9 @@ namespace BINARY {
all.sd->binary_label = true;
//Create new learner
- learner* ret = new learner(NULL, all.l);
- ret->set_learn<float, predict_or_learn<true> >();
- ret->set_predict<float, predict_or_learn<false> >();
- return ret;
+ learner<char>& ret = init_learner<char>(NULL, all.l);
+ ret.set_learn(predict_or_learn<true>);
+ ret.set_predict(predict_or_learn<false>);
+ return make_base(ret);
}
}
diff --git a/vowpalwabbit/binary.h b/vowpalwabbit/binary.h
index 2df4f2e0..609de90b 100644
--- a/vowpalwabbit/binary.h
+++ b/vowpalwabbit/binary.h
@@ -1,4 +1,4 @@
#pragma once
namespace BINARY {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index c199e1db..f55467a6 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -180,7 +180,7 @@ namespace BS {
}
template <bool is_learn>
- void predict_or_learn(bs& d, learner& base, example& ec)
+ void predict_or_learn(bs& d, base_learner& base, example& ec)
{
vw* all = d.all;
bool shouldOutput = all->raw_prediction > 0;
@@ -239,7 +239,7 @@ namespace BS {
d.pred_vec.~vector();
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
bs& data = calloc_or_die<bs>();
data.ub = FLT_MAX;
@@ -254,7 +254,7 @@ namespace BS {
data.B = (uint32_t)vm["bootstrap"].as<size_t>();
//append bs with number of samples to options_from_file so it is saved to regressor later
- all.file_options << " --bootstrap " << data.B;
+ *all.file_options << " --bootstrap " << data.B;
std::string type_string("mean");
@@ -275,17 +275,17 @@ namespace BS {
}
else //by default use mean
data.bs_type = BS_TYPE_MEAN;
- all.file_options << " --bs_type " << type_string;
+ *all.file_options << " --bs_type " << type_string;
data.pred_vec.reserve(data.B);
data.all = &all;
- learner* l = new learner(&data, all.l, data.B);
- l->set_learn<bs, predict_or_learn<true> >();
- l->set_predict<bs, predict_or_learn<false> >();
- l->set_finish_example<bs,finish_example>();
- l->set_finish<bs,finish>();
+ learner<bs>& l = init_learner(&data, all.l, data.B);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/bs.h b/vowpalwabbit/bs.h
index d41593ef..e05bcd05 100644
--- a/vowpalwabbit/bs.h
+++ b/vowpalwabbit/bs.h
@@ -11,7 +11,7 @@ license as described in the file LICENSE.
namespace BS
{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
void print_result(int f, float res, float weight, v_array<char> tag, float lb, float ub);
void output_example(vw& all, example* ec, float lb, float ub);
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index 395d496b..72400a1d 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -284,7 +284,7 @@ namespace CB_ALGS
}
template <bool is_learn>
- void predict_or_learn(cb& c, learner& base, example& ec) {
+ void predict_or_learn(cb& c, base_learner& base, example& ec) {
vw* all = c.all;
CB::label ld = ec.l.cb;
@@ -341,12 +341,12 @@ namespace CB_ALGS
}
}
- void predict_eval(cb& c, learner& base, example& ec) {
+ void predict_eval(cb& c, base_learner& base, example& ec) {
cout << "can not use a test label for evaluation" << endl;
throw exception();
}
- void learn_eval(cb& c, learner& base, example& ec) {
+ void learn_eval(cb& c, base_learner& base, example& ec) {
vw* all = c.all;
CB_EVAL::label ld = ec.l.cb_eval;
@@ -497,7 +497,7 @@ namespace CB_ALGS
VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("CB options");
opts.add_options()
@@ -515,7 +515,7 @@ namespace CB_ALGS
uint32_t nb_actions = (uint32_t)vm["cb"].as<size_t>();
- all.file_options << " --cb " << nb_actions;
+ *all.file_options << " --cb " << nb_actions;
all.sd->k = nb_actions;
@@ -529,7 +529,7 @@ namespace CB_ALGS
std::string type_string;
type_string = vm["cb_type"].as<std::string>();
- all.file_options << " --cb_type " << type_string;
+ *all.file_options << " --cb_type " << type_string;
if (type_string.compare("dr") == 0)
c.cb_type = CB_TYPE_DR;
@@ -556,7 +556,7 @@ namespace CB_ALGS
else {
//by default use doubly robust
c.cb_type = CB_TYPE_DR;
- all.file_options << " --cb_type dr";
+ *all.file_options << " --cb_type dr";
}
if (eval)
@@ -564,25 +564,25 @@ namespace CB_ALGS
else
all.p->lp = CB::cb_label;
- learner* l = new learner(&c, all.l, problem_multiplier);
+ learner<cb>& l = init_learner(&c, all.l, problem_multiplier);
if (eval)
{
- l->set_learn<cb, learn_eval>();
- l->set_predict<cb, predict_eval>();
- l->set_finish_example<cb,eval_finish_example>();
+ l.set_learn(learn_eval);
+ l.set_predict(predict_eval);
+ l.set_finish_example(eval_finish_example);
}
else
{
- l->set_learn<cb, predict_or_learn<true> >();
- l->set_predict<cb, predict_or_learn<false> >();
- l->set_finish_example<cb,finish_example>();
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
}
- l->set_init_driver<cb,init_driver>();
- l->set_finish<cb,finish>();
+ l.set_init_driver(init_driver);
+ l.set_finish(finish);
// preserve the increment of the base learner since we are
// _adding_ to the number of problems rather than multiplying.
- l->increment = all.l->increment;
+ l.increment = all.l->increment;
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h
index 972a45ad..7756fc6f 100644
--- a/vowpalwabbit/cb_algs.h
+++ b/vowpalwabbit/cb_algs.h
@@ -6,8 +6,8 @@ license as described in the file LICENSE.
#pragma once
//TODO: extend to handle CSOAA_LDF and WAP_LDF
namespace CB_ALGS {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
-
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
+
template <bool is_learn>
float get_cost_pred(vw& all, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base)
{
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index e79c25ec..5b738d76 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -83,7 +83,7 @@ namespace CBIFY {
COST_SENSITIVE::label cs_label;
COST_SENSITIVE::label second_cs_label;
- learner* cs;
+ base_learner* cs;
vw* all;
unique_ptr<vw_policy> policy;
@@ -106,7 +106,7 @@ namespace CBIFY {
}
struct vw_context {
- learner* l;
+ base_learner* l;
example* e;
cbify* data;
bool recorded;
@@ -158,7 +158,7 @@ namespace CBIFY {
}
template <bool is_learn>
- void predict_or_learn_first(cbify& data, learner& base, example& ec)
+ void predict_or_learn_first(cbify& data, base_learner& base, example& ec)
{//Explore tau times, then act according to optimal.
MULTICLASS::multiclass ld = ec.l.multi;
@@ -186,7 +186,7 @@ namespace CBIFY {
}
template <bool is_learn>
- void predict_or_learn_greedy(cbify& data, learner& base, example& ec)
+ void predict_or_learn_greedy(cbify& data, base_learner& base, example& ec)
{//Explore uniform random an epsilon fraction of the time.
MULTICLASS::multiclass ld = ec.l.multi;
@@ -213,7 +213,7 @@ namespace CBIFY {
}
template <bool is_learn>
- void predict_or_learn_bag(cbify& data, learner& base, example& ec)
+ void predict_or_learn_bag(cbify& data, base_learner& base, example& ec)
{//Randomize over predictions from a base set of predictors
//Use CB to find current predictions.
MULTICLASS::multiclass ld = ec.l.multi;
@@ -284,7 +284,7 @@ namespace CBIFY {
}
template <bool is_learn>
- void predict_or_learn_cover(cbify& data, learner& base, example& ec)
+ void predict_or_learn_cover(cbify& data, base_learner& base, example& ec)
{//Randomize over predictions from a base set of predictors
//Use cost sensitive oracle to cover actions to form distribution.
MULTICLASS::multiclass ld = ec.l.multi;
@@ -377,7 +377,7 @@ namespace CBIFY {
CB::cb_label.delete_label(&data.cb_label);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{//parse and set arguments
po::options_description opts("CBIFY options");
opts.add_options()
@@ -394,10 +394,10 @@ namespace CBIFY {
data.all = &all;
data.k = (uint32_t)vm["cbify"].as<size_t>();
- all.file_options << " --cbify " << data.k;
+ *all.file_options << " --cbify " << data.k;
all.p->lp = MULTICLASS::mc_label;
- learner* l;
+ learner<cbify>* l;
data.recorder.reset(new vw_recorder());
data.mwt_explorer.reset(new MwtExplorer<vw_context>("vw", *data.recorder.get()));
if (vm.count("cover"))
@@ -406,52 +406,52 @@ namespace CBIFY {
data.cs = all.cost_sensitive;
data.second_cs_label.costs.resize(data.k);
data.second_cs_label.costs.end = data.second_cs_label.costs.begin+data.k;
- float epsilon = 0.05f;
- if (vm.count("epsilon"))
- epsilon = vm["epsilon"].as<float>();
- data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k));
- data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k));
- l = new learner(&data, all.l, cover + 1);
- l->set_learn<cbify, predict_or_learn_cover<true> >();
- l->set_predict<cbify, predict_or_learn_cover<false> >();
+ float epsilon = 0.05f;
+ if (vm.count("epsilon"))
+ epsilon = vm["epsilon"].as<float>();
+ data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k));
+ data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k));
+ l = &init_learner(&data, all.l, cover + 1);
+ l->set_learn(predict_or_learn_cover<true>);
+ l->set_predict(predict_or_learn_cover<false>);
}
else if (vm.count("bag"))
{
size_t bags = (uint32_t)vm["bag"].as<size_t>();
- for (size_t i = 0; i < bags; i++)
- {
- data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i)));
- }
- data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k));
- l = new learner(&data, all.l, bags);
- l->set_learn<cbify, predict_or_learn_bag<true> >();
- l->set_predict<cbify, predict_or_learn_bag<false> >();
+ for (size_t i = 0; i < bags; i++)
+ {
+ data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i)));
+ }
+ data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k));
+ l = &init_learner(&data, all.l, bags);
+ l->set_learn(predict_or_learn_bag<true>);
+ l->set_predict(predict_or_learn_bag<false>);
}
else if (vm.count("first") )
{
- uint32_t tau = (uint32_t)vm["first"].as<size_t>();
- data.policy.reset(new vw_policy());
- data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k));
- l = new learner(&data, all.l, 1);
- l->set_learn<cbify, predict_or_learn_first<true> >();
- l->set_predict<cbify, predict_or_learn_first<false> >();
+ uint32_t tau = (uint32_t)vm["first"].as<size_t>();
+ data.policy.reset(new vw_policy());
+ data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k));
+ l = &init_learner(&data, all.l, 1);
+ l->set_learn(predict_or_learn_first<true>);
+ l->set_predict(predict_or_learn_first<false>);
}
else
{
- float epsilon = 0.05f;
- if (vm.count("epsilon"))
- epsilon = vm["epsilon"].as<float>();
- data.policy.reset(new vw_policy());
- data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k));
- l = new learner(&data, all.l, 1);
- l->set_learn<cbify, predict_or_learn_greedy<true> >();
- l->set_predict<cbify, predict_or_learn_greedy<false> >();
+ float epsilon = 0.05f;
+ if (vm.count("epsilon"))
+ epsilon = vm["epsilon"].as<float>();
+ data.policy.reset(new vw_policy());
+ data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k));
+ l = &init_learner(&data, all.l, 1);
+ l->set_learn(predict_or_learn_greedy<true>);
+ l->set_predict(predict_or_learn_greedy<false>);
}
-
- l->set_finish_example<cbify,finish_example>();
- l->set_finish<cbify,finish>();
- l->set_init_driver<cbify,init_driver>();
- return l;
+ l->set_finish_example(finish_example);
+ l->set_finish(finish);
+ l->set_init_driver(init_driver);
+
+ return make_base(*l);
}
}
diff --git a/vowpalwabbit/cbify.h b/vowpalwabbit/cbify.h
index 97e9e6fd..aed26b75 100644
--- a/vowpalwabbit/cbify.h
+++ b/vowpalwabbit/cbify.h
@@ -5,5 +5,5 @@ license as described in the file LICENSE.
*/
#pragma once
namespace CBIFY {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 679e31ed..a5429dd0 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -13,9 +13,7 @@ license as described in the file LICENSE.
#include "gd.h" // GD::foreach_feature() needed in subtract_example()
using namespace std;
-
using namespace LEARNER;
-
using namespace COST_SENSITIVE;
namespace CSOAA {
@@ -24,7 +22,7 @@ namespace CSOAA {
};
template <bool is_learn>
- void predict_or_learn(csoaa& c, learner& base, example& ec) {
+ void predict_or_learn(csoaa& c, base_learner& base, example& ec) {
vw* all = c.all;
COST_SENSITIVE::label ld = ec.l.cs;
uint32_t prediction = 1;
@@ -68,7 +66,7 @@ namespace CSOAA {
VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("CSOAA options");
opts.add_options()
@@ -83,17 +81,16 @@ namespace CSOAA {
nb_actions = (uint32_t)vm["csoaa"].as<size_t>();
//append csoaa with nb_actions to file_options so it is saved to regressor later
- all.file_options << " --csoaa " << nb_actions;
+ *all.file_options << " --csoaa " << nb_actions;
all.p->lp = cs_label;
all.sd->k = nb_actions;
- learner* l = new learner(&c, all.l, nb_actions);
- l->set_learn<csoaa, predict_or_learn<true> >();
- l->set_predict<csoaa, predict_or_learn<false> >();
- l->set_finish_example<csoaa,finish_example>();
- all.cost_sensitive = all.l;
- return l;
+ learner<csoaa>& l = init_learner(&c, all.l, nb_actions);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
+ return make_base(l);
}
}
@@ -111,7 +108,7 @@ namespace CSOAA_AND_WAP_LDF {
float csoaa_example_t;
vw* all;
- learner* base;
+ base_learner* base;
};
namespace LabelDict {
@@ -300,7 +297,7 @@ namespace LabelDict {
ec->indices.decr();
}
- void make_single_prediction(ldf& data, learner& base, example& ec) {
+ void make_single_prediction(ldf& data, base_learner& base, example& ec) {
COST_SENSITIVE::label ld = ec.l.cs;
label_data simple_label;
simple_label.initial = 0.;
@@ -339,7 +336,7 @@ namespace LabelDict {
return isTest;
}
- void do_actual_learning_wap(vw& all, ldf& data, learner& base, size_t start_K)
+ void do_actual_learning_wap(vw& all, ldf& data, base_learner& base, size_t start_K)
{
size_t K = data.ec_seq.size();
vector<COST_SENSITIVE::wclass*> all_costs;
@@ -394,7 +391,7 @@ namespace LabelDict {
}
}
- void do_actual_learning_oaa(vw& all, ldf& data, learner& base, size_t start_K)
+ void do_actual_learning_oaa(vw& all, ldf& data, base_learner& base, size_t start_K)
{
size_t K = data.ec_seq.size();
float min_cost = FLT_MAX;
@@ -447,7 +444,7 @@ namespace LabelDict {
}
template <bool is_learn>
- void do_actual_learning(vw& all, ldf& data, learner& base)
+ void do_actual_learning(vw& all, ldf& data, base_learner& base)
{
//cdbg << "do_actual_learning size=" << data.ec_seq.size() << endl;
if (data.ec_seq.size() <= 0) return; // nothing to do
@@ -620,7 +617,7 @@ namespace LabelDict {
}
template <bool is_learn>
- void predict_or_learn(ldf& data, learner& base, example &ec) {
+ void predict_or_learn(ldf& data, base_learner& base, example &ec) {
vw* all = data.all;
data.base = &base;
bool is_test_ec = COST_SENSITIVE::example_is_test(ec);
@@ -653,7 +650,7 @@ namespace LabelDict {
}
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("LDF Options");
opts.add_options()
@@ -674,12 +671,12 @@ namespace LabelDict {
if( vm.count("csoaa_ldf") ){
ldf_arg = vm["csoaa_ldf"].as<string>();
- all.file_options << " --csoaa_ldf " << ldf_arg;
+ *all.file_options << " --csoaa_ldf " << ldf_arg;
}
else {
ldf_arg = vm["wap_ldf"].as<string>();
ld.is_wap = true;
- all.file_options << " --wap_ldf " << ldf_arg;
+ *all.file_options << " --wap_ldf " << ldf_arg;
}
if ( vm.count("ldf_override") )
ldf_arg = vm["ldf_override"].as<string>();
@@ -718,18 +715,17 @@ namespace LabelDict {
ld.read_example_this_loop = 0;
ld.need_to_clear = false;
- learner* l = new learner(&ld, all.l);
- l->set_learn<ldf, predict_or_learn<true> >();
- l->set_predict<ldf, predict_or_learn<false> >();
+ learner<ldf>& l = init_learner(&ld, all.l);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
if (ld.is_singleline)
- l->set_finish_example<ldf,finish_singleline_example>();
+ l.set_finish_example(finish_singleline_example);
else
- l->set_finish_example<ldf,finish_multiline_example>();
- l->set_finish<ldf,finish>();
- l->set_end_examples<ldf,end_examples>();
- l->set_end_pass<ldf,end_pass>();
- all.cost_sensitive = all.l;
- return l;
+ l.set_finish_example(finish_multiline_example);
+ l.set_finish(finish);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ return make_base(l);
}
void global_print_newline(vw& all)
diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h
index bfaca0f2..79f5e4d2 100644
--- a/vowpalwabbit/csoaa.h
+++ b/vowpalwabbit/csoaa.h
@@ -5,11 +5,11 @@ license as described in the file LICENSE.
*/
#pragma once
namespace CSOAA {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
namespace CSOAA_AND_WAP_LDF {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
namespace LabelDict {
bool ec_is_example_header(example& ec); // example headers look like "0:-1" or just "shared"
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index e73b2442..1f99eaaa 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -184,7 +184,7 @@ namespace ECT
return e.last_pair + (eliminations-1);
}
- uint32_t ect_predict(vw& all, ect& e, learner& base, example& ec)
+ uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec)
{
if (e.k == (size_t)1)
return 1;
@@ -228,7 +228,7 @@ namespace ECT
return false;
}
- void ect_train(vw& all, ect& e, learner& base, example& ec)
+ void ect_train(vw& all, ect& e, base_learner& base, example& ec)
{
if (e.k == 1)//nothing to do
return;
@@ -317,7 +317,7 @@ namespace ECT
}
}
- void predict(ect& e, learner& base, example& ec) {
+ void predict(ect& e, base_learner& base, example& ec) {
vw* all = e.all;
MULTICLASS::multiclass mc = ec.l.multi;
@@ -327,7 +327,7 @@ namespace ECT
ec.l.multi = mc;
}
- void learn(ect& e, learner& base, example& ec)
+ void learn(ect& e, base_learner& base, example& ec)
{
vw* all = e.all;
@@ -366,7 +366,7 @@ namespace ECT
VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("ECT options");
opts.add_options()
@@ -387,18 +387,18 @@ namespace ECT
} else
data.errors = 0;
//append error flag to options_from_file so it is saved in regressor file later
- all.file_options << " --ect " << data.k << " --error " << data.errors;
+ *all.file_options << " --ect " << data.k << " --error " << data.errors;
all.p->lp = MULTICLASS::mc_label;
size_t wpp = create_circuit(all, data, data.k, data.errors+1);
data.all = &all;
- learner* l = new learner(&data, all.l, wpp);
- l->set_learn<ect, learn>();
- l->set_predict<ect, predict>();
- l->set_finish_example<ect,finish_example>();
- l->set_finish<ect,finish>();
+ learner<ect>& l = init_learner(&data, all.l, wpp);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/ect.h b/vowpalwabbit/ect.h
index a7c392c0..81129791 100644
--- a/vowpalwabbit/ect.h
+++ b/vowpalwabbit/ect.h
@@ -6,5 +6,5 @@ license as described in the file LICENSE.
#pragma once
namespace ECT
{
- LEARNER::learner* setup(vw&, po::variables_map&);
+ LEARNER::base_learner* setup(vw&, po::variables_map&);
}
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
index 2707cbcf..15ad1a06 100644
--- a/vowpalwabbit/ftrl_proximal.cc
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -135,7 +135,7 @@ namespace FTRL {
}
//void learn(void* a, void* d, example* ec) {
- void learn(ftrl& a, learner& base, example& ec) {
+ void learn(ftrl& a, base_learner& base, example& ec) {
vw* all = a.all;
assert(ec.in_use);
@@ -170,15 +170,15 @@ namespace FTRL {
}
// placeholder
- void predict(ftrl& b, learner& base, example& ec)
+ void predict(ftrl& b, base_learner& base, example& ec)
{
vw* all = b.all;
//ec.l.simple.prediction = ftrl_predict(*all,ec);
ec.pred.scalar = ftrl_predict(*all,ec);
}
- learner* setup(vw& all, po::variables_map& vm) {
-
+ base_learner* setup(vw& all, po::variables_map& vm)
+ {
ftrl& b = calloc_or_die<ftrl>();
b.all = &all;
b.ftrl_beta = 0.0;
@@ -217,13 +217,10 @@ namespace FTRL {
cerr << "ftrl_beta = " << b.ftrl_beta << endl;
}
- learner* l = new learner(&b, 1 << all.reg.stride_shift);
- l->set_learn<ftrl, learn>();
- l->set_predict<ftrl, predict>();
- l->set_save_load<ftrl,save_load>();
-
- return l;
+ learner<ftrl>& l = init_learner(&b, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ return make_base(l);
}
-
-
} // end namespace
diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h
index 934d91c8..59bf4653 100644
--- a/vowpalwabbit/ftrl_proximal.h
+++ b/vowpalwabbit/ftrl_proximal.h
@@ -7,6 +7,6 @@ license as described in the file LICENSE.
#define FTRL_PROXIMAL_H
namespace FTRL {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
#endif
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 66575e38..b407bc2f 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -40,7 +40,7 @@ namespace GD
float neg_norm_power;
float neg_power_t;
float update_multiplier;
- void (*predict)(gd&, learner&, example&);
+ void (*predict)(gd&, base_learner&, example&);
vw* all;
};
@@ -346,7 +346,7 @@ float finalize_prediction(shared_data* sd, float ret)
}
template<bool l1, bool audit>
-void predict(gd& g, learner& base, example& ec)
+void predict(gd& g, base_learner& base, example& ec)
{
vw& all = *g.all;
@@ -508,7 +508,7 @@ float compute_update(gd& g, example& ec)
}
template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
-void update(gd& g, learner& base, example& ec)
+void update(gd& g, base_learner& base, example& ec)
{//invariant: not a test label, importance weight > 0
float update;
if ( (update = compute_update<invariant, sqrt_rate, feature_mask_off, adaptive, normalized, spare> (g, ec)) != 0.)
@@ -519,7 +519,7 @@ void update(gd& g, learner& base, example& ec)
}
template<bool invariant, bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normalized, size_t spare>
-void learn(gd& g, learner& base, example& ec)
+void learn(gd& g, base_learner& base, example& ec)
{//invariant: not a test label, importance weight > 0
assert(ec.in_use);
assert(ec.l.simple.label != FLT_MAX);
@@ -789,25 +789,25 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text)
}
template<bool invariant, bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next>
-uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
all.normalized_idx = normalized;
if (feature_mask_off)
{
- ret->set_learn<gd, learn<invariant, sqrt_rate, true, adaptive, normalized, spare> >();
- ret->set_update<gd, update<invariant, sqrt_rate, true, adaptive, normalized, spare> >();
+ ret.set_learn(learn<invariant, sqrt_rate, true, adaptive, normalized, spare>);
+ ret.set_update(update<invariant, sqrt_rate, true, adaptive, normalized, spare>);
return next;
}
else
{
- ret->set_learn<gd, learn<invariant, sqrt_rate, false, adaptive, normalized, spare> >();
- ret->set_update<gd, update<invariant, sqrt_rate, false, adaptive, normalized, spare> >();
+ ret.set_learn(learn<invariant, sqrt_rate, false, adaptive, normalized, spare>);
+ ret.set_update(update<invariant, sqrt_rate, false, adaptive, normalized, spare>);
return next;
}
}
template<bool sqrt_rate, uint32_t adaptive, uint32_t normalized, uint32_t spare, uint32_t next>
-uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
if (all.invariant_updates)
return set_learn<true, sqrt_rate, adaptive, normalized, spare, next>(all, ret, feature_mask_off);
@@ -816,7 +816,7 @@ uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off)
}
template<bool sqrt_rate, uint32_t adaptive, uint32_t spare>
-uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
// select the appropriate learn function based on adaptive, normalization, and feature mask
if (all.normalized_updates)
@@ -826,7 +826,7 @@ uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off)
}
template<bool sqrt_rate>
-uint32_t set_learn(vw& all, learner* ret, bool feature_mask_off)
+uint32_t set_learn(vw& all, learner<gd>& ret, bool feature_mask_off)
{
if (all.adaptive)
return set_learn<sqrt_rate, 1, 2>(all, ret, feature_mask_off);
@@ -842,7 +842,7 @@ uint32_t ceil_log_2(uint32_t v)
return 1 + ceil_log_2(v >> 1);
}
-learner* setup(vw& all, po::variables_map& vm)
+base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Gradient Descent options");
opts.add_options()
@@ -906,29 +906,18 @@ learner* setup(vw& all, po::variables_map& vm)
cerr << "Warning: the learning rate for the last pass is multiplied by: " << pow((double)all.eta_decay_rate, (double)all.numpasses)
<< " adjust --decay_learning_rate larger to avoid this." << endl;
- learner* ret = new learner(&g, 1);
+ learner<gd>& ret = init_learner(&g, 1);
if (all.reg_mode % 2)
if (all.audit || all.hash_inv)
- {
- ret->set_predict<gd, predict<true, true> >();
- g.predict = predict<true, true>;
- }
+ g.predict = predict<true, true>;
else
- {
- ret->set_predict<gd, predict<true, false> >();
- g.predict = predict<true, false>;
- }
+ g.predict = predict<true, false>;
else if (all.audit || all.hash_inv)
- {
- ret->set_predict<gd, predict<false, true> >();
- g.predict = predict<false, true>;
- }
+ g.predict = predict<false, true>;
else
- {
- ret->set_predict<gd, predict<false, false> >();
- g.predict = predict<false, false>;
- }
+ g.predict = predict<false, false>;
+ ret.set_predict(g.predict);
uint32_t stride;
if (all.power_t == 0.5)
@@ -937,11 +926,10 @@ learner* setup(vw& all, po::variables_map& vm)
stride = set_learn<false>(all, ret, feature_mask_off);
all.reg.stride_shift = ceil_log_2(stride-1);
- ret->increment = ((uint64_t)1 << all.reg.stride_shift);
+ ret.increment = ((uint64_t)1 << all.reg.stride_shift);
- ret->set_save_load<gd,save_load>();
-
- ret->set_end_pass<gd, end_pass>();
- return ret;
+ ret.set_save_load(save_load);
+ ret.set_end_pass(end_pass);
+ return make_base(ret);
}
}
diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h
index be52e62f..de3964eb 100644
--- a/vowpalwabbit/gd.h
+++ b/vowpalwabbit/gd.h
@@ -24,7 +24,7 @@ namespace GD{
void compute_update(example* ec);
void offset_train(regressor &reg, example* &ec, float update, size_t offset);
void train_one_example_single_thread(regressor& r, example* ex);
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text);
void output_and_account_example(example* ec);
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index eec54663..532c44f3 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -272,14 +272,14 @@ void mf_train(vw& all, example& ec)
all->current_pass++;
}
- void predict(gdmf& d, learner& base, example& ec)
+ void predict(gdmf& d, base_learner& base, example& ec)
{
vw* all = d.all;
mf_predict(*all,ec);
}
- void learn(gdmf& d, learner& base, example& ec)
+ void learn(gdmf& d, base_learner& base, example& ec)
{
vw* all = d.all;
@@ -288,7 +288,7 @@ void mf_train(vw& all, example& ec)
mf_train(*all, ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Gdmf options");
opts.add_options()
@@ -339,12 +339,12 @@ void mf_train(vw& all, example& ec)
}
all.eta *= powf((float)(all.sd->t), all.power_t);
- learner* l = new learner(&data, 1 << all.reg.stride_shift);
- l->set_learn<gdmf, learn>();
- l->set_predict<gdmf, predict>();
- l->set_save_load<gdmf,save_load>();
- l->set_end_pass<gdmf,end_pass>();
+ learner<gdmf>& l = init_learner(&data, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_end_pass(end_pass);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/gd_mf.h b/vowpalwabbit/gd_mf.h
index 0705eff1..db093750 100644
--- a/vowpalwabbit/gd_mf.h
+++ b/vowpalwabbit/gd_mf.h
@@ -5,5 +5,5 @@ license as described in the file LICENSE.
*/
#pragma once
namespace GDMF{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc
index e27c646d..b4e68aff 100644
--- a/vowpalwabbit/global_data.cc
+++ b/vowpalwabbit/global_data.cc
@@ -250,6 +250,8 @@ vw::vw()
data_filename = "";
+ file_options = new std::stringstream;
+
bfgs = false;
hessian_on = false;
active = false;
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index f651426f..b610d75a 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -170,9 +170,9 @@ struct vw {
node_socks socks;
- LEARNER::learner* l;//the top level learner
- LEARNER::learner* scorer;//a scoring function
- LEARNER::learner* cost_sensitive;//a cost sensitive learning algorithm.
+ LEARNER::base_learner* l;//the top level learner
+ LEARNER::base_learner* scorer;//a scoring function
+ LEARNER::base_learner* cost_sensitive;//a cost sensitive learning algorithm.
void learn(example*);
@@ -199,7 +199,7 @@ struct vw {
double normalized_sum_norm_x;
po::options_description opts;
- std::stringstream file_options;
+ std::stringstream* file_options;
vector<std::string> args;
void* /*Search::search*/ searchstr;
@@ -266,7 +266,7 @@ struct vw {
size_t length () { return ((size_t)1) << num_bits; };
uint32_t rank;
- v_array<LEARNER::learner* (*)(vw& all, po::variables_map& vm)> reduction_stack;
+ v_array<LEARNER::base_learner* (*)(vw& all, po::variables_map& vm)> reduction_stack;
//Prediction output
v_array<int> final_prediction_sink; // set to send global predictions to.
diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc
index 1f72ffc6..ff12e2bf 100644
--- a/vowpalwabbit/kernel_svm.cc
+++ b/vowpalwabbit/kernel_svm.cc
@@ -395,7 +395,7 @@ namespace KSVM
}
}
- void predict(svm_params& params, learner &base, example& ec) {
+ void predict(svm_params& params, base_learner &base, example& ec) {
flat_example* fec = flatten_sort_example(*(params.all),&ec);
if(fec) {
svm_example* sec = &calloc_or_die<svm_example>();
@@ -733,7 +733,7 @@ namespace KSVM
//cerr<<params.model->support_vec[0]->example_counter<<endl;
}
- void learn(svm_params& params, learner& base, example& ec) {
+ void learn(svm_params& params, base_learner& base, example& ec) {
flat_example* fec = flatten_sort_example(*(params.all),&ec);
// for(int i = 0;i < fec->feature_map_len;i++)
// cout<<i<<":"<<fec->feature_map[i].x<<" "<<fec->feature_map[i].weight_index<<" ";
@@ -790,7 +790,7 @@ namespace KSVM
cerr<<"Done with finish \n";
}
- LEARNER::learner* setup(vw &all, po::variables_map& vm) {
+ LEARNER::base_learner* setup(vw &all, po::variables_map& vm) {
po::options_description opts("KSVM options");
opts.add_options()
("ksvm", "kernel svm")
@@ -857,7 +857,7 @@ namespace KSVM
params.lambda = all.l2_lambda;
- all.file_options <<" --lambda "<< params.lambda;
+ *all.file_options <<" --lambda "<< params.lambda;
cerr<<"Lambda = "<<params.lambda<<endl;
@@ -868,7 +868,7 @@ namespace KSVM
else
kernel_type = string("linear");
- all.file_options <<" --kernel "<< kernel_type;
+ *all.file_options <<" --kernel "<< kernel_type;
cerr<<"Kernel = "<<kernel_type<<endl;
@@ -877,7 +877,7 @@ namespace KSVM
float bandwidth = 1.;
if(vm.count("bandwidth")) {
bandwidth = vm["bandwidth"].as<float>();
- all.file_options <<" --bandwidth "<<bandwidth;
+ *all.file_options <<" --bandwidth "<<bandwidth;
}
cerr<<"bandwidth = "<<bandwidth<<endl;
params.kernel_params = &calloc_or_die<double>();
@@ -888,7 +888,7 @@ namespace KSVM
int degree = 2;
if(vm.count("degree")) {
degree = vm["degree"].as<int>();
- all.file_options <<" --degree "<<degree;
+ *all.file_options <<" --degree "<<degree;
}
cerr<<"degree = "<<degree<<endl;
params.kernel_params = &calloc_or_die<int>();
@@ -900,11 +900,11 @@ namespace KSVM
params.all->reg.weight_mask = (uint32_t)LONG_MAX;
params.all->reg.stride_shift = 0;
- learner* l = new learner(&params, 1);
- l->set_learn<svm_params, learn>();
- l->set_predict<svm_params, predict>();
- l->set_save_load<svm_params, save_load>();
- l->set_finish<svm_params, finish>();
- return l;
+ learner<svm_params>& l = init_learner(&params, 1);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish(finish);
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/kernel_svm.h b/vowpalwabbit/kernel_svm.h
index 30b3ac55..563d70e2 100644
--- a/vowpalwabbit/kernel_svm.h
+++ b/vowpalwabbit/kernel_svm.h
@@ -5,6 +5,4 @@ license as described in the file LICENSE.
*/
#pragma once
namespace KSVM
-{
- LEARNER::learner* setup(vw &all, po::variables_map& vm);
-}
+{ LEARNER::base_learner* setup(vw &all, po::variables_map& vm); }
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index 8c8961af..1d03afbd 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -698,7 +698,7 @@ void save_load(lda& l, io_buf& model_file, bool read, bool text)
l.doc_lengths.erase();
}
- void learn(lda& l, learner& base, example& ec)
+ void learn(lda& l, base_learner& base, example& ec)
{
size_t num_ex = l.examples.size();
l.examples.push_back(&ec);
@@ -716,7 +716,7 @@ void save_load(lda& l, io_buf& model_file, bool read, bool text)
}
// placeholder
- void predict(lda& l, learner& base, example& ec)
+ void predict(lda& l, base_learner& base, example& ec)
{
learn(l, base, ec);
}
@@ -754,8 +754,8 @@ void end_examples(lda& l)
}
- learner* setup(vw&all, po::variables_map& vm)
- {
+base_learner* setup(vw&all, po::variables_map& vm)
+{
po::options_description opts("Lda options");
opts.add_options()
("lda", po::value<uint32_t>(), "Run lda with <int> topics")
@@ -788,7 +788,7 @@ void end_examples(lda& l)
all.random_weights = true;
all.add_constant = false;
- all.file_options << " --lda " << all.lda;
+ *all.file_options << " --lda " << all.lda;
if (all.eta > 1.)
{
@@ -799,23 +799,21 @@ void end_examples(lda& l)
if (vm.count("minibatch")) {
size_t minibatch2 = next_pow2(ld.minibatch);
all.p->ring_size = all.p->ring_size > minibatch2 ? all.p->ring_size : minibatch2;
-
}
- ld.v.resize(all.lda*ld.minibatch);
-
- ld.decay_levels.push_back(0.f);
-
- learner* l = new learner(&ld, 1 << all.reg.stride_shift);
- l->set_learn<lda,learn>();
- l->set_predict<lda,predict>();
- l->set_save_load<lda,save_load>();
- l->set_finish_example<lda,finish_example>();
- l->set_end_examples<lda,end_examples>();
- l->set_end_pass<lda,end_pass>();
- l->set_finish<lda,finish>();
-
- return l;
- }
-
- }
+ ld.v.resize(all.lda*ld.minibatch);
+
+ ld.decay_levels.push_back(0.f);
+
+ learner<lda>& l = init_learner(&ld, 1 << all.reg.stride_shift);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_end_pass(end_pass);
+ l.set_finish(finish);
+
+ return make_base(l);
+}
+}
diff --git a/vowpalwabbit/lda_core.h b/vowpalwabbit/lda_core.h
index 3a377d9f..2a065783 100644
--- a/vowpalwabbit/lda_core.h
+++ b/vowpalwabbit/lda_core.h
@@ -5,5 +5,5 @@ license as described in the file LICENSE.
*/
#pragma once
namespace LDA{
- LEARNER::learner* setup(vw&, po::variables_map&);
+ LEARNER::base_learner* setup(vw&, po::variables_map&);
}
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h
index 71571c40..45815bcc 100644
--- a/vowpalwabbit/learner.h
+++ b/vowpalwabbit/learner.h
@@ -6,6 +6,7 @@ license as described in the file LICENSE.
#pragma once
// This is the interface for a learning algorithm
#include<iostream>
+#include"memory.h"
using namespace std;
struct vw;
@@ -13,15 +14,16 @@ void return_simple_example(vw& all, void*, example& ec);
namespace LEARNER
{
- struct learner;
+ template<class T> struct learner;
+ typedef learner<char> base_learner;
struct func_data {
void* data;
- learner* base;
+ base_learner* base;
void (*func)(void* data);
};
- inline func_data tuple_dbf(void* data, learner* base, void (*func)(void* data))
+ inline func_data tuple_dbf(void* data, base_learner* base, void (*func)(void* data))
{
func_data foo;
foo.data = data;
@@ -32,200 +34,184 @@ namespace LEARNER
struct learn_data {
void* data;
- learner* base;
- void (*learn_f)(void* data, learner& base, example&);
- void (*predict_f)(void* data, learner& base, example&);
- void (*update_f)(void* data, learner& base, example&);
+ base_learner* base;
+ void (*learn_f)(void* data, base_learner& base, example&);
+ void (*predict_f)(void* data, base_learner& base, example&);
+ void (*update_f)(void* data, base_learner& base, example&);
};
struct save_load_data{
void* data;
- learner* base;
+ base_learner* base;
void (*save_load_f)(void*, io_buf&, bool read, bool text);
};
struct finish_example_data{
void* data;
- learner* base;
+ base_learner* base;
void (*finish_example_f)(vw&, void* data, example&);
};
void generic_driver(vw& all);
inline void generic_sl(void*, io_buf&, bool, bool) {}
- inline void generic_learner(void* data, learner& base, example&) {}
+ inline void generic_learner(void* data, base_learner& base, example&) {}
inline void generic_func(void* data) {}
const save_load_data generic_save_load_fd = {NULL, NULL, generic_sl};
const learn_data generic_learn_fd = {NULL, NULL, generic_learner, generic_learner, NULL};
const func_data generic_func_fd = {NULL, NULL, generic_func};
- template<class R, void (*T)(R&, learner& base, example& ec)>
- inline void tlearn(void* d, learner& base, example& ec)
- { T(*(R*)d, base, ec); }
-
- template<class R, void (*T)(R&, io_buf& io, bool read, bool text)>
- inline void tsl(void* d, io_buf& io, bool read, bool text)
- { T(*(R*)d, io, read, text); }
-
- template<class R, void (*T)(R&)>
- inline void tfunc(void* d) { T(*(R*)d); }
-
- template<class R, void (*T)(vw& all, R&, example&)>
- inline void tend_example(vw& all, void* d, example& ec)
- { T(all, *(R*)d, ec); }
-
- template <class T, void (*learn)(T* data, learner& base, example&), void (*predict)(T* data, learner& base, example&)>
- struct learn_helper {
- void (*learn_f)(void* data, learner& base, example&);
- void (*predict_f)(void* data, learner& base, example&);
-
- learn_helper()
- { learn_f = tlearn<T,learn>;
- predict_f = tlearn<T,predict>;
+ typedef void (*tlearn)(void* d, base_learner& base, example& ec);
+ typedef void (*tsl)(void* d, io_buf& io, bool read, bool text);
+ typedef void (*tfunc)(void*d);
+ typedef void (*tend_example)(vw& all, void* d, example& ec);
+
+ template<class T> learner<T>& init_learner();
+ template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight);
+ template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1);
+
+ template<class T>
+ struct learner {
+ private:
+ func_data init_fd;
+ learn_data learn_fd;
+ finish_example_data finish_example_fd;
+ save_load_data save_load_fd;
+ func_data end_pass_fd;
+ func_data end_examples_fd;
+ func_data finisher_fd;
+
+ public:
+ size_t weights; //this stores the number of "weight vectors" required by the learner.
+ size_t increment;
+
+ //called once for each example. Must work under reduction.
+ inline void learn(example& ec, size_t i=0)
+ {
+ ec.ft_offset += (uint32_t)(increment*i);
+ learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec);
+ ec.ft_offset -= (uint32_t)(increment*i);
}
+ inline void set_learn(void (*u)(T& data, base_learner& base, example&))
+ {
+ learn_fd.learn_f = (tlearn)u;
+ learn_fd.update_f = (tlearn)u;
+ }
+
+ inline void predict(example& ec, size_t i=0)
+ {
+ ec.ft_offset += (uint32_t)(increment*i);
+ learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec);
+ ec.ft_offset -= (uint32_t)(increment*i);
+ }
+ inline void set_predict(void (*u)(T& data, base_learner& base, example&))
+ { learn_fd.predict_f = (tlearn)u; }
+
+ inline void update(example& ec, size_t i=0)
+ {
+ ec.ft_offset += (uint32_t)(increment*i);
+ learn_fd.update_f(learn_fd.data, *learn_fd.base, ec);
+ ec.ft_offset -= (uint32_t)(increment*i);
+ }
+ inline void set_update(void (*u)(T& data, base_learner& base, example&))
+ { learn_fd.update_f = (tlearn)u; }
+
+ //called anytime saving or loading needs to happen. Autorecursive.
+ inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); }
+ inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool))
+ { save_load_fd.save_load_f = (tsl)sl;
+ save_load_fd.data = learn_fd.data;
+ save_load_fd.base = learn_fd.base;}
+
+ //called to clean up state. Autorecursive.
+ void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); }
+ inline void finish()
+ {
+ if (finisher_fd.data)
+ {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); }
+ if (finisher_fd.base) {
+ finisher_fd.base->finish();
+ free(finisher_fd.base);
+ }
+ }
+
+ void end_pass(){
+ end_pass_fd.func(end_pass_fd.data);
+ if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive
+ void set_end_pass(void (*f)(T&))
+ {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, (tfunc)f);}
+
+ //called after parsing of examples is complete. Autorecursive.
+ void end_examples()
+ { end_examples_fd.func(end_examples_fd.data);
+ if (end_examples_fd.base) end_examples_fd.base->end_examples(); }
+ void set_end_examples(void (*f)(T&))
+ {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f);}
+
+ //Called at the beginning by the driver. Explicitly not recursive.
+ void init_driver() { init_fd.func(init_fd.data);}
+ void set_init_driver(void (*f)(T&))
+ { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); }
+
+ //called after learn example for each example. Explicitly not recursive.
+ inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);}
+ void set_finish_example(void (*f)(vw& all, T&, example&))
+ {finish_example_fd.data = learn_fd.data;
+ finish_example_fd.finish_example_f = (tend_example)f;}
+
+ friend learner<T>& init_learner<>();
+ friend learner<T>& init_learner<>(T* dat, size_t params_per_weight);
+ friend learner<T>& init_learner<>(T* dat, base_learner* base, size_t ws);
};
-
-struct learner {
-private:
- func_data init_fd;
- learn_data learn_fd;
- finish_example_data finish_example_fd;
- save_load_data save_load_fd;
- func_data end_pass_fd;
- func_data end_examples_fd;
- func_data finisher_fd;
-public:
- size_t weights; //this stores the number of "weight vectors" required by the learner.
- size_t increment;
-
- //called once for each example. Must work under reduction.
- inline void learn(example& ec, size_t i=0)
- {
- ec.ft_offset += (uint32_t)(increment*i);
- learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec);
- ec.ft_offset -= (uint32_t)(increment*i);
- }
- template <class T, void (*u)(T& data, learner& base, example&)>
- inline void set_learn()
- {
- learn_fd.learn_f = tlearn<T,u>;
- learn_fd.update_f = tlearn<T,u>;
- }
-
- inline void predict(example& ec, size_t i=0)
- {
- ec.ft_offset += (uint32_t)(increment*i);
- learn_fd.predict_f(learn_fd.data, *learn_fd.base, ec);
- ec.ft_offset -= (uint32_t)(increment*i);
- }
- template <class T, void (*u)(T& data, learner& base, example&)>
- inline void set_predict()
- {
- learn_fd.predict_f = tlearn<T,u>;
- }
-
- inline void update(example& ec, size_t i=0)
- {
- ec.ft_offset += (uint32_t)(increment*i);
- learn_fd.update_f(learn_fd.data, *learn_fd.base, ec);
- ec.ft_offset -= (uint32_t)(increment*i);
- }
- template <class T, void (*u)(T& data, learner& base, example&)>
- inline void set_update()
- {
- learn_fd.update_f = tlearn<T,u>;
- }
-
- //called anytime saving or loading needs to happen. Autorecursive.
- inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); }
- template <class T, void (*sl)(T&, io_buf&, bool, bool)>
- inline void set_save_load()
- { save_load_fd.save_load_f = tsl<T,sl>;
- save_load_fd.data = learn_fd.data;
- save_load_fd.base = learn_fd.base;}
-
- //called to clean up state. Autorecursive.
- template <class T, void (*f)(T&)>
- void set_finish() { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T, f>); }
- inline void finish()
- {
- if (finisher_fd.data)
- {finisher_fd.func(finisher_fd.data); free(finisher_fd.data); }
- if (finisher_fd.base) {
- finisher_fd.base->finish();
- delete finisher_fd.base;
+ template<class T> learner<T>& init_learner()
+ {
+ learner<T>& ret = calloc_or_die<learner<T> >();
+ ret.weights = 1;
+ ret.increment = 1;
+ ret.learn_fd = LEARNER::generic_learn_fd;
+ ret.finish_example_fd.data = NULL;
+ ret.finish_example_fd.finish_example_f = return_simple_example;
+ ret.end_pass_fd = LEARNER::generic_func_fd;
+ ret.end_examples_fd = LEARNER::generic_func_fd;
+ ret.init_fd = LEARNER::generic_func_fd;
+ ret.finisher_fd = LEARNER::generic_func_fd;
+ ret.save_load_fd = LEARNER::generic_save_load_fd;
+ return ret;
}
- }
-
- void end_pass(){
- end_pass_fd.func(end_pass_fd.data);
- if (end_pass_fd.base) end_pass_fd.base->end_pass(); }//autorecursive
- template <class T, void (*f)(T&)>
- void set_end_pass() {end_pass_fd = tuple_dbf(learn_fd.data, learn_fd.base, tfunc<T,f>);}
-
- //called after parsing of examples is complete. Autorecursive.
- void end_examples()
- { end_examples_fd.func(end_examples_fd.data);
- if (end_examples_fd.base) end_examples_fd.base->end_examples(); }
- template <class T, void (*f)(T&)>
- void set_end_examples() {end_examples_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T,f>);}
-
- //Called at the beginning by the driver. Explicitly not recursive.
- void init_driver() { init_fd.func(init_fd.data);}
- template <class T, void (*f)(T&)>
- void set_init_driver() { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, tfunc<T,f>); }
-
- //called after learn example for each example. Explicitly not recursive.
- inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);}
- template<class T, void (*f)(vw& all, T&, example&)>
- void set_finish_example()
- {finish_example_fd.data = learn_fd.data;
- finish_example_fd.finish_example_f = tend_example<T,f>;}
-
- inline learner()
- {
- weights = 1;
- increment = 1;
-
- learn_fd = LEARNER::generic_learn_fd;
- finish_example_fd.data = NULL;
- finish_example_fd.finish_example_f = return_simple_example;
- end_pass_fd = LEARNER::generic_func_fd;
- end_examples_fd = LEARNER::generic_func_fd;
- init_fd = LEARNER::generic_func_fd;
- finisher_fd = LEARNER::generic_func_fd;
- save_load_fd = LEARNER::generic_save_load_fd;
- }
-
- inline learner(void* dat, size_t params_per_weight)
- { // the constructor for all learning algorithms.
- *this = learner();
-
- learn_fd.data = dat;
-
- finisher_fd.data = dat;
- finisher_fd.base = NULL;
- finisher_fd.func = LEARNER::generic_func;
-
- increment = params_per_weight;
- }
-
- inline learner(void *dat, learner* base, size_t ws = 1)
- { //the reduction constructor, with separate learn and predict functions
- *this = *base;
-
- learn_fd.data = dat;
- learn_fd.base = base;
-
- finisher_fd.data = dat;
- finisher_fd.base = base;
- finisher_fd.func = LEARNER::generic_func;
-
- weights = ws;
- increment = base->increment * weights;
- }
-};
-
+
+ template<class T> learner<T>& init_learner(T* dat, size_t params_per_weight)
+ { // the constructor for all learning algorithms.
+ learner<T>& ret = init_learner<T>();
+
+ ret.learn_fd.data = dat;
+
+ ret.finisher_fd.data = dat;
+ ret.finisher_fd.base = NULL;
+ ret.finisher_fd.func = LEARNER::generic_func;
+
+ ret.increment = params_per_weight;
+ return ret;
+ }
+
+ template<class T> learner<T>& init_learner(T* dat, base_learner* base, size_t ws = 1)
+ { //the reduction constructor, with separate learn and predict functions
+ learner<T>& ret = calloc_or_die<learner<T> >();
+ ret = *(learner<T>*)base;
+
+ ret.learn_fd.data = dat;
+ ret.learn_fd.base = base;
+
+ ret.finisher_fd.data = dat;
+ ret.finisher_fd.base = base;
+ ret.finisher_fd.func = LEARNER::generic_func;
+
+ ret.weights = ws;
+ ret.increment = base->increment * ret.weights;
+ return ret;
+ }
+
+ template<class T> base_learner* make_base(learner<T>& base)
+ { return (base_learner*)&base; }
}
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index c1077320..b263d008 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -81,8 +81,8 @@ namespace LOG_MULTI
v_array<node> nodes;
- uint32_t max_predictors;
- uint32_t predictors_used;
+ size_t max_predictors;
+ size_t predictors_used;
bool progress;
uint32_t swap_resist;
@@ -246,7 +246,7 @@ namespace LOG_MULTI
return b.nodes[current].internal;
}
- void train_node(log_multi& b, learner& base, example& ec, uint32_t& current, uint32_t& class_index)
+ void train_node(log_multi& b, base_learner& base, example& ec, uint32_t& current, uint32_t& class_index)
{
if(b.nodes[current].norm_Eh > b.nodes[current].preds[class_index].norm_Ehk)
ec.l.simple.label = -1.f;
@@ -297,7 +297,7 @@ namespace LOG_MULTI
return n.right;
}
- void predict(log_multi& b, learner& base, example& ec)
+ void predict(log_multi& b, base_learner& base, example& ec)
{
MULTICLASS::multiclass mc = ec.l.multi;
@@ -316,7 +316,7 @@ namespace LOG_MULTI
ec.l.multi = mc;
}
- void learn(log_multi& b, learner& base, example& ec)
+ void learn(log_multi& b, base_learner& base, example& ec)
{
// verify_min_dfs(b, b.nodes[0]);
@@ -413,10 +413,10 @@ namespace LOG_MULTI
if (read)
for (uint32_t j = 1; j < temp; j++)
b.nodes.push_back(init_node());
- text_len = sprintf(buff, "max_predictors = %d ",b.max_predictors);
+ text_len = sprintf(buff, "max_predictors = %ld ",b.max_predictors);
bin_text_read_write_fixed(model_file,(char*)&b.max_predictors, sizeof(b.max_predictors), "", read, buff, text_len, text);
- text_len = sprintf(buff, "predictors_used = %d ",b.predictors_used);
+ text_len = sprintf(buff, "predictors_used = %ld ",b.predictors_used);
bin_text_read_write_fixed(model_file,(char*)&b.predictors_used, sizeof(b.predictors_used), "", read, buff, text_len, text);
text_len = sprintf(buff, "progress = %d ",b.progress);
@@ -502,7 +502,7 @@ namespace LOG_MULTI
VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm) //learner setup
+ base_learner* setup(vw& all, po::variables_map& vm) //learner setup
{
po::options_description opts("Log Multi options");
opts.add_options()
@@ -513,23 +513,23 @@ namespace LOG_MULTI
if(!vm.count("log_multi"))
return NULL;
- log_multi* data = (log_multi*)calloc(1, sizeof(log_multi));
+ log_multi& data = calloc_or_die<log_multi>();
- data->k = (uint32_t)vm["log_multi"].as<size_t>();
- data->swap_resist = 4;
+ data.k = (uint32_t)vm["log_multi"].as<size_t>();
+ data.swap_resist = 4;
if (vm.count("swap_resistance"))
- data->swap_resist = vm["swap_resistance"].as<uint32_t>();
+ data.swap_resist = vm["swap_resistance"].as<uint32_t>();
//append log_multi with nb_actions to options_from_file so it is saved to regressor later
- all.file_options << " --log_multi " << data->k;
+ *all.file_options << " --log_multi " << data.k;
if (vm.count("no_progress"))
- data->progress = false;
+ data.progress = false;
else
- data->progress = true;
+ data.progress = true;
- data->all = &all;
+ data.all = &all;
(all.p->lp) = MULTICLASS::mc_label;
string loss_function = "quantile";
@@ -537,17 +537,17 @@ namespace LOG_MULTI
delete(all.loss);
all.loss = getLossFunction(&all, loss_function, loss_parameter);
- data->max_predictors = data->k - 1;
+ data.max_predictors = data.k - 1;
- learner* l = new learner(data, all.l, data->max_predictors);
- l->set_save_load<log_multi,save_load_tree>();
- l->set_learn<log_multi,learn>();
- l->set_predict<log_multi,predict>();
- l->set_finish_example<log_multi,finish_example>();
- l->set_finish<log_multi,finish>();
+ learner<log_multi>& l = init_learner(&data, all.l, data.max_predictors);
+ l.set_save_load(save_load_tree);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish_example(finish_example);
+ l.set_finish(finish);
- init_tree(*data);
+ init_tree(data);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/log_multi.h b/vowpalwabbit/log_multi.h
index 26ecb435..5e1ee3bf 100644
--- a/vowpalwabbit/log_multi.h
+++ b/vowpalwabbit/log_multi.h
@@ -6,5 +6,5 @@ license as described in the file LICENSE.
#pragma once
namespace LOG_MULTI
{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc
index ee826b72..c44b8692 100644
--- a/vowpalwabbit/lrq.cc
+++ b/vowpalwabbit/lrq.cc
@@ -62,7 +62,7 @@ namespace {
namespace LRQ {
template <bool is_learn>
- void predict_or_learn(LRQstate& lrq, learner& base, example& ec)
+ void predict_or_learn(LRQstate& lrq, base_learner& base, example& ec)
{
vw& all = *lrq.all;
@@ -187,7 +187,7 @@ namespace LRQ {
}
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{//parse and set arguments
po::options_description opts("Lrq options");
opts.add_options()
@@ -197,37 +197,37 @@ namespace LRQ {
if(!vm.count("lrq"))
return NULL;
- LRQstate* lrq = (LRQstate*)calloc(1, sizeof (LRQstate));
- unsigned int maxk = 0;
- lrq->all = &all;
+ LRQstate& lrq = calloc_or_die<LRQstate>();
+ size_t maxk = 0;
+ lrq.all = &all;
size_t random_seed = 0;
if (vm.count("random_seed")) random_seed = vm["random_seed"].as<size_t> ();
- lrq->initial_seed = lrq->seed = random_seed | 8675309;
+ lrq.initial_seed = lrq.seed = random_seed | 8675309;
if (vm.count("lrqdropout"))
- lrq->dropout = true;
+ lrq.dropout = true;
else
- lrq->dropout = false;
+ lrq.dropout = false;
- all.file_options << " --lrqdropout ";
+ *all.file_options << " --lrqdropout ";
- lrq->lrpairs = vm["lrq"].as<vector<string> > ();
+ lrq.lrpairs = vm["lrq"].as<vector<string> > ();
- for (vector<string>::iterator i = lrq->lrpairs.begin ();
- i != lrq->lrpairs.end ();
+ for (vector<string>::iterator i = lrq.lrpairs.begin ();
+ i != lrq.lrpairs.end ();
++i)
- all.file_options << " --lrq " << *i;
+ *all.file_options << " --lrq " << *i;
if (! all.quiet)
{
cerr << "creating low rank quadratic features for pairs: ";
- if (lrq->dropout)
+ if (lrq.dropout)
cerr << "(using dropout) ";
}
- for (vector<string>::iterator i = lrq->lrpairs.begin ();
- i != lrq->lrpairs.end ();
+ for (vector<string>::iterator i = lrq.lrpairs.begin ();
+ i != lrq.lrpairs.end ();
++i)
{
if(!all.quiet){
@@ -241,8 +241,8 @@ namespace LRQ {
unsigned int k = atoi (i->c_str () + 2);
- lrq->lrindices[(int) (*i)[0]] = 1;
- lrq->lrindices[(int) (*i)[1]] = 1;
+ lrq.lrindices[(int) (*i)[0]] = 1;
+ lrq.lrindices[(int) (*i)[1]] = 1;
maxk = max (maxk, k);
}
@@ -251,12 +251,12 @@ namespace LRQ {
cerr<<endl;
all.wpp = all.wpp * (1 + maxk);
- learner* l = new learner(lrq, all.l, 1 + maxk);
- l->set_learn<LRQstate, predict_or_learn<true> >();
- l->set_predict<LRQstate, predict_or_learn<false> >();
- l->set_end_pass<LRQstate,reset_seed>();
+ learner<LRQstate>& l = init_learner(&lrq, all.l, 1 + maxk);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_end_pass(reset_seed);
// TODO: leaks memory ?
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/lrq.h b/vowpalwabbit/lrq.h
index af0aae77..376bd6e5 100644
--- a/vowpalwabbit/lrq.h
+++ b/vowpalwabbit/lrq.h
@@ -1,4 +1,4 @@
#pragma once
namespace LRQ {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index e5d99939..31896efa 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -22,7 +22,7 @@ namespace MF {
struct mf {
vector<string> pairs;
- uint32_t rank;
+ size_t rank;
uint32_t increment;
@@ -43,7 +43,7 @@ struct mf {
};
template <bool cache_sub_predictions>
-void predict(mf& data, learner& base, example& ec) {
+void predict(mf& data, base_learner& base, example& ec) {
float prediction = 0;
if (cache_sub_predictions)
data.sub_predictions.resize(2*data.rank+1, true);
@@ -102,7 +102,7 @@ void predict(mf& data, learner& base, example& ec) {
ec.pred.scalar = GD::finalize_prediction(data.all->sd, ec.partial_prediction);
}
-void learn(mf& data, learner& base, example& ec) {
+void learn(mf& data, base_learner& base, example& ec) {
// predict with current weights
predict<true>(data, base, ec);
float predicted = ec.pred.scalar;
@@ -188,7 +188,7 @@ void finish(mf& o) {
o.sub_predictions.delete_v();
}
- learner* setup(vw& all, po::variables_map& vm) {
+ base_learner* setup(vw& all, po::variables_map& vm) {
po::options_description opts("MF options");
opts.add_options()
("new_mf", po::value<size_t>(), "rank for reduction-based matrix factorization");
@@ -196,23 +196,23 @@ void finish(mf& o) {
if(!vm.count("new_mf"))
return NULL;
- mf* data = new mf;
+ mf& data = calloc_or_die<mf>();
// copy global data locally
- data->all = &all;
- data->rank = (uint32_t)vm["new_mf"].as<size_t>();
+ data.all = &all;
+ data.rank = (uint32_t)vm["new_mf"].as<size_t>();
// store global pairs in local data structure and clear global pairs
// for eventual calls to base learner
- data->pairs = all.pairs;
+ data.pairs = all.pairs;
all.pairs.clear();
all.random_positive_weights = true;
- learner* l = new learner(data, all.l, 2*data->rank+1);
- l->set_learn<mf, learn>();
- l->set_predict<mf, predict<false> >();
- l->set_finish<mf,finish>();
- return l;
+ learner<mf>& l = init_learner(&data, all.l, 2*data.rank+1);
+ l.set_learn(learn);
+ l.set_predict(predict<false>);
+ l.set_finish(finish);
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/mf.h b/vowpalwabbit/mf.h
index 73fc7f00..90ddc33a 100644
--- a/vowpalwabbit/mf.h
+++ b/vowpalwabbit/mf.h
@@ -5,5 +5,5 @@ license as described in the file LICENSE.
*/
#pragma once
namespace MF{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index 3a9d8e45..854a0e88 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -96,7 +96,7 @@ namespace NN {
}
template <bool is_learn>
- void predict_or_learn(nn& n, learner& base, example& ec)
+ void predict_or_learn(nn& n, base_learner& base, example& ec)
{
bool shouldOutput = n.all->raw_prediction > 0;
@@ -308,7 +308,7 @@ CONVERSE: // That's right, I'm using goto. So sue me.
free (n.output_layer.atomics[nn_output_namespace].begin);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("NN options");
opts.add_options()
@@ -324,11 +324,11 @@ CONVERSE: // That's right, I'm using goto. So sue me.
n.all = &all;
//first parse for number of hidden units
n.k = (uint32_t)vm["nn"].as<size_t>();
- all.file_options << " --nn " << n.k;
+ *all.file_options << " --nn " << n.k;
if ( vm.count("dropout") ) {
n.dropout = true;
- all.file_options << " --dropout ";
+ *all.file_options << " --dropout ";
}
if ( vm.count("meanfield") ) {
@@ -347,7 +347,7 @@ CONVERSE: // That's right, I'm using goto. So sue me.
if (vm.count ("inpass")) {
n.inpass = true;
- all.file_options << " --inpass";
+ *all.file_options << " --inpass";
}
@@ -366,13 +366,13 @@ CONVERSE: // That's right, I'm using goto. So sue me.
n.save_xsubi = n.xsubi;
n.increment = all.l->increment;//Indexing of output layer is odd.
- learner* l = new learner(&n, all.l, n.k+1);
- l->set_learn<nn, predict_or_learn<true> >();
- l->set_predict<nn, predict_or_learn<false> >();
- l->set_finish<nn, finish>();
- l->set_finish_example<nn, finish_example>();
- l->set_end_pass<nn,end_pass>();
-
- return l;
+ learner<nn>& l = init_learner(&n, all.l, n.k+1);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish(finish);
+ l.set_finish_example(finish_example);
+ l.set_end_pass(end_pass);
+
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/nn.h b/vowpalwabbit/nn.h
index fb090873..820157d7 100644
--- a/vowpalwabbit/nn.h
+++ b/vowpalwabbit/nn.h
@@ -6,5 +6,5 @@ license as described in the file LICENSE.
#pragma once
namespace NN
{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc
index 57cf7264..8db5a795 100644
--- a/vowpalwabbit/noop.cc
+++ b/vowpalwabbit/noop.cc
@@ -10,7 +10,7 @@ license as described in the file LICENSE.
using namespace LEARNER;
namespace NOOP {
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Noop options");
opts.add_options()
@@ -19,6 +19,6 @@ namespace NOOP {
if(!vm.count("noop"))
return NULL;
- return new learner();
+ return &init_learner<char>();
}
}
diff --git a/vowpalwabbit/noop.h b/vowpalwabbit/noop.h
index e314498a..ac8842e9 100644
--- a/vowpalwabbit/noop.h
+++ b/vowpalwabbit/noop.h
@@ -4,6 +4,5 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace NOOP {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
-}
+namespace NOOP
+{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);}
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 1a132045..869ce94d 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -3,12 +3,7 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD (revised)
license as described in the file LICENSE.
*/
-#include <float.h>
-#include <limits.h>
-#include <math.h>
-#include <stdio.h>
#include <sstream>
-
#include "multiclass.h"
#include "simple_label.h"
#include "reductions.h"
@@ -16,7 +11,6 @@ license as described in the file LICENSE.
using namespace std;
using namespace LEARNER;
-using namespace MULTICLASS;
namespace OAA {
struct oaa{
@@ -26,8 +20,8 @@ namespace OAA {
};
template <bool is_learn>
- void predict_or_learn(oaa& o, learner& base, example& ec) {
- multiclass mc_label_data = ec.l.multi;
+ void predict_or_learn(oaa& o, base_learner& base, example& ec) {
+ MULTICLASS::multiclass mc_label_data = ec.l.multi;
if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1))
cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl;
@@ -75,7 +69,7 @@ namespace OAA {
VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("One-against-all options");
opts.add_options()
@@ -85,20 +79,19 @@ namespace OAA {
return NULL;
oaa& data = calloc_or_die<oaa>();
- //first parse for number of actions
- data.k = vm["oaa"].as<size_t>();
- //append oaa with nb_actions to options_from_file so it is saved to regressor later
- all.file_options << " --oaa " << data.k;
+ data.k = vm["oaa"].as<size_t>();
data.shouldOutput = all.raw_prediction > 0;
data.all = &all;
- all.p->lp = mc_label;
- learner* l = new learner(&data, all.l, data.k);
- l->set_learn<oaa, predict_or_learn<true> >();
- l->set_predict<oaa, predict_or_learn<false> >();
- l->set_finish_example<oaa, finish_example>();
+ *all.file_options << " --oaa " << data.k;
+ all.p->lp = MULTICLASS::mc_label;
+
+ learner<oaa>& l = init_learner(&data, all.l, data.k);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h
index 6b47127f..de1b08ab 100644
--- a/vowpalwabbit/oaa.h
+++ b/vowpalwabbit/oaa.h
@@ -5,6 +5,4 @@ license as described in the file LICENSE.
*/
#pragma once
namespace OAA
-{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
-}
+{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm); }
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 56e3a89b..201389b4 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -354,7 +354,7 @@ void parse_feature_tweaks(vw& all, po::variables_map& vm)
if (vm.count("affix")) {
parse_affix_argument(all, vm["affix"].as<string>());
- all.file_options << " --affix " << vm["affix"].as<string>();
+ *all.file_options << " --affix " << vm["affix"].as<string>();
}
if(vm.count("ngram")){
@@ -746,9 +746,9 @@ void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp)
}
}
-LEARNER::learner* setup_base(vw& all, po::variables_map& vm)
+LEARNER::base_learner* setup_base(vw& all, po::variables_map& vm)
{
- LEARNER::learner* ret = all.reduction_stack.pop()(all,vm);
+ LEARNER::base_learner* ret = all.reduction_stack.pop()(all,vm);
if (ret == NULL)
return setup_next(all,vm);
else
@@ -760,6 +760,7 @@ void parse_reductions(vw& all, po::variables_map& vm)
//Base algorithms
all.reduction_stack.push_back(GD::setup);
all.reduction_stack.push_back(KSVM::setup);
+ all.reduction_stack.push_back(FTRL::setup);
all.reduction_stack.push_back(SENDER::setup);
all.reduction_stack.push_back(GDMF::setup);
all.reduction_stack.push_back(PRINT::setup);
@@ -896,7 +897,7 @@ vw* parse_args(int argc, char *argv[])
parse_regressor_args(*all, vm, io_temp);
int temp_argc = 0;
- char** temp_argv = VW::get_argv_from_string(all->file_options.str(), temp_argc);
+ char** temp_argv = VW::get_argv_from_string(all->file_options->str(), temp_argc);
add_to_args(*all, temp_argc, temp_argv);
for (int i = 0; i < temp_argc; i++)
free(temp_argv[i]);
@@ -910,7 +911,7 @@ vw* parse_args(int argc, char *argv[])
po::store(pos, vm);
po::notify(vm);
- all->file_options.str("");
+ all->file_options->str("");
parse_feature_tweaks(*all, vm); //feature tweaks
@@ -966,14 +967,14 @@ vw* parse_args(int argc, char *argv[])
}
namespace VW {
- void cmd_string_replace_value( std::stringstream& ss, string flag_to_replace, string new_value )
+ void cmd_string_replace_value( std::stringstream*& ss, string flag_to_replace, string new_value )
{
flag_to_replace.append(" "); //add a space to make sure we obtain the right flag in case 2 flags start with the same set of characters
- string cmd = ss.str();
+ string cmd = ss->str();
size_t pos = cmd.find(flag_to_replace);
if( pos == string::npos )
//flag currently not present in command string, so just append it to command string
- ss << " " << flag_to_replace << new_value;
+ *ss << " " << flag_to_replace << new_value;
else {
//flag is present, need to replace old value with new value
@@ -989,7 +990,7 @@ namespace VW {
else
//replace characters between pos and pos_after_value by new_value
cmd.replace(pos,pos_after_value-pos,new_value);
- ss.str(cmd);
+ ss->str(cmd);
}
}
@@ -1044,7 +1045,7 @@ namespace VW {
{
finalize_regressor(all, all.final_regressor_name);
all.l->finish();
- delete all.l;
+ free_it(all.l);
if (all.reg.weight_vector != NULL)
free(all.reg.weight_vector);
free_parser(all);
@@ -1053,6 +1054,7 @@ namespace VW {
all.p->parse_name.delete_v();
free(all.p);
free(all.sd);
+ delete all.file_options;
for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
if (all.final_prediction_sink[i] != 1)
io_buf::close_file_or_socket(all.final_prediction_sink[i]);
diff --git a/vowpalwabbit/parse_args.h b/vowpalwabbit/parse_args.h
index 63e47ee3..23531050 100644
--- a/vowpalwabbit/parse_args.h
+++ b/vowpalwabbit/parse_args.h
@@ -7,4 +7,4 @@ license as described in the file LICENSE.
#include "global_data.h"
vw* parse_args(int argc, char *argv[]);
-LEARNER::learner* setup_next(vw& all, po::variables_map& vm);
+LEARNER::base_learner* setup_next(vw& all, po::variables_map& vm);
diff --git a/vowpalwabbit/parse_regressor.cc b/vowpalwabbit/parse_regressor.cc
index daee30da..def3a8de 100644
--- a/vowpalwabbit/parse_regressor.cc
+++ b/vowpalwabbit/parse_regressor.cc
@@ -229,16 +229,16 @@ void save_load_header(vw& all, io_buf& model_file, bool read, bool text)
"", read,
"\n",1, text);
- text_len = sprintf(buff, "options:%s\n", all.file_options.str().c_str());
- uint32_t len = (uint32_t)all.file_options.str().length()+1;
- memcpy(buff2, all.file_options.str().c_str(),len);
+ text_len = sprintf(buff, "options:%s\n", all.file_options->str().c_str());
+ uint32_t len = (uint32_t)all.file_options->str().length()+1;
+ memcpy(buff2, all.file_options->str().c_str(),len);
if (read)
len = buf_size;
bin_text_read_write(model_file,buff2, len,
"", read,
buff, text_len, text);
if (read)
- all.file_options.str(buff2);
+ all.file_options->str(buff2);
}
}
@@ -348,7 +348,7 @@ void parse_mask_regressor_args(vw& all, po::variables_map& vm){
}
} else {
// If no initial regressor, just clear out the options loaded from the header.
- all.file_options.str("");
+ all.file_options->str("");
}
}
}
diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc
index 503b99e5..9f4a93f2 100644
--- a/vowpalwabbit/print.cc
+++ b/vowpalwabbit/print.cc
@@ -22,7 +22,7 @@ namespace PRINT
cout << " ";
}
- void learn(print& p, learner& base, example& ec)
+ void learn(print& p, base_learner& base, example& ec)
{
label_data& ld = ec.l.simple;
if (ld.label != FLT_MAX)
@@ -45,7 +45,7 @@ namespace PRINT
cout << endl;
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Print options");
opts.add_options()
@@ -61,9 +61,9 @@ namespace PRINT
all.reg.weight_mask = (length << all.reg.stride_shift) - 1;
all.reg.stride_shift = 0;
- learner* ret = new learner(&p, 1);
- ret->set_learn<print,learn>();
- ret->set_predict<print,learn>();
- return ret;
+ learner<print>& ret = init_learner(&p, 1);
+ ret.set_learn(learn);
+ ret.set_predict(learn);
+ return make_base(ret);
}
}
diff --git a/vowpalwabbit/print.h b/vowpalwabbit/print.h
index d3628eff..2c855eaa 100644
--- a/vowpalwabbit/print.h
+++ b/vowpalwabbit/print.h
@@ -4,6 +4,5 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace PRINT {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
-}
+namespace PRINT
+{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);}
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index a889a2ed..51be45f2 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -10,7 +10,7 @@ namespace Scorer {
};
template <bool is_learn, float (*link)(float in)>
- void predict_or_learn(scorer& s, learner& base, example& ec)
+ void predict_or_learn(scorer& s, base_learner& base, example& ec)
{
s.all->set_minmax(s.all->sd, ec.l.simple.label);
@@ -45,7 +45,7 @@ namespace Scorer {
return in;
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Link options");
opts.add_options()
@@ -56,23 +56,23 @@ namespace Scorer {
scorer& s = calloc_or_die<scorer>();
s.all = &all;
- learner* l = new learner(&s, all.l);
+ learner<scorer>& l = init_learner(&s, all.l);
if (!vm.count("link") || link.compare("identity") == 0)
{
- l->set_learn<scorer, predict_or_learn<true, noop> >();
- l->set_predict<scorer, predict_or_learn<false, noop> >();
+ l.set_learn(predict_or_learn<true, noop> );
+ l.set_predict(predict_or_learn<false, noop> );
}
else if (link.compare("logistic") == 0)
{
- all.file_options << " --link=logistic ";
- l->set_learn<scorer, predict_or_learn<true, logistic> >();
- l->set_predict<scorer, predict_or_learn<false, logistic> >();
+ *all.file_options << " --link=logistic ";
+ l.set_learn(predict_or_learn<true, logistic> );
+ l.set_predict(predict_or_learn<false, logistic>);
}
else if (link.compare("glf1") == 0)
{
- all.file_options << " --link=glf1 ";
- l->set_learn<scorer, predict_or_learn<true, glf1> >();
- l->set_predict<scorer, predict_or_learn<false, glf1> >();
+ *all.file_options << " --link=glf1 ";
+ l.set_learn(predict_or_learn<true, glf1>);
+ l.set_predict(predict_or_learn<false, glf1>);
}
else
{
@@ -80,6 +80,6 @@ namespace Scorer {
throw exception();
}
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/scorer.h b/vowpalwabbit/scorer.h
index 3405b2e9..2d0ec294 100644
--- a/vowpalwabbit/scorer.h
+++ b/vowpalwabbit/scorer.h
@@ -1,4 +1,4 @@
#pragma once
namespace Scorer {
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index 2907f5c7..e93b4ad6 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -175,7 +175,7 @@ namespace Search {
v_array<size_t> timesteps;
v_array<float> learn_losses;
- LEARNER::learner* base_learner;
+ LEARNER::base_learner* base_learner;
clock_t start_clock_time;
example*empty_example;
@@ -1417,7 +1417,7 @@ namespace Search {
}
template <bool is_learn>
- void search_predict_or_learn(search& sch, learner& base, example& ec) {
+ void search_predict_or_learn(search& sch, base_learner& base, example& ec) {
search_private& priv = *sch.priv;
vw* all = priv.all;
priv.base_learner = &base;
@@ -1653,7 +1653,7 @@ namespace Search {
template<class T> void check_option(T& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, bool(*equal)(T,T), const char* mismatch_error_string, const char* required_error_string) {
if (vm.count(opt_name)) {
ret = vm[opt_name].as<T>();
- all.file_options << " --" << opt_name << " " << ret;
+ *all.file_options << " --" << opt_name << " " << ret;
} else if (strlen(required_error_string)>0) {
std::cerr << required_error_string << endl;
if (! vm.count("help"))
@@ -1664,7 +1664,7 @@ namespace Search {
void check_option(bool& ret, vw&all, po::variables_map& vm, const char* opt_name, bool default_to_cmdline, const char* mismatch_error_string) {
if (vm.count(opt_name)) {
ret = true;
- all.file_options << " --" << opt_name;
+ *all.file_options << " --" << opt_name;
} else
ret = false;
}
@@ -1764,7 +1764,7 @@ namespace Search {
delete[] cstr;
}
- learner* setup(vw&all, po::variables_map& vm) {
+ base_learner* setup(vw&all, po::variables_map& vm) {
po::options_description opts("Search Options");
opts.add_options()
("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF")
@@ -1985,17 +1985,17 @@ namespace Search {
if (!vm.count("csoaa") && !vm.count("csoaa_ldf") && !vm.count("wap_ldf") && !vm.count("cb"))
vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"]));
- learner* base = setup_base(all,vm);
+ base_learner* base = setup_base(all,vm);
- learner* l = new learner(&sch, all.l, priv.total_number_of_policies);
- l->set_learn<search, search_predict_or_learn<true> >();
- l->set_predict<search, search_predict_or_learn<false> >();
- l->set_finish_example<search,finish_example>();
- l->set_end_examples<search,end_examples>();
- l->set_finish<search,search_finish>();
- l->set_end_pass<search,end_pass>();
-
- return l;
+ learner<search>& l = init_learner(&sch, all.l, priv.total_number_of_policies);
+ l.set_learn(search_predict_or_learn<true>);
+ l.set_predict(search_predict_or_learn<false>);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ l.set_finish(search_finish);
+ l.set_end_pass(end_pass);
+
+ return make_base(l);
}
float action_hamming_loss(action a, const action* A, size_t sz) {
diff --git a/vowpalwabbit/search.h b/vowpalwabbit/search.h
index 07c73158..08633c5e 100644
--- a/vowpalwabbit/search.h
+++ b/vowpalwabbit/search.h
@@ -241,5 +241,5 @@ namespace Search {
bool size_equal(size_t a, size_t b);
// our interface within VW
- LEARNER::learner* setup(vw&, po::variables_map&);
+ LEARNER::base_learner* setup(vw&, po::variables_map&);
}
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index 0a0a284a..26fdba51 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -69,7 +69,7 @@ void receive_result(sender& s)
return_simple_example(*(s.all), NULL, *ec);
}
- void learn(sender& s, learner& base, example& ec)
+ void learn(sender& s, base_learner& base, example& ec)
{
if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index)
receive_result(s);
@@ -100,7 +100,8 @@ void end_examples(sender& s)
delete s.buf;
}
-learner* setup(vw& all, po::variables_map& vm)
+
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("Sender options");
opts.add_options()
@@ -120,13 +121,13 @@ learner* setup(vw& all, po::variables_map& vm)
s.all = &all;
s.delay_ring = calloc_or_die<example*>(all.p->ring_size);
- learner* l = new learner(&s, 1);
- l->set_learn<sender, learn>();
- l->set_predict<sender, learn>();
- l->set_finish<sender, finish>();
- l->set_finish_example<sender, finish_example>();
- l->set_end_examples<sender, end_examples>();
- return l;
+ learner<sender>& l = init_learner(&s, 1);
+ l.set_learn(learn);
+ l.set_predict(learn);
+ l.set_finish(finish);
+ l.set_finish_example(finish_example);
+ l.set_end_examples(end_examples);
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/sender.h b/vowpalwabbit/sender.h
index 39713cdb..55f10754 100644
--- a/vowpalwabbit/sender.h
+++ b/vowpalwabbit/sender.h
@@ -4,6 +4,5 @@ individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#pragma once
-namespace SENDER{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
-}
+namespace SENDER
+{ LEARNER::base_learner* setup(vw& all, po::variables_map& vm); }
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc
index fb3dc425..fe4f6fbd 100644
--- a/vowpalwabbit/stagewise_poly.cc
+++ b/vowpalwabbit/stagewise_poly.cc
@@ -502,7 +502,7 @@ namespace StagewisePoly
}
}
- void predict(stagewise_poly &poly, learner &base, example &ec)
+ void predict(stagewise_poly &poly, base_learner &base, example &ec)
{
poly.original_ec = &ec;
synthetic_create(poly, ec, false);
@@ -511,7 +511,7 @@ namespace StagewisePoly
ec.updated_prediction = poly.synth_ec.updated_prediction;
}
- void learn(stagewise_poly &poly, learner &base, example &ec)
+ void learn(stagewise_poly &poly, base_learner &base, example &ec)
{
bool training = poly.all->training && ec.l.simple.label != FLT_MAX;
poly.original_ec = &ec;
@@ -656,7 +656,7 @@ namespace StagewisePoly
//#endif //DEBUG
}
- learner *setup(vw &all, po::variables_map &vm)
+ base_learner *setup(vw &all, po::variables_map &vm)
{
po::options_description opts("Stagewise poly options");
opts.add_options()
@@ -697,16 +697,16 @@ namespace StagewisePoly
poly.next_batch_sz = poly.batch_sz;
//following is so that saved models know to load us.
- all.file_options << " --stage_poly";
+ *all.file_options << " --stage_poly";
- learner *l = new learner(&poly, all.l);
- l->set_learn<stagewise_poly, learn>();
- l->set_predict<stagewise_poly, predict>();
- l->set_finish<stagewise_poly, finish>();
- l->set_save_load<stagewise_poly, save_load>();
- l->set_finish_example<stagewise_poly,finish_example>();
- l->set_end_pass<stagewise_poly, end_pass>();
+ learner<stagewise_poly>& l = init_learner(&poly, all.l);
+ l.set_learn(learn);
+ l.set_predict(predict);
+ l.set_finish(finish);
+ l.set_save_load(save_load);
+ l.set_finish_example(finish_example);
+ l.set_end_pass(end_pass);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/stagewise_poly.h b/vowpalwabbit/stagewise_poly.h
index 4f5ac1fa..983b4382 100644
--- a/vowpalwabbit/stagewise_poly.h
+++ b/vowpalwabbit/stagewise_poly.h
@@ -6,5 +6,5 @@ license as described in the file LICENSE.
#pragma once
namespace StagewisePoly
{
- LEARNER::learner *setup(vw &all, po::variables_map &vm);
+ LEARNER::base_learner *setup(vw &all, po::variables_map &vm);
}
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc
index 1fc86002..a7c80e1b 100644
--- a/vowpalwabbit/topk.cc
+++ b/vowpalwabbit/topk.cc
@@ -85,7 +85,7 @@ namespace TOPK {
}
template <bool is_learn>
- void predict_or_learn(topk& d, learner& base, example& ec)
+ void predict_or_learn(topk& d, base_learner& base, example& ec)
{
if (example_is_newline(ec)) return;//do not predict newline
@@ -111,7 +111,7 @@ namespace TOPK {
VW::finish_example(all, &ec);
}
- learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all, po::variables_map& vm)
{
po::options_description opts("TOP K options");
opts.add_options()
@@ -124,11 +124,11 @@ namespace TOPK {
data.B = (uint32_t)vm["top"].as<size_t>();
data.all = &all;
- learner* l = new learner(&data, all.l);
- l->set_learn<topk, predict_or_learn<true> >();
- l->set_predict<topk, predict_or_learn<false> >();
- l->set_finish_example<topk,finish_example>();
+ learner<topk>& l = init_learner(&data, all.l);
+ l.set_learn(predict_or_learn<true>);
+ l.set_predict(predict_or_learn<false>);
+ l.set_finish_example(finish_example);
- return l;
+ return make_base(l);
}
}
diff --git a/vowpalwabbit/topk.h b/vowpalwabbit/topk.h
index 2de41fe2..964ff618 100644
--- a/vowpalwabbit/topk.h
+++ b/vowpalwabbit/topk.h
@@ -6,5 +6,5 @@ license as described in the file LICENSE.
#pragma once
namespace TOPK
{
- LEARNER::learner* setup(vw& all, po::variables_map& vm);
+ LEARNER::base_learner* setup(vw& all, po::variables_map& vm);
}
diff --git a/vowpalwabbit/vw.h b/vowpalwabbit/vw.h
index 49bfaca3..0fa4ee77 100644
--- a/vowpalwabbit/vw.h
+++ b/vowpalwabbit/vw.h
@@ -18,7 +18,7 @@ namespace VW {
*/
vw* initialize(string s);
- void cmd_string_replace_value( std::stringstream& ss, string flag_to_replace, string new_value );
+ void cmd_string_replace_value( std::stringstream*& ss, string flag_to_replace, string new_value );
char** get_argv_from_string(string s, int& argc);