Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2015-01-03 03:41:00 +0300
committerJohn Langford <jl@hunch.net>2015-01-03 03:41:00 +0300
commit6d51f4a163d2da2cc1dca9b0717c41b0bfc5237f (patch)
treee95ce1de96c5f3ec14d662d84a6ecff4847a3abf
parent00ce5188f90cc426493872e8d630b6c61f99df79 (diff)
shorter multiclass setup
-rw-r--r--vowpalwabbit/ect.cc24
-rw-r--r--vowpalwabbit/log_multi.cc18
-rw-r--r--vowpalwabbit/multiclass.h2
-rw-r--r--vowpalwabbit/oaa.cc9
4 files changed, 18 insertions, 35 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index 25b34ee3..7a68fa13 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -51,8 +51,6 @@ namespace ECT
uint32_t last_pair;
v_array<bool> tournaments_won;
-
- vw* all;
};
bool exists(v_array<size_t> db)
@@ -184,7 +182,7 @@ namespace ECT
return e.last_pair + (eliminations-1);
}
- uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec)
+ uint32_t ect_predict(ect& e, base_learner& base, example& ec)
{
if (e.k == (size_t)1)
return 1;
@@ -228,7 +226,7 @@ namespace ECT
return false;
}
- void ect_train(vw& all, ect& e, base_learner& base, example& ec)
+ void ect_train(ect& e, base_learner& base, example& ec)
{
if (e.k == 1)//nothing to do
return;
@@ -318,25 +316,21 @@ namespace ECT
}
void predict(ect& e, base_learner& base, example& ec) {
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
- ec.pred.multiclass = ect_predict(*all, e, base, ec);
+ ec.pred.multiclass = ect_predict(e, base, ec);
ec.l.multi = mc;
}
void learn(ect& e, base_learner& base, example& ec)
{
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
predict(e, base, ec);
uint32_t pred = ec.pred.multiclass;
- if (mc.label != (uint32_t)-1 && all->training)
- ect_train(*all, e, base, ec);
+ if (mc.label != (uint32_t)-1)
+ ect_train(e, base, ec);
ec.l.multi = mc;
ec.pred.multiclass = pred;
}
@@ -360,8 +354,6 @@ namespace ECT
e.tournaments_won.delete_v();
}
- void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
base_learner* setup(vw& all)
{
new_options(all, "Error Correcting Tournament options")
@@ -376,14 +368,12 @@ namespace ECT
//append error flag to options_from_file so it is saved in regressor file later
*all.file_options << " --ect " << data.k << " --error " << data.errors;
- all.p->lp = MULTICLASS::mc_label;
size_t wpp = create_circuit(all, data, data.k, data.errors+1);
- data.all = &all;
learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp);
- l.set_finish_example(finish_example);
+ l.set_finish_example(MULTICLASS::finish_example<ect>);
+ all.p->lp = MULTICLASS::mc_label;
l.set_finish(finish);
-
return make_base(l);
}
}
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index 38ba19be..20757c70 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -77,7 +77,6 @@ namespace LOG_MULTI
struct log_multi
{
uint32_t k;
- vw* all;
v_array<node> nodes;
@@ -319,11 +318,10 @@ namespace LOG_MULTI
void learn(log_multi& b, base_learner& base, example& ec)
{
// verify_min_dfs(b, b.nodes[0]);
-
- if (ec.l.multi.label == (uint32_t)-1 || !b.all->training || b.progress)
+ if (ec.l.multi.label == (uint32_t)-1 || b.progress)
predict(b,base,ec);
- if(b.all->training && (ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
+ if((ec.l.multi.label != (uint32_t)-1) && !ec.test_only) //if training the tree
{
MULTICLASS::multiclass mc = ec.l.multi;
@@ -496,8 +494,6 @@ namespace LOG_MULTI
}
}
- void finish_example(vw& all, log_multi&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
base_learner* setup(vw& all) //learner setup
{
new_options(all, "Logarithmic Time Multiclass options")
@@ -508,7 +504,7 @@ namespace LOG_MULTI
("swap_resistance", po::value<uint32_t>(), "higher = more resistance to swap, default=4");
add_options(all);
- po::variables_map& vm = all.vm;
+ po::variables_map& vm = all.vm;
log_multi& data = calloc_or_die<log_multi>();
data.k = (uint32_t)vm["log_multi"].as<size_t>();
@@ -524,9 +520,6 @@ namespace LOG_MULTI
else
data.progress = true;
- data.all = &all;
- (all.p->lp) = MULTICLASS::mc_label;
-
string loss_function = "quantile";
float loss_parameter = 0.5;
delete(all.loss);
@@ -534,10 +527,11 @@ namespace LOG_MULTI
data.max_predictors = data.k - 1;
- learner<log_multi>& l = init_learner(&data, setup_base(all), learn, predict, data.max_predictors);
+ learner<log_multi>& l = init_learner(&data, setup_base(all), learn, predict, data.max_predictors);
l.set_save_load(save_load_tree);
- l.set_finish_example(finish_example);
l.set_finish(finish);
+ l.set_finish_example(MULTICLASS::finish_example<log_multi>);
+ all.p->lp = MULTICLASS::mc_label;
init_tree(data);
diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h
index f48efbfc..14f5d443 100644
--- a/vowpalwabbit/multiclass.h
+++ b/vowpalwabbit/multiclass.h
@@ -20,6 +20,8 @@ namespace MULTICLASS
void finish_example(vw& all, example& ec);
+ template <class T> void finish_example(vw& all, T&, example& ec) { finish_example(all, ec); }
+
inline int label_is_test(multiclass* ld)
{ return ld->label == (uint32_t)-1; }
}
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 0554dcae..185264f5 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -7,7 +7,6 @@ license as described in the file LICENSE.
#include "multiclass.h"
#include "simple_label.h"
#include "reductions.h"
-#include "vw.h"
namespace OAA {
struct oaa{
@@ -60,8 +59,6 @@ namespace OAA {
o.all->print_text(o.all->raw_prediction, outputStringStream.str(), ec.tag);
}
- void finish_example(vw& all, oaa&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
LEARNER::base_learner* setup(vw& all)
{
new_options(all, "One-against-all options")
@@ -74,11 +71,11 @@ namespace OAA {
data.all = &all;
*all.file_options << " --oaa " << data.k;
- all.p->lp = MULTICLASS::mc_label;
LEARNER::learner<oaa>& l = init_learner(&data, setup_base(all), predict_or_learn<true>,
- predict_or_learn<false>, data.k);
- l.set_finish_example(finish_example);
+ predict_or_learn<false>, data.k);
+ l.set_finish_example(MULTICLASS::finish_example<oaa>);
+ all.p->lp = MULTICLASS::mc_label;
return make_base(l);
}
}