diff options
Diffstat (limited to 'vowpalwabbit/ect.cc')
-rw-r--r-- | vowpalwabbit/ect.cc | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index 210b2733..3eab22c9 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -203,7 +203,7 @@ namespace ECT base.learn(ec, problem_number); - if (ec.l.simple.prediction > 0.) + if (ec.pred.scalar > 0.) finals_winner = finals_winner | (((size_t)1) << i); } } @@ -213,7 +213,7 @@ namespace ECT { base.learn(ec, id - e.k); - if (ec.l.simple.prediction > 0.) + if (ec.pred.scalar > 0.) id = e.directions[id].right; else id = e.directions[id].left; @@ -256,7 +256,7 @@ namespace ECT ec.l.simple.weight = 0.; base.learn(ec, id-e.k);//inefficient, we should extract final prediction exactly. - bool won = ec.l.simple.prediction * simple_temp.label > 0; + bool won = ec.pred.scalar * simple_temp.label > 0; if (won) { @@ -306,7 +306,7 @@ namespace ECT base.learn(ec, problem_number); - if (ec.l.simple.prediction > 0.) + if (ec.pred.scalar > 0.) e.tournaments_won[j] = right; else e.tournaments_won[j] = left; @@ -324,7 +324,7 @@ namespace ECT MULTICLASS::multiclass mc = ec.l.multi; if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1)) cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl; - mc.prediction = ect_predict(*all, e, base, ec); + ec.pred.multiclass = ect_predict(*all, e, base, ec); ec.l.multi = mc; } @@ -334,11 +334,12 @@ namespace ECT MULTICLASS::multiclass mc = ec.l.multi; predict(e, base, ec); - mc.prediction = ec.l.multi.prediction; + uint32_t pred = ec.pred.multiclass; if (mc.label != (uint32_t)-1 && all->training) ect_train(*all, e, base, ec); ec.l.multi = mc; + ec.pred.multiclass = pred; } void finish(ect& e) |