Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-01-29 22:28:43 +0400
committerJohn Langford <jl@hunch.net>2014-01-29 22:28:43 +0400
commit09df65e84bcceca5c5127a730eaf6b6f0c2e00d6 (patch)
treeae2eefab3a25aa905f784dd8a868d6ac002b8776 /vowpalwabbit/cbify.cc
parent2fb8eca0a0d2a8ba5e8438145b5639486cc4c933 (diff)
improved more compact cbify
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r--vowpalwabbit/cbify.cc53
1 files changed, 41 insertions, 12 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index 2bfcbef5..8061f383 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -165,6 +165,26 @@ namespace CBIFY {
return 1;
}
+ void safety(v_array<float>& distribution, float min_prob)
+ {
+ float added_mass = 0.;
+ for (uint32_t i = 0; i < distribution.size();i++)
+ if (distribution[i] > 0 && distribution[i] <= min_prob)
+ {
+ added_mass += min_prob - distribution[i];
+ distribution[i] = min_prob;
+ }
+
+ float ratio = 1. / (1. + added_mass);
+ if (ratio < 0.999)
+ {
+ for (uint32_t i = 0; i < distribution.size(); i++)
+ if (distribution[i] > min_prob)
+ distribution[i] = distribution[i] * ratio;
+ safety(distribution, min_prob);
+ }
+ }
+
void gen_cs_label(vw& all, CB::cb_class& known_cost, example& ec, CSOAA::label& cs_ld, uint32_t label)
{
CSOAA::wclass wc;
@@ -188,14 +208,12 @@ namespace CBIFY {
//Use cost sensitive oracle to cover actions to form distribution.
OAA::mc_label* ld = (OAA::mc_label*)ec.ld;
data.counter++;
- float round_epsilon = data.epsilon / sqrt(data.counter);
- float base_prob = round_epsilon / data.k;
data.count.erase();
data.cs_label.costs.erase();
for (size_t j = 0; j < data.k; j++)
{
- data.count.push_back(base_prob);
+ data.count.push_back(0);
CSOAA::wclass wc;
@@ -207,18 +225,26 @@ namespace CBIFY {
data.cs_label.costs.push_back(wc);
}
- float additive_probability = (1. - round_epsilon) / (float)data.bags;
+ float additive_probability = 1. / (float)data.bags;
ec.ld = &data.cs_label;
for (size_t i = 0; i < data.bags; i++)
{ //get predicted cost-sensitive predictions
- data.cs->predict(ec,i+2);
+ if (i == 0)
+ data.cs->predict(ec, i);
+ else
+ data.cs->predict(ec,i+1);
data.count[ec.final_prediction-1] += additive_probability;
data.predictions[i] = ec.final_prediction;
}
+
+ float min_prob = data.epsilon * min (1. / data.k, 1. / sqrt(data.counter * data.k));
+
+ safety(data.count, min_prob);
+
//compute random action
uint32_t action = choose_action(data.count);
-
+
if (is_learn)
{
data.cb_label.costs.erase();
@@ -233,12 +259,11 @@ namespace CBIFY {
//1. Compute loss vector
data.cs_label.costs.erase();
- float norm = 0;
+ float norm = min_prob * data.k;
for (uint32_t j = 0; j < data.k; j++)
{ //data.cs_label now contains an unbiased estimate of cost of each class.
gen_cs_label(*data.all, l, ec, data.cs_label, j+1);
- data.count[j] = base_prob;
- norm += base_prob;
+ data.count[j] = 0;
}
ec.ld = &data.second_cs_label;
@@ -247,13 +272,17 @@ namespace CBIFY {
{ //get predicted cost-sensitive predictions
for (size_t j = 0; j < data.k; j++)
{
- float pseudo_cost = data.cs_label.costs[j].x - 0.125 * (base_prob / (data.count[j] / norm) + 1);
+ float pseudo_cost = data.cs_label.costs[j].x - data.epsilon * min_prob / (max(data.count[j], min_prob) / norm) + 1;
data.second_cs_label.costs[j].weight_index = j+1;
data.second_cs_label.costs[j].x = pseudo_cost;
}
- data.cs->learn(ec,i+2);
+ if (i != 0)
+ data.cs->learn(ec,i+1);
+ if (data.count[data.predictions[i]-1] < min_prob)
+ norm += max(0, additive_probability - (min_prob - data.count[data.predictions[i]-1]));
+ else
+ norm += additive_probability;
data.count[data.predictions[i]-1] += additive_probability;
- norm += additive_probability;
}
}