Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author John Langford <jl@hunch.net> 2014-01-25 02:26:16 +0400
committer John Langford <jl@hunch.net> 2014-01-25 02:26:16 +0400
commit 9fd03e09f186516b2cd80f13be74a3645de93d24 (patch)
tree 196090272086429589321480763e81519b42f052 /vowpalwabbit/cbify.cc
parent 48ec8c0dae97cfdb3d48788386e6fdadf6febb40 (diff)
various cbify improvements
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- vowpalwabbit/cbify.cc 66
1 files changed, 28 insertions, 38 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index 3f62a3ef..6818152a 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -17,22 +17,17 @@ namespace CBIFY {
CB::label cb_label;
};
- void do_uniform(cbify& data, example& ec)
- {
- //Draw an action
- uint32_t action = (uint32_t)ceil(frand48() * data.k);
-
- ec.final_prediction = (float)action;
+ uint32_t do_uniform(cbify& data)
+ { //Draw an action
+ return (uint32_t)ceil(frand48() * data.k);
}
- void do_loss(example& ec)
+ float loss(uint32_t label, float final_prediction)
{
- OAA::mc_label* ld = (OAA::mc_label*)ec.ld;//New loss
-
- if (ld->label != ec.final_prediction)
- ec.loss = 1.;
+ if (label != final_prediction)
+ return 1.;
else
- ec.loss = 0.;
+ return 0.;
}
template <bool is_learn>
@@ -40,21 +35,17 @@ namespace CBIFY {
{//Explore tau times, then act according to optimal.
OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
//Use CB to find current prediction for remaining rounds.
- if (data.tau > 0)
+ if (data.tau && is_learn > 0)
{
- do_uniform(data, ec);
- do_loss(ec);
+ ec.final_prediction = do_uniform(data);
+ ec.loss = loss(ld->label, ec.final_prediction);
data.tau--;
- cout << "tau--" << endl;
uint32_t action = (uint32_t)ec.final_prediction;
CB::cb_class l = {ec.loss, action, 1.f / data.k};
data.cb_label.costs.erase();
data.cb_label.costs.push_back(l);
ec.ld = &(data.cb_label);
- if (is_learn)
- base.learn(ec);
- else
- base.predict(ec);
+ base.learn(ec);
ec.final_prediction = (float)action;
ec.loss = l.cost;
}
@@ -62,11 +53,8 @@ namespace CBIFY {
{
data.cb_label.costs.erase();
ec.ld = &(data.cb_label);
- if (is_learn)
- base.learn(ec);
- else
- base.predict(ec);
- do_loss(ec);
+ base.predict(ec);
+ ec.loss = loss(ld->label, ec.final_prediction);
}
ec.ld = ld;
}
@@ -75,33 +63,34 @@ namespace CBIFY {
void predict_or_learn_greedy(cbify& data, learner& base, example& ec)
{//Explore uniform random an epsilon fraction of the time.
OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
-
- data.cb_label.costs.erase();
ec.ld = &(data.cb_label);
+ data.cb_label.costs.erase();
+
base.predict(ec);
- do_loss(ec);
uint32_t action = (uint32_t)ec.final_prediction;
float base_prob = data.epsilon / data.k;
if (frand48() < 1. - data.epsilon)
{
- CB::cb_class l = {ec.loss, action, 1.f - data.epsilon + base_prob};
+ CB::cb_class l = {loss(ld->label, ec.final_prediction),
+ action, 1.f - data.epsilon + base_prob};
data.cb_label.costs.push_back(l);
}
else
{
- do_uniform(data, ec);
- do_loss(ec);
- action = (uint32_t)ec.final_prediction;
- CB::cb_class l = {ec.loss, (uint32_t)ec.final_prediction, base_prob};
+ action = do_uniform(data);
+ CB::cb_class l = {loss(ld->label, action),
+ action, base_prob};
+ if (action == ec.final_prediction)
+ l.probability = 1.f - data.epsilon + base_prob;
data.cb_label.costs.push_back(l);
}
if (is_learn)
- base.learn(ec);
-
+ base.learn(ec);
+
ec.final_prediction = (float)action;
- ec.loss = data.cb_label.costs[0].cost;
+ ec.loss = loss(ld->label, ec.final_prediction);
ec.ld = ld;
}
@@ -119,6 +108,8 @@ namespace CBIFY {
//Use cost sensitive oracle to cover actions to form distribution.
}
+ void init_driver(cbify&) {}
+
void finish_example(vw& all, cbify&, example& ec)
{
OAA::output_example(all, ec);
@@ -179,8 +170,7 @@ namespace CBIFY {
l->set_predict<cbify, predict_or_learn_greedy<false> >();
}
l->set_finish_example<cbify,finish_example>();
-
- cout << "epsilon = " << data->epsilon << endl;
+ l->set_init_driver<cbify,init_driver>();
return l;
}