diff options
author | John Langford <jl@hunch.net> | 2015-01-03 03:41:00 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2015-01-03 03:41:00 +0300 |
commit | 6d51f4a163d2da2cc1dca9b0717c41b0bfc5237f (patch) | |
tree | e95ce1de96c5f3ec14d662d84a6ecff4847a3abf | |
parent | 00ce5188f90cc426493872e8d630b6c61f99df79 (diff) |
shorter multiclass setup
-rw-r--r-- | vowpalwabbit/ect.cc | 24 | ||||
-rw-r--r-- | vowpalwabbit/log_multi.cc | 18 | ||||
-rw-r--r-- | vowpalwabbit/multiclass.h | 2 | ||||
-rw-r--r-- | vowpalwabbit/oaa.cc | 9 |
4 files changed, 18 insertions, 35 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index 25b34ee3..7a68fa13 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -51,8 +51,6 @@ namespace ECT uint32_t last_pair; v_array<bool> tournaments_won; - - vw* all; }; bool exists(v_array<size_t> db) @@ -184,7 +182,7 @@ namespace ECT return e.last_pair + (eliminations-1); } - uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec) + uint32_t ect_predict(ect& e, base_learner& base, example& ec) { if (e.k == (size_t)1) return 1; @@ -228,7 +226,7 @@ namespace ECT return false; } - void ect_train(vw& all, ect& e, base_learner& base, example& ec) + void ect_train(ect& e, base_learner& base, example& ec) { if (e.k == 1)//nothing to do return; @@ -318,25 +316,21 @@ namespace ECT } void predict(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1)) cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl; - ec.pred.multiclass = ect_predict(*all, e, base, ec); + ec.pred.multiclass = ect_predict(e, base, ec); ec.l.multi = mc; } void learn(ect& e, base_learner& base, example& ec) { - vw* all = e.all; - MULTICLASS::multiclass mc = ec.l.multi; predict(e, base, ec); uint32_t pred = ec.pred.multiclass; - if (mc.label != (uint32_t)-1 && all->training) - ect_train(*all, e, base, ec); + if (mc.label != (uint32_t)-1) + ect_train(e, base, ec); ec.l.multi = mc; ec.pred.multiclass = pred; } @@ -360,8 +354,6 @@ namespace ECT e.tournaments_won.delete_v(); } - void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); } - base_learner* setup(vw& all) { new_options(all, "Error Correcting Tournament options") @@ -376,14 +368,12 @@ namespace ECT //append error flag to options_from_file so it is saved in regressor file later *all.file_options << " --ect " << data.k << " --error " << data.errors; - all.p->lp = MULTICLASS::mc_label; size_t wpp = create_circuit(all, data, data.k, data.errors+1); - data.all = &all; learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp); - l.set_finish_example(finish_example); + l.set_finish_example(MULTICLASS::finish_example<ect>); + all.p->lp = MULTICLASS::mc_label; l.set_finish(finish); - return make_base(l); } } diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 38ba19be..20757c70 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -77,7 +77,6 @@ namespace LOG_MULTI struct log_multi
{
uint32_t k;
- vw* all;
v_array<node> nodes;
@@ -319,11 +318,10 @@ namespace LOG_MULTI void learn(log_multi& b, base_learner& base, example& ec)
{
// verify_min_dfs(b, b.nodes[0]);
-
- if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
+ if (ec.l.multi.label == (uint32_t)-1 || b.progress)
predict(b,base,ec);
- if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
+ if((ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
{
MULTICLASS::multiclass mc = ec.l.multi;
@@ -496,8 +494,6 @@ namespace LOG_MULTI }
}
- void finish_example(vw& all, log_multi&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
base_learner* setup(vw& all) //learner setup
{
new_options(all, "Logarithmic Time Multiclass options") @@ -508,7 +504,7 @@ namespace LOG_MULTI ("swap_resistance", po::value<uint32_t>(), "higher = more resistance to swap, default=4"); add_options(all);
- po::variables_map& vm = all.vm; + po::variables_map& vm = all.vm; log_multi& data = calloc_or_die<log_multi>(); data.k = (uint32_t)vm["log_multi"].as<size_t>();
@@ -524,9 +520,6 @@ namespace LOG_MULTI else
data.progress = true;
- data.all = &all;
- (all.p->lp) = MULTICLASS::mc_label;
-
string loss_function = "quantile";
float loss_parameter = 0.5;
delete(all.loss);
@@ -534,10 +527,11 @@ namespace LOG_MULTI data.max_predictors = data.k - 1;
- learner<log_multi>& l = init_learner(&data, setup_base(all), learn, predict, data.max_predictors);
+ learner<log_multi>& l = init_learner(&data, setup_base(all), learn, predict, data.max_predictors);
l.set_save_load(save_load_tree);
- l.set_finish_example(finish_example);
l.set_finish(finish);
+ l.set_finish_example(MULTICLASS::finish_example<log_multi>); + all.p->lp = MULTICLASS::mc_label; init_tree(data);
diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h index f48efbfc..14f5d443 100644 --- a/vowpalwabbit/multiclass.h +++ b/vowpalwabbit/multiclass.h @@ -20,6 +20,8 @@ namespace MULTICLASS void finish_example(vw& all, example& ec); + template <class T> void finish_example(vw& all, T&, example& ec) { finish_example(all, ec); } + inline int label_is_test(multiclass* ld) { return ld->label == (uint32_t)-1; } } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 0554dcae..185264f5 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -7,7 +7,6 @@ license as described in the file LICENSE. #include "multiclass.h" #include "simple_label.h" #include "reductions.h" -#include "vw.h" namespace OAA { struct oaa{ @@ -60,8 +59,6 @@ namespace OAA { o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag); } - void finish_example(vw& all, oaa&, example& ec) { MULTICLASS::finish_example(all, ec); } - LEARNER::base_learner* setup(vw& all) { new_options(all, "One-against-all options") @@ -74,11 +71,11 @@ namespace OAA { data.all = &all; *all.file_options << " --oaa " << data.k; - all.p->lp = MULTICLASS::mc_label; LEARNER::learner<oaa>& l = init_learner(&data, setup_base(all), predict_or_learn<true>, - predict_or_learn<false>, data.k); - l.set_finish_example(finish_example); + predict_or_learn<false>, data.k); + l.set_finish_example(MULTICLASS::finish_example<oaa>); + all.p->lp = MULTICLASS::mc_label; return make_base(l); } } |