Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHal Daume III <me@hal3.name>2014-01-08 20:17:47 +0400
committerHal Daume III <me@hal3.name>2014-01-08 20:17:47 +0400
commit45196928d5f0c0cc0b07912dfce00837957073cf (patch)
tree939bd7ff501a02e32405164dd95b480ec9561d7e /vowpalwabbit/csoaa.cc
parente22b1febf4830f7b54eb16827718515ce390d844 (diff)
searn ldf looks to be working!
Diffstat (limited to 'vowpalwabbit/csoaa.cc')
-rw-r--r--vowpalwabbit/csoaa.cc117
1 files changed, 88 insertions, 29 deletions
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 01d563ac..0eb07f33 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -623,29 +623,39 @@ namespace LabelDict {
v_array<CSOAA::wclass> costs = ld->costs;
label_data simple_label;
- for (size_t j=0; j<costs.size(); j++) {
+ if (costs.size() == 0) {
simple_label.initial = 0.;
simple_label.label = FLT_MAX;
simple_label.weight = 0.;
ec->partial_prediction = 0.;
-
- LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
ec->ld = &simple_label;
base.learn(ec); // make a prediction
- costs[j].partial_prediction = ec->partial_prediction;
+ } else {
+ for (size_t j=0; j<costs.size(); j++) {
+ simple_label.initial = 0.;
+ simple_label.label = FLT_MAX;
+ simple_label.weight = 0.;
+ ec->partial_prediction = 0.;
- if (ec->partial_prediction < *min_score) {
- *min_score = ec->partial_prediction;
- *prediction = costs[j].weight_index;
- }
+ LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
+
+ ec->ld = &simple_label;
+ base.learn(ec); // make a prediction
+ costs[j].partial_prediction = ec->partial_prediction;
- if (min_cost && (costs[j].x < *min_cost)) *min_cost = costs[j].x;
- if (max_cost && (costs[j].x > *max_cost)) *max_cost = costs[j].x;
+ if (min_score && prediction && (ec->partial_prediction < *min_score)) {
+ *min_score = ec->partial_prediction;
+ *prediction = costs[j].weight_index;
+ }
- LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
- }
+ if (min_cost && (costs[j].x < *min_cost)) *min_cost = costs[j].x;
+ if (max_cost && (costs[j].x > *max_cost)) *max_cost = costs[j].x;
+ LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
+ }
+ }
+
ec->ld = ld;
}
@@ -746,6 +756,8 @@ namespace LabelDict {
float min_score = FLT_MAX;
float min_cost = FLT_MAX;
float max_cost = -FLT_MAX;
+
+ cerr << "isTest=" << isTest << " start_K=" << start_K << " K=" << K << endl;
for (size_t k=start_K; k<K; k++) {
example *ec = l.ec_seq.begin[k];
@@ -757,6 +769,7 @@ namespace LabelDict {
cerr << "warning: example headers at position " << k << ": can only have in initial position!" << endl;
throw exception();
}
+ cerr << "msp k=" << k << endl;
make_single_prediction(all, l, base, ec, &prediction, &min_score, &min_cost, &max_cost);
}
@@ -772,6 +785,7 @@ namespace LabelDict {
label_data simple_label;
bool prediction_is_me = false;
for (size_t j=0; j<costs.size(); j++) {
+ cerr << "j=" << j << " costs.size=" << costs.size() << endl;
if (all.training && !isTest) {
float example_t = ec->example_t;
ec->example_t = l.csoaa_example_t;
@@ -790,7 +804,7 @@ namespace LabelDict {
}
}
// TODO: check the example->done and ec->partial_prediction = costs[j].partial_prediciton here
-
+ cerr << "k=" << k << " j=" << j << " label=" << simple_label.label << " cost=" << simple_label.weight << endl;
ec->ld = &simple_label;
//ec->partial_prediction = costs[j].partial_prediction;
//cerr << "[" << ec->partial_prediction << "," << ec->done << "]";
@@ -819,6 +833,7 @@ namespace LabelDict {
void do_actual_learning(vw& all, ldf& l, learner& base)
{
+ cerr << "do_actual_learning size=" << l.ec_seq.size() << endl;
if (l.ec_seq.size() <= 0) return; // nothing to do
/////////////////////// handle label definitions
@@ -919,7 +934,7 @@ namespace LabelDict {
}
}
- void clear_seq(vw& all, ldf& l)
+ void clear_seq_and_finish_examples(vw& all, ldf& l)
{
if (l.ec_seq.size() > 0)
for (example** ecc=l.ec_seq.begin; ecc!=l.ec_seq.end; ecc++)
@@ -934,6 +949,7 @@ namespace LabelDict {
l->first_pass = false;
}
+/*
void learn(void* data, learner& base, example *ec)
{
ldf* l=(ldf*)data;
@@ -984,17 +1000,9 @@ namespace LabelDict {
l->need_to_clear = false;
}
}
+*/
- void finish(void* d)
- {
- ldf* l=(ldf*)d;
- vw* all = l->all;
- clear_seq(*all, *l);
- l->ec_seq.delete_v();
- LabelDict::free_label_features(*l);
- }
-
- void finish_example(vw& all, void*, example* ec)
+ void finish_singleline_example(vw& all, void*, example* ec)
{
if (! LabelDict::ec_is_label_definition(ec)) {
all.sd->weighted_examples += 1;
@@ -1011,19 +1019,70 @@ namespace LabelDict {
if (l->need_to_clear) {
if (l->ec_seq.size() > 0)
output_example_seq(all, *l);
- clear_seq(all, *l);
- l->need_to_clear = false;
+ clear_seq_and_finish_examples(all, *l);
+ l->need_to_clear = false;
}
+ if (ec->in_use) VW::finish_example(all, ec);
}
void end_examples(void* data)
{
ldf* l=(ldf*)data;
vw* all = l->all;
- do_actual_learning(*all, *l, *(l->base));
- output_example_seq(*all, *l);
+ if (l->need_to_clear)
+ l->ec_seq.erase();
+ }
+
+
+ void finish(void* d)
+ {
+ ldf* l=(ldf*)d;
+ /*vw* all = l->all;
clear_seq(*all, *l);
+ l->ec_seq.delete_v();*/
+ l->ec_seq.erase();
l->ec_seq.delete_v();
+ LabelDict::free_label_features(*l);
+ }
+
+ void learn(void* data, learner& base, example* ec) {
+ ldf* l = (ldf*)data;
+ vw* all = l->all;
+ l->base = &base;
+
+ bool is_test = CSOAA::example_is_test(ec) || !all->training;
+
+ if (is_test)
+ make_single_prediction(*all, *l, base, ec, NULL, NULL, NULL, NULL);
+
+ bool need_to_break = l->ec_seq.size() >= all->p->ring_size - 2;
+
+ if (l->is_singleline)
+ assert(is_test);
+ else if (example_is_newline(ec) || need_to_break) {
+ if (need_to_break && l->first_pass)
+ cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
+
+ do_actual_learning(*all, *l, base);
+ l->need_to_clear = true;
+ } else if (LabelDict::ec_is_label_definition(ec)) {
+ if (l->ec_seq.size() > 0) {
+ cerr << "error: label definition encountered in data block" << endl;
+ throw exception();
+ }
+
+ if (! is_test) {
+ l->ec_seq.push_back(ec);
+ do_actual_learning(*all, *l, base);
+ l->need_to_clear = true;
+ }
+ } else {
+ if (l->need_to_clear) { // should only happen if we're NOT driving
+ l->ec_seq.erase();
+ l->need_to_clear = false;
+ }
+ l->ec_seq.push_back(ec);
+ }
}
learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
@@ -1100,7 +1159,7 @@ namespace LabelDict {
ld->need_to_clear = false;
learner* l = new learner(ld, learn, all.l);
if (ld->is_singleline)
- l->set_finish_example(finish_example);
+ l->set_finish_example(finish_singleline_example);
else
l->set_finish_example(finish_multiline_example);
l->set_finish(finish);