Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJake Hofman <jhofman@gmail.com>2014-01-07 02:01:15 +0400
committerJake Hofman <jhofman@gmail.com>2014-01-07 02:01:15 +0400
commitad6ac6854cb482bc4ab27f1cdc4a5276b7bf27b8 (patch)
tree2b8267ee7c7a125c64f51038eac027cc3628354c /vowpalwabbit/csoaa.cc
parentef823fba0220b7693f21b17f71261db1bcccf431 (diff)
reductions now have predict functions, but tests break
Diffstat (limited to 'vowpalwabbit/csoaa.cc')
-rw-r--r--vowpalwabbit/csoaa.cc77
1 files changed, 45 insertions, 32 deletions
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 8b207f0b..d8781ba2 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -311,33 +311,39 @@ namespace CSOAA {
print_update(all, is_test_label((label*)ec->ld), ec);
}
- void learn(void* d, learner& base, example* ec) {
+ template <bool is_learn>
+ void predict_or_learn(void* d, learner& base, example* ec) {
csoaa* c = (csoaa*)d;
vw* all = c->all;
label* ld = (label*)ec->ld;
size_t prediction = 1;
float score = FLT_MAX;
+ label_data simple_temp;
+ simple_temp.initial = 0.;
+ ec->ld = &simple_temp;
for (wclass *cl = ld->costs.begin; cl != ld->costs.end; cl ++)
{
uint32_t i = cl->weight_index;
- label_data simple_temp;
- simple_temp.initial = 0.;
-
- if (cl->x == FLT_MAX || !all->training)
+
+ if (is_learn)
{
- simple_temp.label = FLT_MAX;
- simple_temp.weight = 0.;
+ if (cl->x == FLT_MAX || !all->training)
+ {
+ simple_temp.label = FLT_MAX;
+ simple_temp.weight = 0.;
+ }
+ else
+ {
+ simple_temp.label = cl->x;
+ simple_temp.weight = 1.;
+ }
+
+ base.learn(ec);
}
else
- {
- simple_temp.label = cl->x;
- simple_temp.weight = 1.;
- }
-
- ec->ld = &simple_temp;
+ base.predict(ec);
- base.learn(ec, i - 1);
cl->partial_prediction = ec->partial_prediction;
if (ec->partial_prediction < score || (ec->partial_prediction == score && i < prediction)) {
score = ec->partial_prediction;
@@ -378,7 +384,7 @@ namespace CSOAA {
*(all.p->lp) = cs_label_parser;
all.sd->k = nb_actions;
- learner* l = new learner(c, learn, all.l, nb_actions);
+ learner* l = new learner(c, predict_or_learn<true>, predict_or_learn<false>, all.l, nb_actions);
l->set_finish_example(finish_example);
return l;
}
@@ -623,16 +629,15 @@ namespace LabelDict {
v_array<CSOAA::wclass> costs = ld->costs;
label_data simple_label;
+ simple_label.initial = 0.;
+ simple_label.label = FLT_MAX;
+ simple_label.weight = 0.;
for (size_t j=0; j<costs.size(); j++) {
- simple_label.initial = 0.;
- simple_label.label = FLT_MAX;
- simple_label.weight = 0.;
- ec->partial_prediction = 0.;
LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
-
+
ec->ld = &simple_label;
- base.learn(ec); // make a prediction
+ base.predict(ec); // make a prediction
costs[j].partial_prediction = ec->partial_prediction;
if (ec->partial_prediction < *min_score) {
@@ -650,7 +655,7 @@ namespace LabelDict {
}
-
+ template <bool is_learn>
void do_actual_learning_wap(vw& all, ldf& l, learner& base, size_t start_K)
{
size_t K = l.ec_seq.size();
@@ -721,7 +726,10 @@ namespace LabelDict {
simple_label.weight = value_diff;
ec1->partial_prediction = 0.;
subtract_example(all, ec1, ec2);
- base.learn(ec1);
+ if (is_learn)
+ base.learn(ec1);
+ else
+ base.predict(ec1);
unsubtract_example(all, ec1);
LabelDict::del_example_namespace_from_memory(l, ec2, costs2[j2].weight_index);
@@ -738,6 +746,7 @@ namespace LabelDict {
}
}
+ template <bool is_learn>
void do_actual_learning_oaa(vw& all, ldf& l, learner& base, size_t start_K)
{
size_t K = l.ec_seq.size();
@@ -796,7 +805,10 @@ namespace LabelDict {
//cerr << "[" << ec->partial_prediction << "," << ec->done << "]";
//ec->done = false;
LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
- base.learn(ec);
+ if (is_learn)
+ base.learn(ec);
+ else
+ base.predict(ec);
LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
ec->example_t = example_t;
}
@@ -816,7 +828,7 @@ namespace LabelDict {
}
}
-
+ template <bool is_learn>
void do_actual_learning(vw& all, ldf& l, learner& base)
{
if (l.ec_seq.size() <= 0) return; // nothing to do
@@ -848,8 +860,8 @@ namespace LabelDict {
}
/////////////////////// learn
- if (l.is_wap) do_actual_learning_wap(all, l, base, start_K);
- else do_actual_learning_oaa(all, l, base, start_K);
+ if (l.is_wap) do_actual_learning_wap<is_learn>(all, l, base, start_K);
+ else do_actual_learning_oaa<is_learn>(all, l, base, start_K);
/////////////////////// remove header
if (start_K > 0)
@@ -934,7 +946,8 @@ namespace LabelDict {
l->first_pass = false;
}
- void learn(void* data, learner& base, example *ec)
+ template <bool is_learn>
+ void predict_or_learn(void* data, learner& base, example *ec)
{
ldf* l=(ldf*)data;
vw* all = l->all;
@@ -951,7 +964,7 @@ namespace LabelDict {
if (l->ec_seq.size() >= all->p->ring_size - 2 && l->first_pass)
cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
- do_actual_learning(*all, *l, base);
+ do_actual_learning<is_learn>(*all, *l, base);
if (!LabelDict::ec_seq_is_label_definition(*l, l->ec_seq) && l->ec_seq.size() > 0)
global_print_newline(*all);
@@ -966,7 +979,7 @@ namespace LabelDict {
if (!((!all->training) || CSOAA::example_is_test(ec))) {
l->ec_seq.erase();
l->ec_seq.push_back(ec);
- do_actual_learning(*all, *l, base);
+ do_actual_learning<is_learn>(*all, *l, base);
l->ec_seq.erase();
}
@@ -1018,7 +1031,7 @@ namespace LabelDict {
{
ldf* l=(ldf*)data;
vw* all = l->all;
- do_actual_learning(*all, *l, *(l->base));
+ do_actual_learning<true>(*all, *l, *(l->base));
output_example_seq(*all, *l);
clear_seq(*all, *l);
l->ec_seq.delete_v();
@@ -1096,7 +1109,7 @@ namespace LabelDict {
ld->read_example_this_loop = 0;
ld->need_to_clear = false;
- learner* l = new learner(ld, learn, all.l);
+ learner* l = new learner(ld, predict_or_learn<true>, predict_or_learn<false>, all.l);
if (ld->is_singleline)
l->set_finish_example(finish_example);
else