diff options
Diffstat (limited to 'vowpalwabbit/cb.cc')
-rw-r--r-- | vowpalwabbit/cb.cc | 110 |
1 files changed, 55 insertions, 55 deletions
diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc index 9d2437ac..a3214ba1 100644 --- a/vowpalwabbit/cb.cc +++ b/vowpalwabbit/cb.cc @@ -223,9 +223,9 @@ namespace CB return NULL; } - void gen_cs_example_ips(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) + void gen_cs_example_ips(vw& all, cb& c, example& ec, CSOAA::label& cs_ld) {//this implements the inverse propensity score method, where cost are importance weighted by the probability of the chosen action - CB::label* ld = (CB::label*)ec->ld; + CB::label* ld = (CB::label*)ec.ld; //generate cost-sensitive example cs_ld.costs.erase(); @@ -282,7 +282,7 @@ namespace CB } template <bool is_learn> - void call_scorer(vw& all, cb& c, example* ec, uint32_t index) + void call_scorer(vw& all, cb& c, example& ec, uint32_t index) { float old_min = all.sd->min_label; //all.sd->min_label = c.min_cost; @@ -297,9 +297,9 @@ namespace CB } template <bool is_learn> - float get_cost_pred(vw& all, cb& c, example* ec, uint32_t index) + float get_cost_pred(vw& all, cb& c, example& ec, uint32_t index) { - CB::label* ld = (CB::label*)ec->ld; + CB::label* ld = (CB::label*)ec.ld; label_data simple_temp; simple_temp.initial = 0.; @@ -314,21 +314,21 @@ namespace CB simple_temp.weight = 0.; } - ec->ld = &simple_temp; + ec.ld = &simple_temp; call_scorer<is_learn>(all, c, ec, index); - ec->ld = ld; + ec.ld = ld; - float cost = ec->final_prediction; + float cost = ec.final_prediction; return cost; } template <bool is_learn> - void gen_cs_example_dm(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) + void gen_cs_example_dm(vw& all, cb& c, example& ec, CSOAA::label& cs_ld) { //this implements the direct estimation method, where costs are directly specified by the learned regressor. - CB::label* ld = (CB::label*)ec->ld; + CB::label* ld = (CB::label*)ec.ld; float min = FLT_MAX; size_t argmin = 1; @@ -393,13 +393,13 @@ namespace CB } } - ec->final_prediction = (float)argmin; + ec.final_prediction = (float)argmin; } template <bool is_learn> - void gen_cs_example_dr(vw& all, cb& c, example* ec, CSOAA::label& cs_ld) + void gen_cs_example_dr(vw& all, cb& c, example& ec, CSOAA::label& cs_ld) {//this implements the doubly robust method - CB::label* ld = (CB::label*)ec->ld; + CB::label* ld = (CB::label*)ec.ld; //generate cost sensitive example cs_ld.costs.erase(); @@ -455,9 +455,9 @@ namespace CB } } - void cb_test_to_cs_test_label(vw& all, example* ec, CSOAA::label& cs_ld) + void cb_test_to_cs_test_label(vw& all, example& ec, CSOAA::label& cs_ld) { - CB::label* ld = (CB::label*)ec->ld; + CB::label* ld = (CB::label*)ec.ld; cs_ld.costs.erase(); if(ld->costs.size() > 0) @@ -479,67 +479,67 @@ namespace CB } template <bool is_learn> - void predict_or_learn(cb* c, learner& base, example* ec) { - vw* all = c->all; - CB::label* ld = (CB::label*)ec->ld; + void predict_or_learn(cb& c, learner& base, example& ec) { + vw* all = c.all; + CB::label* ld = (CB::label*)ec.ld; //check if this is a test example where we just want a prediction if( CB::is_test_label(ld) ) { //if so just query base cost-sensitive learner - cb_test_to_cs_test_label(*all,ec,c->cb_cs_ld); + cb_test_to_cs_test_label(*all,ec,c.cb_cs_ld); - ec->ld = &c->cb_cs_ld; + ec.ld = &c.cb_cs_ld; base.predict(ec); - ec->ld = ld; + ec.ld = ld; for (size_t i=0; i<ld->costs.size(); i++) - ld->costs[i].partial_prediction = c->cb_cs_ld.costs[i].partial_prediction; + ld->costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; return; } //now this is a training example - c->known_cost = get_observed_cost(ld); - c->min_cost = min (c->min_cost, c->known_cost->cost); - c->max_cost = max (c->max_cost, c->known_cost->cost); + c.known_cost = get_observed_cost(ld); + c.min_cost = min (c.min_cost, c.known_cost->cost); + c.max_cost = max (c.max_cost, c.known_cost->cost); //generate a cost-sensitive example to update classifiers - switch(c->cb_type) + switch(c.cb_type) { case CB_TYPE_IPS: - gen_cs_example_ips(*all,*c,ec,c->cb_cs_ld); + gen_cs_example_ips(*all,c,ec,c.cb_cs_ld); break; case CB_TYPE_DM: - gen_cs_example_dm<is_learn>(*all,*c,ec,c->cb_cs_ld); + gen_cs_example_dm<is_learn>(*all,c,ec,c.cb_cs_ld); break; case CB_TYPE_DR: - gen_cs_example_dr<is_learn>(*all,*c,ec,c->cb_cs_ld); + gen_cs_example_dr<is_learn>(*all,c,ec,c.cb_cs_ld); break; default: - std::cerr << "Unknown cb_type specified for contextual bandit learning: " << c->cb_type << ". Exiting." << endl; + std::cerr << "Unknown cb_type specified for contextual bandit learning: " << c.cb_type << ". Exiting." << endl; throw exception(); } - if (c->cb_type != CB_TYPE_DM) + if (c.cb_type != CB_TYPE_DM) { - ec->ld = &c->cb_cs_ld; + ec.ld = &c.cb_cs_ld; if (is_learn) base.learn(ec); else base.predict(ec); - ec->ld = ld; + ec.ld = ld; for (size_t i=0; i<ld->costs.size(); i++) - ld->costs[i].partial_prediction = c->cb_cs_ld.costs[i].partial_prediction; + ld->costs[i].partial_prediction = c.cb_cs_ld.costs[i].partial_prediction; } } - void init_driver(cb*) + void init_driver(cb&) { fprintf(stderr, "*estimate* *estimate* avglossreg last pred last correct\n"); } - void print_update(vw& all, cb& c, bool is_test, example *ec) + void print_update(vw& all, cb& c, bool is_test, example& ec) { if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs) { @@ -565,8 +565,8 @@ namespace CB (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long unsigned int)ec->final_prediction, - (long unsigned int)ec->num_features, + (long unsigned int)ec.final_prediction, + (long unsigned int)ec.num_features, c.avg_loss_regressors, c.last_pred_reg, c.last_correct_cost); @@ -581,8 +581,8 @@ namespace CB (long int)all.sd->example_number, all.sd->weighted_examples, label_buf, - (long unsigned int)ec->final_prediction, - (long unsigned int)ec->num_features, + (long unsigned int)ec.final_prediction, + (long unsigned int)ec.num_features, c.avg_loss_regressors, c.last_pred_reg, c.last_correct_cost); @@ -593,14 +593,14 @@ namespace CB } } - void output_example(vw& all, cb& c, example* ec) + void output_example(vw& all, cb& c, example& ec) { - CB::label* ld = (CB::label*)ec->ld; + CB::label* ld = (CB::label*)ec.ld; float loss = 0.; if (!CB::is_test_label(ld)) {//need to compute exact loss - size_t pred = (size_t)ec->final_prediction; + size_t pred = (size_t)ec.final_prediction; float chosen_loss = FLT_MAX; if( know_all_cost_example(ld) ) { @@ -626,11 +626,11 @@ namespace CB loss = chosen_loss; } - if(ec->test_only) + if(ec.test_only) { - all.sd->weighted_holdout_examples += ec->global_weight;//test weight seen - all.sd->weighted_holdout_examples_since_last_dump += ec->global_weight; - all.sd->weighted_holdout_examples_since_last_pass += ec->global_weight; + all.sd->weighted_holdout_examples += ec.global_weight;//test weight seen + all.sd->weighted_holdout_examples_since_last_dump += ec.global_weight; + all.sd->weighted_holdout_examples_since_last_pass += ec.global_weight; all.sd->holdout_sum_loss += loss; all.sd->holdout_sum_loss_since_last_dump += loss; all.sd->holdout_sum_loss_since_last_pass += loss;//since last pass @@ -640,30 +640,30 @@ namespace CB all.sd->sum_loss += loss; all.sd->sum_loss_since_last_dump += loss; all.sd->weighted_examples += 1.; - all.sd->total_features += ec->num_features; + all.sd->total_features += ec.num_features; all.sd->example_number++; } for (size_t i = 0; i<all.final_prediction_sink.size(); i++) { int f = all.final_prediction_sink[i]; - all.print(f, ec->final_prediction, 0, ec->tag); + all.print(f, ec.final_prediction, 0, ec.tag); } - print_update(all, c, CB::is_test_label((CB::label*)ec->ld), ec); + print_update(all, c, CB::is_test_label((CB::label*)ec.ld), ec); } - void finish(cb* c) + void finish(cb& c) { - c->cb_cs_ld.costs.delete_v(); + c.cb_cs_ld.costs.delete_v(); } - void finish_example(vw& all, cb* c, example* ec) + void finish_example(vw& all, cb& c, example& ec) { - output_example(all, *c, ec); - VW::finish_example(all, ec); + output_example(all, c, ec); + VW::finish_example(all, &ec); } learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) |