From 3ad8ebdedaca84f0843c904020f7b4bb4dfa2961 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 19:28:37 -0500 Subject: init consistency enforced --- vowpalwabbit/active.cc | 16 ++++---- vowpalwabbit/autolink.cc | 5 +-- vowpalwabbit/bfgs.cc | 3 +- vowpalwabbit/binary.cc | 4 +- vowpalwabbit/bs.cc | 5 +-- vowpalwabbit/cb_algs.cc | 23 +++++------ vowpalwabbit/cbify.cc | 20 ++++------ vowpalwabbit/csoaa.cc | 9 ++--- vowpalwabbit/ect.cc | 4 +- vowpalwabbit/ftrl_proximal.cc | 3 +- vowpalwabbit/gd.cc | 46 ++++++++++----------- vowpalwabbit/gd_mf.cc | 3 +- vowpalwabbit/kernel_svm.cc | 3 +- vowpalwabbit/lda_core.cc | 3 +- vowpalwabbit/learner.h | 91 +++++++++++++++++++----------------------- vowpalwabbit/log_multi.cc | 22 +++++----- vowpalwabbit/lrq.cc | 5 +-- vowpalwabbit/mf.cc | 4 +- vowpalwabbit/nn.cc | 5 +-- vowpalwabbit/noop.cc | 6 +-- vowpalwabbit/oaa.cc | 5 +-- vowpalwabbit/print.cc | 4 +- vowpalwabbit/scorer.cc | 17 ++++---- vowpalwabbit/search.cc | 6 +-- vowpalwabbit/sender.cc | 4 +- vowpalwabbit/stagewise_poly.cc | 4 +- vowpalwabbit/topk.cc | 5 +-- 27 files changed, 140 insertions(+), 185 deletions(-) diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 201b70b4..627928d6 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -166,20 +166,18 @@ namespace ACTIVE { data.all=&all; //Create new learner - learner& ret = init_learner(&data, all.l); + learner* ret; if (vm.count("simulation")) - { - ret.set_learn(predict_or_learn_simulation); - ret.set_predict(predict_or_learn_simulation); - } + ret = &init_learner(&data, all.l, predict_or_learn_simulation, + predict_or_learn_simulation); else { all.active = true; - ret.set_learn(predict_or_learn_active); - ret.set_predict(predict_or_learn_active); - ret.set_finish_example(return_active_example); + ret = &init_learner(&data, all.l, predict_or_learn_active, + predict_or_learn_active); + ret->set_finish_example(return_active_example); } - return make_base(ret); + return make_base(*ret); } } diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 88ac351f..6d4f0106 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -49,9 +49,8 @@ namespace ALINK { *all.file_options << " --autolink " << data.d; - learner& ret = init_learner(&data, all.l); - ret.set_learn(predict_or_learn); - ret.set_predict(predict_or_learn); + learner& ret = init_learner(&data, all.l, predict_or_learn, + predict_or_learn); return make_base(ret); } } diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index e337d65d..fc4ee851 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -1018,8 +1018,7 @@ base_learner* setup(vw& all, po::variables_map& vm) all.bfgs = true; all.reg.stride_shift = 2; - learner& l = init_learner(&b, 1 << all.reg.stride_shift); - l.set_learn(learn); + learner& l = init_learner(&b, learn, 1 << all.reg.stride_shift); l.set_predict(predict); l.set_save_load(save_load); l.set_init_driver(init_driver); diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index dda3f4a3..0916dd6f 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -28,9 +28,7 @@ namespace BINARY { {//parse and set arguments all.sd->binary_label = true; //Create new learner - learner& ret = init_learner(NULL, all.l); - ret.set_learn(predict_or_learn); - ret.set_predict(predict_or_learn); + learner& ret = init_learner(NULL, all.l, predict_or_learn, predict_or_learn); return make_base(ret); } } diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index f55467a6..78828632 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -280,9 +280,8 @@ namespace BS { data.pred_vec.reserve(data.B); data.all = &all; - learner& l = init_learner(&data, all.l, data.B); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&data, all.l, predict_or_learn, + predict_or_learn, data.B); l.set_finish_example(finish_example); l.set_finish(finish); diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index 58193a05..c7fc2021 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -564,25 +564,24 @@ namespace CB_ALGS else all.p->lp = CB::cb_label; - learner& l = init_learner(&c, all.l, problem_multiplier); + learner* l; if (eval) { - l.set_learn(learn_eval); - l.set_predict(predict_eval); - l.set_finish_example(eval_finish_example); + l = &init_learner(&c, all.l, learn_eval, predict_eval, problem_multiplier); + l->set_finish_example(eval_finish_example); } else { - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); - l.set_finish_example(finish_example); + l = &init_learner(&c, all.l, predict_or_learn, predict_or_learn, + problem_multiplier); + l->set_finish_example(finish_example); } - l.set_init_driver(init_driver); - l.set_finish(finish); // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. - l.increment = all.l->increment; - - return make_base(l); + l->increment = all.l->increment; + + l->set_init_driver(init_driver); + l->set_finish(finish); + return make_base(*l); } } diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index feded927..9d8d830f 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -409,9 +409,8 @@ namespace CBIFY { epsilon = vm["epsilon"].as(); data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); data.generic_explorer.reset(new GenericExplorer(*data.scorer.get(), (u32)data.k)); - l = &init_learner(&data, all.l, cover + 1); - l->set_learn(predict_or_learn_cover); - l->set_predict(predict_or_learn_cover); + l = &init_learner(&data, all.l, predict_or_learn_cover, + predict_or_learn_cover, cover + 1); } else if (vm.count("bag")) { @@ -421,18 +420,16 @@ namespace CBIFY { data.policies.push_back(unique_ptr>(new vw_policy(i))); } data.bootstrap_explorer.reset(new BootstrapExplorer(data.policies, (u32)data.k)); - l = &init_learner(&data, all.l, bags); - l->set_learn(predict_or_learn_bag); - l->set_predict(predict_or_learn_bag); + l = &init_learner(&data, all.l, predict_or_learn_bag, + predict_or_learn_bag, bags); } else if (vm.count("first") ) { uint32_t tau = (uint32_t)vm["first"].as(); data.policy.reset(new vw_policy()); data.tau_explorer.reset(new TauFirstExplorer(*data.policy.get(), (u32)tau, (u32)data.k)); - l = &init_learner(&data, all.l, 1); - l->set_learn(predict_or_learn_first); - l->set_predict(predict_or_learn_first); + l = &init_learner(&data, all.l, predict_or_learn_first, + predict_or_learn_first, 1); } else { @@ -441,9 +438,8 @@ namespace CBIFY { epsilon = vm["epsilon"].as(); data.policy.reset(new vw_policy()); data.greedy_explorer.reset(new EpsilonGreedyExplorer(*data.policy.get(), epsilon, (u32)data.k)); - l = &init_learner(&data, all.l, 1); - l->set_learn(predict_or_learn_greedy); - l->set_predict(predict_or_learn_greedy); + l = &init_learner(&data, all.l, predict_or_learn_greedy, + predict_or_learn_greedy, 1); } l->set_finish_example(finish_example); diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 26956d86..350d40a3 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -82,9 +82,8 @@ namespace CSOAA { all.p->lp = cs_label; all.sd->k = nb_actions; - learner& l = init_learner(&c, all.l, nb_actions); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&c, all.l, predict_or_learn, + predict_or_learn, nb_actions); l.set_finish_example(finish_example); return make_base(l); } @@ -709,9 +708,7 @@ namespace LabelDict { ld.read_example_this_loop = 0; ld.need_to_clear = false; - learner& l = init_learner(&ld, all.l); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&ld, all.l, predict_or_learn, predict_or_learn); if (ld.is_singleline) l.set_finish_example(finish_singleline_example); else diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index ec1aef5f..bcf15c02 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -390,9 +390,7 @@ namespace ECT size_t wpp = create_circuit(all, data, data.k, data.errors+1); data.all = &all; - learner& l = init_learner(&data, all.l, wpp); - l.set_learn(learn); - l.set_predict(predict); + learner& l = init_learner(&data, all.l, learn, predict, wpp); l.set_finish_example(finish_example); l.set_finish(finish); diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index 15ad1a06..6216058b 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -217,8 +217,7 @@ namespace FTRL { cerr << "ftrl_beta = " << b.ftrl_beta << endl; } - learner& l = init_learner(&b, 1 << all.reg.stride_shift); - l.set_learn(learn); + learner& l = init_learner(&b, learn, 1 << all.reg.stride_shift); l.set_predict(predict); l.set_save_load(save_load); return make_base(l); diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 07a37788..851fbea7 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -41,6 +41,8 @@ namespace GD float neg_power_t; float update_multiplier; void (*predict)(gd&, base_learner&, example&); + void (*learn)(gd&, base_learner&, example&); + void (*update)(gd&, base_learner&, example&); vw* all; }; @@ -789,49 +791,49 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text) } template -uint32_t set_learn(vw& all, learner& ret, bool feature_mask_off) +uint32_t set_learn(vw& all, bool feature_mask_off, gd& g) { all.normalized_idx = normalized; if (feature_mask_off) { - ret.set_learn(learn); - ret.set_update(update); + g.learn = learn; + g.update = update; return next; } else { - ret.set_learn(learn); - ret.set_update(update); + g.learn = learn; + g.update = update; return next; } } template -uint32_t set_learn(vw& all, learner& ret, bool feature_mask_off) +uint32_t set_learn(vw& all, bool feature_mask_off, gd& g) { if (all.invariant_updates) - return set_learn(all, ret, feature_mask_off); + return set_learn(all, feature_mask_off, g); else - return set_learn(all, ret, feature_mask_off); + return set_learn(all, feature_mask_off, g); } template -uint32_t set_learn(vw& all, learner& ret, bool feature_mask_off) +uint32_t set_learn(vw& all, bool feature_mask_off, gd& g) { // select the appropriate learn function based on adaptive, normalization, and feature mask if (all.normalized_updates) - return set_learn(all, ret, feature_mask_off); + return set_learn(all, feature_mask_off, g); else - return set_learn(all, ret, feature_mask_off); + return set_learn(all, feature_mask_off, g); } template -uint32_t set_learn(vw& all, learner& ret, bool feature_mask_off) +uint32_t set_learn(vw& all, bool feature_mask_off, gd& g) { if (all.adaptive) - return set_learn(all, ret, feature_mask_off); + return set_learn(all, feature_mask_off, g); else - return set_learn(all, ret, feature_mask_off); + return set_learn(all, feature_mask_off, g); } uint32_t ceil_log_2(uint32_t v) @@ -852,7 +854,7 @@ base_learner* setup(vw& all, po::variables_map& vm) g.early_stop_thres = 3; g.neg_norm_power = (all.adaptive ? (all.power_t - 1.f) : -1.f); g.neg_power_t = - all.power_t; - + if(all.initial_t > 0)//for the normalized update: if initial_t is bigger than 1 we interpret this as if we had seen (all.initial_t) previous fake datapoints all with norm 1 { g.all->normalized_sum_norm_x = all.initial_t; @@ -898,8 +900,6 @@ base_learner* setup(vw& all, po::variables_map& vm) cerr << "Warning: the learning rate for the last pass is multiplied by: " << pow((double)all.eta_decay_rate, (double)all.numpasses) << " adjust --decay_learning_rate larger to avoid this." << endl; - learner& ret = init_learner(&g, 1); - if (all.reg_mode % 2) if (all.audit || all.hash_inv) g.predict = predict; @@ -909,17 +909,17 @@ base_learner* setup(vw& all, po::variables_map& vm) g.predict = predict; else g.predict = predict; - ret.set_predict(g.predict); - + uint32_t stride; if (all.power_t == 0.5) - stride = set_learn(all, ret, feature_mask_off); + stride = set_learn(all, feature_mask_off, g); else - stride = set_learn(all, ret, feature_mask_off); - + stride = set_learn(all, feature_mask_off, g); all.reg.stride_shift = ceil_log_2(stride-1); - ret.increment = ((uint64_t)1 << all.reg.stride_shift); + learner& ret = init_learner(&g, g.learn, ((uint64_t)1 << all.reg.stride_shift)); + ret.set_predict(g.predict); + ret.set_update(g.update); ret.set_save_load(save_load); ret.set_end_pass(end_pass); return make_base(ret); diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index c1daaa48..b57328f4 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -330,8 +330,7 @@ void mf_train(vw& all, example& ec) } all.eta *= powf((float)(all.sd->t), all.power_t); - learner& l = init_learner(&data, 1 << all.reg.stride_shift); - l.set_learn(learn); + learner& l = init_learner(&data, learn, 1 << all.reg.stride_shift); l.set_predict(predict); l.set_save_load(save_load); l.set_end_pass(end_pass); diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 0584958c..607d90b6 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -898,8 +898,7 @@ namespace KSVM params.all->reg.weight_mask = (uint32_t)LONG_MAX; params.all->reg.stride_shift = 0; - learner& l = init_learner(¶ms, 1); - l.set_learn(learn); + learner& l = init_learner(¶ms, learn, 1); l.set_predict(predict); l.set_save_load(save_load); l.set_finish(finish); diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 69027dc3..8da81a4d 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -786,8 +786,7 @@ base_learner* setup(vw&all, po::variables_map& vm) ld.decay_levels.push_back(0.f); - learner& l = init_learner(&ld, 1 << all.reg.stride_shift); - l.set_learn(learn); + learner& l = init_learner(&ld, learn, 1 << all.reg.stride_shift); l.set_predict(predict); l.set_save_load(save_load); l.set_finish_example(finish_example); diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 45815bcc..2649429c 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -54,22 +54,18 @@ namespace LEARNER void generic_driver(vw& all); - inline void generic_sl(void*, io_buf&, bool, bool) {} - inline void generic_learner(void* data, base_learner& base, example&) {} - inline void generic_func(void* data) {} + inline void noop_sl(void*, io_buf&, bool, bool) {} + inline void noop(void* data) {} - const save_load_data generic_save_load_fd = {NULL, NULL, generic_sl}; - const learn_data generic_learn_fd = {NULL, NULL, generic_learner, generic_learner, NULL}; - const func_data generic_func_fd = {NULL, NULL, generic_func}; - typedef void (*tlearn)(void* d, base_learner& base, example& ec); typedef void (*tsl)(void* d, io_buf& io, bool read, bool text); typedef void (*tfunc)(void*d); typedef void (*tend_example)(vw& all, void* d, example& ec); - template learner& init_learner(); - template learner& init_learner(T* dat, size_t params_per_weight); - template learner& init_learner(T* dat, base_learner* base, size_t ws = 1); + template learner& init_learner(T*, void (*)(T&, base_learner&, example&), size_t); + template + learner& init_learner(T*, base_learner*, void (*learn)(T&, base_learner&, example&), + void (*predict)(T&, base_learner&, example&), size_t ws = 1); template struct learner { @@ -93,12 +89,6 @@ namespace LEARNER learn_fd.learn_f(learn_fd.data, *learn_fd.base, ec); ec.ft_offset -= (uint32_t)(increment*i); } - inline void set_learn(void (*u)(T& data, base_learner& base, example&)) - { - learn_fd.learn_f = (tlearn)u; - learn_fd.update_f = (tlearn)u; - } - inline void predict(example& ec, size_t i=0) { ec.ft_offset += (uint32_t)(increment*i); @@ -118,14 +108,17 @@ namespace LEARNER { learn_fd.update_f = (tlearn)u; } //called anytime saving or loading needs to happen. Autorecursive. - inline void save_load(io_buf& io, bool read, bool text) { save_load_fd.save_load_f(save_load_fd.data, io, read, text); if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } + inline void save_load(io_buf& io, bool read, bool text) + { save_load_fd.save_load_f(save_load_fd.data, io, read, text); + if (save_load_fd.base) save_load_fd.base->save_load(io, read, text); } inline void set_save_load(void (*sl)(T&, io_buf&, bool, bool)) { save_load_fd.save_load_f = (tsl)sl; save_load_fd.data = learn_fd.data; save_load_fd.base = learn_fd.base;} //called to clean up state. Autorecursive. - void set_finish(void (*f)(T&)) { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } + void set_finish(void (*f)(T&)) + { finisher_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } inline void finish() { if (finisher_fd.data) @@ -155,63 +148,63 @@ namespace LEARNER { init_fd = tuple_dbf(learn_fd.data,learn_fd.base, (tfunc)f); } //called after learn example for each example. Explicitly not recursive. - inline void finish_example(vw& all, example& ec) { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} + inline void finish_example(vw& all, example& ec) + { finish_example_fd.finish_example_f(all, finish_example_fd.data, ec);} void set_finish_example(void (*f)(vw& all, T&, example&)) {finish_example_fd.data = learn_fd.data; finish_example_fd.finish_example_f = (tend_example)f;} - friend learner& init_learner<>(); - friend learner& init_learner<>(T* dat, size_t params_per_weight); - friend learner& init_learner<>(T* dat, base_learner* base, size_t ws); + friend learner& init_learner<>(T*, void (*learn)(T&, base_learner&, example&), size_t); + friend learner& init_learner<>(T*, base_learner*, void (*l)(T&, base_learner&, example&), + void (*pred)(T&, base_learner&, example&), size_t); }; - template learner& init_learner() - { + template + learner& init_learner(T* dat, void (*learn)(T&, base_learner&, example&), + size_t params_per_weight) + { // the constructor for all learning algorithms. learner& ret = calloc_or_die >(); ret.weights = 1; - ret.increment = 1; - ret.learn_fd = LEARNER::generic_learn_fd; - ret.finish_example_fd.data = NULL; - ret.finish_example_fd.finish_example_f = return_simple_example; - ret.end_pass_fd = LEARNER::generic_func_fd; - ret.end_examples_fd = LEARNER::generic_func_fd; - ret.init_fd = LEARNER::generic_func_fd; - ret.finisher_fd = LEARNER::generic_func_fd; - ret.save_load_fd = LEARNER::generic_save_load_fd; - return ret; - } - - template learner& init_learner(T* dat, size_t params_per_weight) - { // the constructor for all learning algorithms. - learner& ret = init_learner(); - - ret.learn_fd.data = dat; - + ret.increment = params_per_weight; + ret.end_pass_fd.func = noop; + ret.end_examples_fd.func = noop; + ret.init_fd.func = noop; + ret.save_load_fd.save_load_f = noop_sl; ret.finisher_fd.data = dat; - ret.finisher_fd.base = NULL; - ret.finisher_fd.func = LEARNER::generic_func; + ret.finisher_fd.func = noop; - ret.increment = params_per_weight; + ret.learn_fd.data = dat; + ret.learn_fd.learn_f = (tlearn)learn; + ret.learn_fd.update_f = (tlearn)learn; + ret.learn_fd.predict_f = (tlearn)learn; + ret.finish_example_fd.data = dat; + ret.finish_example_fd.finish_example_f = return_simple_example; + return ret; } - template learner& init_learner(T* dat, base_learner* base, size_t ws = 1) + template + learner& init_learner(T* dat, base_learner* base, + void (*learn)(T&, base_learner&, example&), + void (*predict)(T&, base_learner&, example&), size_t ws = 1) { //the reduction constructor, with separate learn and predict functions learner& ret = calloc_or_die >(); ret = *(learner*)base; ret.learn_fd.data = dat; + ret.learn_fd.learn_f = (tlearn)learn; + ret.learn_fd.update_f = (tlearn)learn; + ret.learn_fd.predict_f = (tlearn)predict; ret.learn_fd.base = base; ret.finisher_fd.data = dat; ret.finisher_fd.base = base; - ret.finisher_fd.func = LEARNER::generic_func; + ret.finisher_fd.func = noop; ret.weights = ws; ret.increment = base->increment * ret.weights; return ret; } - template base_learner* make_base(learner& base) - { return (base_learner*)&base; } + template base_learner* make_base(learner& base) { return (base_learner*)&base; } } diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index a0d6fc35..19250a65 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -504,24 +504,24 @@ namespace LOG_MULTI base_learner* setup(vw& all, po::variables_map& vm) //learner setup { - log_multi* data = (log_multi*)calloc(1, sizeof(log_multi)); + log_multi& data = calloc_or_die(); po::options_description opts("TXM Online options"); opts.add_options() ("no_progress", "disable progressive validation") - ("swap_resistance", po::value(&(data->swap_resist))->default_value(4), "higher = more resistance to swap, default=4"); + ("swap_resistance", po::value(&(data.swap_resist))->default_value(4), "higher = more resistance to swap, default=4"); vm = add_options(all, opts); - data->k = (uint32_t)vm["log_multi"].as(); - *all.file_options << " --log_multi " << data->k; + data.k = (uint32_t)vm["log_multi"].as(); + *all.file_options << " --log_multi " << data.k; if (vm.count("no_progress")) - data->progress = false; + data.progress = false; else - data->progress = true; + data.progress = true; - data->all = &all; + data.all = &all; (all.p->lp) = MULTICLASS::mc_label; string loss_function = "quantile"; @@ -529,16 +529,14 @@ namespace LOG_MULTI delete(all.loss); all.loss = getLossFunction(&all, loss_function, loss_parameter); - data->max_predictors = data->k - 1; + data.max_predictors = data.k - 1; - learner& l = init_learner(data, all.l, data->max_predictors); + learner& l = init_learner(&data, all.l, learn, predict, data.max_predictors); l.set_save_load(save_load_tree); - l.set_learn(learn); - l.set_predict(predict); l.set_finish_example(finish_example); l.set_finish(finish); - init_tree(*data); + init_tree(data); return make_base(l); } diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index a2c2205d..268815d5 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -243,9 +243,8 @@ namespace LRQ { cerr<& l = init_learner(&lrq, all.l, 1 + maxk); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&lrq, all.l, predict_or_learn, + predict_or_learn, 1 + maxk); l.set_end_pass(reset_seed); // TODO: leaks memory ? diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index e392cb10..4e00be8d 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -203,9 +203,7 @@ base_learner* setup(vw& all, po::variables_map& vm) { all.random_positive_weights = true; - learner& l = init_learner(data, all.l, 2*data->rank+1); - l.set_learn(learn); - l.set_predict(predict); + learner& l = init_learner(data, all.l, learn, predict, 2*data->rank+1); l.set_finish(finish); return make_base(l); } diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index 428c9f02..bf9bc020 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -365,9 +365,8 @@ CONVERSE: // That's right, I'm using goto. So sue me. n.save_xsubi = n.xsubi; n.increment = all.l->increment;//Indexing of output layer is odd. - learner& l = init_learner(&n, all.l, n.k+1); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&n, all.l, predict_or_learn, + predict_or_learn, n.k+1); l.set_finish(finish); l.set_finish_example(finish_example); l.set_end_pass(end_pass); diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc index 9baec922..797bfc3e 100644 --- a/vowpalwabbit/noop.cc +++ b/vowpalwabbit/noop.cc @@ -10,8 +10,8 @@ license as described in the file LICENSE. using namespace LEARNER; namespace NOOP { + void learn(char&, base_learner&, example&) {} + base_learner* setup(vw& all) - { - return new base_learner(); - } + { return &init_learner(NULL, learn, 1); } } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 7ed02ad2..89729f36 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -80,9 +80,8 @@ namespace OAA { *all.file_options << " --oaa " << data.k; all.p->lp = MULTICLASS::mc_label; - learner& l = init_learner(&data, all.l, data.k); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&data, all.l, predict_or_learn, + predict_or_learn, data.k); l.set_finish_example(finish_example); return make_base(l); diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 19e0d4b3..8cfe5dd7 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -54,9 +54,7 @@ namespace PRINT all.reg.weight_mask = (length << all.reg.stride_shift) - 1; all.reg.stride_shift = 0; - learner& ret = init_learner(&p, 1); - ret.set_learn(learn); - ret.set_predict(learn); + learner& ret = init_learner(&p, learn, 1); return make_base(ret); } } diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index 58607b47..9f66b4cf 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -57,25 +57,22 @@ namespace Scorer { vm = add_options(all, link_opts); - learner& l = init_learner(&s, all.l); + learner* l; string link = vm["link"].as(); if (!vm.count("link") || link.compare("identity") == 0) - { - l.set_learn(predict_or_learn ); - l.set_predict(predict_or_learn ); - } + l = &init_learner(&s, all.l, predict_or_learn, predict_or_learn); else if (link.compare("logistic") == 0) { *all.file_options << " --link=logistic "; - l.set_learn(predict_or_learn ); - l.set_predict(predict_or_learn); + l = &init_learner(&s, all.l, predict_or_learn, + predict_or_learn); } else if (link.compare("glf1") == 0) { *all.file_options << " --link=glf1 "; - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + l = &init_learner(&s, all.l, predict_or_learn, + predict_or_learn); } else { @@ -83,6 +80,6 @@ namespace Scorer { throw exception(); } - return make_base(l); + return make_base(*l); } } diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index a702f348..63ba16d8 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -1981,9 +1981,9 @@ namespace Search { priv.start_clock_time = clock(); - learner& l = init_learner(&sch, all.l, priv.total_number_of_policies); - l.set_learn(search_predict_or_learn); - l.set_predict(search_predict_or_learn); + learner& l = init_learner(&sch, all.l, search_predict_or_learn, + search_predict_or_learn, + priv.total_number_of_policies); l.set_finish_example(finish_example); l.set_end_examples(end_examples); l.set_finish(search_finish); diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index 4995b663..3b14131f 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -113,9 +113,7 @@ void end_examples(sender& s) s.all = &all; s.delay_ring = calloc_or_die(all.p->ring_size); - learner& l = init_learner(&s, 1); - l.set_learn(learn); - l.set_predict(learn); + learner& l = init_learner(&s, learn, 1); l.set_finish(finish); l.set_finish_example(finish_example); l.set_end_examples(end_examples); diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index 91f1ca44..b2e7e150 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -698,9 +698,7 @@ namespace StagewisePoly //following is so that saved models know to load us. *all.file_options << " --stage_poly"; - learner& l = init_learner(&poly, all.l); - l.set_learn(learn); - l.set_predict(predict); + learner& l = init_learner(&poly, all.l, learn, predict); l.set_finish(finish); l.set_save_load(save_load); l.set_finish_example(finish_example); diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index a614f390..8dd245f7 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -119,9 +119,8 @@ namespace TOPK { data.all = &all; - learner& l = init_learner(&data, all.l); - l.set_learn(predict_or_learn); - l.set_predict(predict_or_learn); + learner& l = init_learner(&data, all.l, predict_or_learn, + predict_or_learn); l.set_finish_example(finish_example); return make_base(l); -- cgit v1.2.3 From cb12c710d80dc55dedfc6d5289607b6ad69dbbd9 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 19:31:59 -0500 Subject: minor tweaks --- vowpalwabbit/oaa.cc | 1 - vowpalwabbit/scorer.cc | 17 ++++------------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 89729f36..4c905e13 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -72,7 +72,6 @@ namespace OAA { base_learner* setup(vw& all, po::variables_map& vm) { oaa& data = calloc_or_die(); - data.k = vm["oaa"].as(); data.shouldOutput = all.raw_prediction > 0; data.all = &all; diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index 9f66b4cf..e93eb84f 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -5,9 +5,7 @@ using namespace LEARNER; namespace Scorer { - struct scorer{ - vw* all; - }; + struct scorer{ vw* all; }; template void predict_or_learn(scorer& s, base_learner& base, example& ec) @@ -27,23 +25,17 @@ namespace Scorer { // y = f(x) -> [0, 1] float logistic(float in) - { - return 1.f / (1.f + exp(- in)); - } + { return 1.f / (1.f + exp(- in)); } // http://en.wikipedia.org/wiki/Generalized_logistic_curve // where the lower & upper asymptotes are -1 & 1 respectively // 'glf1' stands for 'Generalized Logistic Function with [-1,1] range' // y = f(x) -> [-1, 1] float glf1(float in) - { - return 2.f / (1.f + exp(- in)) - 1.f; - } + { return 2.f / (1.f + exp(- in)) - 1.f; } float noop(float in) - { - return in; - } + { return in; } base_learner* setup(vw& all, po::variables_map& vm) { @@ -79,7 +71,6 @@ namespace Scorer { cerr << "Unknown link function: " << link << endl; throw exception(); } - return make_base(*l); } } -- cgit v1.2.3 From 82a148fd5266a3b932528e53f9e053edf72fecfe Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 19:39:10 -0500 Subject: simpler oaa.cc --- vowpalwabbit/oaa.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 4c905e13..02a9eaf2 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -9,9 +9,6 @@ license as described in the file LICENSE. #include "reductions.h" #include "vw.h" -using namespace std; -using namespace LEARNER; - namespace OAA { struct oaa{ size_t k; @@ -20,7 +17,7 @@ namespace OAA { }; template - void predict_or_learn(oaa& o, base_learner& base, example& ec) { + void predict_or_learn(oaa& o, LEARNER::base_learner& base, example& ec) { MULTICLASS::multiclass mc_label_data = ec.l.multi; if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1)) cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl; @@ -69,7 +66,7 @@ namespace OAA { VW::finish_example(all, &ec); } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { oaa& data = calloc_or_die(); data.k = vm["oaa"].as(); @@ -79,10 +76,9 @@ namespace OAA { *all.file_options << " --oaa " << data.k; all.p->lp = MULTICLASS::mc_label; - learner& l = init_learner(&data, all.l, predict_or_learn, + LEARNER::learner& l = init_learner(&data, all.l, predict_or_learn, predict_or_learn, data.k); l.set_finish_example(finish_example); - return make_base(l); } } -- cgit v1.2.3 From 6ded0936d05668220e5e857fc96e08f2ce1939c4 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 20:33:43 -0500 Subject: various simplifications --- vowpalwabbit/autolink.cc | 11 ++++------- vowpalwabbit/binary.cc | 29 ++++++++++++++++------------- vowpalwabbit/global_data.h | 1 - vowpalwabbit/kernel_svm.cc | 2 +- vowpalwabbit/log_multi.cc | 2 +- vowpalwabbit/loss_functions.cc | 20 ++++++++------------ vowpalwabbit/loss_functions.h | 5 ++--- vowpalwabbit/nn.cc | 2 +- vowpalwabbit/parse_args.cc | 2 +- vowpalwabbit/scorer.cc | 20 +++++++------------- vowpalwabbit/simple_label.cc | 3 --- 11 files changed, 41 insertions(+), 56 deletions(-) diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 6d4f0106..7b3bebd8 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -1,8 +1,6 @@ #include "reductions.h" #include "simple_label.h" -using namespace LEARNER; - namespace ALINK { const int autoconstant = 524267083; @@ -11,8 +9,8 @@ namespace ALINK { uint32_t stride_shift; }; - template - void predict_or_learn(autolink& b, base_learner& base, example& ec) + template + void predict_or_learn(autolink& b, LEARNER::base_learner& base, example& ec) { base.predict(ec); float base_pred = ec.pred.scalar; @@ -30,7 +28,6 @@ namespace ALINK { } ec.total_sum_feat_sq += sum_sq; - // apply predict or learn if (is_learn) base.learn(ec); else @@ -41,7 +38,7 @@ namespace ALINK { ec.total_sum_feat_sq -= sum_sq; } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { autolink& data = calloc_or_die(); data.d = (uint32_t)vm["autolink"].as(); @@ -49,7 +46,7 @@ namespace ALINK { *all.file_options << " --autolink " << data.d; - learner& ret = init_learner(&data, all.l, predict_or_learn, + LEARNER::learner& ret = init_learner(&data, all.l, predict_or_learn, predict_or_learn); return make_base(ret); } diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 0916dd6f..04810e26 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -1,13 +1,11 @@ +#include #include "reductions.h" #include "multiclass.h" #include "simple_label.h" -using namespace LEARNER; - namespace BINARY { - template - void predict_or_learn(char&, base_learner& base, example& ec) { + void predict_or_learn(char&, LEARNER::base_learner& base, example& ec) { if (is_learn) base.learn(ec); else @@ -18,17 +16,22 @@ namespace BINARY { else ec.pred.scalar = -1; - if (ec.l.simple.label == ec.pred.scalar) - ec.loss = 0.; - else - ec.loss = ec.l.simple.weight; + if (ec.l.simple.label != FLT_MAX) + { + if (fabs(ec.l.simple.label) != 1.f) + cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; + else + if (ec.l.simple.label == ec.pred.scalar) + ec.loss = 0.; + else + ec.loss = ec.l.simple.weight; + } } - base_learner* setup(vw& all, po::variables_map& vm) - {//parse and set arguments - all.sd->binary_label = true; - //Create new learner - learner& ret = init_learner(NULL, all.l, predict_or_learn, predict_or_learn); + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) + { + LEARNER::learner& ret = + LEARNER::init_learner(NULL, all.l, predict_or_learn, predict_or_learn); return make_base(ret); } } diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index fffe8b55..bdf76b46 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -154,7 +154,6 @@ struct shared_data { double holdout_sum_loss_since_last_pass; size_t holdout_best_pass; - bool binary_label; uint32_t k; }; diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index 607d90b6..9a88f17e 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -810,7 +810,7 @@ namespace KSVM string loss_function = "hinge"; float loss_parameter = 0.0; delete all.loss; - all.loss = getLossFunction(&all, loss_function, (float)loss_parameter); + all.loss = getLossFunction(all, loss_function, (float)loss_parameter); svm_params& params = calloc_or_die(); params.model = &calloc_or_die(); diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 19250a65..26f6b454 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -527,7 +527,7 @@ namespace LOG_MULTI string loss_function = "quantile"; float loss_parameter = 0.5; delete(all.loss); - all.loss = getLossFunction(&all, loss_function, loss_parameter); + all.loss = getLossFunction(all, loss_function, loss_parameter); data.max_predictors = data.k - 1; diff --git a/vowpalwabbit/loss_functions.cc b/vowpalwabbit/loss_functions.cc index 4d67ae5e..6badc633 100644 --- a/vowpalwabbit/loss_functions.cc +++ b/vowpalwabbit/loss_functions.cc @@ -297,21 +297,18 @@ public: float tau; }; -loss_function* getLossFunction(void* a, string funcName, float function_parameter) { - vw* all=(vw*)a; - if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0) { +loss_function* getLossFunction(vw& all, string funcName, float function_parameter) { + if(funcName.compare("squared") == 0 || funcName.compare("Huber") == 0) return new squaredloss(); - } else if(funcName.compare("classic") == 0){ + else if(funcName.compare("classic") == 0) return new classic_squaredloss(); - } else if(funcName.compare("hinge") == 0) { - all->sd->binary_label = true; + else if(funcName.compare("hinge") == 0) return new hingeloss(); - } else if(funcName.compare("logistic") == 0) { - if (all->set_minmax != noop_mm) + else if(funcName.compare("logistic") == 0) { + if (all.set_minmax != noop_mm) { - all->sd->min_label = -50; - all->sd->max_label = 50; - all->sd->binary_label = true; + all.sd->min_label = -50; + all.sd->max_label = 50; } return new logloss(); } else if(funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0) { @@ -320,5 +317,4 @@ loss_function* getLossFunction(void* a, string funcName, float function_paramete cout << "Invalid loss function name: \'" << funcName << "\' Bailing!" << endl; throw exception(); } - cout << "end getLossFunction" << endl; } diff --git a/vowpalwabbit/loss_functions.h b/vowpalwabbit/loss_functions.h index 421b3bfe..35b6f24b 100644 --- a/vowpalwabbit/loss_functions.h +++ b/vowpalwabbit/loss_functions.h @@ -8,8 +8,7 @@ license as described in the file LICENSE. #include "parse_primitives.h" struct shared_data; - -using namespace std; +struct vw; class loss_function { @@ -34,4 +33,4 @@ public : virtual ~loss_function() {}; }; -loss_function* getLossFunction(void*, string funcName, float function_parameter = 0); +loss_function* getLossFunction(vw&, std::string funcName, float function_parameter = 0); diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index bf9bc020..bfab6009 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -356,7 +356,7 @@ CONVERSE: // That's right, I'm using goto. So sue me. << std::endl; n.finished_setup = false; - n.squared_loss = getLossFunction (0, "squared", 0); + n.squared_loss = getLossFunction (all, "squared", 0); n.xsubi = 0; diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index 9f5d071a..ef65345c 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -600,7 +600,7 @@ void parse_example_tweaks(vw& all, po::variables_map& vm) if(vm.count("quantile_tau")) loss_parameter = vm["quantile_tau"].as(); - all.loss = getLossFunction(&all, loss_function, (float)loss_parameter); + all.loss = getLossFunction(all, loss_function, (float)loss_parameter); if (all.l1_lambda < 0.) { cerr << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl; diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index e93eb84f..50645ed8 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -1,14 +1,11 @@ #include - #include "reductions.h" -using namespace LEARNER; - namespace Scorer { struct scorer{ vw* all; }; template - void predict_or_learn(scorer& s, base_learner& base, example& ec) + void predict_or_learn(scorer& s, LEARNER::base_learner& base, example& ec) { s.all->set_minmax(s.all->sd, ec.l.simple.label); @@ -24,20 +21,17 @@ namespace Scorer { } // y = f(x) -> [0, 1] - float logistic(float in) - { return 1.f / (1.f + exp(- in)); } + float logistic(float in) { return 1.f / (1.f + exp(- in)); } // http://en.wikipedia.org/wiki/Generalized_logistic_curve // where the lower & upper asymptotes are -1 & 1 respectively // 'glf1' stands for 'Generalized Logistic Function with [-1,1] range' // y = f(x) -> [-1, 1] - float glf1(float in) - { return 2.f / (1.f + exp(- in)) - 1.f; } + float glf1(float in) { return 2.f / (1.f + exp(- in)) - 1.f; } - float noop(float in) - { return in; } + float id(float in) { return in; } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { scorer& s = calloc_or_die(); s.all = &all; @@ -49,11 +43,11 @@ namespace Scorer { vm = add_options(all, link_opts); - learner* l; + LEARNER::learner* l; string link = vm["link"].as(); if (!vm.count("link") || link.compare("identity") == 0) - l = &init_learner(&s, all.l, predict_or_learn, predict_or_learn); + l = &init_learner(&s, all.l, predict_or_learn, predict_or_learn); else if (link.compare("logistic") == 0) { *all.file_options << " --link=logistic "; diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc index 3bf748de..150f95e9 100644 --- a/vowpalwabbit/simple_label.cc +++ b/vowpalwabbit/simple_label.cc @@ -96,9 +96,6 @@ void parse_simple_label(parser* p, shared_data* sd, void* v, v_array& cerr << "malformed example!\n"; cerr << "words.size() = " << words.size() << endl; } - if (words.size() > 0 && sd->binary_label && fabs(ld->label) != 1.f) - cout << "You are using a label not -1 or 1 with a loss function expecting that!" << endl; - count_label(ld->label); } -- cgit v1.2.3 From 19b21b995948621aa0df7500f3f397d3bb869bb7 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 21:14:23 -0500 Subject: many simplifications --- vowpalwabbit/accumulate.cc | 23 +++++++------------- vowpalwabbit/active.cc | 2 -- vowpalwabbit/active_interactor.cc | 6 ------ vowpalwabbit/autolink.cc | 1 - vowpalwabbit/best_constant.cc | 2 -- vowpalwabbit/cache.cc | 6 +----- vowpalwabbit/io_buf.cc | 24 +++++++-------------- vowpalwabbit/learner.cc | 3 +-- vowpalwabbit/main.cc | 9 -------- vowpalwabbit/memory.cc | 3 +-- vowpalwabbit/multiclass.cc | 42 ++++++++++++++++++------------------- vowpalwabbit/network.cc | 1 - vowpalwabbit/noop.cc | 8 +++---- vowpalwabbit/parse_args.cc | 8 +++---- vowpalwabbit/parse_example.cc | 2 +- vowpalwabbit/parse_primitives.cc | 12 ++++------- vowpalwabbit/print.cc | 13 ++++-------- vowpalwabbit/rand48.cc | 13 ++---------- vowpalwabbit/search.cc | 4 ++-- vowpalwabbit/search_sequencetask.cc | 2 +- vowpalwabbit/sender.cc | 13 ++++-------- vowpalwabbit/topk.cc | 35 +++++++++++-------------------- vowpalwabbit/unique_sort.cc | 4 +--- 23 files changed, 75 insertions(+), 161 deletions(-) diff --git a/vowpalwabbit/accumulate.cc b/vowpalwabbit/accumulate.cc index d6c5e71f..49d2e969 100644 --- a/vowpalwabbit/accumulate.cc +++ b/vowpalwabbit/accumulate.cc @@ -17,9 +17,7 @@ Alekh Agarwal and John Langford, with help Olivier Chapelle. using namespace std; -void add_float(float& c1, const float& c2) { - c1 += c2; -} +void add_float(float& c1, const float& c2) { c1 += c2; } void accumulate(vw& all, string master_location, regressor& reg, size_t o) { uint32_t length = 1 << all.num_bits; //This is size of gradient @@ -27,15 +25,11 @@ void accumulate(vw& all, string master_location, regressor& reg, size_t o) { float* local_grad = new float[length]; weight* weights = reg.weight_vector; for(uint32_t i = 0;i < length;i++) - { - local_grad[i] = weights[stride*i+o]; - } + local_grad[i] = weights[stride*i+o]; all_reduce(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks); for(uint32_t i = 0;i < length;i++) - { - weights[stride*i+o] = local_grad[i]; - } + weights[stride*i+o] = local_grad[i]; delete[] local_grad; } @@ -53,11 +47,11 @@ void accumulate_avg(vw& all, string master_location, regressor& reg, size_t o) { float numnodes = (float)all.total; for(uint32_t i = 0;i < length;i++) - local_grad[i] = weights[stride*i+o]; + local_grad[i] = weights[stride*i+o]; all_reduce(local_grad, length, master_location, all.unique_id, all.total, all.node, all.socks); for(uint32_t i = 0;i < length;i++) - weights[stride*i+o] = local_grad[i]/numnodes; + weights[stride*i+o] = local_grad[i]/numnodes; delete[] local_grad; } @@ -83,17 +77,14 @@ void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) { uint32_t length = 1 << all.num_bits; //This is the number of parameters size_t stride = 1 << all.reg.stride_shift; weight* weights = reg.weight_vector; - - float* local_weights = new float[length]; for(uint32_t i = 0;i < length;i++) local_weights[i] = weights[stride*i+1]; - //First compute weights for averaging all_reduce(local_weights, length, master_location, all.unique_id, all.total, all.node, all.socks); - + for(uint32_t i = 0;i < length;i++) //Compute weighted versions if(local_weights[i] > 0) { float ratio = weights[stride*i+1]/local_weights[i]; @@ -107,7 +98,7 @@ void accumulate_weighted_avg(vw& all, string master_location, regressor& reg) { local_weights[i] = 0; weights[stride*i] = 0; } - + all_reduce(weights, length*stride, master_location, all.unique_id, all.total, all.node, all.socks); delete[] local_weights; diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 627928d6..a1070be3 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -160,9 +160,7 @@ namespace ACTIVE { ("simulation", "active learning simulation mode") ("mellowness", po::value(&(data.active_c0)), "active learning mellowness parameter c_0. Default 8") ; - vm = add_options(all, active_opts); - data.all=&all; //Create new learner diff --git a/vowpalwabbit/active_interactor.cc b/vowpalwabbit/active_interactor.cc index 1f67fa69..ad4f5869 100644 --- a/vowpalwabbit/active_interactor.cc +++ b/vowpalwabbit/active_interactor.cc @@ -19,12 +19,6 @@ license as described in the file LICENSE. #include #endif -using std::cin; -using std::endl; -using std::cout; -using std::cerr; -using std::string; - using namespace std; int open_socket(const char* host, unsigned short port) diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 7b3bebd8..7cdaecef 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -14,7 +14,6 @@ namespace ALINK { { base.predict(ec); float base_pred = ec.pred.scalar; - // add features of label ec.indices.push_back(autolink_namespace); float sum_sq = 0; diff --git a/vowpalwabbit/best_constant.cc b/vowpalwabbit/best_constant.cc index 3642a014..c56cb336 100644 --- a/vowpalwabbit/best_constant.cc +++ b/vowpalwabbit/best_constant.cc @@ -76,13 +76,11 @@ bool get_best_constant(vw& all, float& best_constant, float& best_constant_loss) } else return false; - if (!is_more_than_two_labels_observed) best_constant_loss = ( all.loss->getLoss(all.sd, best_constant, label1) * label1_cnt + all.loss->getLoss(all.sd, best_constant, label2) * label2_cnt ) / (label1_cnt + label2_cnt); else best_constant_loss = FLT_MIN; - return true; } diff --git a/vowpalwabbit/cache.cc b/vowpalwabbit/cache.cc index f91eba3c..9881e7f8 100644 --- a/vowpalwabbit/cache.cc +++ b/vowpalwabbit/cache.cc @@ -7,8 +7,6 @@ license as described in the file LICENSE. #include "unique_sort.h" #include "global_data.h" -using namespace std; - const size_t neg_1 = 1; const size_t general = 2; @@ -40,9 +38,7 @@ size_t read_cached_tag(io_buf& cache, example* ae) return tag_size+sizeof(tag_size); } -struct one_float { - float f; -} +struct one_float { float f; } #ifndef _WIN32 __attribute__((packed)) #endif diff --git a/vowpalwabbit/io_buf.cc b/vowpalwabbit/io_buf.cc index 4dbf3cf1..ba220762 100644 --- a/vowpalwabbit/io_buf.cc +++ b/vowpalwabbit/io_buf.cc @@ -3,10 +3,7 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ -#include - #include "io_buf.h" - #ifdef WIN32 #include #endif @@ -110,20 +107,17 @@ void buf_write(io_buf &o, char* &pointer, size_t n) } bool io_buf::is_socket(int f) -{ - // this appears to work in practice, but could probably be done in a cleaner fashion +{ // this appears to work in practice, but could probably be done in a cleaner fashion const int _nhandle = 32; return f >= _nhandle; } ssize_t io_buf::read_file_or_socket(int f, void* buf, size_t nbytes) { #ifdef _WIN32 - if (is_socket(f)) { + if (is_socket(f)) return recv(f, reinterpret_cast(buf), static_cast(nbytes), 0); - } - else { + else return _read(f, buf, (unsigned int)nbytes); - } #else return read(f, buf, (unsigned int)nbytes); #endif @@ -132,12 +126,10 @@ ssize_t io_buf::read_file_or_socket(int f, void* buf, size_t nbytes) { ssize_t io_buf::write_file_or_socket(int f, const void* buf, size_t nbytes) { #ifdef _WIN32 - if (is_socket(f)) { + if (is_socket(f)) return send(f, reinterpret_cast(buf), static_cast(nbytes), 0); - } - else { + else return _write(f, buf, (unsigned int)nbytes); - } #else return write(f, buf, (unsigned int)nbytes); #endif @@ -146,12 +138,10 @@ ssize_t io_buf::write_file_or_socket(int f, const void* buf, size_t nbytes) void io_buf::close_file_or_socket(int f) { #ifdef _WIN32 - if (io_buf::is_socket(f)) { + if (io_buf::is_socket(f)) closesocket(f); - } - else { + else _close(f); - } #else close(f); #endif diff --git a/vowpalwabbit/learner.cc b/vowpalwabbit/learner.cc index 40d37ba3..0376a00c 100644 --- a/vowpalwabbit/learner.cc +++ b/vowpalwabbit/learner.cc @@ -2,8 +2,7 @@ #include "parser.h" #include "learner.h" #include "vw.h" - -void save_predictor(vw& all, string reg_name, size_t current_pass); +#include "parse_regressor.h" void dispatch_example(vw& all, example& ec) { diff --git a/vowpalwabbit/main.cc b/vowpalwabbit/main.cc index 6773d3be..c7f40326 100644 --- a/vowpalwabbit/main.cc +++ b/vowpalwabbit/main.cc @@ -3,12 +3,6 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ - -#include -#include -#include -#include -#include #ifdef _WIN32 #include #else @@ -17,7 +11,6 @@ license as described in the file LICENSE. #endif #include #include "global_data.h" -#include "parse_example.h" #include "parse_args.h" #include "accumulate.h" #include "best_constant.h" @@ -44,9 +37,7 @@ int main(int argc, char *argv[]) } VW::start_parser(*all); - LEARNER::generic_driver(*all); - VW::end_parser(*all); ftime(&t_end); diff --git a/vowpalwabbit/memory.cc b/vowpalwabbit/memory.cc index ea23597c..7e40bf71 100644 --- a/vowpalwabbit/memory.cc +++ b/vowpalwabbit/memory.cc @@ -1,7 +1,6 @@ #include -#include -void free_it(void*ptr) +void free_it(void* ptr) { if (ptr != NULL) free(ptr); diff --git a/vowpalwabbit/multiclass.cc b/vowpalwabbit/multiclass.cc index 8736bc6e..7e1f7a8e 100644 --- a/vowpalwabbit/multiclass.cc +++ b/vowpalwabbit/multiclass.cc @@ -56,9 +56,7 @@ namespace MULTICLASS { ld->weight = 1.; } - void delete_label(void* v) - { - } + void delete_label(void* v) {} void parse_label(parser* p, shared_data*, void* v, v_array& words) { @@ -145,32 +143,32 @@ namespace MULTICLASS { void output_example(vw& all, example& ec) { multiclass ld = ec.l.multi; - + size_t loss = 1; if (ld.label == (uint32_t)ec.pred.multiclass) loss = 0; - + if(ec.test_only) - { - all.sd->weighted_holdout_examples += ld.weight;//test weight seen - all.sd->weighted_holdout_examples_since_last_dump += ld.weight; - all.sd->weighted_holdout_examples_since_last_pass += ld.weight; - all.sd->holdout_sum_loss += loss; - all.sd->holdout_sum_loss_since_last_dump += loss; - all.sd->holdout_sum_loss_since_last_pass += loss;//since last pass - } + { + all.sd->weighted_holdout_examples += ld.weight;//test weight seen + all.sd->weighted_holdout_examples_since_last_dump += ld.weight; + all.sd->weighted_holdout_examples_since_last_pass += ld.weight; + all.sd->holdout_sum_loss += loss; + all.sd->holdout_sum_loss_since_last_dump += loss; + all.sd->holdout_sum_loss_since_last_pass += loss;//since last pass + } else - { - all.sd->weighted_examples += ld.weight; - all.sd->total_features += ec.num_features; - all.sd->sum_loss += loss; - all.sd->sum_loss_since_last_dump += loss; - all.sd->example_number++; - } - + { + all.sd->weighted_examples += ld.weight; + all.sd->total_features += ec.num_features; + all.sd->sum_loss += loss; + all.sd->sum_loss_since_last_dump += loss; + all.sd->example_number++; + } + for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); - + MULTICLASS::print_update(all, ec); } } diff --git a/vowpalwabbit/network.cc b/vowpalwabbit/network.cc index 7e39e879..b2922063 100644 --- a/vowpalwabbit/network.cc +++ b/vowpalwabbit/network.cc @@ -18,7 +18,6 @@ license as described in the file LICENSE. #include #include #endif -#include #include #include diff --git a/vowpalwabbit/noop.cc b/vowpalwabbit/noop.cc index 797bfc3e..0c883a8c 100644 --- a/vowpalwabbit/noop.cc +++ b/vowpalwabbit/noop.cc @@ -7,11 +7,9 @@ license as described in the file LICENSE. #include "reductions.h" -using namespace LEARNER; - namespace NOOP { - void learn(char&, base_learner&, example&) {} + void learn(char&, LEARNER::base_learner&, example&) {} - base_learner* setup(vw& all) - { return &init_learner(NULL, learn, 1); } + LEARNER::base_learner* setup(vw& all) + { return &LEARNER::init_learner(NULL, learn, 1); } } diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index ef65345c..a760261f 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -99,13 +99,13 @@ void parse_dictionary_argument(vw&all, string str) { ifstream infile(s); size_t def = (size_t)' '; for (string line; getline(infile, line);) { - char*c = (char*)line.c_str(); // we're throwing away const, which is dangerous... + char* c = (char*)line.c_str(); // we're throwing away const, which is dangerous... while (*c == ' ' || *c == '\t') ++c; // skip initial whitespace - char*d = c; + char* d = c; while (*d != ' ' && *d != '\t' && *d != '\n' && *d != '\0') ++d; // gobble up initial word if (d == c) continue; // no word if (*d != ' ' && *d != '\t') continue; // reached end of line - char*word = (char*)calloc(d-c, sizeof(char)); + char* word = calloc_or_die(d-c); memcpy(word, c, d-c); substring ss = { word, word + (d - c) }; uint32_t hash = uniform_hash( ss.begin, ss.end-ss.begin, quadratic_constant); @@ -132,7 +132,7 @@ void parse_dictionary_argument(vw&all, string str) { cerr << "dictionary " << s << " contains " << map->size() << " item" << (map->size() == 1 ? "\n" : "s\n"); all.namespace_dictionaries[(size_t)ns].push_back(map); - dictionary_info info = { (char*)calloc(strlen(s)+1, sizeof(char)), map }; + dictionary_info info = { calloc_or_die(strlen(s)+1), map }; strcpy(info.name, s); all.read_dictionaries.push_back(info); } diff --git a/vowpalwabbit/parse_example.cc b/vowpalwabbit/parse_example.cc index 941f0366..63bcb227 100644 --- a/vowpalwabbit/parse_example.cc +++ b/vowpalwabbit/parse_example.cc @@ -182,7 +182,7 @@ public: for (feature*f = feats->begin; f != feats->end; ++f) { uint32_t id = f->weight_index; size_t len = 2 + (feature_name.end-feature_name.begin) + 1 + (size_t)ceil(log10(id)) + 1; - char* str = (char*)calloc(len, sizeof(char)); + char* str = calloc_or_die(len); str[0] = index; str[1] = '_'; char *c = str+2; diff --git a/vowpalwabbit/parse_primitives.cc b/vowpalwabbit/parse_primitives.cc index b08f05fb..4ed67313 100644 --- a/vowpalwabbit/parse_primitives.cc +++ b/vowpalwabbit/parse_primitives.cc @@ -11,8 +11,6 @@ license as described in the file LICENSE. #include "parse_primitives.h" #include "hash.h" -using namespace std; - void tokenize(char delim, substring s, v_array& ret, bool allow_empty) { ret.erase(); @@ -53,17 +51,15 @@ size_t hashstring (substring s, uint32_t h) } size_t hashall (substring s, uint32_t h) -{ - return uniform_hash((unsigned char *)s.begin, s.end - s.begin, h); -} +{ return uniform_hash((unsigned char *)s.begin, s.end - s.begin, h); } -hash_func_t getHasher(const string& s){ +hash_func_t getHasher(const std::string& s){ if (s=="strings") return hashstring; else if(s=="all") return hashall; else{ - cerr << "Unknown hash function: " << s.c_str() << ". Exiting " << endl; - throw exception(); + std::cerr << "Unknown hash function: " << s.c_str() << ". Exiting " << std::endl; + throw std::exception(); } } diff --git a/vowpalwabbit/print.cc b/vowpalwabbit/print.cc index 8cfe5dd7..d0dc2765 100644 --- a/vowpalwabbit/print.cc +++ b/vowpalwabbit/print.cc @@ -4,14 +4,9 @@ #include "float.h" #include "reductions.h" -using namespace LEARNER; - namespace PRINT { - struct print{ - vw* all; - - }; + struct print{ vw* all; }; void print_feature(vw& all, float value, float& weight) { @@ -23,7 +18,7 @@ namespace PRINT cout << " "; } - void learn(print& p, base_learner& base, example& ec) + void learn(print& p, LEARNER::base_learner& base, example& ec) { label_data& ld = ec.l.simple; if (ld.label != FLT_MAX) @@ -46,7 +41,7 @@ namespace PRINT cout << endl; } - base_learner* setup(vw& all) + LEARNER::base_learner* setup(vw& all) { print& p = calloc_or_die(); p.all = &all; @@ -54,7 +49,7 @@ namespace PRINT all.reg.weight_mask = (length << all.reg.stride_shift) - 1; all.reg.stride_shift = 0; - learner& ret = init_learner(&p, learn, 1); + LEARNER::learner& ret = init_learner(&p, learn, 1); return make_base(ret); } } diff --git a/vowpalwabbit/rand48.cc b/vowpalwabbit/rand48.cc index 4ea4e75e..4288e64d 100644 --- a/vowpalwabbit/rand48.cc +++ b/vowpalwabbit/rand48.cc @@ -1,8 +1,5 @@ //A quick implementation similar to drand48 for cross-platform compatibility #include -#include -using namespace std; - // // NB: the 'ULL' suffix is not part of the constant it is there to // prevent truncation of constant to (32-bit long) when compiling @@ -25,15 +22,9 @@ float merand48(uint64_t& initial) uint64_t v = c; -void msrand48(uint64_t initial) -{ - v = initial; -} +void msrand48(uint64_t initial) { v = initial; } -float frand48() -{ - return merand48(v); -} +float frand48() { return merand48(v); } float frand48_noadvance() { diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 63ba16d8..762ef1e9 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -855,7 +855,7 @@ namespace Search { size_t sz = sizeof(size_t) + sizeof(ptag) + sizeof(int) + sizeof(size_t) + sizeof(size_t) + condition_on_cnt * (sizeof(ptag) + sizeof(action) + sizeof(char)); if (sz % 4 != 0) sz = 4 * (sz / 4 + 1); // make sure sz aligns to 4 so that uniform_hash does the right thing - unsigned char* item = (unsigned char*)calloc(sz, 1); + unsigned char* item = &calloc_or_die(); unsigned char* here = item; *here = (unsigned char)sz; here += sizeof(size_t); *here = mytag; here += sizeof(ptag); @@ -2117,7 +2117,7 @@ namespace Search { void predictor::set_input_length(size_t input_length) { is_ldf = true; if (ec_alloced) ec = (example*)realloc(ec, input_length * sizeof(example)); - else ec = (example*)calloc(input_length, sizeof(example)); + else ec = calloc_or_die(input_length); ec_cnt = input_length; ec_alloced = true; } diff --git a/vowpalwabbit/search_sequencetask.cc b/vowpalwabbit/search_sequencetask.cc index 24f97ad5..d92013b4 100644 --- a/vowpalwabbit/search_sequencetask.cc +++ b/vowpalwabbit/search_sequencetask.cc @@ -264,7 +264,7 @@ namespace SequenceTask_DemoLDF { // this is just to debug/show off how to do LD lab.costs.push_back(default_wclass); } - task_data* data = (task_data*)calloc(1, sizeof(task_data)); + task_data* data = &calloc_or_die(); data->ldf_examples = ldf_examples; data->num_actions = num_actions; diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index 3b14131f..a9ded7e4 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -21,9 +21,6 @@ #include "network.h" #include "reductions.h" -using namespace std; -using namespace LEARNER; - namespace SENDER { struct sender { io_buf* buf; @@ -69,7 +66,7 @@ void receive_result(sender& s) return_simple_example(*(s.all), NULL, *ec); } - void learn(sender& s, base_learner& base, example& ec) + void learn(sender& s, LEARNER::base_learner& base, example& ec) { if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index) receive_result(s); @@ -81,8 +78,7 @@ void receive_result(sender& s) s.delay_ring[s.sent_index++ % s.all->p->ring_size] = &ec; } - void finish_example(vw& all, sender&, example& ec) -{} + void finish_example(vw& all, sender&, example& ec){} void end_examples(sender& s) { @@ -100,7 +96,7 @@ void end_examples(sender& s) delete s.buf; } - base_learner* setup(vw& all, po::variables_map& vm, vector pairs) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm, vector pairs) { sender& s = calloc_or_die(); s.sd = -1; @@ -113,11 +109,10 @@ void end_examples(sender& s) s.all = &all; s.delay_ring = calloc_or_die(all.p->ring_size); - learner& l = init_learner(&s, learn, 1); + LEARNER::learner& l = init_learner(&s, learn, 1); l.set_finish(finish); l.set_finish_example(finish_example); l.set_end_examples(end_examples); return make_base(l); } - } diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index 8dd245f7..445bdb23 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -4,31 +4,23 @@ individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ #include -#include -#include #include -#include -#include #include #include "reductions.h" #include "vw.h" -using namespace std; -using namespace LEARNER; - -typedef pair > scored_example; - -struct compare_scored_examples -{ +namespace TOPK { + typedef pair > scored_example; + + struct compare_scored_examples + { bool operator()(scored_example const& a, scored_example const& b) const { - return a.first > b.first; + return a.first > b.first; } -}; - -namespace TOPK { - + }; + struct topk{ uint32_t B; //rec number priority_queue, compare_scored_examples > pr_queue; @@ -85,7 +77,7 @@ namespace TOPK { } template - void predict_or_learn(topk& d, base_learner& base, example& ec) + void predict_or_learn(topk& d, LEARNER::base_learner& base, example& ec) { if (example_is_newline(ec)) return;//do not predict newline @@ -102,7 +94,6 @@ namespace TOPK { d.pr_queue.pop(); d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); } - } void finish_example(vw& all, topk& d, example& ec) @@ -111,16 +102,14 @@ namespace TOPK { VW::finish_example(all, &ec); } - base_learner* setup(vw& all, po::variables_map& vm) + LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { topk& data = calloc_or_die(); - data.B = (uint32_t)vm["top"].as(); - data.all = &all; - learner& l = init_learner(&data, all.l, predict_or_learn, - predict_or_learn); + LEARNER::learner& l = init_learner(&data, all.l, predict_or_learn, + predict_or_learn); l.set_finish_example(finish_example); return make_base(l); diff --git a/vowpalwabbit/unique_sort.cc b/vowpalwabbit/unique_sort.cc index c682cf63..1a323d2d 100644 --- a/vowpalwabbit/unique_sort.cc +++ b/vowpalwabbit/unique_sort.cc @@ -6,9 +6,7 @@ license as described in the file LICENSE. #include "global_data.h" int order_features(const void* first, const void* second) -{ - return ((feature*)first)->weight_index - ((feature*)second)->weight_index; -} +{ return ((feature*)first)->weight_index - ((feature*)second)->weight_index;} int order_audit_features(const void* first, const void* second) { -- cgit v1.2.3 From 33b4571cff31610dd4220d4a0b09c44a207e8357 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 21:36:03 -0500 Subject: simpler multiclass finish --- vowpalwabbit/cbify.cc | 9 ++------- vowpalwabbit/ect.cc | 5 +---- vowpalwabbit/log_multi.cc | 5 +---- vowpalwabbit/multiclass.cc | 3 ++- vowpalwabbit/multiclass.h | 2 +- vowpalwabbit/oaa.cc | 5 +---- 6 files changed, 8 insertions(+), 21 deletions(-) diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 9d8d830f..f66f1856 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -367,15 +367,10 @@ namespace CBIFY { void init_driver(cbify&) {} void finish_example(vw& all, cbify&, example& ec) - { - MULTICLASS::output_example(all, ec); - VW::finish_example(all, &ec); - } + { MULTICLASS::finish_multiclass_example(all, ec); } void finish(cbify& data) - { - CB::cb_label.delete_label(&data.cb_label); - } + { CB::cb_label.delete_label(&data.cb_label); } base_learner* setup(vw& all, po::variables_map& vm) {//parse and set arguments diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index bcf15c02..59c67dac 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -361,10 +361,7 @@ namespace ECT } void finish_example(vw& all, ect&, example& ec) - { - MULTICLASS::output_example(all, ec); - VW::finish_example(all, &ec); - } + { MULTICLASS::finish_multiclass_example(all, ec); } base_learner* setup(vw& all, po::variables_map& vm) { diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 26f6b454..5595da92 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -497,10 +497,7 @@ namespace LOG_MULTI } void finish_example(vw& all, log_multi&, example& ec) - { - MULTICLASS::output_example(all, ec); - VW::finish_example(all, &ec); - } + { MULTICLASS::finish_multiclass_example(all, ec); } base_learner* setup(vw& all, po::variables_map& vm) //learner setup { diff --git a/vowpalwabbit/multiclass.cc b/vowpalwabbit/multiclass.cc index 7e1f7a8e..1e90aca6 100644 --- a/vowpalwabbit/multiclass.cc +++ b/vowpalwabbit/multiclass.cc @@ -140,7 +140,7 @@ namespace MULTICLASS { } } - void output_example(vw& all, example& ec) + void finish_multiclass_example(vw& all, example& ec) { multiclass ld = ec.l.multi; @@ -170,5 +170,6 @@ namespace MULTICLASS { all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); MULTICLASS::print_update(all, ec); + VW::finish_example(all, &ec); } } diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h index aca34075..738de421 100644 --- a/vowpalwabbit/multiclass.h +++ b/vowpalwabbit/multiclass.h @@ -18,7 +18,7 @@ namespace MULTICLASS extern label_parser mc_label; - void output_example(vw& all, example& ec); + void finish_multiclass_example(vw& all, example& ec); inline int label_is_test(multiclass* ld) { return ld->label == (uint32_t)-1; } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 02a9eaf2..634ef365 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -61,10 +61,7 @@ namespace OAA { } void finish_example(vw& all, oaa&, example& ec) - { - MULTICLASS::output_example(all, ec); - VW::finish_example(all, &ec); - } + { MULTICLASS::finish_multiclass_example(all, ec); } LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { -- cgit v1.2.3 From 5669a17da5d03265ef700b495eec6477e016340d Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 28 Dec 2014 21:40:50 -0500 Subject: simplifying multiclass finish --- vowpalwabbit/cbify.cc | 3 +-- vowpalwabbit/ect.cc | 3 +-- vowpalwabbit/log_multi.cc | 3 +-- vowpalwabbit/multiclass.cc | 2 +- vowpalwabbit/multiclass.h | 2 +- vowpalwabbit/oaa.cc | 3 +-- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index f66f1856..d8176228 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -366,8 +366,7 @@ namespace CBIFY { void init_driver(cbify&) {} - void finish_example(vw& all, cbify&, example& ec) - { MULTICLASS::finish_multiclass_example(all, ec); } + void finish_example(vw& all, cbify&, example& ec) { MULTICLASS::finish_example(all, ec); } void finish(cbify& data) { CB::cb_label.delete_label(&data.cb_label); } diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index 59c67dac..dea87040 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -360,8 +360,7 @@ namespace ECT e.tournaments_won.delete_v(); } - void finish_example(vw& all, ect&, example& ec) - { MULTICLASS::finish_multiclass_example(all, ec); } + void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); } base_learner* setup(vw& all, po::variables_map& vm) { diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 5595da92..226376bd 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -496,8 +496,7 @@ namespace LOG_MULTI } } - void finish_example(vw& all, log_multi&, example& ec) - { MULTICLASS::finish_multiclass_example(all, ec); } + void finish_example(vw& all, log_multi&, example& ec) { MULTICLASS::finish_example(all, ec); } base_learner* setup(vw& all, po::variables_map& vm) //learner setup { diff --git a/vowpalwabbit/multiclass.cc b/vowpalwabbit/multiclass.cc index 1e90aca6..4aad3a73 100644 --- a/vowpalwabbit/multiclass.cc +++ b/vowpalwabbit/multiclass.cc @@ -140,7 +140,7 @@ namespace MULTICLASS { } } - void finish_multiclass_example(vw& all, example& ec) + void finish_example(vw& all, example& ec) { multiclass ld = ec.l.multi; diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h index 738de421..f48efbfc 100644 --- a/vowpalwabbit/multiclass.h +++ b/vowpalwabbit/multiclass.h @@ -18,7 +18,7 @@ namespace MULTICLASS extern label_parser mc_label; - void finish_multiclass_example(vw& all, example& ec); + void finish_example(vw& all, example& ec); inline int label_is_test(multiclass* ld) { return ld->label == (uint32_t)-1; } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 634ef365..2328b00d 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -60,8 +60,7 @@ namespace OAA { o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag); } - void finish_example(vw& all, oaa&, example& ec) - { MULTICLASS::finish_multiclass_example(all, ec); } + void finish_example(vw& all, oaa&, example& ec) { MULTICLASS::finish_example(all, ec); } LEARNER::base_learner* setup(vw& all, po::variables_map& vm) { -- cgit v1.2.3