diff options
author | John Langford <jl@hunch.net> | 2014-11-28 22:43:58 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-11-28 22:43:58 +0300 |
commit | 4bf53b154796eee7501598f1edf5debb5c8e49df (patch) | |
tree | d124f9e8e992bca47b38aebd30342b1b5a91cc14 | |
parent | 76c9c77631308ef612ec9f9580e49aeebab8172c (diff) |
separated label from prediction
36 files changed, 178 insertions, 167 deletions
diff --git a/library/library_example.cc b/library/library_example.cc index 7a3a9367..8abfc1f7 100644 --- a/library/library_example.cc +++ b/library/library_example.cc @@ -18,7 +18,7 @@ int main(int argc, char *argv[]) example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
model->learn(vec2);
- cerr << "p2 = " << vec2->l.simple.prediction << endl;
+ cerr << "p2 = " << vec2->pred.scalar << endl;
VW::finish_example(*model, vec2);
vector< VW::feature_space > ec_info;
@@ -36,7 +36,7 @@ int main(int argc, char *argv[]) example* vec3 = VW::import_example(*model, ec_info);
model->learn(vec3);
- cerr << "p3 = " << vec3->l.simple.prediction << endl;
+ cerr << "p3 = " << vec3->pred.scalar << endl;
VW::finish_example(*model, vec3);
VW::finish(*model);
@@ -44,7 +44,7 @@ int main(int argc, char *argv[]) vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
model2->learn(vec2);
- cerr << "p4 = " << vec2->l.simple.prediction << endl;
+ cerr << "p4 = " << vec2->pred.scalar << endl;
size_t len=0;
VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
diff --git a/library/recommend.cc b/library/recommend.cc index b78d6fa5..3d970d76 100644 --- a/library/recommend.cc +++ b/library/recommend.cc @@ -213,12 +213,12 @@ int main(int argc, char *argv[]) if(pr_queue.size() < (size_t)topk) { - pr_queue.push(make_pair(ex->l.simple.prediction, str)); + pr_queue.push(make_pair(ex->pred.scalar, str)); } - else if(pr_queue.top().first < ex->l.simple.prediction) + else if(pr_queue.top().first < ex->pred.scalar) { pr_queue.pop(); - pr_queue.push(make_pair(ex->l.simple.prediction, str)); + pr_queue.push(make_pair(ex->pred.scalar, str)); } VW::finish_example(*model, ex); diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc index 19ce9778..78e8e0e1 100644 --- a/vowpalwabbit/active.cc +++ b/vowpalwabbit/active.cc @@ -53,7 +53,7 @@ namespace ACTIVE { vw& all = *a.all; float k = ec.example_t - ec.l.simple.weight; - ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.l.simple.prediction, all.eta/powf(k,all.power_t)); + ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.pred.scalar, all.eta/powf(k,all.power_t)); float importance = query_decision(a, ec, k); if(importance > 0){ @@ -77,7 +77,7 @@ namespace ACTIVE { vw& all = *a.all; float t = (float)(ec.example_t - all.sd->weighted_holdout_examples); - ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.l.simple.prediction, + ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.pred.scalar, all.eta/powf(t,all.power_t)); } } @@ -139,7 +139,7 @@ namespace ACTIVE { for (size_t i = 0; i<all.final_prediction_sink.size(); i++) { int f = (int)all.final_prediction_sink[i]; - active_print_result(f, ld.prediction, ai, ec.tag); + active_print_result(f, ec.pred.scalar, ai, ec.tag); } print_update(all, ec); diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc index 60766443..4eebe6d4 100644 --- a/vowpalwabbit/autolink.cc +++ b/vowpalwabbit/autolink.cc @@ -15,7 +15,7 @@ namespace ALINK { void predict_or_learn(autolink& b, learner& base, example& ec) { base.predict(ec); - float base_pred = ec.l.simple.prediction; + float base_pred = ec.pred.scalar; // add features of label ec.indices.push_back(autolink_namespace); @@ -26,7 +26,7 @@ namespace ALINK { feature f = { base_pred, (uint32_t) (autoconstant + (i << b.stride_shift)) }; ec.atomics[autolink_namespace].push_back(f); sum_sq += base_pred*base_pred; - base_pred *= ec.l.simple.prediction; + base_pred *= ec.pred.scalar; } ec.total_sum_feat_sq += sum_sq; diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index bb22e393..b964c218 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -187,7 +187,7 @@ inline void add_precond(float& d, float f, float& fw) void update_preconditioner(vw& all, example& ec) { label_data& ld = ec.l.simple; - float curvature = all.loss->second_derivative(all.sd, ld.prediction,ld.label) * ld.weight; + float curvature = all.loss->second_derivative(all.sd, ec.pred.scalar, ld.label) * ld.weight; ec.ft_offset += W_COND; GD::foreach_feature<float,add_precond>(all, ec, curvature); @@ -734,10 +734,10 @@ void process_example(vw& all, bfgs& b, example& ec) /********************************************************************/ if (b.gradient_pass) { - ld.prediction = predict_and_gradient(all, ec);//w[0] & w[1] - ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight; + ec.pred.scalar = predict_and_gradient(all, ec);//w[0] & w[1] + ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight; b.loss_sum += ec.loss; - b.predictions.push_back(ld.prediction); + b.predictions.push_back(ec.pred.scalar); } /********************************************************************/ /* II) CURVATURE CALCULATION ****************************************/ @@ -747,13 +747,13 @@ void process_example(vw& all, bfgs& b, example& ec) float d_dot_x = dot_with_direction(all, ec);//w[2] if (b.example_number >= b.predictions.size())//Make things safe in case example source is strange. b.example_number = b.predictions.size()-1; - ld.prediction = b.predictions[b.example_number]; + ec.pred.scalar = b.predictions[b.example_number]; ec.partial_prediction = b.predictions[b.example_number]; - ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight; + ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight; float sd = all.loss->second_derivative(all.sd, b.predictions[b.example_number++],ld.label); b.curvature += d_dot_x*d_dot_x*sd*ld.weight; } - ec.updated_prediction = ld.prediction; + ec.updated_prediction = ec.pred.scalar; if (b.preconditioner_pass) update_preconditioner(all, ec);//w[3] @@ -821,7 +821,7 @@ void end_pass(bfgs& b) void predict(bfgs& b, learner& base, example& ec) { vw* all = b.all; - ec.l.simple.prediction = bfgs_predict(*all,ec); + ec.pred.scalar = bfgs_predict(*all,ec); } void learn(bfgs& b, learner& base, example& ec) diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc index 68b3a341..3f78e789 100644 --- a/vowpalwabbit/binary.cc +++ b/vowpalwabbit/binary.cc @@ -13,12 +13,12 @@ namespace BINARY { else base.predict(ec); - if ( ec.l.simple.prediction > 0) - ec.l.simple.prediction = 1; + if ( ec.pred.scalar > 0) + ec.pred.scalar = 1; else - ec.l.simple.prediction = -1; + ec.pred.scalar = -1; - if (ec.l.simple.label == ec.l.simple.prediction) + if (ec.l.simple.label == ec.pred.scalar) ec.loss = 0.; else ec.loss = ec.l.simple.weight; diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index dd95b6c8..8c2ca4b9 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -31,8 +31,8 @@ namespace BS { void bs_predict_mean(vw& all, example& ec, vector<double> &pred_vec) { - ec.l.simple.prediction = (float)accumulate(pred_vec.begin(), pred_vec.end(), 0.0)/pred_vec.size(); - ec.loss = all.loss->getLoss(all.sd, ec.l.simple.prediction, ec.l.simple.label) * ec.l.simple.weight; + ec.pred.scalar = (float)accumulate(pred_vec.begin(), pred_vec.end(), 0.0)/pred_vec.size(); + ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight; } void bs_predict_vote(vw& all, example& ec, vector<double> &pred_vec) @@ -107,10 +107,10 @@ namespace BS { } // ld.prediction = sum_labels/(float)counter; //replace line below for: "avg on votes" and getLoss() - ec.l.simple.prediction = (float)current_label; + ec.pred.scalar = (float)current_label; // ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight; //replace line below for: "avg on votes" and getLoss() - ec.loss = ((ec.l.simple.prediction == ec.l.simple.label) ? 0.f : 1.f) * ec.l.simple.weight; + ec.loss = ((ec.pred.scalar == ec.l.simple.label) ? 0.f : 1.f) * ec.l.simple.weight; } void print_result(int f, float res, float weight, v_array<char> tag, float lb, float ub) @@ -174,7 +174,7 @@ namespace BS { } for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) - BS::print_result(*sink, ld.prediction, 0, ec.tag, d.lb, d.ub); + BS::print_result(*sink, ec.pred.scalar, 0, ec.tag, d.lb, d.ub); print_update(all, ec); } @@ -200,7 +200,7 @@ namespace BS { else base.predict(ec, i-1); - d.pred_vec.push_back(ec.l.simple.prediction); + d.pred_vec.push_back(ec.pred.scalar); if (shouldOutput) { if (i > 1) outputStringStream << ' '; diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc index eb8d1386..ba65febe 100644 --- a/vowpalwabbit/cb.cc +++ b/vowpalwabbit/cb.cc @@ -172,31 +172,52 @@ namespace CB_EVAL { size_t read_cached_label(shared_data*sd, void* v, io_buf& cache) { - CB::label* ld = (CB::label*) v; + CB_EVAL::label* ld = (CB_EVAL::label*) v; char* c; size_t total = sizeof(uint32_t); if (buf_read(cache, c, total) < total) return 0; - ld->prediction = *(uint32_t*)c; + ld->action = *(uint32_t*)c; c += sizeof(uint32_t); - return total + CB::read_cached_label(sd, ld, cache); + return total + CB::read_cached_label(sd, &(ld->event), cache); } void cache_label(void* v, io_buf& cache) { char *c; - CB::label* ld = (CB::label*) v; + CB_EVAL::label* ld = (CB_EVAL::label*) v; buf_write(cache, c, sizeof(uint32_t)); - *(uint32_t *)c = ld->prediction; + *(uint32_t *)c = ld->action; c+= sizeof(uint32_t); - CB::cache_label(ld, cache); + CB::cache_label(&(ld->event), cache); + } + + void default_label(void* v) + { + CB_EVAL::label* ld = (CB_EVAL::label*) v; + CB::default_label(&(ld->event)); + ld->action = 0; + } + + void delete_label(void* v) + { + CB_EVAL::label* ld = (CB_EVAL::label*)v; + CB::delete_label(&(ld->event)); + } + + void copy_label(void*dst, void*src) + { + CB_EVAL::label* ldD = (CB_EVAL::label*)dst; + CB_EVAL::label* ldS = (CB_EVAL::label*)src; + CB::copy_label(&(ldD->event), &(ldS)->event); + ldD->action = ldS->action; } void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words) { - CB::label* ld = (CB::label*)v; + CB_EVAL::label* ld = (CB_EVAL::label*)v; if (words.size() < 2) { @@ -204,18 +225,18 @@ namespace CB_EVAL throw exception(); } - ld->prediction = (uint32_t)hashstring(words[0], 0); + ld->action = (uint32_t)hashstring(words[0], 0); words.begin++; - CB::parse_label(p, sd, ld, words); + CB::parse_label(p, sd, &(ld->event), words); words.begin--; } - label_parser cb_eval = {CB::default_label, parse_label, + label_parser cb_eval = {default_label, parse_label, cache_label, read_cached_label, - CB::delete_label, CB::weight, - CB::copy_label, - sizeof(CB::label)}; + delete_label, CB::weight, + copy_label, + sizeof(CB_EVAL::label)}; } diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h index cc803d71..1ce2cf9e 100644 --- a/vowpalwabbit/cb.h +++ b/vowpalwabbit/cb.h @@ -19,12 +19,16 @@ namespace CB { struct label { v_array<cb_class> costs; - uint32_t prediction; }; extern label_parser cb_label;//for learning } namespace CB_EVAL { + struct label { + uint32_t action; + CB::label event; + }; + extern label_parser cb_eval;//for evaluation of an arbitrary policy. } diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index e2dfaee1..4550efaa 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -79,10 +79,8 @@ namespace CB_ALGS return NULL; } - void gen_cs_example_ips(vw& all, cb& c, example& ec, COST_SENSITIVE::label& cs_ld) + void gen_cs_example_ips(vw& all, cb& c, example& ec, CB::label& ld, COST_SENSITIVE::label& cs_ld) {//this implements the inverse propensity score method, where cost are importance weighted by the probability of the chosen action - CB::label ld = ec.l.cb; - //generate cost-sensitive example cs_ld.costs.erase(); if( ld.costs.size() == 1) { //this is a typical example where we can perform all actions @@ -206,7 +204,7 @@ namespace CB_ALGS } } - ec.l.cb.prediction = argmin; + ec.pred.multiclass = argmin; } template <bool is_learn> @@ -233,10 +231,8 @@ namespace CB_ALGS } template <bool is_learn> - void gen_cs_example_dr(vw& all, cb& c, example& ec, COST_SENSITIVE::label& cs_ld) + void gen_cs_example_dr(vw& all, cb& c, example& ec, CB::label& ld, COST_SENSITIVE::label& cs_ld) {//this implements the doubly robust method - CB::label ld = ec.l.cb; - //generate cost sensitive example cs_ld.costs.erase(); if( ld.costs.size() == 1) //this is a typical example where we can perform all actions @@ -300,7 +296,6 @@ namespace CB_ALGS ec.l.cs = c.cb_cs_ld; base.predict(ec); - ld.prediction = c.cb_cs_ld.prediction; for (size_t i=0; i<ld.costs.size(); i++) ld.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; @@ -318,13 +313,13 @@ namespace CB_ALGS switch(c.cb_type) { case CB_TYPE_IPS: - gen_cs_example_ips(*all,c,ec,c.cb_cs_ld); + gen_cs_example_ips(*all,c,ec,ld,c.cb_cs_ld); break; case CB_TYPE_DM: gen_cs_example_dm<is_learn>(*all,c,ec,c.cb_cs_ld); break; case CB_TYPE_DR: - gen_cs_example_dr<is_learn>(*all,c,ec,c.cb_cs_ld); + gen_cs_example_dr<is_learn>(*all,c,ec,ld,c.cb_cs_ld); break; default: std::cerr << "Unknown cb_type specified for contextual bandit learning: " << c.cb_type << ". Exiting." << endl; @@ -340,7 +335,6 @@ namespace CB_ALGS else base.predict(ec); - ld.prediction = ec.l.cs.prediction; for (size_t i=0; i<ld.costs.size(); i++) ld.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; ec.l.cb = ld; @@ -354,19 +348,19 @@ namespace CB_ALGS void learn_eval(cb& c, learner& base, example& ec) { vw* all = c.all; - CB::label ld = ec.l.cb; + CB_EVAL::label ld = ec.l.cb_eval; - c.known_cost = get_observed_cost(ld); + c.known_cost = get_observed_cost(ld.event); if (c.cb_type == CB_TYPE_DR) - gen_cs_example_dr<true>(*all,c,ec,c.cb_cs_ld); + gen_cs_example_dr<true>(*all, c, ec, ld.event, c.cb_cs_ld); else //c.cb_type == CB_TYPE_IPS - gen_cs_example_ips(*all,c,ec,c.cb_cs_ld); + gen_cs_example_ips(*all, c, ec, ld.event, c.cb_cs_ld); - for (size_t i=0; i<ld.costs.size(); i++) - ld.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; + for (size_t i=0; i<ld.event.costs.size(); i++) + ld.event.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; - ec.l.cb = ld; + ec.l.cb_eval = ld; } void init_driver(cb&) @@ -384,7 +378,6 @@ namespace CB_ALGS else sprintf(label_buf," known"); - CB::label& ld = ec.l.cb; if(!all.holdout_set_off && all.current_pass >= 1) { if(all.sd->holdout_sum_loss == 0. && all.sd->weighted_holdout_examples == 0.) @@ -401,7 +394,7 @@ namespace CB_ALGS (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long unsigned int)ld.prediction, + (long unsigned int)ec.pred.multiclass, (long unsigned int)ec.num_features, c.avg_loss_regressors, c.last_pred_reg, @@ -417,7 +410,7 @@ namespace CB_ALGS (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long unsigned int)ld.prediction, + (long unsigned int)ec.pred.multiclass, (long unsigned int)ec.num_features, c.avg_loss_regressors, c.last_pred_reg, @@ -430,29 +423,25 @@ namespace CB_ALGS } } - void output_example(vw& all, cb& c, example& ec) + void output_example(vw& all, cb& c, example& ec, CB::label& ld) { - CB::label& ld = ec.l.cb; - float loss = 0.; if (!is_test_label(ld)) {//need to compute exact loss - size_t pred = (size_t)ld.prediction; - float chosen_loss = FLT_MAX; if( know_all_cost_example(ld) ) { for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl ++) { - if (cl->action == pred) + if (cl->action == ec.pred.multiclass) chosen_loss = cl->cost; } } else { //we do not know exact cost of each action, so evaluate on generated cost-sensitive example currently stored in cb_cs_ld for (COST_SENSITIVE::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) { - if (cl->class_index == pred) + if (cl->class_index == ec.pred.multiclass) { chosen_loss = cl->x; - if (c.known_cost->action == pred && c.cb_type == CB_TYPE_DM) + if (c.known_cost->action == ec.pred.multiclass && c.cb_type == CB_TYPE_DM) chosen_loss += (c.known_cost->cost - chosen_loss) / c.known_cost->probability; } } @@ -484,7 +473,7 @@ namespace CB_ALGS for (size_t i = 0; i<all.final_prediction_sink.size(); i++) { int f = all.final_prediction_sink[i]; - all.print(f, (float)ld.prediction, 0, ec.tag); + all.print(f, (float)ec.pred.multiclass, 0, ec.tag); } print_update(all, c, is_test_label(ec.l.cb), ec); @@ -497,7 +486,13 @@ namespace CB_ALGS void finish_example(vw& all, cb& c, example& ec) { - output_example(all, c, ec); + output_example(all, c, ec, ec.l.cb); + VW::finish_example(all, &ec); + } + + void eval_finish_example(vw& all, cb& c, example& ec) + { + output_example(all, c, ec, ec.l.cb_eval.event); VW::finish_example(all, &ec); } @@ -577,13 +572,14 @@ namespace CB_ALGS { l->set_learn<cb, learn_eval>(); l->set_predict<cb, predict_eval>(); + l->set_finish_example<cb,eval_finish_example>(); } else { l->set_learn<cb, predict_or_learn<true> >(); l->set_predict<cb, predict_or_learn<false> >(); + l->set_finish_example<cb,finish_example>(); } - l->set_finish_example<cb,finish_example>(); l->set_init_driver<cb,init_driver>(); l->set_finish<cb,finish>(); // preserve the increment of the base learner since we are diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h index 85d21591..68da7610 100644 --- a/vowpalwabbit/cb_algs.h +++ b/vowpalwabbit/cb_algs.h @@ -34,10 +34,10 @@ namespace CB_ALGS { else all.scorer->predict(ec, index-1+base); - simple_temp.prediction = ec.l.simple.prediction; + float pred = ec.pred.scalar; ec.l.cb = ld; - return simple_temp.prediction; + return pred; } } diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 0bd64ecb..c36e345f 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -117,7 +117,7 @@ namespace CBIFY { else ctx.l->predict(*ctx.e, (size_t)m_index); ctx.recorded = false; - return (u32)(ctx.e->l.cb.prediction); + return (u32)(ctx.e->pred.multiclass); } void vw_recorder::Record(vw_context& context, u32 action, float probability, string unique_key) @@ -136,8 +136,9 @@ namespace CBIFY { ctx.data->cs->predict(*ctx.e, i); else ctx.data->cs->predict(*ctx.e, i + 1); - m_scores[ctx.data->cs_label.prediction - 1] += additive_probability; - m_predictions[i] = (uint32_t)ctx.data->cs_label.prediction; + uint32_t pred = ctx.e->pred.multiclass; + m_scores[pred - 1] += additive_probability; + m_predictions[i] = (uint32_t)pred; } float min_prob = m_epsilon * min(1.f / ctx.data->k, 1.f / (float)sqrt(m_counter * ctx.data->k)); @@ -178,7 +179,7 @@ namespace CBIFY { ec.loss = l.cost; } - ld.prediction = action; + ec.pred.multiclass = action; ec.l.multi = ld; } @@ -204,7 +205,7 @@ namespace CBIFY { if (is_learn) base.learn(ec); - ld.prediction = action; + ec.pred.multiclass = action; ec.l.multi = ld; ec.loss = loss(ld.label, action); } @@ -239,7 +240,7 @@ namespace CBIFY { base.learn(ec,i); } } - ld.prediction = action; + ec.pred.multiclass = action; ec.l.multi = ld; } @@ -357,7 +358,7 @@ namespace CBIFY { } } - ld.prediction = action; + ec.pred.multiclass = action; ec.l.multi = ld; } diff --git a/vowpalwabbit/cost_sensitive.cc b/vowpalwabbit/cost_sensitive.cc index e0fe0ca3..11bc7231 100644 --- a/vowpalwabbit/cost_sensitive.cc +++ b/vowpalwabbit/cost_sensitive.cc @@ -206,7 +206,6 @@ namespace COST_SENSITIVE { num_current_features += (*ecc)->num_features; } - label& ld = ec.l.cs; char label_buf[32]; if (is_test) strcpy(label_buf," unknown"); @@ -229,7 +228,7 @@ namespace COST_SENSITIVE { (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long unsigned int)ld.prediction, + (long unsigned int)ec.pred.multiclass, (long unsigned int)num_current_features); all.sd->weighted_holdout_examples_since_last_dump = 0; @@ -242,7 +241,7 @@ namespace COST_SENSITIVE { (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long unsigned int)ld.prediction, + (long unsigned int)ec.pred.multiclass, (long unsigned int)num_current_features); all.sd->sum_loss_since_last_dump = 0.0; @@ -258,7 +257,7 @@ namespace COST_SENSITIVE { float loss = 0.; if (!is_test_label(ld)) {//need to compute exact loss - size_t pred = (size_t)ld.prediction; + size_t pred = (size_t)ec.pred.multiclass; float chosen_loss = FLT_MAX; float min = FLT_MAX; @@ -293,7 +292,7 @@ namespace COST_SENSITIVE { } for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) - all.print((int)*sink, (float)ld.prediction, 0, ec.tag); + all.print((int)*sink, (float)ec.pred.multiclass, 0, ec.tag); if (all.raw_prediction > 0) { string outputString; diff --git a/vowpalwabbit/cost_sensitive.h b/vowpalwabbit/cost_sensitive.h index f6934d4c..a5f56a61 100644 --- a/vowpalwabbit/cost_sensitive.h +++ b/vowpalwabbit/cost_sensitive.h @@ -25,7 +25,6 @@ namespace COST_SENSITIVE { struct label { v_array<wclass> costs; - uint32_t prediction; }; void output_example(vw& all, example& ec); diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index 73139b31..33fdbae2 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -57,7 +57,7 @@ namespace CSOAA { ec.partial_prediction = 0.; } - ld.prediction = prediction; + ec.pred.multiclass = prediction; ec.l.cs = ld; } @@ -415,7 +415,7 @@ namespace LabelDict { if (prediction == costs1[j1].class_index) prediction_is_me = true; } - ld1.prediction = prediction_is_me ? prediction : 0; + ec1->pred.multiclass = prediction_is_me ? prediction : 0; ec1->l.cs = ld1; ec1->example_t = example_t1; } @@ -493,7 +493,7 @@ namespace LabelDict { ec->partial_prediction = costs[j].partial_prediction; if (prediction == costs[j].class_index) prediction_is_me = true; } - ld.prediction = prediction_is_me ? prediction : 0; + ec->pred.multiclass = prediction_is_me ? prediction : 0; // restore label ec->l.cs = ld; @@ -555,12 +555,11 @@ namespace LabelDict { all.sd->total_features += ec.num_features; float loss = 0.; - size_t final_pred = ld.prediction; if (!COST_SENSITIVE::example_is_test(ec)) { for (size_t j=0; j<costs.size(); j++) { if (hit_loss) break; - if (final_pred == costs[j].class_index) { + if (ec.pred.multiclass == costs[j].class_index) { loss = costs[j].x; hit_loss = true; } @@ -572,7 +571,7 @@ namespace LabelDict { } for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) - all.print(*sink, (float)ld.prediction, 0, ec.tag); + all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); if (all.raw_prediction > 0) { string outputString; diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc index 210b2733..3eab22c9 100644 --- a/vowpalwabbit/ect.cc +++ b/vowpalwabbit/ect.cc @@ -203,7 +203,7 @@ namespace ECT base.learn(ec, problem_number); - if (ec.l.simple.prediction > 0.) + if (ec.pred.scalar > 0.) finals_winner = finals_winner | (((size_t)1) << i); } } @@ -213,7 +213,7 @@ namespace ECT { base.learn(ec, id - e.k); - if (ec.l.simple.prediction > 0.) + if (ec.pred.scalar > 0.) id = e.directions[id].right; else id = e.directions[id].left; @@ -256,7 +256,7 @@ namespace ECT ec.l.simple.weight = 0.; base.learn(ec, id-e.k);//inefficient, we should extract final prediction exactly. - bool won = ec.l.simple.prediction * simple_temp.label > 0; + bool won = ec.pred.scalar * simple_temp.label > 0; if (won) { @@ -306,7 +306,7 @@ namespace ECT base.learn(ec, problem_number); - if (ec.l.simple.prediction > 0.) + if (ec.pred.scalar > 0.) e.tournaments_won[j] = right; else e.tournaments_won[j] = left; @@ -324,7 +324,7 @@ namespace ECT MULTICLASS::multiclass mc = ec.l.multi; if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1)) cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl; - mc.prediction = ect_predict(*all, e, base, ec); + ec.pred.multiclass = ect_predict(*all, e, base, ec); ec.l.multi = mc; } @@ -334,11 +334,12 @@ namespace ECT MULTICLASS::multiclass mc = ec.l.multi; predict(e, base, ec); - mc.prediction = ec.l.multi.prediction; + uint32_t pred = ec.pred.multiclass; if (mc.label != (uint32_t)-1 && all->training) ect_train(*all, e, base, ec); ec.l.multi = mc; + ec.pred.multiclass = pred; } void finish(ect& e) diff --git a/vowpalwabbit/example.h b/vowpalwabbit/example.h index b8915b4a..8abbade7 100644 --- a/vowpalwabbit/example.h +++ b/vowpalwabbit/example.h @@ -41,11 +41,12 @@ typedef union { MULTICLASS::multiclass multi; COST_SENSITIVE::label cs; CB::label cb; + CB_EVAL::label cb_eval; } polylabel; typedef union { - float simple; - uint32_t multi; + float scalar; + uint32_t multiclass; } polyprediction; struct example // core example datatype. diff --git a/vowpalwabbit/ezexample.h b/vowpalwabbit/ezexample.h index 44f0d5a7..8720bb62 100644 --- a/vowpalwabbit/ezexample.h +++ b/vowpalwabbit/ezexample.h @@ -222,7 +222,7 @@ class ezexample { float predict() { setup_for_predict(); - return ec->l.simple.prediction; + return ec->pred.scalar; } float predict_partial() { diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 1579cb57..589d39c2 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -310,7 +310,7 @@ void print_features(vw& all, example& ec) void print_audit_features(vw& all, example& ec) { if(all.audit) - print_result(all.stdout_fileno,ec.l.simple.prediction,-1,ec.tag); + print_result(all.stdout_fileno,ec.pred.scalar,-1,ec.tag); fflush(stdout); print_features(all, ec); } @@ -356,7 +356,7 @@ void predict(gd& g, learner& base, example& ec) ec.partial_prediction = inline_predict(all, ec); ec.partial_prediction *= (float)all.sd->contraction; - ec.l.simple.prediction = finalize_prediction(all.sd, ec.partial_prediction); + ec.pred.scalar = finalize_prediction(all.sd, ec.partial_prediction); if (audit) print_audit_features(all, ec); @@ -443,7 +443,7 @@ template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normaliz {//We must traverse the features in _precisely_ the same order as during training. label_data& ld = ec.l.simple; vw& all = *g.all; - float grad_squared = all.loss->getSquareGrad(ld.prediction, ld.label) * ld.weight; + float grad_squared = all.loss->getSquareGrad(ec.pred.scalar, ld.label) * ld.weight; if (grad_squared == 0) return 1.; norm_data nd = {grad_squared, 0., 0., {g.neg_power_t, g.neg_norm_power}}; @@ -468,8 +468,8 @@ float compute_update(gd& g, example& ec) vw& all = *g.all; float ret = 0.; - ec.updated_prediction = ld.prediction; - if (all.loss->getLoss(all.sd, ld.prediction, ld.label) > 0.) + ec.updated_prediction = ec.pred.scalar; + if (all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) > 0.) { float pred_per_update; if(adaptive || normalized) @@ -486,15 +486,15 @@ float compute_update(gd& g, example& ec) float update; if(invariant) - update = all.loss->getUpdate(ld.prediction, ld.label, delta_pred, pred_per_update); + update = all.loss->getUpdate(ec.pred.scalar, ld.label, delta_pred, pred_per_update); else - update = all.loss->getUnsafeUpdate(ld.prediction, ld.label, delta_pred, pred_per_update); + update = all.loss->getUnsafeUpdate(ec.pred.scalar, ld.label, delta_pred, pred_per_update); // changed from ec.partial_prediction to ld.prediction ec.updated_prediction += pred_per_update * update; if (all.reg_mode && fabs(update) > 1e-8) { - double dev1 = all.loss->first_derivative(all.sd, ld.prediction, ld.label); + double dev1 = all.loss->first_derivative(all.sd, ec.pred.scalar, ld.label); double eta_bar = (fabs(dev1) > 1e-8) ? (-update / dev1) : 0.0; if (fabs(dev1) > 1e-8) all.sd->contraction *= (1. - all.l2_lambda * eta_bar); diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc index 1261a07a..df48c934 100644 --- a/vowpalwabbit/gd_mf.cc +++ b/vowpalwabbit/gd_mf.cc @@ -79,7 +79,7 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset) void mf_print_audit_features(vw& all, example& ec, size_t offset) { - print_result(all.stdout_fileno,ec.l.simple.prediction,-1,ec.tag); + print_result(all.stdout_fileno,ec.pred.scalar,-1,ec.tag); mf_print_offset_features(all, ec, offset); } @@ -140,17 +140,15 @@ float mf_predict(vw& all, example& ec) all.set_minmax(all.sd, ld.label); - ld.prediction = GD::finalize_prediction(all.sd, ec.partial_prediction); - + ec.pred.scalar = GD::finalize_prediction(all.sd, ec.partial_prediction); + if (ld.label != FLT_MAX) - { - ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight; - } - + ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight; + if (all.audit) mf_print_audit_features(all, ec, 0); - - return ld.prediction; + + return ec.pred.scalar; } @@ -169,7 +167,7 @@ void mf_train(vw& all, example& ec) // use final prediction to get update size // update = eta_t*(y-y_hat) where eta_t = eta/(3*t^p) * importance weight float eta_t = all.eta/pow(ec.example_t,all.power_t) / 3.f * ld.weight; - float update = all.loss->getUpdate(ld.prediction, ld.label, eta_t, 1.); //ec.total_sum_feat_sq); + float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); //ec.total_sum_feat_sq); float regularization = eta_t * all.l2_lambda; diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc index a06b03be..1667fdd1 100644 --- a/vowpalwabbit/kernel_svm.cc +++ b/vowpalwabbit/kernel_svm.cc @@ -408,7 +408,7 @@ namespace KSVM sec->init_svm_example(fec); float score; predict(params, &sec, &score, 1); - ec.l.simple.prediction = score; + ec.pred.scalar = score; } } @@ -749,7 +749,7 @@ namespace KSVM free_flatten_example(fec); float score = 0; predict(params, &sec, &score, 1); - ec.l.simple.prediction = score; + ec.pred.scalar = score; ec.loss = max(0.f, 1.f - score*ec.l.simple.label); params.loss_sum += ec.loss; if(params.all->training && ec.example_counter % 100 == 0) diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc index 27dda73d..e2859972 100644 --- a/vowpalwabbit/log_multi.cc +++ b/vowpalwabbit/log_multi.cc @@ -309,9 +309,9 @@ namespace LOG_MULTI while(b.nodes[cn].internal)
{
base.predict(ec, b.nodes[cn].base_predictor);
- cn = descend(b.nodes[cn], simple_temp.prediction);
+ cn = descend(b.nodes[cn], ec.pred.scalar);
}
- mc.prediction = b.nodes[cn].max_count_label;
+ ec.pred.multiclass = b.nodes[cn].max_count_label;
ec.l.multi = mc;
}
@@ -337,7 +337,7 @@ namespace LOG_MULTI while(children(b, cn, class_index, mc.label))
{
train_node(b, base, ec, cn, class_index);
- cn = descend(b.nodes[cn], simple_temp.prediction);
+ cn = descend(b.nodes[cn], ec.pred.scalar);
}
b.nodes[cn].min_count++;
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc index 8f1db3ff..02272bc0 100644 --- a/vowpalwabbit/lrq.cc +++ b/vowpalwabbit/lrq.cc @@ -162,12 +162,12 @@ namespace LRQ { // Restore example if (iter == 0) { - first_prediction = ec.l.simple.prediction; + first_prediction = ec.pred.scalar; first_loss = ec.loss; } else { - ec.l.simple.prediction = first_prediction; + ec.pred.scalar = first_prediction; ec.loss = first_loss; } diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc index 33c02944..2cce6272 100644 --- a/vowpalwabbit/mf.cc +++ b/vowpalwabbit/mf.cc @@ -99,17 +99,17 @@ void predict(mf& data, learner& base, example& ec) { // finalize prediction ec.partial_prediction = prediction; - ec.l.simple.prediction = GD::finalize_prediction(data.all->sd, ec.partial_prediction); + ec.pred.scalar = GD::finalize_prediction(data.all->sd, ec.partial_prediction); } void learn(mf& data, learner& base, example& ec) { // predict with current weights predict<true>(data, base, ec); - float predicted = ec.l.simple.prediction; + float predicted = ec.pred.scalar; // update linear weights base.update(ec); - ec.l.simple.prediction = ec.updated_prediction; + ec.pred.scalar = ec.updated_prediction; // store namespace indices copy_array(data.indices, ec.indices); @@ -148,7 +148,7 @@ void learn(mf& data, learner& base, example& ec) { // compute new l_k * x_l scaling factors // base.predict(ec, k); // data.sub_predictions[2*k-1] = ec.partial_prediction; - // ec.l.simple.prediction = ec.updated_prediction; + // ec.pred.scalar = ec.updated_prediction; } // set example to right namespace only @@ -165,7 +165,7 @@ void learn(mf& data, learner& base, example& ec) { // update r^k using base learner base.update(ec, k + data.rank); - ec.l.simple.prediction = ec.updated_prediction; + ec.pred.scalar = ec.updated_prediction; // restore right namespace features copy_array(ec.atomics[right_ns], data.temp_features); @@ -176,7 +176,7 @@ void learn(mf& data, learner& base, example& ec) { copy_array(ec.indices, data.indices); // restore original prediction - ec.l.simple.prediction = predicted; + ec.pred.scalar = predicted; } void finish(mf& o) { diff --git a/vowpalwabbit/multiclass.cc b/vowpalwabbit/multiclass.cc index 9e485aa7..8736bc6e 100644 --- a/vowpalwabbit/multiclass.cc +++ b/vowpalwabbit/multiclass.cc @@ -119,7 +119,7 @@ namespace MULTICLASS { (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long int)ld.prediction, + (long int)ec.pred.multiclass, (long unsigned int)ec.num_features); all.sd->weighted_holdout_examples_since_last_dump = 0; @@ -132,7 +132,7 @@ namespace MULTICLASS { (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long int)ld.prediction, + (long int)ec.pred.multiclass, (long unsigned int)ec.num_features); all.sd->sum_loss_since_last_dump = 0.0; @@ -147,7 +147,7 @@ namespace MULTICLASS { multiclass ld = ec.l.multi; size_t loss = 1; - if (ld.label == (uint32_t)ld.prediction) + if (ld.label == (uint32_t)ec.pred.multiclass) loss = 0; if(ec.test_only) @@ -169,7 +169,7 @@ namespace MULTICLASS { } for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++) - all.print(*sink, (float)ld.prediction, 0, ec.tag); + all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag); MULTICLASS::print_update(all, ec); } diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h index 1addf83e..aca34075 100644 --- a/vowpalwabbit/multiclass.h +++ b/vowpalwabbit/multiclass.h @@ -14,7 +14,6 @@ namespace MULTICLASS struct multiclass { uint32_t label; float weight; - uint32_t prediction; }; extern label_parser mc_label; diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index b422c754..b6d3a292 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -143,7 +143,7 @@ namespace NN { base.predict(ec, i); - hidden_units[i] = ec.l.simple.prediction; + hidden_units[i] = ec.pred.scalar; dropped_out[i] = (n.dropout && merand48 (n.xsubi) < 0.5); @@ -285,7 +285,7 @@ CONVERSE: // That's right, I'm using goto. So sue me. } ec.partial_prediction = save_partial_prediction; - ec.l.simple.prediction = save_final_prediction; + ec.pred.scalar = save_final_prediction; ec.loss = save_ec_loss; n.all->sd = save_sd; diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index 6349520c..c1bb6757 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -32,7 +32,7 @@ namespace OAA { if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1)) cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl; - ec.l.simple = {0.f, mc_label_data.weight, 0.f, 0.f}; + ec.l.simple = {0.f, mc_label_data.weight, 0.f}; string outputString; stringstream outputStringStream(outputString); @@ -64,7 +64,7 @@ namespace OAA { outputStringStream << i << ':' << ec.partial_prediction; } } - mc_label_data.prediction = prediction; + ec.pred.multiclass = prediction; ec.l.multi = mc_label_data; if (o.shouldOutput) diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc index f78283dc..765b9f26 100644 --- a/vowpalwabbit/parser.cc +++ b/vowpalwabbit/parser.cc @@ -1136,12 +1136,12 @@ float get_initial(example* ec) float get_prediction(example* ec) { - return ec->l.simple.prediction; + return ec->pred.scalar; } float get_cost_sensitive_prediction(example* ec) { - return ec->l.cs.prediction; + return ec->pred.multiclass; } size_t get_tag_length(example* ec) diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index db04dfee..fda5c271 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -20,9 +20,9 @@ namespace Scorer { base.predict(ec); if(ec.l.simple.weight > 0 && ec.l.simple.label != FLT_MAX) - ec.loss = s.all->loss->getLoss(s.all->sd, ec.l.simple.prediction, ec.l.simple.label) * ec.l.simple.weight; + ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight; - ec.l.simple.prediction = link(ec.l.simple.prediction); + ec.pred.scalar = link(ec.pred.scalar); } // y = f(x) -> [0, 1] diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 3a7c66b0..20433ae6 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -578,11 +578,6 @@ namespace Search { del_features_in_top_namespace(priv, ec, conditioning_namespace); } - uint32_t cs_get_prediction(bool isCB, polylabel& ld) { - return isCB ? ld.cb.prediction - : ld.cs.prediction; - } - size_t cs_get_costs_size(bool isCB, polylabel& ld) { return isCB ? ld.cb.costs.size() : ld.cs.costs.size(); @@ -671,7 +666,7 @@ namespace Search { polylabel old_label = ec.l; ec.l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt); priv.base_learner->predict(ec, policy); - uint32_t act = cs_get_prediction(priv.cb_learner, ec.l); + uint32_t act = ec.pred.multiclass; // in beam search mode, go through alternatives and add them as back-ups if (priv.beam) { @@ -1720,7 +1715,7 @@ namespace Search { costs.push_back(c); } - CS::label ld = { costs, 0 }; + CS::label ld = { costs }; allowed.push_back(ld); } free(bg); diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index ec53d1c4..8bff73e7 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -62,9 +62,9 @@ void receive_result(sender& s) example* ec=s.delay_ring[s.received_index++ % s.all->p->ring_size]; label_data& ld = ec->l.simple; - ld.prediction = res; + ec->pred.scalar = res; - ec->loss = s.all->loss->getLoss(s.all->sd, ld.prediction, ld.label) * ld.weight; + ec->loss = s.all->loss->getLoss(s.all->sd, ec->pred.scalar, ld.label) * ld.weight; return_simple_example(*(s.all), NULL, *ec); } diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc index 0d5a32be..9e220cde 100644 --- a/vowpalwabbit/simple_label.cc +++ b/vowpalwabbit/simple_label.cc @@ -136,7 +136,7 @@ void print_update(vw& all, example& ec) (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - ld.prediction, + ec.pred.scalar, (long unsigned int)ec.num_features); all.sd->weighted_holdout_examples_since_last_dump = 0.; @@ -149,7 +149,7 @@ void print_update(vw& all, example& ec) (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - ld.prediction, + ec.pred.scalar, (long unsigned int)ec.num_features); all.sd->sum_loss_since_last_dump = 0.0; @@ -192,7 +192,7 @@ void output_and_account_example(vw& all, example& ec) if (all.lda > 0) print_lda_result(all, f,ec.topic_predictions.begin,0.,ec.tag); else - all.print(f, ld.prediction, 0, ec.tag); + all.print(f, ec.pred.scalar, 0, ec.tag); } print_update(all, ec); diff --git a/vowpalwabbit/simple_label.h b/vowpalwabbit/simple_label.h index 365d2f52..3c026483 100644 --- a/vowpalwabbit/simple_label.h +++ b/vowpalwabbit/simple_label.h @@ -13,7 +13,6 @@ struct label_data { float label; float weight; float initial; - float prediction; }; void return_simple_example(vw& all, void*, example& ec); diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc index c1f4fada..ff20561d 100644 --- a/vowpalwabbit/stagewise_poly.cc +++ b/vowpalwabbit/stagewise_poly.cc @@ -523,7 +523,7 @@ namespace StagewisePoly base.learn(poly.synth_ec); ec.partial_prediction = poly.synth_ec.partial_prediction; ec.updated_prediction = poly.synth_ec.updated_prediction; - ec.l.simple.prediction = poly.synth_ec.l.simple.prediction; + ec.pred.scalar = poly.synth_ec.pred.scalar; if (ec.example_counter //following line is to avoid repeats when multiple reductions on same example. diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc index a1baae6b..f3e8be8e 100644 --- a/vowpalwabbit/topk.cc +++ b/vowpalwabbit/topk.cc @@ -94,14 +94,13 @@ namespace TOPK { else base.predict(ec); - label_data& ld = ec.l.simple; if(d.pr_queue.size() < d.B) - d.pr_queue.push(make_pair(ld.prediction, ec.tag)); + d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); - else if(d.pr_queue.top().first < ld.prediction) + else if(d.pr_queue.top().first < ec.pred.scalar) { d.pr_queue.pop(); - d.pr_queue.push(make_pair(ld.prediction, ec.tag)); + d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag)); } } |