Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@nyclamp.(none)>2013-10-29 18:55:22 +0400
committerJohn Langford <jl@nyclamp.(none)>2013-10-29 18:55:22 +0400
commit2080aa9aaccf6bb5709f223c2f779b565cc4d104 (patch)
tree804732d81522c2651973875ab38496b2c301a6bf /vowpalwabbit
parent39d2d61df606929569b9d6736bfd25cbe373a06e (diff)
more standard drivers
Diffstat (limited to 'vowpalwabbit')
-rw-r--r--vowpalwabbit/bs.cc35
-rw-r--r--vowpalwabbit/cb.cc29
-rw-r--r--vowpalwabbit/csoaa.cc2
-rw-r--r--vowpalwabbit/csoaa.h2
-rw-r--r--vowpalwabbit/lda_core.cc4
-rw-r--r--vowpalwabbit/learner.h8
-rw-r--r--vowpalwabbit/nn.cc8
-rw-r--r--vowpalwabbit/oaa.cc2
-rw-r--r--vowpalwabbit/oaa.h2
-rw-r--r--vowpalwabbit/sender.cc4
-rw-r--r--vowpalwabbit/simple_label.cc2
-rw-r--r--vowpalwabbit/simple_label.h2
12 files changed, 37 insertions, 63 deletions
diff --git a/vowpalwabbit/bs.cc b/vowpalwabbit/bs.cc
index 639b442b..2e533481 100644
--- a/vowpalwabbit/bs.cc
+++ b/vowpalwabbit/bs.cc
@@ -127,7 +127,7 @@ namespace BS {
}
}
- void output_example(vw& all, example* ec, bs* d)
+ void output_example(vw& all, bs* d, example* ec)
{
if (command_example(&all,ec))
return;
@@ -172,9 +172,12 @@ namespace BS {
print_update(all, ec);
}
- void learn_with_output(bs* d, example* ec, bool shouldOutput)
+ void learn(void* data, example* ec)
{
+ bs* d = (bs*)data;
vw* all = d->all;
+ bool shouldOutput = all->raw_prediction > 0;
+
if (command_example(all,ec))
{
d->base.learn(ec);
@@ -226,27 +229,11 @@ namespace BS {
}
- void learn(void* d, example* ec) {
- learn_with_output((bs*)d, ec, false);
- }
-
- void drive(vw* all, void* d)
+ void finish_example(vw& all, void* d, example* ec)
{
- example* ec = NULL;
- while ( true )
- {
- if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation.
- {
- learn_with_output((bs*)d, ec, all->raw_prediction > 0);
- if (!command_example(all, ec))
- BS::output_example(*all, ec, (bs*)d);
- VW::finish_example(*all, ec);
- }
- else if (parser_done(all->p))
- return;
- else
- ;
- }
+ if (!command_example(&all, ec))
+ BS::output_example(all, (bs*)d, ec);
+ VW::finish_example(all, ec);
}
void finish(void* data)
@@ -332,7 +319,9 @@ namespace BS {
all.weights_per_problem *= data->B;
data->total_increment = data->increment*(data->B-1);
data->base = all.l;
- learner l(data, drive, learn, finish, all.l.sl);
+ learner l(data, LEARNER::generic_driver, learn, finish, all.l.sl);
+
+ l.set_finish_example(finish_example);
return l;
}
}
diff --git a/vowpalwabbit/cb.cc b/vowpalwabbit/cb.cc
index d3a4c5d1..8d204450 100644
--- a/vowpalwabbit/cb.cc
+++ b/vowpalwabbit/cb.cc
@@ -658,27 +658,12 @@ namespace CB
free(c);
}
- void drive(vw* all, void* d)
+ void finish_example(vw& all, void* data, example* ec)
{
- cb* c = (cb*)d;
- example* ec = NULL;
- while ( true )
- {
- if(all-> early_terminate)
- {
- all->p->done = true;
- return;
- }
- if ((ec = VW::get_example(all->p)) != NULL)//semiblocking operation.
- {
- learn(d, ec);
- if (!command_example(&all, ec))
- output_example(*all, *c, ec);
- VW::finish_example(*all, ec);
- }
- else if (parser_done(all->p))
- return;
- }
+ cb* c = (cb*)data;
+ if (!command_example(&all, ec))
+ output_example(all, *c, ec);
+ VW::finish_example(all, ec);
}
learner setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
@@ -766,8 +751,10 @@ namespace CB
all.sd->k = nb_actions;
- learner l(c, drive, learn, finish, all.l.sl);
+ learner l(c, LEARNER::generic_driver, learn, finish, all.l.sl);
c->base = all.l;
+ l.set_finish_example(finish_example);
+
return l;
}
}
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index fa773725..85190934 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -368,7 +368,7 @@ namespace CSOAA {
update_example_indicies(all->audit, ec, -current_increment);
}
- void finish_example(vw& all, example* ec)
+ void finish_example(vw& all, void*, example* ec)
{
if (!command_example(&all, ec))
output_example(all, ec);
diff --git a/vowpalwabbit/csoaa.h b/vowpalwabbit/csoaa.h
index b26d4ed1..ac978eb4 100644
--- a/vowpalwabbit/csoaa.h
+++ b/vowpalwabbit/csoaa.h
@@ -30,7 +30,7 @@ namespace CSOAA {
learner setup(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file);
- void finish_example(vw& all, example* ec);
+ void finish_example(vw& all, void*, example* ec);
size_t read_cached_label(shared_data* sd, void* v, io_buf& cache);
void cache_label(void* v, io_buf& cache);
void default_label(void* v);
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index 12756e73..be5e8ccc 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -644,7 +644,7 @@ void save_load(void* d, io_buf& model_file, bool read, bool text)
l.all->sd->sum_loss -= score;
l.all->sd->sum_loss_since_last_dump -= score;
}
- return_simple_example(*l.all, l.examples[d]);
+ return_simple_example(*l.all, NULL, l.examples[d]);
}
for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back();)
@@ -717,7 +717,7 @@ void end_examples(void* d)
free(d);
}
-void finish_example(vw& all, example*ec)
+ void finish_example(vw& all, void*, example*ec)
{}
learner setup(vw&all, std::vector<std::string>&opts, po::variables_map& vm)
diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h
index 2f9a5ff8..9ad1efe4 100644
--- a/vowpalwabbit/learner.h
+++ b/vowpalwabbit/learner.h
@@ -13,7 +13,7 @@ struct sl_t {
void (*save_loader)(void* sldata, io_buf&, bool read, bool text);
};
-void return_simple_example(vw& all, example* ec);
+void return_simple_example(vw& all, void*, example* ec);
#include<iostream>
using namespace std;
@@ -40,7 +40,7 @@ private:
void* data;
void (*driver)(vw* all, void* data);
void (*learn_f)(void* data, example*);
- void (*finish_example_f)(vw&, example*);
+ void (*finish_example_f)(vw&, void* data, example*);
void (*finisher)(void* data);
void (*end_pass)(void* data);
void (*end_examples_f)(void* data);
@@ -50,11 +50,11 @@ public:
inline void learn(example* ec) { learn_f(data,ec); }
inline void finish() { finisher(data); }
- inline void finish_example(vw& all, example* ec) { finish_example_f(all,ec);}
+ inline void finish_example(vw& all, example* ec) { finish_example_f(all, data, ec);}
inline void drive(vw* all) { driver(all, data); }
inline void save_load(io_buf& io, bool read, bool text) { sl.save_loader(sl.sldata, io, read, text); }
- void set_finish_example(void (*ef)(vw& all, example*))
+ void set_finish_example(void (*ef)(vw& all, void*, example*))
{finish_example_f = ef;}
void set_end_examples(void (*ee)(void*))
diff --git a/vowpalwabbit/nn.cc b/vowpalwabbit/nn.cc
index d16cadad..afff4854 100644
--- a/vowpalwabbit/nn.cc
+++ b/vowpalwabbit/nn.cc
@@ -69,9 +69,7 @@ namespace NN {
uint32_t offset)
{
for (feature* x = f.begin; x != f.end; ++x)
- {
- x->weight_index += offset;
- }
+ x->weight_index += offset;
}
void finish_setup (nn* n, vw& all);
@@ -255,11 +253,11 @@ CONVERSE: // That's right, I'm using goto. So sue me.
ec->loss = save_ec_loss;
}
- void finish_example(vw& all, example* ec)
+ void finish_example(vw& all, void*, example* ec)
{
int save_raw_prediction = all.raw_prediction;
all.raw_prediction = -1;
- return_simple_example(all, ec);
+ return_simple_example(all, NULL, ec);
all.raw_prediction = save_raw_prediction;
}
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index b8a7a3d6..1bf492bb 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -191,7 +191,7 @@ namespace OAA {
OAA::print_update(all, ec);
}
- void finish_example(vw& all, example* ec)
+ void finish_example(vw& all, void*, example* ec)
{
output_example(all, ec);
VW::finish_example(all, ec);
diff --git a/vowpalwabbit/oaa.h b/vowpalwabbit/oaa.h
index f6b7e492..2abb0d30 100644
--- a/vowpalwabbit/oaa.h
+++ b/vowpalwabbit/oaa.h
@@ -36,7 +36,7 @@ namespace OAA
NULL,
sizeof(mc_label)};
- void finish_example(vw& all, example* ec);
+ void finish_example(vw& all, void*, example* ec);
inline int example_is_test(example* ec)
{
diff --git a/vowpalwabbit/sender.cc b/vowpalwabbit/sender.cc
index a966acda..87b4b3b6 100644
--- a/vowpalwabbit/sender.cc
+++ b/vowpalwabbit/sender.cc
@@ -68,7 +68,7 @@ void receive_result(sender& s)
ec->loss = s.all->loss->getLoss(s.all->sd, ec->final_prediction, ld->label) * ld->weight;
- return_simple_example(*(s.all), ec);
+ return_simple_example(*(s.all), NULL, ec);
}
void learn(void* d, example* ec)
@@ -85,7 +85,7 @@ void receive_result(sender& s)
s->delay_ring[s->sent_index++ % s->all->p->ring_size] = ec;
}
-void finish_example(vw& all, example*ec)
+ void finish_example(vw& all, void*, example*ec)
{}
void end_examples(void* d)
diff --git a/vowpalwabbit/simple_label.cc b/vowpalwabbit/simple_label.cc
index d3605d68..c56d6921 100644
--- a/vowpalwabbit/simple_label.cc
+++ b/vowpalwabbit/simple_label.cc
@@ -225,7 +225,7 @@ void output_and_account_example(vw& all, example* ec)
print_update(all, ec);
}
-void return_simple_example(vw& all, example* ec)
+void return_simple_example(vw& all, void*, example* ec)
{
if (!command_example(&all, ec))
output_and_account_example(all, ec);
diff --git a/vowpalwabbit/simple_label.h b/vowpalwabbit/simple_label.h
index 1384a45c..898df9be 100644
--- a/vowpalwabbit/simple_label.h
+++ b/vowpalwabbit/simple_label.h
@@ -17,7 +17,7 @@ struct label_data {
float initial;
};
-void return_simple_example(vw& all, example* ec);
+void return_simple_example(vw& all, void*, example* ec);
size_t read_cached_simple_label(shared_data* sd, void* v, io_buf& cache);
void cache_simple_label(void* v, io_buf& cache);