Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/ect.cc')
-rw-r--r--vowpalwabbit/ect.cc72
1 files changed, 36 insertions, 36 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index 94576f41..3fd7c385 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -186,7 +186,7 @@ namespace ECT
return e.last_pair + (eliminations-1);
}
- float ect_predict(vw& all, ect& e, learner& base, example* ec)
+ float ect_predict(vw& all, ect& e, learner& base, example& ec)
{
if (e.k == (size_t)1)
return 1;
@@ -195,7 +195,7 @@ namespace ECT
//Binary final elimination tournament first
label_data simple_temp = {FLT_MAX, 0., 0.};
- ec->ld = & simple_temp;
+ ec.ld = & simple_temp;
for (size_t i = e.tree_height-1; i != (size_t)0 -1; i--)
{
@@ -205,7 +205,7 @@ namespace ECT
base.learn(ec, problem_number);
- float pred = ec->final_prediction;
+ float pred = ec.final_prediction;
if (pred > 0.)
finals_winner = finals_winner | (((size_t)1) << i);
}
@@ -216,7 +216,7 @@ namespace ECT
{
base.learn(ec, id - e.k);
- if (ec->final_prediction > 0.)
+ if (ec.final_prediction > 0.)
id = e.directions[id].right;
else
id = e.directions[id].left;
@@ -232,11 +232,11 @@ namespace ECT
return false;
}
- void ect_train(vw& all, ect& e, learner& base, example* ec)
+ void ect_train(vw& all, ect& e, learner& base, example& ec)
{
if (e.k == 1)//nothing to do
return;
- OAA::mc_label * mc = (OAA::mc_label*)ec->ld;
+ OAA::mc_label * mc = (OAA::mc_label*)ec.ld;
label_data simple_temp = {1.,mc->weight,0.};
@@ -252,12 +252,12 @@ namespace ECT
simple_temp.label = 1;
simple_temp.weight = mc->weight;
- ec->ld = &simple_temp;
+ ec.ld = &simple_temp;
base.learn(ec, id-e.k);
simple_temp.weight = 0.;
base.learn(ec, id-e.k);//inefficient, we should extract final prediction exactly.
- float pred = ec->final_prediction;
+ float pred = ec.final_prediction;
bool won = pred*simple_temp.label > 0;
@@ -305,13 +305,13 @@ namespace ECT
label = 1;
simple_temp.label = label;
simple_temp.weight = (float)(1 << (e.tree_height -i -1));
- ec->ld = & simple_temp;
+ ec.ld = & simple_temp;
uint32_t problem_number = e.last_pair + j*(1 << (i+1)) + (1 << i) -1;
base.learn(ec, problem_number);
- float pred = ec->final_prediction;
+ float pred = ec.final_prediction;
if (pred > 0.)
e.tournaments_won[j] = right;
else
@@ -324,53 +324,53 @@ namespace ECT
}
}
- void predict(ect* e, learner& base, example* ec) {
- vw* all = e->all;
+ void predict(ect& e, learner& base, example& ec) {
+ vw* all = e.all;
- OAA::mc_label* mc = (OAA::mc_label*)ec->ld;
- if (mc->label == 0 || (mc->label > e->k && mc->label != (uint32_t)-1))
- cout << "label " << mc->label << " is not in {1,"<< e->k << "} This won't work right." << endl;
- ec->final_prediction = ect_predict(*all, *e, base, ec);
- ec->ld = mc;
+ OAA::mc_label* mc = (OAA::mc_label*)ec.ld;
+ if (mc->label == 0 || (mc->label > e.k && mc->label != (uint32_t)-1))
+ cout << "label " << mc->label << " is not in {1,"<< e.k << "} This won't work right." << endl;
+ ec.final_prediction = ect_predict(*all, e, base, ec);
+ ec.ld = mc;
}
- void learn(ect* e, learner& base, example* ec)
+ void learn(ect& e, learner& base, example& ec)
{
- vw* all = e->all;
+ vw* all = e.all;
- OAA::mc_label* mc = (OAA::mc_label*)ec->ld;
+ OAA::mc_label* mc = (OAA::mc_label*)ec.ld;
predict(e, base, ec);
- float new_label = ec->final_prediction;
+ float new_label = ec.final_prediction;
if (mc->label != (uint32_t)-1 && all->training)
- ect_train(*all, *e, base, ec);
- ec->ld = mc;
- ec->final_prediction = new_label;
+ ect_train(*all, e, base, ec);
+ ec.ld = mc;
+ ec.final_prediction = new_label;
}
- void finish(ect* e)
+ void finish(ect& e)
{
- for (size_t l = 0; l < e->all_levels.size(); l++)
+ for (size_t l = 0; l < e.all_levels.size(); l++)
{
- for (size_t t = 0; t < e->all_levels[l].size(); t++)
- e->all_levels[l][t].delete_v();
- e->all_levels[l].delete_v();
+ for (size_t t = 0; t < e.all_levels[l].size(); t++)
+ e.all_levels[l][t].delete_v();
+ e.all_levels[l].delete_v();
}
- e->final_nodes.delete_v();
+ e.final_nodes.delete_v();
- e->up_directions.delete_v();
+ e.up_directions.delete_v();
- e->directions.delete_v();
+ e.directions.delete_v();
- e->down_directions.delete_v();
+ e.down_directions.delete_v();
- e->tournaments_won.delete_v();
+ e.tournaments_won.delete_v();
}
- void finish_example(vw& all, ect*, example* ec)
+ void finish_example(vw& all, ect&, example& ec)
{
OAA::output_example(all, ec);
- VW::finish_example(all, ec);
+ VW::finish_example(all, &ec);
}
learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)