diff options
author | John Langford <jl@hunch.net> | 2013-12-25 03:04:13 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2013-12-25 03:04:13 +0400 |
commit | 7c89547c50cdb9a4d876debef2aee7fa4828e3c6 (patch) | |
tree | 85b375c2ceb340f6f8ca4c19aebf404faf9fc846 /vowpalwabbit/cbify.cc | |
parent | 1ff5894ad74f034473c44c8f82bca692e3e0d8a4 (diff) |
added cbify
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- | vowpalwabbit/cbify.cc | 172 |
1 files changed, 172 insertions, 0 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc new file mode 100644 index 00000000..c739fc76 --- /dev/null +++ b/vowpalwabbit/cbify.cc @@ -0,0 +1,172 @@ +#include "oaa.h" +#include "vw.h" +#include "cb.h" +#include "rand48.h" + +namespace CBIFY { + + struct cbify { + size_t k; + + size_t tau; + + float epsilon; + + CB::label cb_label; + }; + + void do_uniform(cbify* data, example* ec) + { + //Draw an action + size_t action = (size_t)ceil(frand48() * data->k); + + ec->final_prediction = action; + } + + void do_loss(example* ec) + { + OAA::mc_label* ld = (OAA::mc_label*)ec->ld;//New loss + + if (ld->label != ec->final_prediction) + ec->loss = 1.; + else + ec->loss = 0.; + } + + void learn_first(void* d, learner& base, example* ec) + {//Explore tau times, then act according to optimal. + cbify* data = (cbify*)d; + + OAA::mc_label* ld = (OAA::mc_label*)ec->ld; + //Use CB to find current prediction for remaining rounds. + if (data->tau > 0) + { + do_uniform(data, ec); + do_loss(ec); + data->tau--; + cout << "tau--" << endl; + size_t action = ec->final_prediction; + CB::cb_class l = {ec->loss, action, 1. / data->k}; + data->cb_label.costs.erase(); + data->cb_label.costs.push_back(l); + ec->ld = &(data->cb_label); + base.learn(ec); + ec->final_prediction = action; + ec->loss = l.cost; + } + else + { + data->cb_label.costs.erase(); + ec->ld = &(data->cb_label); + base.learn(ec); + do_loss(ec); + } + ec->ld = ld; + } + + void learn_greedy(void* d, learner& base, example* ec) + {//Explore uniform random an epsilon fraction of the time. + cbify* data = (cbify*)d; + + //Use CB to find current prediction. + OAA::mc_label* ld = (OAA::mc_label*)ec->ld; + + data->cb_label.costs.erase(); + ec->ld = &(data->cb_label); + base.learn(ec); + do_loss(ec); + float action = ec->final_prediction; + + float base_prob = data->epsilon / data->k; + if (frand48() < 1. - data->epsilon) + { + CB::cb_class l = {ec->loss, action, 1. - data->epsilon + base_prob}; + data->cb_label.costs.push_back(l); + } + else + { + do_uniform(data, ec); + do_loss(ec); + action = ec->final_prediction; + CB::cb_class l = {ec->loss, ec->final_prediction, base_prob}; + data->cb_label.costs.push_back(l); + } + base.learn(ec); + + ec->final_prediction = action; + ec->loss = data->cb_label.costs[0].cost; + ec->ld = ld; + } + + void learn_bagging(void* d, learner& base, example* ec) + {//Randomize over predictions from a base set of predictors + // cbify* data = (cbify*)d; + + //Use CB to find current predictions. + } + + void learn_cover(void* d, learner& base, example* ec) + {//Randomize over predictions from a base set of predictors + //cbify* data = (cbify*)d; + + //Use cost sensitive oracle to cover actions to form distribution. + } + + learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) + {//parse and set arguments + cbify* data = (cbify*)calloc(1, sizeof(cbify)); + + data->epsilon = 0.05; + data->tau = 1000; + po::options_description desc("CBIFY options"); + desc.add_options() + ("first", po::value<size_t>(), "tau-first exploration") + ("greedy",po::value<float>() ,"epsilon-greedy exploration"); + + po::parsed_options parsed = po::command_line_parser(opts). + style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). + options(desc).allow_unregistered().run(); + opts = po::collect_unrecognized(parsed.options, po::include_positional); + po::store(parsed, vm); + po::notify(vm); + + po::parsed_options parsed_file = po::command_line_parser(all.options_from_file_argc,all.options_from_file_argv). + style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). + options(desc).allow_unregistered().run(); + po::store(parsed_file, vm_file); + po::notify(vm_file); + + if( vm_file.count("cbify") ) { + data->k = (uint32_t)vm_file["cbify"].as<size_t>(); + if( vm.count("cbify") && (uint32_t)vm["cbify"].as<size_t>() != data->k ) + std::cerr << "warning: you specified a different number of actions through --cbify than the one loaded from predictor. Pursuing with loaded value of: " << data->k << endl; + } + else { + data->k = (uint32_t)vm["cbify"].as<size_t>(); + + //appends nb_actions to options_from_file so it is saved to regressor later + std::stringstream ss; + ss << " --cbify " << data->k; + all.options_from_file.append(ss.str()); + } + + *(all.p->lp) = OAA::mc_label_parser; + learner* l; + if (vm.count("first") ) + { + data->tau = (uint32_t)vm["first"].as<size_t>(); + l = new learner(data, learn_first, all.l, 1); + } + else + { + if ( vm.count("greedy") ) + data->epsilon = vm["greedy"].as<float>(); + l = new learner(data, learn_greedy, all.l, 1); + } + l->set_finish_example(OAA::finish_example); + + cout << "epsilon = " << data->epsilon << endl; + + return l; + } +} |