Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/ect.cc')
-rw-r--r--vowpalwabbit/ect.cc51
1 files changed, 17 insertions, 34 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index dea87040..1a071eb2 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -17,7 +17,6 @@ license as described in the file LICENSE.
#include "reductions.h"
#include "multiclass.h"
#include "simple_label.h"
-#include "vw.h"
using namespace std;
using namespace LEARNER;
@@ -51,8 +50,6 @@ namespace ECT
uint32_t last_pair;
v_array<bool> tournaments_won;
-
- vw* all;
};
bool exists(v_array<size_t> db)
@@ -184,7 +181,7 @@ namespace ECT
return e.last_pair + (eliminations-1);
}
- uint32_t ect_predict(vw& all, ect& e, base_learner& base, example& ec)
+ uint32_t ect_predict(ect& e, base_learner& base, example& ec)
{
if (e.k == (size_t)1)
return 1;
@@ -228,7 +225,7 @@ namespace ECT
return false;
}
- void ect_train(vw& all, ect& e, base_learner& base, example& ec)
+ void ect_train(ect& e, base_learner& base, example& ec)
{
if (e.k == 1)//nothing to do
return;
@@ -318,25 +315,21 @@ namespace ECT
}
void predict(ect& e, base_learner& base, example& ec) {
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
- ec.pred.multiclass = ect_predict(*all, e, base, ec);
+ ec.pred.multiclass = ect_predict(e, base, ec);
ec.l.multi = mc;
}
void learn(ect& e, base_learner& base, example& ec)
{
- vw* all = e.all;
-
MULTICLASS::multiclass mc = ec.l.multi;
predict(e, base, ec);
uint32_t pred = ec.pred.multiclass;
- if (mc.label != (uint32_t)-1 && all->training)
- ect_train(*all, e, base, ec);
+ if (mc.label != (uint32_t)-1)
+ ect_train(e, base, ec);
ec.l.multi = mc;
ec.pred.multiclass = pred;
}
@@ -360,36 +353,26 @@ namespace ECT
e.tournaments_won.delete_v();
}
- void finish_example(vw& all, ect&, example& ec) { MULTICLASS::finish_example(all, ec); }
-
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
- ect& data = calloc_or_die<ect>();
- po::options_description ect_opts("ECT options");
- ect_opts.add_options()
- ("error", po::value<size_t>(), "error in ECT");
+ new_options(all, "Error Correcting Tournament options")
+ ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels");
+ if (missing_required(all)) return NULL;
+ new_options(all)("error", po::value<size_t>()->default_value(0), "error in ECT");
+ add_options(all);
- vm = add_options(all, ect_opts);
-
- //first parse for number of actions
- data.k = (int)vm["ect"].as<size_t>();
-
- //append ect with nb_actions to options_from_file so it is saved to regressor later
- if (vm.count("error")) {
- data.errors = (uint32_t)vm["error"].as<size_t>();
- } else
- data.errors = 0;
+ ect& data = calloc_or_die<ect>();
+ data.k = (int)all.vm["ect"].as<size_t>();
+ data.errors = (uint32_t)all.vm["error"].as<size_t>();
//append error flag to options_from_file so it is saved in regressor file later
*all.file_options << " --ect " << data.k << " --error " << data.errors;
- all.p->lp = MULTICLASS::mc_label;
size_t wpp = create_circuit(all, data, data.k, data.errors+1);
- data.all = &all;
- learner<ect>& l = init_learner(&data, all.l, learn, predict, wpp);
- l.set_finish_example(finish_example);
+ learner<ect>& l = init_learner(&data, setup_base(all), learn, predict, wpp);
+ l.set_finish_example(MULTICLASS::finish_example<ect>);
+ all.p->lp = MULTICLASS::mc_label;
l.set_finish(finish);
-
return make_base(l);
}
}