diff options
author | John Langford <jl@hunch.net> | 2014-01-25 02:26:16 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-01-25 02:26:16 +0400 |
commit | 9fd03e09f186516b2cd80f13be74a3645de93d24 (patch) | |
tree | 196090272086429589321480763e81519b42f052 /vowpalwabbit/cbify.cc | |
parent | 48ec8c0dae97cfdb3d48788386e6fdadf6febb40 (diff) |
various cbify improvements
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- | vowpalwabbit/cbify.cc | 66 |
1 files changed, 28 insertions, 38 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 3f62a3ef..6818152a 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -17,22 +17,17 @@ namespace CBIFY { CB::label cb_label; }; - void do_uniform(cbify& data, example& ec) - { - //Draw an action - uint32_t action = (uint32_t)ceil(frand48() * data.k); - - ec.final_prediction = (float)action; + uint32_t do_uniform(cbify& data) + { //Draw an action + return (uint32_t)ceil(frand48() * data.k); } - void do_loss(example& ec) + float loss(uint32_t label, float final_prediction) { - OAA::mc_label* ld = (OAA::mc_label*)ec.ld;//New loss - - if (ld->label != ec.final_prediction) - ec.loss = 1.; + if (label != final_prediction) + return 1.; else - ec.loss = 0.; + return 0.; } template <bool is_learn> @@ -40,21 +35,17 @@ namespace CBIFY { {//Explore tau times, then act according to optimal. OAA::mc_label* ld = (OAA::mc_label*)ec.ld; //Use CB to find current prediction for remaining rounds. - if (data.tau > 0) + if (data.tau && is_learn > 0) { - do_uniform(data, ec); - do_loss(ec); + ec.final_prediction = do_uniform(data); + ec.loss = loss(ld->label, ec.final_prediction); data.tau--; - cout << "tau--" << endl; uint32_t action = (uint32_t)ec.final_prediction; CB::cb_class l = {ec.loss, action, 1.f / data.k}; data.cb_label.costs.erase(); data.cb_label.costs.push_back(l); ec.ld = &(data.cb_label); - if (is_learn) - base.learn(ec); - else - base.predict(ec); + base.learn(ec); ec.final_prediction = (float)action; ec.loss = l.cost; } @@ -62,11 +53,8 @@ namespace CBIFY { { data.cb_label.costs.erase(); ec.ld = &(data.cb_label); - if (is_learn) - base.learn(ec); - else - base.predict(ec); - do_loss(ec); + base.predict(ec); + ec.loss = loss(ld->label, ec.final_prediction); } ec.ld = ld; } @@ -75,33 +63,34 @@ namespace CBIFY { void predict_or_learn_greedy(cbify& data, learner& base, example& ec) {//Explore uniform random an epsilon fraction of the time. OAA::mc_label* ld = (OAA::mc_label*)ec.ld; - - data.cb_label.costs.erase(); ec.ld = &(data.cb_label); + data.cb_label.costs.erase(); + base.predict(ec); - do_loss(ec); uint32_t action = (uint32_t)ec.final_prediction; float base_prob = data.epsilon / data.k; if (frand48() < 1. - data.epsilon) { - CB::cb_class l = {ec.loss, action, 1.f - data.epsilon + base_prob}; + CB::cb_class l = {loss(ld->label, ec.final_prediction), + action, 1.f - data.epsilon + base_prob}; data.cb_label.costs.push_back(l); } else { - do_uniform(data, ec); - do_loss(ec); - action = (uint32_t)ec.final_prediction; - CB::cb_class l = {ec.loss, (uint32_t)ec.final_prediction, base_prob}; + action = do_uniform(data); + CB::cb_class l = {loss(ld->label, action), + action, base_prob}; + if (action == ec.final_prediction) + l.probability = 1.f - data.epsilon + base_prob; data.cb_label.costs.push_back(l); } if (is_learn) - base.learn(ec); - + base.learn(ec); + ec.final_prediction = (float)action; - ec.loss = data.cb_label.costs[0].cost; + ec.loss = loss(ld->label, ec.final_prediction); ec.ld = ld; } @@ -119,6 +108,8 @@ namespace CBIFY { //Use cost sensitive oracle to cover actions to form distribution. } + void init_driver(cbify&) {} + void finish_example(vw& all, cbify&, example& ec) { OAA::output_example(all, ec); @@ -179,8 +170,7 @@ namespace CBIFY { l->set_predict<cbify, predict_or_learn_greedy<false> >(); } l->set_finish_example<cbify,finish_example>(); - - cout << "epsilon = " << data->epsilon << endl; + l->set_init_driver<cbify,init_driver>(); return l; } |