diff options
author | John Langford <jl@hunch.net> | 2014-01-29 02:28:32 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-01-29 02:28:32 +0400 |
commit | 2fb8eca0a0d2a8ba5e8438145b5639486cc4c933 (patch) | |
tree | 6ebf73e8d06d551776f5235857eb7b5fe48f3609 /vowpalwabbit/cbify.cc | |
parent | 1a9eae25b5fc861af1f6714b3ad61822bcf4726f (diff) |
baseline bag and cover approaches working
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- | vowpalwabbit/cbify.cc | 218 |
1 files changed, 200 insertions, 18 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 6818152a..2bfcbef5 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -1,7 +1,10 @@ +#include <float.h> #include "oaa.h" #include "vw.h" +#include "csoaa.h" #include "cb.h" #include "rand48.h" +#include "bs.h" using namespace LEARNER; @@ -13,15 +16,31 @@ namespace CBIFY { size_t tau; float epsilon; + + size_t counter; + + size_t bags; + v_array<float> count; + v_array<uint32_t> predictions; CB::label cb_label; + CSOAA::label cs_label; + CSOAA::label second_cs_label; + + learner* cs; + vw* all; }; uint32_t do_uniform(cbify& data) { //Draw an action return (uint32_t)ceil(frand48() * data.k); } - + + uint32_t choose_bag(cbify& data) + { //Draw an action + return (uint32_t)floor(frand48() * data.bags); + } + float loss(uint32_t label, float final_prediction) { if (label != final_prediction) @@ -94,20 +113,154 @@ namespace CBIFY { ec.ld = ld; } - void learn_bagging(void* d, learner& base, example* ec) - {//Randomize over predictions from a base set of predictors - // cbify* data = (cbify*)d; - - //Use CB to find current predictions. - } + template <bool is_learn> + void predict_or_learn_bag(cbify& data, learner& base, example& ec) + {//Randomize over predictions from a base set of predictors + //Use CB to find current predictions. + OAA::mc_label* ld = (OAA::mc_label*)ec.ld; + ec.ld = &(data.cb_label); + data.cb_label.costs.erase(); - void learn_cover(void* d, learner& base, example* ec) - {//Randomize over predictions from a base set of predictors - //cbify* data = (cbify*)d; - - //Use cost sensitive oracle to cover actions to form distribution. - } + for (size_t j = 1; j <= data.bags; j++) + data.count[j] = 0; + + size_t bag = choose_bag(data); + size_t action = 0; + for (size_t i = 0; i < data.bags; i++) + { + base.predict(ec,i); + data.count[ec.final_prediction]++; + if (i == bag) + action = ec.final_prediction; + } + assert(action != 0); + if (is_learn) + { + float probability = (float)data.count[action] / (float)data.bags; + CB::cb_class l = {loss(ld->label, action), + action, probability}; + data.cb_label.costs.push_back(l); + for (size_t i = 0; i < data.bags; i++) + { + uint32_t count = BS::weight_gen(); + for (uint32_t j = 0; j < count; j++) + base.learn(ec,i); + } + } + ec.ld = ld; + ec.final_prediction = action; + } + + uint32_t choose_action(v_array<float>& distribution) + { + float value = frand48(); + for (uint32_t i = 0; i < distribution.size();i++) + { + if (value <= distribution[i]) + return i+1; + else + value -= distribution[i]; + } + //some rounding problem presumably. + return 1; + } + + void gen_cs_label(vw& all, CB::cb_class& known_cost, example& ec, CSOAA::label& cs_ld, uint32_t label) + { + CSOAA::wclass wc; + + //get cost prediction for this label + wc.x = CB::get_cost_pred<false>(all, &known_cost, ec, label, all.sd->k); + wc.weight_index = label; + wc.partial_prediction = 0.; + wc.wap_value = 0.; + + //add correction if we observed cost for this action and regressor is wrong + if( known_cost.action == label ) + wc.x += (known_cost.cost - wc.x) / known_cost.probability; + + cs_ld.costs.push_back( wc ); + } + + template <bool is_learn> + void predict_or_learn_cover(cbify& data, learner& base, example& ec) + {//Randomize over predictions from a base set of predictors + //Use cost sensitive oracle to cover actions to form distribution. + OAA::mc_label* ld = (OAA::mc_label*)ec.ld; + data.counter++; + float round_epsilon = data.epsilon / sqrt(data.counter); + + float base_prob = round_epsilon / data.k; + data.count.erase(); + data.cs_label.costs.erase(); + for (size_t j = 0; j < data.k; j++) + { + data.count.push_back(base_prob); + + CSOAA::wclass wc; + + //get cost prediction for this label + wc.x = FLT_MAX; + wc.weight_index = j+1; + wc.partial_prediction = 0.; + wc.wap_value = 0.; + data.cs_label.costs.push_back(wc); + } + + float additive_probability = (1. - round_epsilon) / (float)data.bags; + + ec.ld = &data.cs_label; + for (size_t i = 0; i < data.bags; i++) + { //get predicted cost-sensitive predictions + data.cs->predict(ec,i+2); + data.count[ec.final_prediction-1] += additive_probability; + data.predictions[i] = ec.final_prediction; + } + //compute random action + uint32_t action = choose_action(data.count); + + if (is_learn) + { + data.cb_label.costs.erase(); + float probability = (float)data.count[action-1]; + CB::cb_class l = {loss(ld->label, action), + action, probability}; + data.cb_label.costs.push_back(l); + ec.ld = &(data.cb_label); + base.learn(ec); + + //Now update oracles + + //1. Compute loss vector + data.cs_label.costs.erase(); + float norm = 0; + for (uint32_t j = 0; j < data.k; j++) + { //data.cs_label now contains an unbiased estimate of cost of each class. + gen_cs_label(*data.all, l, ec, data.cs_label, j+1); + data.count[j] = base_prob; + norm += base_prob; + } + + ec.ld = &data.second_cs_label; + //2. Update functions + for (size_t i = 0; i < data.bags; i++) + { //get predicted cost-sensitive predictions + for (size_t j = 0; j < data.k; j++) + { + float pseudo_cost = data.cs_label.costs[j].x - 0.125 * (base_prob / (data.count[j] / norm) + 1); + data.second_cs_label.costs[j].weight_index = j+1; + data.second_cs_label.costs[j].x = pseudo_cost; + } + data.cs->learn(ec,i+2); + data.count[data.predictions[i]-1] += additive_probability; + norm += additive_probability; + } + } + ec.ld = ld; + ec.final_prediction = action; + } + void init_driver(cbify&) {} void finish_example(vw& all, cbify&, example& ec) @@ -121,11 +274,15 @@ namespace CBIFY { cbify* data = (cbify*)calloc(1, sizeof(cbify)); data->epsilon = 0.05f; + data->counter = 0; data->tau = 1000; + data->all = &all; po::options_description desc("CBIFY options"); desc.add_options() ("first", po::value<size_t>(), "tau-first exploration") - ("greedy",po::value<float>() ,"epsilon-greedy exploration"); + ("epsilon",po::value<float>() ,"epsilon-greedy exploration") + ("bag",po::value<size_t>() ,"bagging-based exploration") + ("cover",po::value<size_t>() ,"bagging-based exploration"); po::parsed_options parsed = po::command_line_parser(opts). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). @@ -155,20 +312,45 @@ namespace CBIFY { } all.p->lp = OAA::mc_label_parser; - learner* l = new learner(data, all.l, 1); - if (vm.count("first") ) + learner* l; + if (vm.count("cover")) + { + data->bags = (uint32_t)vm["cover"].as<size_t>(); + data->cs = all.cost_sensitive; + data->count.resize(data->k); + data->predictions.resize(data->bags); + data->second_cs_label.costs.resize(data->k); + data->second_cs_label.costs.end = data->second_cs_label.costs.begin+data->k; + if ( vm.count("epsilon") ) + data->epsilon = vm["epsilon"].as<float>(); + l = new learner(data, all.l, data->bags + 1); + l->set_learn<cbify, predict_or_learn_cover<true> >(); + l->set_predict<cbify, predict_or_learn_cover<false> >(); + } + else if (vm.count("bag")) + { + data->bags = (uint32_t)vm["bag"].as<size_t>(); + data->count.resize(data->bags+1); + l = new learner(data, all.l, data->bags); + l->set_learn<cbify, predict_or_learn_bag<true> >(); + l->set_predict<cbify, predict_or_learn_bag<false> >(); + } + else if (vm.count("first") ) { data->tau = (uint32_t)vm["first"].as<size_t>(); + l = new learner(data, all.l, 1); l->set_learn<cbify, predict_or_learn_first<true> >(); l->set_predict<cbify, predict_or_learn_first<false> >(); } else { - if ( vm.count("greedy") ) - data->epsilon = vm["greedy"].as<float>(); + if ( vm.count("epsilon") ) + data->epsilon = vm["epsilon"].as<float>(); + l = new learner(data, all.l, 1); l->set_learn<cbify, predict_or_learn_greedy<true> >(); l->set_predict<cbify, predict_or_learn_greedy<false> >(); } + l->set_finish_example<cbify,finish_example>(); l->set_init_driver<cbify,init_driver>(); |