Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-11-28 22:43:58 +0300
committerJohn Langford <jl@hunch.net>2014-11-28 22:43:58 +0300
commit4bf53b154796eee7501598f1edf5debb5c8e49df (patch)
treed124f9e8e992bca47b38aebd30342b1b5a91cc14
parent76c9c77631308ef612ec9f9580e49aeebab8172c (diff)
separated label from prediction
-rw-r--r--library/library_example.cc6
-rw-r--r--library/recommend.cc6
-rw-r--r--vowpalwabbit/active.cc6
-rw-r--r--vowpalwabbit/autolink.cc4
-rw-r--r--vowpalwabbit/bfgs.cc16
-rw-r--r--vowpalwabbit/binary.cc8
-rw-r--r--vowpalwabbit/bs.cc12
-rw-r--r--vowpalwabbit/cb.cc47
-rw-r--r--vowpalwabbit/cb.h6
-rw-r--r--vowpalwabbit/cb_algs.cc60
-rw-r--r--vowpalwabbit/cb_algs.h4
-rw-r--r--vowpalwabbit/cbify.cc15
-rw-r--r--vowpalwabbit/cost_sensitive.cc9
-rw-r--r--vowpalwabbit/cost_sensitive.h1
-rw-r--r--vowpalwabbit/csoaa.cc11
-rw-r--r--vowpalwabbit/ect.cc13
-rw-r--r--vowpalwabbit/example.h5
-rw-r--r--vowpalwabbit/ezexample.h2
-rw-r--r--vowpalwabbit/gd.cc16
-rw-r--r--vowpalwabbit/gd_mf.cc18
-rw-r--r--vowpalwabbit/kernel_svm.cc4
-rw-r--r--vowpalwabbit/log_multi.cc6
-rw-r--r--vowpalwabbit/lrq.cc4
-rw-r--r--vowpalwabbit/mf.cc12
-rw-r--r--vowpalwabbit/multiclass.cc8
-rw-r--r--vowpalwabbit/multiclass.h1
-rw-r--r--vowpalwabbit/nn.cc4
-rw-r--r--vowpalwabbit/oaa.cc4
-rw-r--r--vowpalwabbit/parser.cc4
-rw-r--r--vowpalwabbit/scorer.cc4
-rw-r--r--vowpalwabbit/search.cc9
-rw-r--r--vowpalwabbit/sender.cc4
-rw-r--r--vowpalwabbit/simple_label.cc6
-rw-r--r--vowpalwabbit/simple_label.h1
-rw-r--r--vowpalwabbit/stagewise_poly.cc2
-rw-r--r--vowpalwabbit/topk.cc7
36 files changed, 178 insertions, 167 deletions
diff --git a/library/library_example.cc b/library/library_example.cc
index 7a3a9367..8abfc1f7 100644
--- a/library/library_example.cc
+++ b/library/library_example.cc
@@ -18,7 +18,7 @@ int main(int argc, char *argv[])
example *vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme");
model->learn(vec2);
- cerr << "p2 = " << vec2->l.simple.prediction << endl;
+ cerr << "p2 = " << vec2->pred.scalar << endl;
VW::finish_example(*model, vec2);
vector< VW::feature_space > ec_info;
@@ -36,7 +36,7 @@ int main(int argc, char *argv[])
example* vec3 = VW::import_example(*model, ec_info);
model->learn(vec3);
- cerr << "p3 = " << vec3->l.simple.prediction << endl;
+ cerr << "p3 = " << vec3->pred.scalar << endl;
VW::finish_example(*model, vec3);
VW::finish(*model);
@@ -44,7 +44,7 @@ int main(int argc, char *argv[])
vw* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw");
vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme");
model2->learn(vec2);
- cerr << "p4 = " << vec2->l.simple.prediction << endl;
+ cerr << "p4 = " << vec2->pred.scalar << endl;
size_t len=0;
VW::primitive_feature_space* pfs = VW::export_example(*model2, vec2, len);
diff --git a/library/recommend.cc b/library/recommend.cc
index b78d6fa5..3d970d76 100644
--- a/library/recommend.cc
+++ b/library/recommend.cc
@@ -213,12 +213,12 @@ int main(int argc, char *argv[])
if(pr_queue.size() < (size_t)topk)
{
- pr_queue.push(make_pair(ex->l.simple.prediction, str));
+ pr_queue.push(make_pair(ex->pred.scalar, str));
}
- else if(pr_queue.top().first < ex->l.simple.prediction)
+ else if(pr_queue.top().first < ex->pred.scalar)
{
pr_queue.pop();
- pr_queue.push(make_pair(ex->l.simple.prediction, str));
+ pr_queue.push(make_pair(ex->pred.scalar, str));
}
VW::finish_example(*model, ex);
diff --git a/vowpalwabbit/active.cc b/vowpalwabbit/active.cc
index 19ce9778..78e8e0e1 100644
--- a/vowpalwabbit/active.cc
+++ b/vowpalwabbit/active.cc
@@ -53,7 +53,7 @@ namespace ACTIVE {
vw& all = *a.all;
float k = ec.example_t - ec.l.simple.weight;
- ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.l.simple.prediction, all.eta/powf(k,all.power_t));
+ ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.pred.scalar, all.eta/powf(k,all.power_t));
float importance = query_decision(a, ec, k);
if(importance > 0){
@@ -77,7 +77,7 @@ namespace ACTIVE {
vw& all = *a.all;
float t = (float)(ec.example_t - all.sd->weighted_holdout_examples);
- ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.l.simple.prediction,
+ ec.revert_weight = all.loss->getRevertingWeight(all.sd, ec.pred.scalar,
all.eta/powf(t,all.power_t));
}
}
@@ -139,7 +139,7 @@ namespace ACTIVE {
for (size_t i = 0; i<all.final_prediction_sink.size(); i++)
{
int f = (int)all.final_prediction_sink[i];
- active_print_result(f, ld.prediction, ai, ec.tag);
+ active_print_result(f, ec.pred.scalar, ai, ec.tag);
}
print_update(all, ec);
diff --git a/vowpalwabbit/autolink.cc b/vowpalwabbit/autolink.cc
index 60766443..4eebe6d4 100644
--- a/vowpalwabbit/autolink.cc
+++ b/vowpalwabbit/autolink.cc
@@ -15,7 +15,7 @@ namespace ALINK {
void predict_or_learn(autolink& b, learner& base, example& ec)
{
base.predict(ec);
- float base_pred = ec.l.simple.prediction;
+ float base_pred = ec.pred.scalar;
// add features of label
ec.indices.push_back(autolink_namespace);
@@ -26,7 +26,7 @@ namespace ALINK {
feature f = { base_pred, (uint32_t) (autoconstant + (i << b.stride_shift)) };
ec.atomics[autolink_namespace].push_back(f);
sum_sq += base_pred*base_pred;
- base_pred *= ec.l.simple.prediction;
+ base_pred *= ec.pred.scalar;
}
ec.total_sum_feat_sq += sum_sq;
diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc
index bb22e393..b964c218 100644
--- a/vowpalwabbit/bfgs.cc
+++ b/vowpalwabbit/bfgs.cc
@@ -187,7 +187,7 @@ inline void add_precond(float& d, float f, float& fw)
void update_preconditioner(vw& all, example& ec)
{
label_data& ld = ec.l.simple;
- float curvature = all.loss->second_derivative(all.sd, ld.prediction,ld.label) * ld.weight;
+ float curvature = all.loss->second_derivative(all.sd, ec.pred.scalar, ld.label) * ld.weight;
ec.ft_offset += W_COND;
GD::foreach_feature<float,add_precond>(all, ec, curvature);
@@ -734,10 +734,10 @@ void process_example(vw& all, bfgs& b, example& ec)
/********************************************************************/
if (b.gradient_pass)
{
- ld.prediction = predict_and_gradient(all, ec);//w[0] & w[1]
- ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight;
+ ec.pred.scalar = predict_and_gradient(all, ec);//w[0] & w[1]
+ ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
b.loss_sum += ec.loss;
- b.predictions.push_back(ld.prediction);
+ b.predictions.push_back(ec.pred.scalar);
}
/********************************************************************/
/* II) CURVATURE CALCULATION ****************************************/
@@ -747,13 +747,13 @@ void process_example(vw& all, bfgs& b, example& ec)
float d_dot_x = dot_with_direction(all, ec);//w[2]
if (b.example_number >= b.predictions.size())//Make things safe in case example source is strange.
b.example_number = b.predictions.size()-1;
- ld.prediction = b.predictions[b.example_number];
+ ec.pred.scalar = b.predictions[b.example_number];
ec.partial_prediction = b.predictions[b.example_number];
- ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight;
+ ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
float sd = all.loss->second_derivative(all.sd, b.predictions[b.example_number++],ld.label);
b.curvature += d_dot_x*d_dot_x*sd*ld.weight;
}
- ec.updated_prediction = ld.prediction;
+ ec.updated_prediction = ec.pred.scalar;
if (b.preconditioner_pass)
update_preconditioner(all, ec);//w[3]
@@ -821,7 +821,7 @@ void end_pass(bfgs& b)
void predict(bfgs& b, learner& base, example& ec)
{
vw* all = b.all;
- ec.l.simple.prediction = bfgs_predict(*all,ec);
+ ec.pred.scalar = bfgs_predict(*all,ec);
}
void learn(bfgs& b, learner& base, example& ec)
diff --git a/vowpalwabbit/binary.cc b/vowpalwabbit/binary.cc
index 68b3a341..3f78e789 100644
--- a/vowpalwabbit/binary.cc
+++ b/vowpalwabbit/binary.cc
@@ -13,12 +13,12 @@ namespace BINARY {
else
base.predict(ec);
- if ( ec.l.simple.prediction > 0)
- ec.l.simple.prediction = 1;
+ if ( ec.pred.scalar > 0)
+ ec.pred.scalar = 1;
else
- ec.l.simple.prediction = -1;
+ ec.pred.scalar = -1;
- if (ec.l.simple.label == ec.l.simple.prediction)
+ if (ec.l.simple.label == ec.pred.scalar)
ec.loss = 0.;
else
ec.loss = ec.l.simple.weight;
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index dd95b6c8..8c2ca4b9 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -31,8 +31,8 @@ namespace BS {
void bs_predict_mean(vw& all, example& ec, vector<double> &pred_vec)
{
- ec.l.simple.prediction = (float)accumulate(pred_vec.begin(), pred_vec.end(), 0.0)/pred_vec.size();
- ec.loss = all.loss->getLoss(all.sd, ec.l.simple.prediction, ec.l.simple.label) * ec.l.simple.weight;
+ ec.pred.scalar = (float)accumulate(pred_vec.begin(), pred_vec.end(), 0.0)/pred_vec.size();
+ ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight;
}
void bs_predict_vote(vw& all, example& ec, vector<double> &pred_vec)
@@ -107,10 +107,10 @@ namespace BS {
}
// ld.prediction = sum_labels/(float)counter; //replace line below for: "avg on votes" and getLoss()
- ec.l.simple.prediction = (float)current_label;
+ ec.pred.scalar = (float)current_label;
// ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight; //replace line below for: "avg on votes" and getLoss()
- ec.loss = ((ec.l.simple.prediction == ec.l.simple.label) ? 0.f : 1.f) * ec.l.simple.weight;
+ ec.loss = ((ec.pred.scalar == ec.l.simple.label) ? 0.f : 1.f) * ec.l.simple.weight;
}
void print_result(int f, float res, float weight, v_array<char> tag, float lb, float ub)
@@ -174,7 +174,7 @@ namespace BS {
}
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
- BS::print_result(*sink, ld.prediction, 0, ec.tag, d.lb, d.ub);
+ BS::print_result(*sink, ec.pred.scalar, 0, ec.tag, d.lb, d.ub);
print_update(all, ec);
}
@@ -200,7 +200,7 @@ namespace BS {
else
base.predict(ec, i-1);
- d.pred_vec.push_back(ec.l.simple.prediction);
+ d.pred_vec.push_back(ec.pred.scalar);
if (shouldOutput) {
if (i > 1) outputStringStream << ' ';
diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc
index eb8d1386..ba65febe 100644
--- a/vowpalwabbit/cb.cc
+++ b/vowpalwabbit/cb.cc
@@ -172,31 +172,52 @@ namespace CB_EVAL
{
size_t read_cached_label(shared_data*sd, void* v, io_buf& cache)
{
- CB::label* ld = (CB::label*) v;
+ CB_EVAL::label* ld = (CB_EVAL::label*) v;
char* c;
size_t total = sizeof(uint32_t);
if (buf_read(cache, c, total) < total)
return 0;
- ld->prediction = *(uint32_t*)c;
+ ld->action = *(uint32_t*)c;
c += sizeof(uint32_t);
- return total + CB::read_cached_label(sd, ld, cache);
+ return total + CB::read_cached_label(sd, &(ld->event), cache);
}
void cache_label(void* v, io_buf& cache)
{
char *c;
- CB::label* ld = (CB::label*) v;
+ CB_EVAL::label* ld = (CB_EVAL::label*) v;
buf_write(cache, c, sizeof(uint32_t));
- *(uint32_t *)c = ld->prediction;
+ *(uint32_t *)c = ld->action;
c+= sizeof(uint32_t);
- CB::cache_label(ld, cache);
+ CB::cache_label(&(ld->event), cache);
+ }
+
+ void default_label(void* v)
+ {
+ CB_EVAL::label* ld = (CB_EVAL::label*) v;
+ CB::default_label(&(ld->event));
+ ld->action = 0;
+ }
+
+ void delete_label(void* v)
+ {
+ CB_EVAL::label* ld = (CB_EVAL::label*)v;
+ CB::delete_label(&(ld->event));
+ }
+
+ void copy_label(void*dst, void*src)
+ {
+ CB_EVAL::label* ldD = (CB_EVAL::label*)dst;
+ CB_EVAL::label* ldS = (CB_EVAL::label*)src;
+ CB::copy_label(&(ldD->event), &(ldS)->event);
+ ldD->action = ldS->action;
}
void parse_label(parser* p, shared_data* sd, void* v, v_array<substring>& words)
{
- CB::label* ld = (CB::label*)v;
+ CB_EVAL::label* ld = (CB_EVAL::label*)v;
if (words.size() < 2)
{
@@ -204,18 +225,18 @@ namespace CB_EVAL
throw exception();
}
- ld->prediction = (uint32_t)hashstring(words[0], 0);
+ ld->action = (uint32_t)hashstring(words[0], 0);
words.begin++;
- CB::parse_label(p, sd, ld, words);
+ CB::parse_label(p, sd, &(ld->event), words);
words.begin--;
}
- label_parser cb_eval = {CB::default_label, parse_label,
+ label_parser cb_eval = {default_label, parse_label,
cache_label, read_cached_label,
- CB::delete_label, CB::weight,
- CB::copy_label,
- sizeof(CB::label)};
+ delete_label, CB::weight,
+ copy_label,
+ sizeof(CB_EVAL::label)};
}
diff --git a/vowpalwabbit/cb.h b/vowpalwabbit/cb.h
index cc803d71..1ce2cf9e 100644
--- a/vowpalwabbit/cb.h
+++ b/vowpalwabbit/cb.h
@@ -19,12 +19,16 @@ namespace CB {
struct label {
v_array<cb_class> costs;
- uint32_t prediction;
};
extern label_parser cb_label;//for learning
}
namespace CB_EVAL {
+ struct label {
+ uint32_t action;
+ CB::label event;
+ };
+
extern label_parser cb_eval;//for evaluation of an arbitrary policy.
}
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index e2dfaee1..4550efaa 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -79,10 +79,8 @@ namespace CB_ALGS
return NULL;
}
- void gen_cs_example_ips(vw& all, cb& c, example& ec, COST_SENSITIVE::label& cs_ld)
+ void gen_cs_example_ips(vw& all, cb& c, example& ec, CB::label& ld, COST_SENSITIVE::label& cs_ld)
{//this implements the inverse propensity score method, where cost are importance weighted by the probability of the chosen action
- CB::label ld = ec.l.cb;
-
//generate cost-sensitive example
cs_ld.costs.erase();
if( ld.costs.size() == 1) { //this is a typical example where we can perform all actions
@@ -206,7 +204,7 @@ namespace CB_ALGS
}
}
- ec.l.cb.prediction = argmin;
+ ec.pred.multiclass = argmin;
}
template <bool is_learn>
@@ -233,10 +231,8 @@ namespace CB_ALGS
}
template <bool is_learn>
- void gen_cs_example_dr(vw& all, cb& c, example& ec, COST_SENSITIVE::label& cs_ld)
+ void gen_cs_example_dr(vw& all, cb& c, example& ec, CB::label& ld, COST_SENSITIVE::label& cs_ld)
{//this implements the doubly robust method
- CB::label ld = ec.l.cb;
-
//generate cost sensitive example
cs_ld.costs.erase();
if( ld.costs.size() == 1) //this is a typical example where we can perform all actions
@@ -300,7 +296,6 @@ namespace CB_ALGS
ec.l.cs = c.cb_cs_ld;
base.predict(ec);
- ld.prediction = c.cb_cs_ld.prediction;
for (size_t i=0; i<ld.costs.size(); i++)
ld.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction;
@@ -318,13 +313,13 @@ namespace CB_ALGS
switch(c.cb_type)
{
case CB_TYPE_IPS:
- gen_cs_example_ips(*all,c,ec,c.cb_cs_ld);
+ gen_cs_example_ips(*all,c,ec,ld,c.cb_cs_ld);
break;
case CB_TYPE_DM:
gen_cs_example_dm<is_learn>(*all,c,ec,c.cb_cs_ld);
break;
case CB_TYPE_DR:
- gen_cs_example_dr<is_learn>(*all,c,ec,c.cb_cs_ld);
+ gen_cs_example_dr<is_learn>(*all,c,ec,ld,c.cb_cs_ld);
break;
default:
std::cerr << "Unknown cb_type specified for contextual bandit learning: " << c.cb_type << ". Exiting." << endl;
@@ -340,7 +335,6 @@ namespace CB_ALGS
else
base.predict(ec);
- ld.prediction = ec.l.cs.prediction;
for (size_t i=0; i<ld.costs.size(); i++)
ld.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction;
ec.l.cb = ld;
@@ -354,19 +348,19 @@ namespace CB_ALGS
void learn_eval(cb& c, learner& base, example& ec) {
vw* all = c.all;
- CB::label ld = ec.l.cb;
+ CB_EVAL::label ld = ec.l.cb_eval;
- c.known_cost = get_observed_cost(ld);
+ c.known_cost = get_observed_cost(ld.event);
if (c.cb_type == CB_TYPE_DR)
- gen_cs_example_dr<true>(*all,c,ec,c.cb_cs_ld);
+ gen_cs_example_dr<true>(*all, c, ec, ld.event, c.cb_cs_ld);
else //c.cb_type == CB_TYPE_IPS
- gen_cs_example_ips(*all,c,ec,c.cb_cs_ld);
+ gen_cs_example_ips(*all, c, ec, ld.event, c.cb_cs_ld);
- for (size_t i=0; i<ld.costs.size(); i++)
- ld.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction;
+ for (size_t i=0; i<ld.event.costs.size(); i++)
+ ld.event.costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction;
- ec.l.cb = ld;
+ ec.l.cb_eval = ld;
}
void init_driver(cb&)
@@ -384,7 +378,6 @@ namespace CB_ALGS
else
sprintf(label_buf," known");
- CB::label& ld = ec.l.cb;
if(!all.holdout_set_off && all.current_pass >= 1)
{
if(all.sd->holdout_sum_loss == 0. && all.sd->weighted_holdout_examples == 0.)
@@ -401,7 +394,7 @@ namespace CB_ALGS
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long unsigned int)ld.prediction,
+ (long unsigned int)ec.pred.multiclass,
(long unsigned int)ec.num_features,
c.avg_loss_regressors,
c.last_pred_reg,
@@ -417,7 +410,7 @@ namespace CB_ALGS
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long unsigned int)ld.prediction,
+ (long unsigned int)ec.pred.multiclass,
(long unsigned int)ec.num_features,
c.avg_loss_regressors,
c.last_pred_reg,
@@ -430,29 +423,25 @@ namespace CB_ALGS
}
}
- void output_example(vw& all, cb& c, example& ec)
+ void output_example(vw& all, cb& c, example& ec, CB::label& ld)
{
- CB::label& ld = ec.l.cb;
-
float loss = 0.;
if (!is_test_label(ld))
{//need to compute exact loss
- size_t pred = (size_t)ld.prediction;
-
float chosen_loss = FLT_MAX;
if( know_all_cost_example(ld) ) {
for (cb_class *cl = ld.costs.begin; cl != ld.costs.end; cl ++) {
- if (cl->action == pred)
+ if (cl->action == ec.pred.multiclass)
chosen_loss = cl->cost;
}
}
else {
//we do not know exact cost of each action, so evaluate on generated cost-sensitive example currently stored in cb_cs_ld
for (COST_SENSITIVE::wclass *cl = c.cb_cs_ld.costs.begin; cl != c.cb_cs_ld.costs.end; cl ++) {
- if (cl->class_index == pred)
+ if (cl->class_index == ec.pred.multiclass)
{
chosen_loss = cl->x;
- if (c.known_cost->action == pred && c.cb_type == CB_TYPE_DM)
+ if (c.known_cost->action == ec.pred.multiclass && c.cb_type == CB_TYPE_DM)
chosen_loss += (c.known_cost->cost - chosen_loss) / c.known_cost->probability;
}
}
@@ -484,7 +473,7 @@ namespace CB_ALGS
for (size_t i = 0; i<all.final_prediction_sink.size(); i++)
{
int f = all.final_prediction_sink[i];
- all.print(f, (float)ld.prediction, 0, ec.tag);
+ all.print(f, (float)ec.pred.multiclass, 0, ec.tag);
}
print_update(all, c, is_test_label(ec.l.cb), ec);
@@ -497,7 +486,13 @@ namespace CB_ALGS
void finish_example(vw& all, cb& c, example& ec)
{
- output_example(all, c, ec);
+ output_example(all, c, ec, ec.l.cb);
+ VW::finish_example(all, &ec);
+ }
+
+ void eval_finish_example(vw& all, cb& c, example& ec)
+ {
+ output_example(all, c, ec, ec.l.cb_eval.event);
VW::finish_example(all, &ec);
}
@@ -577,13 +572,14 @@ namespace CB_ALGS
{
l->set_learn<cb, learn_eval>();
l->set_predict<cb, predict_eval>();
+ l->set_finish_example<cb,eval_finish_example>();
}
else
{
l->set_learn<cb, predict_or_learn<true> >();
l->set_predict<cb, predict_or_learn<false> >();
+ l->set_finish_example<cb,finish_example>();
}
- l->set_finish_example<cb,finish_example>();
l->set_init_driver<cb,init_driver>();
l->set_finish<cb,finish>();
// preserve the increment of the base learner since we are
diff --git a/vowpalwabbit/cb_algs.h b/vowpalwabbit/cb_algs.h
index 85d21591..68da7610 100644
--- a/vowpalwabbit/cb_algs.h
+++ b/vowpalwabbit/cb_algs.h
@@ -34,10 +34,10 @@ namespace CB_ALGS {
else
all.scorer->predict(ec, index-1+base);
- simple_temp.prediction = ec.l.simple.prediction;
+ float pred = ec.pred.scalar;
ec.l.cb = ld;
- return simple_temp.prediction;
+ return pred;
}
}
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index 0bd64ecb..c36e345f 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -117,7 +117,7 @@ namespace CBIFY {
else
ctx.l->predict(*ctx.e, (size_t)m_index);
ctx.recorded = false;
- return (u32)(ctx.e->l.cb.prediction);
+ return (u32)(ctx.e->pred.multiclass);
}
void vw_recorder::Record(vw_context& context, u32 action, float probability, string unique_key)
@@ -136,8 +136,9 @@ namespace CBIFY {
ctx.data->cs->predict(*ctx.e, i);
else
ctx.data->cs->predict(*ctx.e, i + 1);
- m_scores[ctx.data->cs_label.prediction - 1] += additive_probability;
- m_predictions[i] = (uint32_t)ctx.data->cs_label.prediction;
+ uint32_t pred = ctx.e->pred.multiclass;
+ m_scores[pred - 1] += additive_probability;
+ m_predictions[i] = (uint32_t)pred;
}
float min_prob = m_epsilon * min(1.f / ctx.data->k, 1.f / (float)sqrt(m_counter * ctx.data->k));
@@ -178,7 +179,7 @@ namespace CBIFY {
ec.loss = l.cost;
}
- ld.prediction = action;
+ ec.pred.multiclass = action;
ec.l.multi = ld;
}
@@ -204,7 +205,7 @@ namespace CBIFY {
if (is_learn)
base.learn(ec);
- ld.prediction = action;
+ ec.pred.multiclass = action;
ec.l.multi = ld;
ec.loss = loss(ld.label, action);
}
@@ -239,7 +240,7 @@ namespace CBIFY {
base.learn(ec,i);
}
}
- ld.prediction = action;
+ ec.pred.multiclass = action;
ec.l.multi = ld;
}
@@ -357,7 +358,7 @@ namespace CBIFY {
}
}
- ld.prediction = action;
+ ec.pred.multiclass = action;
ec.l.multi = ld;
}
diff --git a/vowpalwabbit/cost_sensitive.cc b/vowpalwabbit/cost_sensitive.cc
index e0fe0ca3..11bc7231 100644
--- a/vowpalwabbit/cost_sensitive.cc
+++ b/vowpalwabbit/cost_sensitive.cc
@@ -206,7 +206,6 @@ namespace COST_SENSITIVE {
num_current_features += (*ecc)->num_features;
}
- label& ld = ec.l.cs;
char label_buf[32];
if (is_test)
strcpy(label_buf," unknown");
@@ -229,7 +228,7 @@ namespace COST_SENSITIVE {
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long unsigned int)ld.prediction,
+ (long unsigned int)ec.pred.multiclass,
(long unsigned int)num_current_features);
all.sd->weighted_holdout_examples_since_last_dump = 0;
@@ -242,7 +241,7 @@ namespace COST_SENSITIVE {
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long unsigned int)ld.prediction,
+ (long unsigned int)ec.pred.multiclass,
(long unsigned int)num_current_features);
all.sd->sum_loss_since_last_dump = 0.0;
@@ -258,7 +257,7 @@ namespace COST_SENSITIVE {
float loss = 0.;
if (!is_test_label(ld))
{//need to compute exact loss
- size_t pred = (size_t)ld.prediction;
+ size_t pred = (size_t)ec.pred.multiclass;
float chosen_loss = FLT_MAX;
float min = FLT_MAX;
@@ -293,7 +292,7 @@ namespace COST_SENSITIVE {
}
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
- all.print((int)*sink, (float)ld.prediction, 0, ec.tag);
+ all.print((int)*sink, (float)ec.pred.multiclass, 0, ec.tag);
if (all.raw_prediction > 0) {
string outputString;
diff --git a/vowpalwabbit/cost_sensitive.h b/vowpalwabbit/cost_sensitive.h
index f6934d4c..a5f56a61 100644
--- a/vowpalwabbit/cost_sensitive.h
+++ b/vowpalwabbit/cost_sensitive.h
@@ -25,7 +25,6 @@ namespace COST_SENSITIVE {
struct label {
v_array<wclass> costs;
- uint32_t prediction;
};
void output_example(vw& all, example& ec);
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 73139b31..33fdbae2 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -57,7 +57,7 @@ namespace CSOAA {
ec.partial_prediction = 0.;
}
- ld.prediction = prediction;
+ ec.pred.multiclass = prediction;
ec.l.cs = ld;
}
@@ -415,7 +415,7 @@ namespace LabelDict {
if (prediction == costs1[j1].class_index) prediction_is_me = true;
}
- ld1.prediction = prediction_is_me ? prediction : 0;
+ ec1->pred.multiclass = prediction_is_me ? prediction : 0;
ec1->l.cs = ld1;
ec1->example_t = example_t1;
}
@@ -493,7 +493,7 @@ namespace LabelDict {
ec->partial_prediction = costs[j].partial_prediction;
if (prediction == costs[j].class_index) prediction_is_me = true;
}
- ld.prediction = prediction_is_me ? prediction : 0;
+ ec->pred.multiclass = prediction_is_me ? prediction : 0;
// restore label
ec->l.cs = ld;
@@ -555,12 +555,11 @@ namespace LabelDict {
all.sd->total_features += ec.num_features;
float loss = 0.;
- size_t final_pred = ld.prediction;
if (!COST_SENSITIVE::example_is_test(ec)) {
for (size_t j=0; j<costs.size(); j++) {
if (hit_loss) break;
- if (final_pred == costs[j].class_index) {
+ if (ec.pred.multiclass == costs[j].class_index) {
loss = costs[j].x;
hit_loss = true;
}
@@ -572,7 +571,7 @@ namespace LabelDict {
}
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
- all.print(*sink, (float)ld.prediction, 0, ec.tag);
+ all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag);
if (all.raw_prediction > 0) {
string outputString;
diff --git a/vowpalwabbit/ect.cc b/vowpalwabbit/ect.cc
index 210b2733..3eab22c9 100644
--- a/vowpalwabbit/ect.cc
+++ b/vowpalwabbit/ect.cc
@@ -203,7 +203,7 @@ namespace ECT
base.learn(ec, problem_number);
- if (ec.l.simple.prediction > 0.)
+ if (ec.pred.scalar > 0.)
finals_winner = finals_winner | (((size_t)1) << i);
}
}
@@ -213,7 +213,7 @@ namespace ECT
{
base.learn(ec, id - e.k);
- if (ec.l.simple.prediction > 0.)
+ if (ec.pred.scalar > 0.)
id = e.directions[id].right;
else
id = e.directions[id].left;
@@ -256,7 +256,7 @@ namespace ECT
ec.l.simple.weight = 0.;
base.learn(ec, id-e.k);//inefficient, we should extract final prediction exactly.
- bool won = ec.l.simple.prediction * simple_temp.label > 0;
+ bool won = ec.pred.scalar * simple_temp.label > 0;
if (won)
{
@@ -306,7 +306,7 @@ namespace ECT
base.learn(ec, problem_number);
- if (ec.l.simple.prediction > 0.)
+ if (ec.pred.scalar > 0.)
e.tournaments_won[j] = right;
else
e.tournaments_won[j] = left;
@@ -324,7 +324,7 @@ namespace ECT
MULTICLASS::multiclass mc = ec.l.multi;
if (mc.label == 0 || (mc.label > e.k && mc.label != (uint32_t)-1))
cout << "label " << mc.label << " is not in {1,"<< e.k << "} This won't work right." << endl;
- mc.prediction = ect_predict(*all, e, base, ec);
+ ec.pred.multiclass = ect_predict(*all, e, base, ec);
ec.l.multi = mc;
}
@@ -334,11 +334,12 @@ namespace ECT
MULTICLASS::multiclass mc = ec.l.multi;
predict(e, base, ec);
- mc.prediction = ec.l.multi.prediction;
+ uint32_t pred = ec.pred.multiclass;
if (mc.label != (uint32_t)-1 && all->training)
ect_train(*all, e, base, ec);
ec.l.multi = mc;
+ ec.pred.multiclass = pred;
}
void finish(ect& e)
diff --git a/vowpalwabbit/example.h b/vowpalwabbit/example.h
index b8915b4a..8abbade7 100644
--- a/vowpalwabbit/example.h
+++ b/vowpalwabbit/example.h
@@ -41,11 +41,12 @@ typedef union {
MULTICLASS::multiclass multi;
COST_SENSITIVE::label cs;
CB::label cb;
+ CB_EVAL::label cb_eval;
} polylabel;
typedef union {
- float simple;
- uint32_t multi;
+ float scalar;
+ uint32_t multiclass;
} polyprediction;
struct example // core example datatype.
diff --git a/vowpalwabbit/ezexample.h b/vowpalwabbit/ezexample.h
index 44f0d5a7..8720bb62 100644
--- a/vowpalwabbit/ezexample.h
+++ b/vowpalwabbit/ezexample.h
@@ -222,7 +222,7 @@ class ezexample {
float predict() {
setup_for_predict();
- return ec->l.simple.prediction;
+ return ec->pred.scalar;
}
float predict_partial() {
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 1579cb57..589d39c2 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -310,7 +310,7 @@ void print_features(vw& all, example& ec)
void print_audit_features(vw& all, example& ec)
{
if(all.audit)
- print_result(all.stdout_fileno,ec.l.simple.prediction,-1,ec.tag);
+ print_result(all.stdout_fileno,ec.pred.scalar,-1,ec.tag);
fflush(stdout);
print_features(all, ec);
}
@@ -356,7 +356,7 @@ void predict(gd& g, learner& base, example& ec)
ec.partial_prediction = inline_predict(all, ec);
ec.partial_prediction *= (float)all.sd->contraction;
- ec.l.simple.prediction = finalize_prediction(all.sd, ec.partial_prediction);
+ ec.pred.scalar = finalize_prediction(all.sd, ec.partial_prediction);
if (audit)
print_audit_features(all, ec);
@@ -443,7 +443,7 @@ template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normaliz
{//We must traverse the features in _precisely_ the same order as during training.
label_data& ld = ec.l.simple;
vw& all = *g.all;
- float grad_squared = all.loss->getSquareGrad(ld.prediction, ld.label) * ld.weight;
+ float grad_squared = all.loss->getSquareGrad(ec.pred.scalar, ld.label) * ld.weight;
if (grad_squared == 0) return 1.;
norm_data nd = {grad_squared, 0., 0., {g.neg_power_t, g.neg_norm_power}};
@@ -468,8 +468,8 @@ float compute_update(gd& g, example& ec)
vw& all = *g.all;
float ret = 0.;
- ec.updated_prediction = ld.prediction;
- if (all.loss->getLoss(all.sd, ld.prediction, ld.label) > 0.)
+ ec.updated_prediction = ec.pred.scalar;
+ if (all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) > 0.)
{
float pred_per_update;
if(adaptive || normalized)
@@ -486,15 +486,15 @@ float compute_update(gd& g, example& ec)
float update;
if(invariant)
- update = all.loss->getUpdate(ld.prediction, ld.label, delta_pred, pred_per_update);
+ update = all.loss->getUpdate(ec.pred.scalar, ld.label, delta_pred, pred_per_update);
else
- update = all.loss->getUnsafeUpdate(ld.prediction, ld.label, delta_pred, pred_per_update);
+ update = all.loss->getUnsafeUpdate(ec.pred.scalar, ld.label, delta_pred, pred_per_update);
// changed from ec.partial_prediction to ld.prediction
ec.updated_prediction += pred_per_update * update;
if (all.reg_mode && fabs(update) > 1e-8) {
- double dev1 = all.loss->first_derivative(all.sd, ld.prediction, ld.label);
+ double dev1 = all.loss->first_derivative(all.sd, ec.pred.scalar, ld.label);
double eta_bar = (fabs(dev1) > 1e-8) ? (-update / dev1) : 0.0;
if (fabs(dev1) > 1e-8)
all.sd->contraction *= (1. - all.l2_lambda * eta_bar);
diff --git a/vowpalwabbit/gd_mf.cc b/vowpalwabbit/gd_mf.cc
index 1261a07a..df48c934 100644
--- a/vowpalwabbit/gd_mf.cc
+++ b/vowpalwabbit/gd_mf.cc
@@ -79,7 +79,7 @@ void mf_print_offset_features(vw& all, example& ec, size_t offset)
void mf_print_audit_features(vw& all, example& ec, size_t offset)
{
- print_result(all.stdout_fileno,ec.l.simple.prediction,-1,ec.tag);
+ print_result(all.stdout_fileno,ec.pred.scalar,-1,ec.tag);
mf_print_offset_features(all, ec, offset);
}
@@ -140,17 +140,15 @@ float mf_predict(vw& all, example& ec)
all.set_minmax(all.sd, ld.label);
- ld.prediction = GD::finalize_prediction(all.sd, ec.partial_prediction);
-
+ ec.pred.scalar = GD::finalize_prediction(all.sd, ec.partial_prediction);
+
if (ld.label != FLT_MAX)
- {
- ec.loss = all.loss->getLoss(all.sd, ld.prediction, ld.label) * ld.weight;
- }
-
+ ec.loss = all.loss->getLoss(all.sd, ec.pred.scalar, ld.label) * ld.weight;
+
if (all.audit)
mf_print_audit_features(all, ec, 0);
-
- return ld.prediction;
+
+ return ec.pred.scalar;
}
@@ -169,7 +167,7 @@ void mf_train(vw& all, example& ec)
// use final prediction to get update size
// update = eta_t*(y-y_hat) where eta_t = eta/(3*t^p) * importance weight
float eta_t = all.eta/pow(ec.example_t,all.power_t) / 3.f * ld.weight;
- float update = all.loss->getUpdate(ld.prediction, ld.label, eta_t, 1.); //ec.total_sum_feat_sq);
+ float update = all.loss->getUpdate(ec.pred.scalar, ld.label, eta_t, 1.); //ec.total_sum_feat_sq);
float regularization = eta_t * all.l2_lambda;
diff --git a/vowpalwabbit/kernel_svm.cc b/vowpalwabbit/kernel_svm.cc
index a06b03be..1667fdd1 100644
--- a/vowpalwabbit/kernel_svm.cc
+++ b/vowpalwabbit/kernel_svm.cc
@@ -408,7 +408,7 @@ namespace KSVM
sec->init_svm_example(fec);
float score;
predict(params, &sec, &score, 1);
- ec.l.simple.prediction = score;
+ ec.pred.scalar = score;
}
}
@@ -749,7 +749,7 @@ namespace KSVM
free_flatten_example(fec);
float score = 0;
predict(params, &sec, &score, 1);
- ec.l.simple.prediction = score;
+ ec.pred.scalar = score;
ec.loss = max(0.f, 1.f - score*ec.l.simple.label);
params.loss_sum += ec.loss;
if(params.all->training && ec.example_counter % 100 == 0)
diff --git a/vowpalwabbit/log_multi.cc b/vowpalwabbit/log_multi.cc
index 27dda73d..e2859972 100644
--- a/vowpalwabbit/log_multi.cc
+++ b/vowpalwabbit/log_multi.cc
@@ -309,9 +309,9 @@ namespace LOG_MULTI
while(b.nodes[cn].internal)
{
base.predict(ec, b.nodes[cn].base_predictor);
- cn = descend(b.nodes[cn], simple_temp.prediction);
+ cn = descend(b.nodes[cn], ec.pred.scalar);
}
- mc.prediction = b.nodes[cn].max_count_label;
+ ec.pred.multiclass = b.nodes[cn].max_count_label;
ec.l.multi = mc;
}
@@ -337,7 +337,7 @@ namespace LOG_MULTI
while(children(b, cn, class_index, mc.label))
{
train_node(b, base, ec, cn, class_index);
- cn = descend(b.nodes[cn], simple_temp.prediction);
+ cn = descend(b.nodes[cn], ec.pred.scalar);
}
b.nodes[cn].min_count++;
diff --git a/vowpalwabbit/lrq.cc b/vowpalwabbit/lrq.cc
index 8f1db3ff..02272bc0 100644
--- a/vowpalwabbit/lrq.cc
+++ b/vowpalwabbit/lrq.cc
@@ -162,12 +162,12 @@ namespace LRQ {
// Restore example
if (iter == 0)
{
- first_prediction = ec.l.simple.prediction;
+ first_prediction = ec.pred.scalar;
first_loss = ec.loss;
}
else
{
- ec.l.simple.prediction = first_prediction;
+ ec.pred.scalar = first_prediction;
ec.loss = first_loss;
}
diff --git a/vowpalwabbit/mf.cc b/vowpalwabbit/mf.cc
index 33c02944..2cce6272 100644
--- a/vowpalwabbit/mf.cc
+++ b/vowpalwabbit/mf.cc
@@ -99,17 +99,17 @@ void predict(mf& data, learner& base, example& ec) {
// finalize prediction
ec.partial_prediction = prediction;
- ec.l.simple.prediction = GD::finalize_prediction(data.all->sd, ec.partial_prediction);
+ ec.pred.scalar = GD::finalize_prediction(data.all->sd, ec.partial_prediction);
}
void learn(mf& data, learner& base, example& ec) {
// predict with current weights
predict<true>(data, base, ec);
- float predicted = ec.l.simple.prediction;
+ float predicted = ec.pred.scalar;
// update linear weights
base.update(ec);
- ec.l.simple.prediction = ec.updated_prediction;
+ ec.pred.scalar = ec.updated_prediction;
// store namespace indices
copy_array(data.indices, ec.indices);
@@ -148,7 +148,7 @@ void learn(mf& data, learner& base, example& ec) {
// compute new l_k * x_l scaling factors
// base.predict(ec, k);
// data.sub_predictions[2*k-1] = ec.partial_prediction;
- // ec.l.simple.prediction = ec.updated_prediction;
+ // ec.pred.scalar = ec.updated_prediction;
}
// set example to right namespace only
@@ -165,7 +165,7 @@ void learn(mf& data, learner& base, example& ec) {
// update r^k using base learner
base.update(ec, k + data.rank);
- ec.l.simple.prediction = ec.updated_prediction;
+ ec.pred.scalar = ec.updated_prediction;
// restore right namespace features
copy_array(ec.atomics[right_ns], data.temp_features);
@@ -176,7 +176,7 @@ void learn(mf& data, learner& base, example& ec) {
copy_array(ec.indices, data.indices);
// restore original prediction
- ec.l.simple.prediction = predicted;
+ ec.pred.scalar = predicted;
}
void finish(mf& o) {
diff --git a/vowpalwabbit/multiclass.cc b/vowpalwabbit/multiclass.cc
index 9e485aa7..8736bc6e 100644
--- a/vowpalwabbit/multiclass.cc
+++ b/vowpalwabbit/multiclass.cc
@@ -119,7 +119,7 @@ namespace MULTICLASS {
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long int)ld.prediction,
+ (long int)ec.pred.multiclass,
(long unsigned int)ec.num_features);
all.sd->weighted_holdout_examples_since_last_dump = 0;
@@ -132,7 +132,7 @@ namespace MULTICLASS {
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long int)ld.prediction,
+ (long int)ec.pred.multiclass,
(long unsigned int)ec.num_features);
all.sd->sum_loss_since_last_dump = 0.0;
@@ -147,7 +147,7 @@ namespace MULTICLASS {
multiclass ld = ec.l.multi;
size_t loss = 1;
- if (ld.label == (uint32_t)ld.prediction)
+ if (ld.label == (uint32_t)ec.pred.multiclass)
loss = 0;
if(ec.test_only)
@@ -169,7 +169,7 @@ namespace MULTICLASS {
}
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
- all.print(*sink, (float)ld.prediction, 0, ec.tag);
+ all.print(*sink, (float)ec.pred.multiclass, 0, ec.tag);
MULTICLASS::print_update(all, ec);
}
diff --git a/vowpalwabbit/multiclass.h b/vowpalwabbit/multiclass.h
index 1addf83e..aca34075 100644
--- a/vowpalwabbit/multiclass.h
+++ b/vowpalwabbit/multiclass.h
@@ -14,7 +14,6 @@ namespace MULTICLASS
struct multiclass {
uint32_t label;
float weight;
- uint32_t prediction;
};
extern label_parser mc_label;
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index b422c754..b6d3a292 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -143,7 +143,7 @@ namespace NN {
base.predict(ec, i);
- hidden_units[i] = ec.l.simple.prediction;
+ hidden_units[i] = ec.pred.scalar;
dropped_out[i] = (n.dropout && merand48 (n.xsubi) < 0.5);
@@ -285,7 +285,7 @@ CONVERSE: // That's right, I'm using goto. So sue me.
}
ec.partial_prediction = save_partial_prediction;
- ec.l.simple.prediction = save_final_prediction;
+ ec.pred.scalar = save_final_prediction;
ec.loss = save_ec_loss;
n.all->sd = save_sd;
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 6349520c..c1bb6757 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -32,7 +32,7 @@ namespace OAA {
if (mc_label_data.label == 0 || (mc_label_data.label > o.k && mc_label_data.label != (uint32_t)-1))
cout << "label " << mc_label_data.label << " is not in {1,"<< o.k << "} This won't work right." << endl;
- ec.l.simple = {0.f, mc_label_data.weight, 0.f, 0.f};
+ ec.l.simple = {0.f, mc_label_data.weight, 0.f};
string outputString;
stringstream outputStringStream(outputString);
@@ -64,7 +64,7 @@ namespace OAA {
outputStringStream << i << ':' << ec.partial_prediction;
}
}
- mc_label_data.prediction = prediction;
+ ec.pred.multiclass = prediction;
ec.l.multi = mc_label_data;
if (o.shouldOutput)
diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc
index f78283dc..765b9f26 100644
--- a/vowpalwabbit/parser.cc
+++ b/vowpalwabbit/parser.cc
@@ -1136,12 +1136,12 @@ float get_initial(example* ec)
float get_prediction(example* ec)
{
- return ec->l.simple.prediction;
+ return ec->pred.scalar;
}
float get_cost_sensitive_prediction(example* ec)
{
- return ec->l.cs.prediction;
+ return ec->pred.multiclass;
}
size_t get_tag_length(example* ec)
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index db04dfee..fda5c271 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -20,9 +20,9 @@ namespace Scorer {
base.predict(ec);
if(ec.l.simple.weight > 0 && ec.l.simple.label != FLT_MAX)
- ec.loss = s.all->loss->getLoss(s.all->sd, ec.l.simple.prediction, ec.l.simple.label) * ec.l.simple.weight;
+ ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ec.l.simple.label) * ec.l.simple.weight;
- ec.l.simple.prediction = link(ec.l.simple.prediction);
+ ec.pred.scalar = link(ec.pred.scalar);
}
// y = f(x) -> [0, 1]
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index 3a7c66b0..20433ae6 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -578,11 +578,6 @@ namespace Search {
del_features_in_top_namespace(priv, ec, conditioning_namespace);
}
- uint32_t cs_get_prediction(bool isCB, polylabel& ld) {
- return isCB ? ld.cb.prediction
- : ld.cs.prediction;
- }
-
size_t cs_get_costs_size(bool isCB, polylabel& ld) {
return isCB ? ld.cb.costs.size()
: ld.cs.costs.size();
@@ -671,7 +666,7 @@ namespace Search {
polylabel old_label = ec.l;
ec.l = allowed_actions_to_ld(priv, 1, allowed_actions, allowed_actions_cnt);
priv.base_learner->predict(ec, policy);
- uint32_t act = cs_get_prediction(priv.cb_learner, ec.l);
+ uint32_t act = ec.pred.multiclass;
// in beam search mode, go through alternatives and add them as back-ups
if (priv.beam) {
@@ -1720,7 +1715,7 @@ namespace Search {
costs.push_back(c);
}
- CS::label ld = { costs, 0 };
+ CS::label ld = { costs };
allowed.push_back(ld);
}
free(bg);
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index ec53d1c4..8bff73e7 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -62,9 +62,9 @@ void receive_result(sender& s)
example* ec=s.delay_ring[s.received_index++ % s.all->p->ring_size];
label_data& ld = ec->l.simple;
- ld.prediction = res;
+ ec->pred.scalar = res;
- ec->loss = s.all->loss->getLoss(s.all->sd, ld.prediction, ld.label) * ld.weight;
+ ec->loss = s.all->loss->getLoss(s.all->sd, ec->pred.scalar, ld.label) * ld.weight;
return_simple_example(*(s.all), NULL, *ec);
}
diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc
index 0d5a32be..9e220cde 100644
--- a/vowpalwabbit/simple_label.cc
+++ b/vowpalwabbit/simple_label.cc
@@ -136,7 +136,7 @@ void print_update(vw& all, example& ec)
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- ld.prediction,
+ ec.pred.scalar,
(long unsigned int)ec.num_features);
all.sd->weighted_holdout_examples_since_last_dump = 0.;
@@ -149,7 +149,7 @@ void print_update(vw& all, example& ec)
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- ld.prediction,
+ ec.pred.scalar,
(long unsigned int)ec.num_features);
all.sd->sum_loss_since_last_dump = 0.0;
@@ -192,7 +192,7 @@ void output_and_account_example(vw& all, example& ec)
if (all.lda > 0)
print_lda_result(all, f,ec.topic_predictions.begin,0.,ec.tag);
else
- all.print(f, ld.prediction, 0, ec.tag);
+ all.print(f, ec.pred.scalar, 0, ec.tag);
}
print_update(all, ec);
diff --git a/vowpalwabbit/simple_label.h b/vowpalwabbit/simple_label.h
index 365d2f52..3c026483 100644
--- a/vowpalwabbit/simple_label.h
+++ b/vowpalwabbit/simple_label.h
@@ -13,7 +13,6 @@ struct label_data {
float label;
float weight;
float initial;
- float prediction;
};
void return_simple_example(vw& all, void*, example& ec);
diff --git a/vowpalwabbit/stagewise_poly.cc b/vowpalwabbit/stagewise_poly.cc
index c1f4fada..ff20561d 100644
--- a/vowpalwabbit/stagewise_poly.cc
+++ b/vowpalwabbit/stagewise_poly.cc
@@ -523,7 +523,7 @@ namespace StagewisePoly
base.learn(poly.synth_ec);
ec.partial_prediction = poly.synth_ec.partial_prediction;
ec.updated_prediction = poly.synth_ec.updated_prediction;
- ec.l.simple.prediction = poly.synth_ec.l.simple.prediction;
+ ec.pred.scalar = poly.synth_ec.pred.scalar;
if (ec.example_counter
//following line is to avoid repeats when multiple reductions on same example.
diff --git a/vowpalwabbit/topk.cc b/vowpalwabbit/topk.cc
index a1baae6b..f3e8be8e 100644
--- a/vowpalwabbit/topk.cc
+++ b/vowpalwabbit/topk.cc
@@ -94,14 +94,13 @@ namespace TOPK {
else
base.predict(ec);
- label_data& ld = ec.l.simple;
if(d.pr_queue.size() < d.B)
- d.pr_queue.push(make_pair(ld.prediction, ec.tag));
+ d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
- else if(d.pr_queue.top().first < ld.prediction)
+ else if(d.pr_queue.top().first < ec.pred.scalar)
{
d.pr_queue.pop();
- d.pr_queue.push(make_pair(ld.prediction, ec.tag));
+ d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
}
}