
github.com/moses-smt/vowpal_wabbit.git
author    John <jl@hunch.net>  2014-01-10 22:04:00 +0400
committer John <jl@hunch.net>  2014-01-10 22:04:00 +0400
commit    2ee1516beec670c6275aaf08de7cff03164c73d4 (patch)
tree      27fb755b34bcf4765b71983e4de807c0454a731c  /vowpalwabbit/csoaa.cc
parent    18dd4ea36dc6eacac98049f99c27ce50a81121c2 (diff)
parent    861c48b2f682370d5fc42846adb70ed4a98d1db0 (diff)

Merge pull request #228 from hal3/master

    searn ldf is working

Diffstat (limited to 'vowpalwabbit/csoaa.cc')
 -rw-r--r--  vowpalwabbit/csoaa.cc  150
 1 file changed, 107 insertions, 43 deletions
diff --git a/vowpalwabbit/csoaa.cc b/vowpalwabbit/csoaa.cc
index 9da05810..0e80b32f 100644
--- a/vowpalwabbit/csoaa.cc
+++ b/vowpalwabbit/csoaa.cc
@@ -320,8 +320,7 @@ namespace CSOAA {
size_t prediction = 1;
float score = FLT_MAX;
- label_data simple_temp;
- simple_temp.initial = 0.;
+ label_data simple_temp = { 0., 0., 0. };
ec->ld = &simple_temp;
for (wclass *cl = ld->costs.begin; cl != ld->costs.end; cl ++)
{
@@ -340,10 +339,10 @@ namespace CSOAA {
simple_temp.weight = 1.;
}
- base.learn(ec);
+ base.learn(ec, i);
}
else
- base.predict(ec);
+ base.predict(ec, i);
cl->partial_prediction = ec->partial_prediction;
if (ec->partial_prediction < score || (ec->partial_prediction == score && i < prediction)) {
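[Note] This hunk replaces the field-by-field setup of label_data with an aggregate initializer and threads the class index i through to base.learn/base.predict, so each one-against-all sub-problem is scored against its own offset. A minimal standalone sketch of the pattern; the struct layout and the learner type here are simplified stand-ins for illustration, not VW's actual headers:

#include <cstddef>
#include <iostream>

// Simplified stand-ins for VW's types; the field layout is an assumption
// made for illustration, not the real vowpalwabbit header.
struct label_data { float label; float weight; float initial; };
struct example    { label_data* ld; float partial_prediction; };

// Toy base learner: the extra index selects which sub-model handles the call,
// mirroring the new base.learn(ec, i) / base.predict(ec, i) signature.
struct learner {
  void learn(example* /*ec*/, size_t i)   { std::cout << "learn   offset " << i << '\n'; }
  void predict(example* /*ec*/, size_t i) { std::cout << "predict offset " << i << '\n'; }
};

int main() {
  // One aggregate initializer sets every field, replacing the old per-field assignments.
  label_data simple_temp = { 0.f, 0.f, 0.f };
  example ec = { &simple_temp, 0.f };
  learner base;
  for (size_t i = 1; i <= 3; ++i)   // one pass per class, as in the csoaa loop
    base.predict(&ec, i);
  return 0;
}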
@@ -547,12 +546,14 @@ namespace LabelDict {
void free_label_features(ldf& l) {
void* label_iter = l.label_features.iterator();
while (label_iter != NULL) {
- v_array<feature> features = l.label_features.iterator_get_value(label_iter);
- features.erase();
- features.delete_v();
+ v_array<feature> *features = l.label_features.iterator_get_value(label_iter);
+ features->erase();
+ features->delete_v();
label_iter = l.label_features.iterator_next(label_iter);
}
+ l.label_features.clear();
+ l.label_features.delete_v();
}
}
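[Note] The free_label_features fix fetches each stored v_array<feature> through a pointer instead of by value: erasing and deleting a by-value copy leaves the array actually held in the hash table untouched, while going through a pointer releases the real storage; the table itself is then cleared and freed. The same aliasing point in plain standard C++ (hypothetical names, std containers standing in for v_array):

#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<size_t, std::vector<int>> label_features;
  label_features[7] = {1, 2, 3};

  // By value: clearing the copy does nothing to the stored vector.
  std::vector<int> copy = label_features[7];
  copy.clear();
  std::cout << "after clearing a copy:      " << label_features[7].size() << "\n";  // 3

  // Through a pointer (or reference): the stored vector itself is emptied,
  // which is what the patched free_label_features needs before freeing the table.
  std::vector<int>* p = &label_features[7];
  p->clear();
  p->shrink_to_fit();
  std::cout << "after clearing via pointer: " << label_features[7].size() << "\n";  // 0

  label_features.clear();  // analogous to l.label_features.clear()/delete_v()
  return 0;
}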
@@ -632,28 +633,39 @@ namespace LabelDict {
v_array<CSOAA::wclass> costs = ld->costs;
label_data simple_label;
- simple_label.initial = 0.;
- simple_label.label = FLT_MAX;
- simple_label.weight = 0.;
- for (size_t j=0; j<costs.size(); j++) {
-
- LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
-
+ if (costs.size() == 0) {
+ simple_label.initial = 0.;
+ simple_label.label = FLT_MAX;
+ simple_label.weight = 0.;
+ ec->partial_prediction = 0.;
+
ec->ld = &simple_label;
base.predict(ec); // make a prediction
- costs[j].partial_prediction = ec->partial_prediction;
+ } else {
+ for (size_t j=0; j<costs.size(); j++) {
+ simple_label.initial = 0.;
+ simple_label.label = FLT_MAX;
+ simple_label.weight = 0.;
+ ec->partial_prediction = 0.;
- if (ec->partial_prediction < *min_score) {
- *min_score = ec->partial_prediction;
- *prediction = costs[j].weight_index;
- }
+ LabelDict::add_example_namespace_from_memory(l, ec, costs[j].weight_index);
+
+ ec->ld = &simple_label;
+ base.predict(ec); // make a prediction
+ costs[j].partial_prediction = ec->partial_prediction;
- if (min_cost && (costs[j].x < *min_cost)) *min_cost = costs[j].x;
- if (max_cost && (costs[j].x > *max_cost)) *max_cost = costs[j].x;
+ if (min_score && prediction && (ec->partial_prediction < *min_score)) {
+ *min_score = ec->partial_prediction;
+ *prediction = costs[j].weight_index;
+ }
- LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
- }
+ if (min_cost && (costs[j].x < *min_cost)) *min_cost = costs[j].x;
+ if (max_cost && (costs[j].x > *max_cost)) *max_cost = costs[j].x;
+ LabelDict::del_example_namespace_from_memory(l, ec, costs[j].weight_index);
+ }
+ }
+
ec->ld = ld;
}
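[Note] The rewritten make_single_prediction splits the empty-cost case (one bare prediction with a default label_data) from the per-cost loop, and guards every output pointer (min_score, prediction, min_cost, max_cost) so callers may pass NULL when they only want the scoring side effect. A rough standalone illustration of that calling convention; the cost struct and the scorer are stand-ins, not VW code:

#include <cfloat>
#include <cstddef>
#include <vector>

struct wclass { float x; size_t weight_index; float partial_prediction; };

// Hypothetical scorer standing in for base.predict(); returns a score per cost.
static float score_cost(const wclass& c) { return 0.1f * (float)c.weight_index; }

// Optional outputs are pointers; every write is guarded so NULL means "skip".
void make_single_prediction(std::vector<wclass>& costs,
                            size_t* prediction, float* min_score,
                            float* min_cost, float* max_cost) {
  if (costs.empty()) return;            // empty case: nothing to aggregate
  for (wclass& c : costs) {
    c.partial_prediction = score_cost(c);
    if (min_score && prediction && c.partial_prediction < *min_score) {
      *min_score  = c.partial_prediction;
      *prediction = c.weight_index;
    }
    if (min_cost && c.x < *min_cost) *min_cost = c.x;
    if (max_cost && c.x > *max_cost) *max_cost = c.x;
  }
}

int main() {
  std::vector<wclass> costs = { {0.f, 3, 0.f}, {1.f, 1, 0.f} };
  size_t prediction = 0; float min_score = FLT_MAX;
  make_single_prediction(costs, &prediction, &min_score, NULL, NULL);  // cost range not needed here
  return 0;
}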
@@ -758,6 +770,8 @@ namespace LabelDict {
float min_score = FLT_MAX;
float min_cost = FLT_MAX;
float max_cost = -FLT_MAX;
+
+ //clog << "isTest=" << isTest << " start_K=" << start_K << " K=" << K << endl;
for (size_t k=start_K; k<K; k++) {
example *ec = l.ec_seq.begin[k];
@@ -769,6 +783,7 @@ namespace LabelDict {
cerr << "warning: example headers at position " << k << ": can only have in initial position!" << endl;
throw exception();
}
+ //clog << "msp k=" << k << endl;
make_single_prediction(all, l, base, ec, &prediction, &min_score, &min_cost, &max_cost);
}
@@ -784,6 +799,7 @@ namespace LabelDict {
label_data simple_label;
bool prediction_is_me = false;
for (size_t j=0; j<costs.size(); j++) {
+ //clog << "j=" << j << " costs.size=" << costs.size() << endl;
if (all.training && !isTest) {
float example_t = ec->example_t;
ec->example_t = l.csoaa_example_t;
@@ -802,7 +818,7 @@ namespace LabelDict {
}
}
// TODO: check the example->done and ec->partial_prediction = costs[j].partial_prediciton here
-
+ //clog << "k=" << k << " j=" << j << " label=" << simple_label.label << " cost=" << simple_label.weight << endl;
ec->ld = &simple_label;
//ec->partial_prediction = costs[j].partial_prediction;
//cerr << "[" << ec->partial_prediction << "," << ec->done << "]";
@@ -834,6 +850,7 @@ namespace LabelDict {
template <bool is_learn>
void do_actual_learning(vw& all, ldf& l, learner& base)
{
+ //clog << "do_actual_learning size=" << l.ec_seq.size() << endl;
if (l.ec_seq.size() <= 0) return; // nothing to do
/////////////////////// handle label definitions
@@ -934,7 +951,7 @@ namespace LabelDict {
}
}
- void clear_seq(vw& all, ldf& l)
+ void clear_seq_and_finish_examples(vw& all, ldf& l)
{
if (l.ec_seq.size() > 0)
for (example** ecc=l.ec_seq.begin; ecc!=l.ec_seq.end; ecc++)
@@ -948,8 +965,12 @@ namespace LabelDict {
l->first_pass = false;
}
+/*
+ void learn(void* data, learner& base, example *ec)
+=======
template <bool is_learn>
void predict_or_learn(ldf* l, learner& base, example *ec)
+>>>>>>> d73d8f6aef1c02dd05554f91ca57e1d304336130
{
vw* all = l->all;
l->base = &base;
@@ -962,6 +983,7 @@ namespace LabelDict {
if (l->is_singleline) {
// must be test mode
} else if (example_is_newline(ec) || l->ec_seq.size() >= all->p->ring_size - 2) {
+ cerr << "newline, example_is_newline=" << example_is_newline(ec) << ", size=" << l->ec_seq.size() << ", indices.size=" << ec->indices.size() << endl;
if (l->ec_seq.size() >= all->p->ring_size - 2 && l->first_pass)
cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
@@ -987,6 +1009,7 @@ namespace LabelDict {
if (ec->in_use)
VW::finish_example(*all, ec);
} else {
+ cerr << "push_back" << endl;
l->ec_seq.push_back(ec);
}
@@ -996,16 +1019,9 @@ namespace LabelDict {
l->need_to_clear = false;
}
}
+*/
- void finish(ldf* l)
- {
- vw* all = l->all;
- clear_seq(*all, *l);
- l->ec_seq.delete_v();
- LabelDict::free_label_features(*l);
- }
-
- void finish_example(vw& all, ldf*, example* ec)
+ void finish_singleline_example(vw& all, ldf*, example* ec)
{
if (! LabelDict::ec_is_label_definition(ec)) {
all.sd->weighted_examples += 1;
@@ -1019,20 +1035,68 @@ namespace LabelDict {
void finish_multiline_example(vw& all, ldf* l, example* ec)
{
if (l->need_to_clear) {
- if (l->ec_seq.size() > 0)
+ if (l->ec_seq.size() > 0) {
output_example_seq(all, *l);
- clear_seq(all, *l);
- l->need_to_clear = false;
+ global_print_newline(all);
+ }
+ clear_seq_and_finish_examples(all, *l);
+ l->need_to_clear = false;
+ if (ec->in_use) VW::finish_example(all, ec);
}
}
void end_examples(ldf* l)
{
- vw* all = l->all;
- do_actual_learning<true>(*all, *l, *(l->base));
- output_example_seq(*all, *l);
- clear_seq(*all, *l);
+ if (l->need_to_clear)
+ l->ec_seq.erase();
+ }
+
+
+ void finish(ldf* l)
+ {
+ //vw* all = l->all;
l->ec_seq.delete_v();
+ LabelDict::free_label_features(*l);
+ }
+
+ template <bool is_learn>
+ void predict_or_learn(ldf* l, learner& base, example *ec) {
+ vw* all = l->all;
+ l->base = &base;
+
+ bool is_test = CSOAA::example_is_test(ec) || !all->training;
+
+ if (is_test)
+ make_single_prediction(*all, *l, base, ec, NULL, NULL, NULL, NULL);
+
+ bool need_to_break = l->ec_seq.size() >= all->p->ring_size - 2;
+
+ if (l->is_singleline)
+ assert(is_test);
+ else if (example_is_newline(ec) || need_to_break) {
+ if (need_to_break && l->first_pass)
+ cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << endl;
+
+ do_actual_learning<is_learn>(*all, *l, base);
+ l->need_to_clear = true;
+ } else if (LabelDict::ec_is_label_definition(ec)) {
+ if (l->ec_seq.size() > 0) {
+ cerr << "error: label definition encountered in data block" << endl;
+ throw exception();
+ }
+
+ if (! is_test) {
+ l->ec_seq.push_back(ec);
+ do_actual_learning<is_learn>(*all, *l, base);
+ l->need_to_clear = true;
+ }
+ } else {
+ if (l->need_to_clear) { // should only happen if we're NOT driving
+ l->ec_seq.erase();
+ l->need_to_clear = false;
+ }
+ l->ec_seq.push_back(ec);
+ }
}
learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)
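[Note] The reinstated predict_or_learn (below the commented-out conflicted block) buffers LDF examples into l->ec_seq until it sees a newline example or the sequence threatens to overflow the parser ring, then runs do_actual_learning over the whole block and flags it for clearing; label-definition examples are trained immediately and must not appear inside a data block. A compressed sketch of that control flow, using simplified stand-in types rather than VW's:

#include <cstddef>
#include <iostream>
#include <vector>

struct example { bool is_newline; bool is_label_definition; };

struct ldf_state {
  std::vector<example*> ec_seq;   // buffered multiline block
  bool need_to_clear = false;
  size_t ring_size = 256;         // stand-in for all->p->ring_size
};

// Hypothetical stand-in for do_actual_learning<is_learn>.
void do_actual_learning(ldf_state& l) {
  std::cout << "learn over " << l.ec_seq.size() << " examples\n";
}

void predict_or_learn(ldf_state& l, example* ec) {
  bool need_to_break = l.ec_seq.size() >= l.ring_size - 2;
  if (ec->is_newline || need_to_break) {
    do_actual_learning(l);               // process the whole buffered block
    l.need_to_clear = true;              // finish_multiline_example flushes it later
  } else if (ec->is_label_definition) {
    if (!l.ec_seq.empty()) {
      std::cerr << "error: label definition encountered in data block\n";
      return;
    }
    l.ec_seq.push_back(ec);              // label definitions are learned on their own
    do_actual_learning(l);
    l.need_to_clear = true;
  } else {
    if (l.need_to_clear) {               // previous block already processed; start fresh
      l.ec_seq.clear();
      l.need_to_clear = false;
    }
    l.ec_seq.push_back(ec);              // keep accumulating the current block
  }
}

int main() {
  ldf_state l;
  example a = {false, false}, b = {false, false}, nl = {true, false};
  predict_or_learn(l, &a);
  predict_or_learn(l, &b);
  predict_or_learn(l, &nl);   // newline: triggers learning over the two buffered examples
  return 0;
}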
@@ -1103,7 +1167,7 @@ namespace LabelDict {
all.add_constant = false;
}
ld->label_features.init(256, v_array<feature>(), LabelDict::size_t_eq);
- ld->label_features.get(1, 94717244);
+ ld->label_features.get(1, 94717244); // TODO: figure this out
ld->read_example_this_loop = 0;
ld->need_to_clear = false;
@@ -1111,7 +1175,7 @@ namespace LabelDict {
l->set_learn<ldf, predict_or_learn<true> >();
l->set_predict<ldf, predict_or_learn<false> >();
if (ld->is_singleline)
- l->set_finish_example<ldf,finish_example>();
+ l->set_finish_example<ldf,finish_singleline_example>();
else
l->set_finish_example<ldf,finish_multiline_example>();
l->set_finish<ldf,finish>();
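[Note] The setup hunk registers the renamed finish_singleline_example for single-line mode and keeps finish_multiline_example otherwise, alongside the learn/predict/finish hooks. A toy version of that registration pattern; the learner_stub type and these hooks are illustrative stand-ins, not VW's reduction API:

#include <functional>
#include <iostream>

struct ldf { bool is_singleline = false; };

void finish_singleline_example(ldf&) { std::cout << "finish one example\n"; }
void finish_multiline_example(ldf&)  { std::cout << "finish a whole sequence\n"; }

// Toy stand-in for VW's learner object: the callback is stored and invoked later.
struct learner_stub {
  std::function<void(ldf&)> finish_example;
  template <void (*F)(ldf&)>
  void set_finish_example() { finish_example = F; }
};

learner_stub* setup(ldf& ld) {
  static learner_stub l;
  if (ld.is_singleline)
    l.set_finish_example<finish_singleline_example>();   // one example at a time
  else
    l.set_finish_example<finish_multiline_example>();    // flush the buffered block
  return &l;
}

int main() {
  ldf ld;
  learner_stub* l = setup(ld);
  l->finish_example(ld);
  return 0;
}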