diff options
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- | vowpalwabbit/cbify.cc | 78 |
1 files changed, 39 insertions, 39 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index d41975ae..3f62a3ef 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -17,92 +17,92 @@ namespace CBIFY { CB::label cb_label; }; - void do_uniform(cbify* data, example* ec) + void do_uniform(cbify& data, example& ec) { //Draw an action - uint32_t action = (uint32_t)ceil(frand48() * data->k); + uint32_t action = (uint32_t)ceil(frand48() * data.k); - ec->final_prediction = (float)action; + ec.final_prediction = (float)action; } - void do_loss(example* ec) + void do_loss(example& ec) { - OAA::mc_label* ld = (OAA::mc_label*)ec->ld;//New loss + OAA::mc_label* ld = (OAA::mc_label*)ec.ld;//New loss - if (ld->label != ec->final_prediction) - ec->loss = 1.; + if (ld->label != ec.final_prediction) + ec.loss = 1.; else - ec->loss = 0.; + ec.loss = 0.; } template <bool is_learn> - void predict_or_learn_first(cbify* data, learner& base, example* ec) + void predict_or_learn_first(cbify& data, learner& base, example& ec) {//Explore tau times, then act according to optimal. - OAA::mc_label* ld = (OAA::mc_label*)ec->ld; + OAA::mc_label* ld = (OAA::mc_label*)ec.ld; //Use CB to find current prediction for remaining rounds. - if (data->tau > 0) + if (data.tau > 0) { do_uniform(data, ec); do_loss(ec); - data->tau--; + data.tau--; cout << "tau--" << endl; - uint32_t action = (uint32_t)ec->final_prediction; - CB::cb_class l = {ec->loss, action, 1.f / data->k}; - data->cb_label.costs.erase(); - data->cb_label.costs.push_back(l); - ec->ld = &(data->cb_label); + uint32_t action = (uint32_t)ec.final_prediction; + CB::cb_class l = {ec.loss, action, 1.f / data.k}; + data.cb_label.costs.erase(); + data.cb_label.costs.push_back(l); + ec.ld = &(data.cb_label); if (is_learn) base.learn(ec); else base.predict(ec); - ec->final_prediction = (float)action; - ec->loss = l.cost; + ec.final_prediction = (float)action; + ec.loss = l.cost; } else { - data->cb_label.costs.erase(); - ec->ld = &(data->cb_label); + data.cb_label.costs.erase(); + ec.ld = &(data.cb_label); if (is_learn) base.learn(ec); else base.predict(ec); do_loss(ec); } - ec->ld = ld; + ec.ld = ld; } template <bool is_learn> - void predict_or_learn_greedy(cbify* data, learner& base, example* ec) + void predict_or_learn_greedy(cbify& data, learner& base, example& ec) {//Explore uniform random an epsilon fraction of the time. - OAA::mc_label* ld = (OAA::mc_label*)ec->ld; + OAA::mc_label* ld = (OAA::mc_label*)ec.ld; - data->cb_label.costs.erase(); - ec->ld = &(data->cb_label); + data.cb_label.costs.erase(); + ec.ld = &(data.cb_label); base.predict(ec); do_loss(ec); - uint32_t action = (uint32_t)ec->final_prediction; + uint32_t action = (uint32_t)ec.final_prediction; - float base_prob = data->epsilon / data->k; - if (frand48() < 1. - data->epsilon) + float base_prob = data.epsilon / data.k; + if (frand48() < 1. - data.epsilon) { - CB::cb_class l = {ec->loss, action, 1.f - data->epsilon + base_prob}; - data->cb_label.costs.push_back(l); + CB::cb_class l = {ec.loss, action, 1.f - data.epsilon + base_prob}; + data.cb_label.costs.push_back(l); } else { do_uniform(data, ec); do_loss(ec); - action = (uint32_t)ec->final_prediction; - CB::cb_class l = {ec->loss, (uint32_t)ec->final_prediction, base_prob}; - data->cb_label.costs.push_back(l); + action = (uint32_t)ec.final_prediction; + CB::cb_class l = {ec.loss, (uint32_t)ec.final_prediction, base_prob}; + data.cb_label.costs.push_back(l); } if (is_learn) base.learn(ec); - ec->final_prediction = (float)action; - ec->loss = data->cb_label.costs[0].cost; - ec->ld = ld; + ec.final_prediction = (float)action; + ec.loss = data.cb_label.costs[0].cost; + ec.ld = ld; } void learn_bagging(void* d, learner& base, example* ec) @@ -119,10 +119,10 @@ namespace CBIFY { //Use cost sensitive oracle to cover actions to form distribution. } - void finish_example(vw& all, cbify*, example* ec) + void finish_example(vw& all, cbify&, example& ec) { OAA::output_example(all, ec); - VW::finish_example(all, ec); + VW::finish_example(all, &ec); } learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) |