Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/ect.cc')
-rw-r--r--vowpalwabbit/ect.cc24
1 files changed, 7 insertions, 17 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index 25b34ee3..7a68fa13 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -51,8 +51,6 @@ namespace ECT
uint32_t last_pair;
v_array<bool> tournaments_won;
-
- vw* all;
};
bool exists(v_array<size_t> db)
@@ -184,7 +182,7 @@ namespace ECT
return e.last_pair + (eliminations-1);
}
- uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec)
+ uint32_t ect_predict(ect& e, base_learner& base, example& ec)
{
if (e.k == (size_t)1)
return 1;
@@ -228,7 +226,7 @@ namespace ECT
return false;
}
- void ect_train(vw& all, ect& e, base_learner& base, example& ec)
+ void ect_train(ect& e, base_learner& base, example& ec)
{
if (e.k == 1)//nothing to do
return;
@@ -318,25 +316,21 @@ namespace ECT
}
void predict(ect& e, base_learner& base, example& ec) {
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
- ec.pred.multiclass = ect_predict(*all, e, base, ec);
+ ec.pred.multiclass = ect_predict(e, base, ec);
ec.l.multi = mc;
}
void learn(ect& e, base_learner& base, example& ec)
{
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
predict(e, base, ec);
uint32_t pred = ec.pred.multiclass;
- if (mc.label != (uint32_t)-1 && all->training)
- ect_train(*all, e, base, ec);
+ if (mc.label != (uint32_t)-1)
+ ect_train(e, base, ec);
ec.l.multi = mc;
ec.pred.multiclass = pred;
}
@@ -360,8 +354,6 @@ namespace ECT
e.tournaments_won.delete_v();
}
- void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
base_learner* setup(vw& all)
{
new_options(all, "Error Correcting Tournament options")
@@ -376,14 +368,12 @@ namespace ECT
//append error flag to options_from_file so it is saved in regressor file later
*all.file_options << " --ect " << data.k << " --error " << data.errors;
- all.p->lp = MULTICLASS::mc_label;
size_t wpp = create_circuit(all, data, data.k, data.errors+1);
- data.all = &all;
learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp);
- l.set_finish_example(finish_example);
+ l.set_finish_example(MULTICLASS::finish_example<ect>);
+ all.p->lp = MULTICLASS::mc_label;
l.set_finish(finish);
-
return make_base(l);
}
}