diff options
author | John Langford <jl@hunch.net> | 2014-01-29 22:28:43 +0400 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-01-29 22:28:43 +0400 |
commit | 09df65e84bcceca5c5127a730eaf6b6f0c2e00d6 (patch) | |
tree | ae2eefab3a25aa905f784dd8a868d6ac002b8776 /vowpalwabbit/cbify.cc | |
parent | 2fb8eca0a0d2a8ba5e8438145b5639486cc4c933 (diff) |
improved more compact cbify
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- | vowpalwabbit/cbify.cc | 53 |
1 files changed, 41 insertions, 12 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 2bfcbef5..8061f383 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -165,6 +165,26 @@ namespace CBIFY { return 1; } + void safety(v_array<float>& distribution, float min_prob) + { + float added_mass = 0.; + for (uint32_t i = 0; i < distribution.size();i++) + if (distribution[i] > 0 && distribution[i] <= min_prob) + { + added_mass += min_prob - distribution[i]; + distribution[i] = min_prob; + } + + float ratio = 1. / (1. + added_mass); + if (ratio < 0.999) + { + for (uint32_t i = 0; i < distribution.size(); i++) + if (distribution[i] > min_prob) + distribution[i] = distribution[i] * ratio; + safety(distribution, min_prob); + } + } + void gen_cs_label(vw& all, CB::cb_class& known_cost, example& ec, CSOAA::label& cs_ld, uint32_t label) { CSOAA::wclass wc; @@ -188,14 +208,12 @@ namespace CBIFY { //Use cost sensitive oracle to cover actions to form distribution. OAA::mc_label* ld = (OAA::mc_label*)ec.ld; data.counter++; - float round_epsilon = data.epsilon / sqrt(data.counter); - float base_prob = round_epsilon / data.k; data.count.erase(); data.cs_label.costs.erase(); for (size_t j = 0; j < data.k; j++) { - data.count.push_back(base_prob); + data.count.push_back(0); CSOAA::wclass wc; @@ -207,18 +225,26 @@ namespace CBIFY { data.cs_label.costs.push_back(wc); } - float additive_probability = (1. - round_epsilon) / (float)data.bags; + float additive_probability = 1. / (float)data.bags; ec.ld = &data.cs_label; for (size_t i = 0; i < data.bags; i++) { //get predicted cost-sensitive predictions - data.cs->predict(ec,i+2); + if (i == 0) + data.cs->predict(ec, i); + else + data.cs->predict(ec,i+1); data.count[ec.final_prediction-1] += additive_probability; data.predictions[i] = ec.final_prediction; } + + float min_prob = data.epsilon * min (1. / data.k, 1. / sqrt(data.counter * data.k)); + + safety(data.count, min_prob); + //compute random action uint32_t action = choose_action(data.count); - + if (is_learn) { data.cb_label.costs.erase(); @@ -233,12 +259,11 @@ namespace CBIFY { //1. Compute loss vector data.cs_label.costs.erase(); - float norm = 0; + float norm = min_prob * data.k; for (uint32_t j = 0; j < data.k; j++) { //data.cs_label now contains an unbiased estimate of cost of each class. gen_cs_label(*data.all, l, ec, data.cs_label, j+1); - data.count[j] = base_prob; - norm += base_prob; + data.count[j] = 0; } ec.ld = &data.second_cs_label; @@ -247,13 +272,17 @@ namespace CBIFY { { //get predicted cost-sensitive predictions for (size_t j = 0; j < data.k; j++) { - float pseudo_cost = data.cs_label.costs[j].x - 0.125 * (base_prob / (data.count[j] / norm) + 1); + float pseudo_cost = data.cs_label.costs[j].x - data.epsilon * min_prob / (max(data.count[j], min_prob) / norm) + 1; data.second_cs_label.costs[j].weight_index = j+1; data.second_cs_label.costs[j].x = pseudo_cost; } - data.cs->learn(ec,i+2); + if (i != 0) + data.cs->learn(ec,i+1); + if (data.count[data.predictions[i]-1] < min_prob) + norm += max(0, additive_probability - (min_prob - data.count[data.predictions[i]-1])); + else + norm += additive_probability; data.count[data.predictions[i]-1] += additive_probability; - norm += additive_probability; } } |