Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHal Daume III <me@hal3.name>2014-01-09 17:45:34 +0400
committerHal Daume III <me@hal3.name>2014-01-09 17:45:34 +0400
commit90b37b4996ce88e4f41a5353f9ffc284387b14de (patch)
tree588da572019224cdbe83bfbb615006d812f6bc3c /vowpalwabbit/csoaa.cc
parent05dcce8b1f673c0a083fd75a36016c51a945af4d (diff)
parentd73d8f6aef1c02dd05554f91ca57e1d304336130 (diff)
pulled john's changes, tests are ok, but some other things broke :( -- namely due to offsetting
Diffstat (limited to 'vowpalwabbit/csoaa.cc')
-rw-r--r--vowpalwabbit/csoaa.cc121
1 files changed, 67 insertions, 54 deletions
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 068cf5d9..6e9fbad0 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -17,6 +17,8 @@ license as described in the file LICENSE.
using namespace std;
+using namespace LEARNER;
+
namespace CSOAA {
struct csoaa{
vw* all;
@@ -206,7 +208,7 @@ namespace CSOAA {
void print_update(vw& all, bool is_test, example *ec)
{
- if ( /* (all.sd->weighted_examples > all.sd->old_weighted_examples) || */ (all.sd->weighted_examples > all.sd->dump_interval && !all.quiet && !all.bfgs))
+ if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs)
{
char label_buf[32];
if (is_test)
@@ -248,7 +250,7 @@ namespace CSOAA {
all.sd->sum_loss_since_last_dump = 0.0;
all.sd->old_weighted_examples = all.sd->weighted_examples;
- all.sd->dump_interval *= 2;
+ VW::update_dump_interval(all);
}
}
@@ -311,33 +313,38 @@ namespace CSOAA {
print_update(all, is_test_label((label*)ec->ld), ec);
}
- void learn(void* d, learner& base, example* ec) {
- csoaa* c = (csoaa*)d;
+ template <bool is_learn>
+ void predict_or_learn(csoaa* c, learner& base, example* ec) {
vw* all = c->all;
label* ld = (label*)ec->ld;
size_t prediction = 1;
float score = FLT_MAX;
+ label_data simple_temp;
+ simple_temp.initial = 0.;
+ ec->ld = &simple_temp;
for (wclass *cl = ld->costs.begin; cl != ld->costs.end; cl ++)
{
uint32_t i = cl->weight_index;
- label_data simple_temp;
- simple_temp.initial = 0.;
-
- if (cl->x == FLT_MAX || !all->training)
+
+ if (is_learn)
{
- simple_temp.label = FLT_MAX;
- simple_temp.weight = 0.;
+ if (cl->x == FLT_MAX || !all->training)
+ {
+ simple_temp.label = FLT_MAX;
+ simple_temp.weight = 0.;
+ }
+ else
+ {
+ simple_temp.label = cl->x;
+ simple_temp.weight = 1.;
+ }
+
+ base.learn(ec, i);
}
else
- {
- simple_temp.label = cl->x;
- simple_temp.weight = 1.;
- }
+ base.predict(ec, i);
- ec->ld = &simple_temp;
-
- base.learn(ec, i - 1);
cl->partial_prediction = ec->partial_prediction;
if (ec->partial_prediction < score || (ec->partial_prediction == score && i < prediction)) {
score = ec->partial_prediction;
@@ -349,7 +356,7 @@ namespace CSOAA {
ec->final_prediction = (float)prediction;
}
- void finish_example(vw& all, void*, example* ec)
+ void finish_example(vw& all, csoaa*, example* ec)
{
output_example(all, ec);
VW::finish_example(all, ec);
@@ -378,8 +385,10 @@ namespace CSOAA {
*(all.p->lp) = cs_label_parser;
all.sd->k = nb_actions;
- learner* l = new learner(c, learn, all.l, nb_actions);
- l->set_finish_example(finish_example);
+ learner* l = new learner(c, all.l, nb_actions);
+ l->set_learn<csoaa, predict_or_learn<true> >();
+ l->set_predict<csoaa, predict_or_learn<false> >();
+ l->set_finish_example<csoaa,finish_example>();
return l;
}
@@ -643,7 +652,7 @@ namespace LabelDict {
LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
ec->ld = &simple_label;
- base.learn(ec); // make a prediction
+ base.predict(ec); // make a prediction
costs[j].partial_prediction = ec->partial_prediction;
if (min_score && prediction && (ec->partial_prediction < *min_score)) {
@@ -662,7 +671,7 @@ namespace LabelDict {
}
-
+ template <bool is_learn>
void do_actual_learning_wap(vw& all, ldf& l, learner& base, size_t start_K)
{
size_t K = l.ec_seq.size();
@@ -733,7 +742,10 @@ namespace LabelDict {
simple_label.weight = value_diff;
ec1->partial_prediction = 0.;
subtract_example(all, ec1, ec2);
- base.learn(ec1);
+ if (is_learn)
+ base.learn(ec1);
+ else
+ base.predict(ec1);
unsubtract_example(all, ec1);
LabelDict::del_example_namespace_from_memory(l, ec2, costs2[j2].weight_index);
@@ -750,6 +762,7 @@ namespace LabelDict {
}
}
+ template <bool is_learn>
void do_actual_learning_oaa(vw& all, ldf& l, learner& base, size_t start_K)
{
size_t K = l.ec_seq.size();
@@ -812,7 +825,10 @@ namespace LabelDict {
//cerr << "[" << ec->partial_prediction << "," << ec->done << "]";
//ec->done = false;
LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
- base.learn(ec);
+ if (is_learn)
+ base.learn(ec);
+ else
+ base.predict(ec);
LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
ec->example_t = example_t;
}
@@ -832,7 +848,7 @@ namespace LabelDict {
}
}
-
+ template <bool is_learn>
void do_actual_learning(vw& all, ldf& l, learner& base)
{
//clog << "do_actual_learning size=" << l.ec_seq.size() << endl;
@@ -865,8 +881,8 @@ namespace LabelDict {
}
/////////////////////// learn
- if (l.is_wap) do_actual_learning_wap(all, l, base, start_K);
- else do_actual_learning_oaa(all, l, base, start_K);
+ if (l.is_wap) do_actual_learning_wap<is_learn>(all, l, base, start_K);
+ else do_actual_learning_oaa<is_learn>(all, l, base, start_K);
/////////////////////// remove header
if (start_K > 0)
@@ -945,16 +961,18 @@ namespace LabelDict {
l.ec_seq.erase();
}
- void end_pass(void* data)
+ void end_pass(ldf* l)
{
- ldf* l=(ldf*)data;
l->first_pass = false;
}
/*
void learn(void* data, learner& base, example *ec)
+=======
+ template <bool is_learn>
+ void predict_or_learn(ldf* l, learner& base, example *ec)
+>>>>>>> d73d8f6aef1c02dd05554f91ca57e1d304336130
{
- ldf* l=(ldf*)data;
vw* all = l->all;
l->base = &base;
@@ -970,7 +988,7 @@ namespace LabelDict {
if (l->ec_seq.size() >= all->p->ring_size - 2 && l->first_pass)
cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
- do_actual_learning(*all, *l, base);
+ do_actual_learning<is_learn>(*all, *l, base);
if (!LabelDict::ec_seq_is_label_definition(*l, l->ec_seq) && l->ec_seq.size() > 0)
global_print_newline(*all);
@@ -985,7 +1003,7 @@ namespace LabelDict {
if (!((!all->training) || CSOAA::example_is_test(ec))) {
l->ec_seq.erase();
l->ec_seq.push_back(ec);
- do_actual_learning(*all, *l, base);
+ do_actual_learning<is_learn>(*all, *l, base);
l->ec_seq.erase();
}
@@ -1004,7 +1022,7 @@ namespace LabelDict {
}
*/
- void finish_singleline_example(vw& all, void*, example* ec)
+ void finish_singleline_example(vw& all, ldf*, example* ec)
{
if (! LabelDict::ec_is_label_definition(ec)) {
all.sd->weighted_examples += 1;
@@ -1015,9 +1033,8 @@ namespace LabelDict {
VW::finish_example(all, ec);
}
- void finish_multiline_example(vw& all, void* data, example* ec)
+ void finish_multiline_example(vw& all, ldf* l, example* ec)
{
- ldf* l=(ldf*)data;
if (l->need_to_clear) {
if (l->ec_seq.size() > 0) {
output_example_seq(all, *l);
@@ -1029,28 +1046,22 @@ namespace LabelDict {
}
}
- void end_examples(void* data)
+ void end_examples(ldf* l)
{
- ldf* l=(ldf*)data;
- //vw* all = l->all;
if (l->need_to_clear)
l->ec_seq.erase();
}
- void finish(void* d)
+ void finish(ldf* l)
{
- ldf* l=(ldf*)d;
- /*vw* all = l->all;
- clear_seq(*all, *l);
- l->ec_seq.delete_v();*/
- l->ec_seq.erase();
+ //vw* all = l->all;
l->ec_seq.delete_v();
LabelDict::free_label_features(*l);
}
- void learn(void* data, learner& base, example* ec) {
- ldf* l = (ldf*)data;
+ template <bool is_learn>
+ void predict_or_learn(ldf* l, learner& base, example *ec) {
vw* all = l->all;
l->base = &base;
@@ -1067,7 +1078,7 @@ namespace LabelDict {
if (need_to_break && l->first_pass)
cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
- do_actual_learning(*all, *l, base);
+ do_actual_learning<is_learn>(*all, *l, base);
l->need_to_clear = true;
} else if (LabelDict::ec_is_label_definition(ec)) {
if (l->ec_seq.size() > 0) {
@@ -1077,7 +1088,7 @@ namespace LabelDict {
if (! is_test) {
l->ec_seq.push_back(ec);
- do_actual_learning(*all, *l, base);
+ do_actual_learning<is_learn>(*all, *l, base);
l->need_to_clear = true;
}
} else {
@@ -1161,14 +1172,16 @@ namespace LabelDict {
ld->read_example_this_loop = 0;
ld->need_to_clear = false;
- learner* l = new learner(ld, learn, all.l);
+ learner* l = new learner(ld, all.l);
+ l->set_learn<ldf, predict_or_learn<true> >();
+ l->set_predict<ldf, predict_or_learn<false> >();
if (ld->is_singleline)
- l->set_finish_example(finish_singleline_example);
+ l->set_finish_example<ldf,finish_singleline_example>();
else
- l->set_finish_example(finish_multiline_example);
- l->set_finish(finish);
- l->set_end_examples(end_examples);
- l->set_end_pass(end_pass);
+ l->set_finish_example<ldf,finish_multiline_example>();
+ l->set_finish<ldf,finish>();
+ l->set_end_examples<ldf,end_examples>();
+ l->set_end_pass<ldf,end_pass>();
return l;
}