Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r--vowpalwabbit/cbify.cc78
1 files changed, 39 insertions, 39 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index d41975ae..3f62a3ef 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -17,92 +17,92 @@ namespace CBIFY {
CB::label cb_label;
};
- void do_uniform(cbify* data, example* ec)
+ void do_uniform(cbify& data, example& ec)
{
//Draw an action
- uint32_t action = (uint32_t)ceil(frand48() * data->k);
+ uint32_t action = (uint32_t)ceil(frand48() * data.k);
- ec->final_prediction = (float)action;
+ ec.final_prediction = (float)action;
}
- void do_loss(example* ec)
+ void do_loss(example& ec)
{
- OAA::mc_label* ld = (OAA::mc_label*)ec->ld;//New loss
+ OAA::mc_label* ld = (OAA::mc_label*)ec.ld;//New loss
- if (ld->label != ec->final_prediction)
- ec->loss = 1.;
+ if (ld->label != ec.final_prediction)
+ ec.loss = 1.;
else
- ec->loss = 0.;
+ ec.loss = 0.;
}
template <bool is_learn>
- void predict_or_learn_first(cbify* data, learner& base, example* ec)
+ void predict_or_learn_first(cbify& data, learner& base, example& ec)
{//Explore tau times, then act according to optimal.
- OAA::mc_label* ld = (OAA::mc_label*)ec->ld;
+ OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
//Use CB to find current prediction for remaining rounds.
- if (data->tau > 0)
+ if (data.tau > 0)
{
do_uniform(data, ec);
do_loss(ec);
- data->tau--;
+ data.tau--;
cout << "tau--" << endl;
- uint32_t action = (uint32_t)ec->final_prediction;
- CB::cb_class l = {ec->loss, action, 1.f / data->k};
- data->cb_label.costs.erase();
- data->cb_label.costs.push_back(l);
- ec->ld = &(data->cb_label);
+ uint32_t action = (uint32_t)ec.final_prediction;
+ CB::cb_class l = {ec.loss, action, 1.f / data.k};
+ data.cb_label.costs.erase();
+ data.cb_label.costs.push_back(l);
+ ec.ld = &(data.cb_label);
if (is_learn)
base.learn(ec);
else
base.predict(ec);
- ec->final_prediction = (float)action;
- ec->loss = l.cost;
+ ec.final_prediction = (float)action;
+ ec.loss = l.cost;
}
else
{
- data->cb_label.costs.erase();
- ec->ld = &(data->cb_label);
+ data.cb_label.costs.erase();
+ ec.ld = &(data.cb_label);
if (is_learn)
base.learn(ec);
else
base.predict(ec);
do_loss(ec);
}
- ec->ld = ld;
+ ec.ld = ld;
}
template <bool is_learn>
- void predict_or_learn_greedy(cbify* data, learner& base, example* ec)
+ void predict_or_learn_greedy(cbify& data, learner& base, example& ec)
{//Explore uniform random an epsilon fraction of the time.
- OAA::mc_label* ld = (OAA::mc_label*)ec->ld;
+ OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
- data->cb_label.costs.erase();
- ec->ld = &(data->cb_label);
+ data.cb_label.costs.erase();
+ ec.ld = &(data.cb_label);
base.predict(ec);
do_loss(ec);
- uint32_t action = (uint32_t)ec->final_prediction;
+ uint32_t action = (uint32_t)ec.final_prediction;
- float base_prob = data->epsilon / data->k;
- if (frand48() < 1. - data->epsilon)
+ float base_prob = data.epsilon / data.k;
+ if (frand48() < 1. - data.epsilon)
{
- CB::cb_class l = {ec->loss, action, 1.f - data->epsilon + base_prob};
- data->cb_label.costs.push_back(l);
+ CB::cb_class l = {ec.loss, action, 1.f - data.epsilon + base_prob};
+ data.cb_label.costs.push_back(l);
}
else
{
do_uniform(data, ec);
do_loss(ec);
- action = (uint32_t)ec->final_prediction;
- CB::cb_class l = {ec->loss, (uint32_t)ec->final_prediction, base_prob};
- data->cb_label.costs.push_back(l);
+ action = (uint32_t)ec.final_prediction;
+ CB::cb_class l = {ec.loss, (uint32_t)ec.final_prediction, base_prob};
+ data.cb_label.costs.push_back(l);
}
if (is_learn)
base.learn(ec);
- ec->final_prediction = (float)action;
- ec->loss = data->cb_label.costs[0].cost;
- ec->ld = ld;
+ ec.final_prediction = (float)action;
+ ec.loss = data.cb_label.costs[0].cost;
+ ec.ld = ld;
}
void learn_bagging(void* d, learner& base, example* ec)
@@ -119,10 +119,10 @@ namespace CBIFY {
//Use cost sensitive oracle to cover actions to form distribution.
}
- void finish_example(vw& all, cbify*, example* ec)
+ void finish_example(vw& all, cbify&, example& ec)
{
OAA::output_example(all, ec);
- VW::finish_example(all, ec);
+ VW::finish_example(all, &ec);
}
learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)