github.com/moses-smt/vowpal_wabbit.git
author      John Langford <jl@hunch.net>    2014-01-29 02:28:32 +0400
committer   John Langford <jl@hunch.net>    2014-01-29 02:28:32 +0400
commit      2fb8eca0a0d2a8ba5e8438145b5639486cc4c933 (patch)
tree        6ebf73e8d06d551776f5235857eb7b5fe48f3609 /vowpalwabbit/cbify.cc
parent      1a9eae25b5fc861af1f6714b3ad61822bcf4726f (diff)
baseline bag and cover approaches working
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r--  vowpalwabbit/cbify.cc  218
1 file changed, 200 insertions(+), 18 deletions(-)
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index 6818152a..2bfcbef5 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -1,7 +1,10 @@
+#include <float.h>
#include "oaa.h"
#include "vw.h"
+#include "csoaa.h"
#include "cb.h"
#include "rand48.h"
+#include "bs.h"
using namespace LEARNER;
@@ -13,15 +16,31 @@ namespace CBIFY {
size_t tau;
float epsilon;
+
+ size_t counter;
+
+ size_t bags;
+ v_array<float> count;
+ v_array<uint32_t> predictions;
CB::label cb_label;
+ CSOAA::label cs_label;
+ CSOAA::label second_cs_label;
+
+ learner* cs;
+ vw* all;
};
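
The roles of the new cbify members, as inferred from the rest of this diff (the comments below are editorial annotations, not part of the commit):

  size_t counter;                  // examples seen so far; drives the epsilon / sqrt(t) schedule in the cover approach
  size_t bags;                     // number of bagged policies (--bag) or covering policies (--cover)
  v_array<float> count;            // per-action vote counts (bag) or the action probability distribution (cover)
  v_array<uint32_t> predictions;   // each covering policy's predicted action for the current example
  CSOAA::label cs_label;           // cost-sensitive label holding the unbiased per-action cost estimates
  CSOAA::label second_cs_label;    // pseudo-costs used to train the covering policies
  learner* cs;                     // the cost-sensitive base learner (all.cost_sensitive) used by cover
  vw* all;                         // global state, needed by CB::get_cost_pred
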
uint32_t do_uniform(cbify& data)
{ //Draw an action
return (uint32_t)ceil(frand48() * data.k);
}
-
+
+ uint32_t choose_bag(cbify& data)
+ { //Draw an action
+ return (uint32_t)floor(frand48() * data.bags);
+ }
+
float loss(uint32_t label, float final_prediction)
{
if (label != final_prediction)
@@ -94,20 +113,154 @@ namespace CBIFY {
ec.ld = ld;
}
- void learn_bagging(void* d, learner& base, example* ec)
- {//Randomize over predictions from a base set of predictors
- // cbify* data = (cbify*)d;
-
- //Use CB to find current predictions.
- }
+ template <bool is_learn>
+ void predict_or_learn_bag(cbify& data, learner& base, example& ec)
+ {//Randomize over predictions from a base set of predictors
+ //Use CB to find current predictions.
+ OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
+ ec.ld = &(data.cb_label);
+ data.cb_label.costs.erase();
- void learn_cover(void* d, learner& base, example* ec)
- {//Randomize over predictions from a base set of predictors
- //cbify* data = (cbify*)d;
-
- //Use cost sensitive oracle to cover actions to form distribution.
- }
+ for (size_t j = 1; j <= data.bags; j++)
+ data.count[j] = 0;
+
+ size_t bag = choose_bag(data);
+ size_t action = 0;
+ for (size_t i = 0; i < data.bags; i++)
+ {
+ base.predict(ec,i);
+ data.count[ec.final_prediction]++;
+ if (i == bag)
+ action = ec.final_prediction;
+ }
+ assert(action != 0);
+ if (is_learn)
+ {
+ float probability = (float)data.count[action] / (float)data.bags;
+ CB::cb_class l = {loss(ld->label, action),
+ action, probability};
+ data.cb_label.costs.push_back(l);
+ for (size_t i = 0; i < data.bags; i++)
+ {
+ uint32_t count = BS::weight_gen();
+ for (uint32_t j = 0; j < count; j++)
+ base.learn(ec,i);
+ }
+ }
+ ec.ld = ld;
+ ec.final_prediction = action;
+ }
+
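
For reference, a minimal standalone sketch of the bagging exploration step above, outside the VW learner stack. The ToyPolicy and bag_explore names and the std::mt19937 RNG are illustrative assumptions, not VW code; the real code then re-trains each bag member a bootstrap-resampled number of times via BS::weight_gen().

#include <cstdint>
#include <random>
#include <vector>

// Stand-in for one bagged base policy: maps an example id to an action in 1..k.
struct ToyPolicy
{
  uint32_t k;
  uint32_t predict(uint32_t example_id) const { return example_id % k + 1; }
};

struct BagChoice { uint32_t action; float probability; };

// One round of bag exploration: pick one bag member uniformly at random, act on its
// prediction, and record that action's probability as its vote frequency across all members.
BagChoice bag_explore(const std::vector<ToyPolicy>& bags, uint32_t example_id, std::mt19937& rng)
{
  std::uniform_int_distribution<size_t> pick_bag(0, bags.size() - 1);
  size_t chosen = pick_bag(rng);

  std::vector<uint32_t> votes(bags[0].k + 1, 0);  // 1-based actions; assumes bags is non-empty
  uint32_t action = 0;
  for (size_t i = 0; i < bags.size(); i++)
  {
    uint32_t a = bags[i].predict(example_id);
    votes[a]++;
    if (i == chosen) action = a;
  }
  return { action, (float)votes[action] / (float)bags.size() };
}
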
+ uint32_t choose_action(v_array<float>& distribution)
+ {
+ float value = frand48();
+ for (uint32_t i = 0; i < distribution.size();i++)
+ {
+ if (value <= distribution[i])
+ return i+1;
+ else
+ value -= distribution[i];
+ }
+ //some rounding problem presumably.
+ return 1;
+ }
+
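
choose_action above is inverse-CDF sampling over a (roughly normalized) action distribution; a small self-contained equivalent, with std::mt19937 standing in for frand48 (illustrative, not VW code):

#include <cstdint>
#include <random>
#include <vector>

// Walk the distribution, subtracting mass until the uniform draw is covered.
// Returns a 1-based action; falls back to action 1 on floating-point rounding slack,
// mirroring choose_action above.
uint32_t sample_action(const std::vector<float>& distribution, std::mt19937& rng)
{
  float value = std::uniform_real_distribution<float>(0.f, 1.f)(rng);
  for (uint32_t i = 0; i < distribution.size(); i++)
  {
    if (value <= distribution[i])
      return i + 1;
    value -= distribution[i];
  }
  return 1;
}
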
+ void gen_cs_label(vw& all, CB::cb_class& known_cost, example& ec, CSOAA::label& cs_ld, uint32_t label)
+ {
+ CSOAA::wclass wc;
+
+ //get cost prediction for this label
+ wc.x = CB::get_cost_pred<false>(all, &known_cost, ec, label, all.sd->k);
+ wc.weight_index = label;
+ wc.partial_prediction = 0.;
+ wc.wap_value = 0.;
+
+ //add correction if we observed cost for this action and regressor is wrong
+ if( known_cost.action == label )
+ wc.x += (known_cost.cost - wc.x) / known_cost.probability;
+
+ cs_ld.costs.push_back( wc );
+ }
+
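
The correction in gen_cs_label is a doubly-robust style estimate: start from the regressor's cost prediction and, only for the action whose cost was actually observed, add the inverse-probability-weighted residual. A toy illustration (dr_cost and the numbers are made up for this sketch):

#include <cstdio>

// Doubly-robust style per-action cost estimate, in the same form as gen_cs_label:
// predicted cost, plus (observed - predicted) / probability when the action was the one played.
float dr_cost(float predicted, bool observed, float observed_cost, float probability)
{
  float x = predicted;
  if (observed)
    x += (observed_cost - predicted) / probability;
  return x;
}

int main()
{
  // Regressor says 0.4; the action was played with probability 0.25 and its true cost was 1.0.
  std::printf("played action:   %f\n", dr_cost(0.4f, true, 1.0f, 0.25f));  // 0.4 + 0.6/0.25 = 2.8
  std::printf("unplayed action: %f\n", dr_cost(0.4f, false, 0.0f, 1.0f));  // stays at 0.4
  return 0;
}
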
+ template <bool is_learn>
+ void predict_or_learn_cover(cbify& data, learner& base, example& ec)
+ {//Randomize over predictions from a base set of predictors
+ //Use cost sensitive oracle to cover actions to form distribution.
+ OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
+ data.counter++;
+ float round_epsilon = data.epsilon / sqrt(data.counter);
+
+ float base_prob = round_epsilon / data.k;
+ data.count.erase();
+ data.cs_label.costs.erase();
+ for (size_t j = 0; j < data.k; j++)
+ {
+ data.count.push_back(base_prob);
+
+ CSOAA::wclass wc;
+
+ //get cost prediction for this label
+ wc.x = FLT_MAX;
+ wc.weight_index = j+1;
+ wc.partial_prediction = 0.;
+ wc.wap_value = 0.;
+ data.cs_label.costs.push_back(wc);
+ }
+
+ float additive_probability = (1. - round_epsilon) / (float)data.bags;
+
+ ec.ld = &data.cs_label;
+ for (size_t i = 0; i < data.bags; i++)
+ { //get predicted cost-sensitive predictions
+ data.cs->predict(ec,i+2);
+ data.count[ec.final_prediction-1] += additive_probability;
+ data.predictions[i] = ec.final_prediction;
+ }
+ //compute random action
+ uint32_t action = choose_action(data.count);
+
+ if (is_learn)
+ {
+ data.cb_label.costs.erase();
+ float probability = (float)data.count[action-1];
+ CB::cb_class l = {loss(ld->label, action),
+ action, probability};
+ data.cb_label.costs.push_back(l);
+ ec.ld = &(data.cb_label);
+ base.learn(ec);
+
+ //Now update oracles
+
+ //1. Compute loss vector
+ data.cs_label.costs.erase();
+ float norm = 0;
+ for (uint32_t j = 0; j < data.k; j++)
+ { //data.cs_label now contains an unbiased estimate of cost of each class.
+ gen_cs_label(*data.all, l, ec, data.cs_label, j+1);
+ data.count[j] = base_prob;
+ norm += base_prob;
+ }
+
+ ec.ld = &data.second_cs_label;
+ //2. Update functions
+ for (size_t i = 0; i < data.bags; i++)
+ { //get predicted cost-sensitive predictions
+ for (size_t j = 0; j < data.k; j++)
+ {
+ float pseudo_cost = data.cs_label.costs[j].x - 0.125 * (base_prob / (data.count[j] / norm) + 1);
+ data.second_cs_label.costs[j].weight_index = j+1;
+ data.second_cs_label.costs[j].x = pseudo_cost;
+ }
+ data.cs->learn(ec,i+2);
+ data.count[data.predictions[i]-1] += additive_probability;
+ norm += additive_probability;
+ }
+ }
+ ec.ld = ld;
+ ec.final_prediction = action;
+ }
+
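
A condensed sketch of the cover approach above, stripped of the VW learner plumbing. cover_distribution mirrors the prediction path (epsilon_t/k uniform mass plus (1 - epsilon_t)/bags for each covering policy's action), and pseudo_cost mirrors the smoothed cost fed to each cost-sensitive oracle; in the diff the oracles are updated sequentially against a distribution that grows as each policy is added, and the 0.125 constant is taken directly from the code. Function names here are illustrative, not VW code.

#include <cmath>
#include <cstdint>
#include <vector>

// Exploration distribution over k actions: uniform epsilon_t/k mass, plus
// (1 - epsilon_t)/bags added to each covering policy's predicted action (1-based).
std::vector<float> cover_distribution(const std::vector<uint32_t>& predictions,
                                      uint32_t k, float epsilon, size_t examples_seen)
{
  float round_epsilon = epsilon / std::sqrt((float)examples_seen);
  std::vector<float> p(k, round_epsilon / k);
  float additive_probability = (1.f - round_epsilon) / (float)predictions.size();
  for (uint32_t action : predictions)
    p[action - 1] += additive_probability;
  return p;
}

// Pseudo-cost handed to a covering policy's cost-sensitive oracle: the unbiased cost
// estimate minus a bonus that favors actions the current mixture plays with low probability.
float pseudo_cost(float estimated_cost, float base_prob, float action_prob_over_norm)
{
  return estimated_cost - 0.125f * (base_prob / action_prob_over_norm + 1.f);
}
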
void init_driver(cbify&) {}
void finish_example(vw& all, cbify&, example& ec)
@@ -121,11 +274,15 @@ namespace CBIFY {
cbify* data = (cbify*)calloc(1, sizeof(cbify));
data->epsilon = 0.05f;
+ data->counter = 0;
data->tau = 1000;
+ data->all = &all;
po::options_description desc("CBIFY options");
desc.add_options()
("first", po::value<size_t>(), "tau-first exploration")
- ("greedy",po::value<float>() ,"epsilon-greedy exploration");
+ ("epsilon",po::value<float>() ,"epsilon-greedy exploration")
+ ("bag",po::value<size_t>() ,"bagging-based exploration")
+      ("cover",po::value<size_t>() ,"cover-based exploration");
po::parsed_options parsed = po::command_line_parser(opts).
style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing).
@@ -155,20 +312,45 @@ namespace CBIFY {
}
all.p->lp = OAA::mc_label_parser;
- learner* l = new learner(data, all.l, 1);
- if (vm.count("first") )
+ learner* l;
+ if (vm.count("cover"))
+ {
+ data->bags = (uint32_t)vm["cover"].as<size_t>();
+ data->cs = all.cost_sensitive;
+ data->count.resize(data->k);
+ data->predictions.resize(data->bags);
+ data->second_cs_label.costs.resize(data->k);
+ data->second_cs_label.costs.end = data->second_cs_label.costs.begin+data->k;
+ if ( vm.count("epsilon") )
+ data->epsilon = vm["epsilon"].as<float>();
+ l = new learner(data, all.l, data->bags + 1);
+ l->set_learn<cbify, predict_or_learn_cover<true> >();
+ l->set_predict<cbify, predict_or_learn_cover<false> >();
+ }
+ else if (vm.count("bag"))
+ {
+ data->bags = (uint32_t)vm["bag"].as<size_t>();
+ data->count.resize(data->bags+1);
+ l = new learner(data, all.l, data->bags);
+ l->set_learn<cbify, predict_or_learn_bag<true> >();
+ l->set_predict<cbify, predict_or_learn_bag<false> >();
+ }
+ else if (vm.count("first") )
{
data->tau = (uint32_t)vm["first"].as<size_t>();
+ l = new learner(data, all.l, 1);
l->set_learn<cbify, predict_or_learn_first<true> >();
l->set_predict<cbify, predict_or_learn_first<false> >();
}
else
{
- if ( vm.count("greedy") )
- data->epsilon = vm["greedy"].as<float>();
+ if ( vm.count("epsilon") )
+ data->epsilon = vm["epsilon"].as<float>();
+ l = new learner(data, all.l, 1);
l->set_learn<cbify, predict_or_learn_greedy<true> >();
l->set_predict<cbify, predict_or_learn_greedy<false> >();
}
+
l->set_finish_example<cbify,finish_example>();
l->set_init_driver<cbify,init_driver>();