diff options
author | John Langford <jl@nyclamp.(none)> | 2013-10-29 18:55:22 +0400 |
---|---|---|
committer | John Langford <jl@nyclamp.(none)> | 2013-10-29 18:55:22 +0400 |
commit | 2080aa9aaccf6bb5709f223c2f779b565cc4d104 (patch) | |
tree | 804732d81522c2651973875ab38496b2c301a6bf /vowpalwabbit | |
parent | 39d2d61df606929569b9d6736bfd25cbe373a06e (diff) |
more standard drivers
Diffstat (limited to 'vowpalwabbit')
-rw-r--r-- | vowpalwabbit/bs.cc | 35 | ||||
-rw-r--r-- | vowpalwabbit/cb.cc | 29 | ||||
-rw-r--r-- | vowpalwabbit/csoaa.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/csoaa.h | 2 | ||||
-rw-r--r-- | vowpalwabbit/lda_core.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/learner.h | 8 | ||||
-rw-r--r-- | vowpalwabbit/nn.cc | 8 | ||||
-rw-r--r-- | vowpalwabbit/oaa.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/oaa.h | 2 | ||||
-rw-r--r-- | vowpalwabbit/sender.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/simple_label.cc | 2 | ||||
-rw-r--r-- | vowpalwabbit/simple_label.h | 2 |
12 files changed, 37 insertions, 63 deletions
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc index 639b442b..2e533481 100644 --- a/vowpalwabbit/bs.cc +++ b/vowpalwabbit/bs.cc @@ -127,7 +127,7 @@ namespace BS { } } - void output_example(vw& all, example* ec, bs* d) + void output_example(vw& all, bs* d, example* ec) { if (command_example(&all,ec)) return; @@ -172,9 +172,12 @@ namespace BS { print_update(all, ec); } - void learn_with_output(bs* d, example* ec, bool shouldOutput) + void learn(void* data, example* ec) { + bs* d = (bs*)data; vw* all = d->all; + bool shouldOutput = all->raw_prediction > 0; + if (command_example(all,ec)) { d->base.learn(ec); @@ -226,27 +229,11 @@ namespace BS { } - void learn(void* d, example* ec) { - learn_with_output((bs*)d, ec, false); - } - - void drive(vw* all, void* d) + void finish_example(vw& all, void* d, example* ec) { - example* ec = NULL; - while ( true ) - { - if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation. - { - learn_with_output((bs*)d, ec, all->raw_prediction > 0); - if (!command_example(all, ec)) - BS::output_example(*all, ec, (bs*)d); - VW::finish_example(*all, ec); - } - else if (parser_done(all->p)) - return; - else - ; - } + if (!command_example(&all, ec)) + BS::output_example(all, (bs*)d, ec); + VW::finish_example(all, ec); } void finish(void* data) @@ -332,7 +319,9 @@ namespace BS { all.weights_per_problem *= data->B; data->total_increment = data->increment*(data->B-1); data->base = all.l; - learner l(data, drive, learn, finish, all.l.sl); + learner l(data, LEARNER::generic_driver, learn, finish, all.l.sl); + + l.set_finish_example(finish_example); return l; } } diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc index d3a4c5d1..8d204450 100644 --- a/vowpalwabbit/cb.cc +++ b/vowpalwabbit/cb.cc @@ -658,27 +658,12 @@ namespace CB free(c); } - void drive(vw* all, void* d) + void finish_example(vw& all, void* data, example* ec) { - cb* c = (cb*)d; - example* ec = NULL; - while ( true ) - { - if(all-> early_terminate) - { - all->p->done = true; - return; - } - if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation. - { - learn(d, ec); - if (!command_example(&all, ec)) - output_example(*all, *c, ec); - VW::finish_example(*all, ec); - } - else if (parser_done(all->p)) - return; - } + cb* c = (cb*)data; + if (!command_example(&all, ec)) + output_example(all, *c, ec); + VW::finish_example(all, ec); } learner setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file) @@ -766,8 +751,10 @@ namespace CB all.sd->k = nb_actions; - learner l(c, drive, learn, finish, all.l.sl); + learner l(c, LEARNER::generic_driver, learn, finish, all.l.sl); c->base = all.l; + l.set_finish_example(finish_example); + return l; } } diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc index fa773725..85190934 100644 --- a/vowpalwabbit/csoaa.cc +++ b/vowpalwabbit/csoaa.cc @@ -368,7 +368,7 @@ namespace CSOAA { update_example_indicies(all->audit, ec, -current_increment); } - void finish_example(vw& all, example* ec) + void finish_example(vw& all, void*, example* ec) { if (!command_example(&all, ec)) output_example(all, ec); diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h index b26d4ed1..ac978eb4 100644 --- a/vowpalwabbit/csoaa.h +++ b/vowpalwabbit/csoaa.h @@ -30,7 +30,7 @@ namespace CSOAA { learner setup(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file); - void finish_example(vw& all, example* ec); + void finish_example(vw& all, void*, example* ec); size_t read_cached_label(shared_data* sd, void* v, io_buf& cache); void cache_label(void* v, io_buf& cache); void default_label(void* v); diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index 12756e73..be5e8ccc 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -644,7 +644,7 @@ void save_load(void* d, io_buf& model_file, bool read, bool text) l.all->sd->sum_loss -= score;
l.all->sd->sum_loss_since_last_dump -= score;
}
- return_simple_example(*l.all, l.examples[d]);
+ return_simple_example(*l.all, NULL, l.examples[d]);
}
for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
@@ -717,7 +717,7 @@ void end_examples(void* d) free(d);
}
-void finish_example(vw& all, example*ec)
+ void finish_example(vw& all, void*, example*ec)
{}
learner setup(vw&all, std::vector<std::string>&opts, po::variables_map& vm)
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 2f9a5ff8..9ad1efe4 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -13,7 +13,7 @@ struct sl_t { void (*save_loader)(void* sldata, io_buf&, bool read, bool text); }; -void return_simple_example(vw& all, example* ec); +void return_simple_example(vw& all, void*, example* ec); #include<iostream> using namespace std; @@ -40,7 +40,7 @@ private: void* data; void (*driver)(vw* all, void* data); void (*learn_f)(void* data, example*); - void (*finish_example_f)(vw&, example*); + void (*finish_example_f)(vw&, void* data, example*); void (*finisher)(void* data); void (*end_pass)(void* data); void (*end_examples_f)(void* data); @@ -50,11 +50,11 @@ public: inline void learn(example* ec) { learn_f(data,ec); } inline void finish() { finisher(data); } - inline void finish_example(vw& all, example* ec) { finish_example_f(all,ec);} + inline void finish_example(vw& all, example* ec) { finish_example_f(all, data, ec);} inline void drive(vw* all) { driver(all, data); } inline void save_load(io_buf& io, bool read, bool text) { sl.save_loader(sl.sldata, io, read, text); } - void set_finish_example(void (*ef)(vw& all, example*)) + void set_finish_example(void (*ef)(vw& all, void*, example*)) {finish_example_f = ef;} void set_end_examples(void (*ee)(void*)) diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc index d16cadad..afff4854 100644 --- a/vowpalwabbit/nn.cc +++ b/vowpalwabbit/nn.cc @@ -69,9 +69,7 @@ namespace NN { uint32_t offset) { for (feature* x = f.begin; x != f.end; ++x) - { - x->weight_index += offset; - } + x->weight_index += offset; } void finish_setup (nn* n, vw& all); @@ -255,11 +253,11 @@ CONVERSE: // That's right, I'm using goto. So sue me. ec->loss = save_ec_loss; } - void finish_example(vw& all, example* ec) + void finish_example(vw& all, void*, example* ec) { int save_raw_prediction = all.raw_prediction; all.raw_prediction = -1; - return_simple_example(all, ec); + return_simple_example(all, NULL, ec); all.raw_prediction = save_raw_prediction; } diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc index b8a7a3d6..1bf492bb 100644 --- a/vowpalwabbit/oaa.cc +++ b/vowpalwabbit/oaa.cc @@ -191,7 +191,7 @@ namespace OAA { OAA::print_update(all, ec); } - void finish_example(vw& all, example* ec) + void finish_example(vw& all, void*, example* ec) { output_example(all, ec); VW::finish_example(all, ec); diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h index f6b7e492..2abb0d30 100644 --- a/vowpalwabbit/oaa.h +++ b/vowpalwabbit/oaa.h @@ -36,7 +36,7 @@ namespace OAA NULL, sizeof(mc_label)}; - void finish_example(vw& all, example* ec); + void finish_example(vw& all, void*, example* ec); inline int example_is_test(example* ec) { diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc index a966acda..87b4b3b6 100644 --- a/vowpalwabbit/sender.cc +++ b/vowpalwabbit/sender.cc @@ -68,7 +68,7 @@ void receive_result(sender& s) ec->loss = s.all->loss->getLoss(s.all->sd, ec->final_prediction, ld->label) * ld->weight; - return_simple_example(*(s.all), ec); + return_simple_example(*(s.all), NULL, ec); } void learn(void* d, example* ec) @@ -85,7 +85,7 @@ void receive_result(sender& s) s->delay_ring[s->sent_index++ % s->all->p->ring_size] = ec; } -void finish_example(vw& all, example*ec) + void finish_example(vw& all, void*, example*ec) {} void end_examples(void* d) diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc index d3605d68..c56d6921 100644 --- a/vowpalwabbit/simple_label.cc +++ b/vowpalwabbit/simple_label.cc @@ -225,7 +225,7 @@ void output_and_account_example(vw& all, example* ec) print_update(all, ec); } -void return_simple_example(vw& all, example* ec) +void return_simple_example(vw& all, void*, example* ec) { if (!command_example(&all, ec)) output_and_account_example(all, ec); diff --git a/vowpalwabbit/simple_label.h b/vowpalwabbit/simple_label.h index 1384a45c..898df9be 100644 --- a/vowpalwabbit/simple_label.h +++ b/vowpalwabbit/simple_label.h @@ -17,7 +17,7 @@ struct label_data { float initial; }; -void return_simple_example(vw& all, example* ec); +void return_simple_example(vw& all, void*, example* ec); size_t read_cached_simple_label(shared_data* sd, void* v, io_buf& cache); void cache_simple_label(void* v, io_buf& cache); |