Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHal Daume III <me@hal3.name>2014-05-05 01:34:36 +0400
committerHal Daume III <me@hal3.name>2014-05-05 01:34:36 +0400
commit8c79df51fc26e01294c01fead79255e2c35223c1 (patch)
treee279f353298a20f338b56b46d03848fe465a0c9f /vowpalwabbit
parent754e70470d77b4e1217ca1ffd521f8763f3f58b4 (diff)
debugging beam train
Diffstat (limited to 'vowpalwabbit')
-rw-r--r--vowpalwabbit/searn.cc136
1 files changed, 76 insertions, 60 deletions
diff --git a/vowpalwabbit/searn.cc b/vowpalwabbit/searn.cc
index dfe8486b..5a201097 100644
--- a/vowpalwabbit/searn.cc
+++ b/vowpalwabbit/searn.cc
@@ -906,7 +906,7 @@ namespace Searn
int pol = -1;
if (!srn->priv->trajectory_oracle)
pol = choose_policy(*srn, allow_current, allow_optimal);
- cdbg << "BEAM_INIT: pol = " << pol << endl;
+ cdbg << "BEAM_INIT: pol = " << pol << ", beta = " << srn->priv->beta << endl;
size_t num_actions = get_all_labels(srn->priv->valid_labels, *srn, num_ec, yallowed);
single_action<T>(all, *srn, base, ecs, num_ec, (T*)srn->priv->valid_labels, pol, ystar, ystar_is_uint32t, false, true);
// fill in relevant information
@@ -917,7 +917,7 @@ namespace Searn
srn->priv->t++;
uint32_t this_a = get_any_label(*srn, yallowed);
uint32_t a_name = (! srn->priv->is_ldf) ? (uint32_t)this_a : ((COST_SENSITIVE::label*)ecs[this_a].ld)->costs[0].class_index;
- if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t);
+ if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t + srn->priv->hinfo.length - 1);
cdbg << "A rollout_action.push_back(" << a_name << ", @ " << (srn->priv->t) << ")" << endl;
if (srn->priv->hinfo.length>0) {cdbg << " rollout_action = ["; for (size_t i=0; i<srn->priv->t+1; i++) cdbg << " " << srn->priv->rollout_action.begin[i]; cdbg << " ], len=" << srn->priv->rollout_action.size() << endl;}
return this_a;
@@ -931,17 +931,17 @@ namespace Searn
cdbg << "valid_labels = ["; for (COST_SENSITIVE::wclass*wc=((COST_SENSITIVE::label*)srn->priv->valid_labels)->costs.begin; wc!= ((COST_SENSITIVE::label*)srn->priv->valid_labels)->costs.end; ++wc) cdbg << " " << wc->class_index; cdbg << " ]" << endl;
}
uint32_t a_name = (! srn->priv->is_ldf) ? (uint32_t)this_a : ((COST_SENSITIVE::label*)ecs[this_a].ld)->costs[0].class_index;
- if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t);
+ if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t + srn->priv->hinfo.length - 1);
cdbg << "B rollout_action.push_back(" << a_name << ", @ " << (srn->priv->t) << ")" << endl;
if (srn->priv->hinfo.length>0) {cdbg << " rollout_action = ["; for (size_t i=0; i<srn->priv->t+1; i++) cdbg << " " << srn->priv->rollout_action.begin[i]; cdbg << " ], len=" << srn->priv->rollout_action.size() << endl;}
return this_a;
} else if (srn->priv->t == srn->priv->cur_beam_hyp->t) {
- bool allow_optimal = srn->priv->beam_is_training;
bool allow_current = (! srn->priv->beam_is_training) || srn->priv->allow_current_policy;
+ bool allow_optimal = srn->priv->beam_is_training;
int pol = -1;
if (!srn->priv->trajectory_oracle)
pol = choose_policy(*srn, allow_current, allow_optimal);
- cdbg << "BEAM_ADVANCE: pol = " << pol << endl;
+ cdbg << "BEAM_ADVANCE: pol = " << pol << ", beta = " << srn->priv->beta << endl;
size_t num_actions = get_all_labels(srn->priv->valid_labels, *srn, num_ec, yallowed);
single_action<T>(all, *srn, base, ecs, num_ec, (T*)srn->priv->valid_labels, pol, ystar, ystar_is_uint32t, false, true);
// fill in relevant information
@@ -950,7 +950,7 @@ namespace Searn
srn->priv->t++;
uint32_t this_a = get_any_label(*srn, yallowed);
uint32_t a_name = (! srn->priv->is_ldf) ? (uint32_t)this_a : ((COST_SENSITIVE::label*)ecs[this_a].ld)->costs[0].class_index;
- if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t);
+ if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t + srn->priv->hinfo.length - 1);
cdbg << "C rollout_action.push_back(" << a_name << ", @ " << (srn->priv->t) << ")" << endl;
if (srn->priv->hinfo.length>0) {cdbg << " rollout_action = ["; for (size_t i=0; i<srn->priv->t+1; i++) cdbg << " " << srn->priv->rollout_action.begin[i]; cdbg << " ], len=" << srn->priv->rollout_action.size() << endl;}
return this_a;
@@ -959,7 +959,7 @@ namespace Searn
srn->priv->t++;
uint32_t this_a = get_any_label(*srn, yallowed);
uint32_t a_name = (! srn->priv->is_ldf) ? (uint32_t)this_a : ((COST_SENSITIVE::label*)ecs[this_a].ld)->costs[0].class_index;
- if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t);
+ if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t + srn->priv->hinfo.length - 1);
cdbg << "D rollout_action.push_back(" << a_name << ", @ " << (srn->priv->t) << ")" << endl;
//cdbg << " rollout_action = ["; for (size_t i=0; i<srn->priv->t+1; i++) cdbg << " " << srn->priv->rollout_action.begin[i]; cdbg << " ], len=" << srn->priv->rollout_action.size() << endl;
return this_a;
@@ -973,7 +973,7 @@ namespace Searn
if (!srn->priv->is_ldf)
this_a = ((COST_SENSITIVE::label*)srn->priv->valid_labels)->costs[this_a].class_index;
uint32_t a_name = (! srn->priv->is_ldf) ? (uint32_t)this_a : ((COST_SENSITIVE::label*)ecs[this_a].ld)->costs[0].class_index;
- if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t - 1);
+ if (srn->priv->auto_history) push_at(srn->priv->rollout_action, a_name, srn->priv->t - 1 + srn->priv->hinfo.length);
return this_a;
} else {
throw exception();
@@ -1520,27 +1520,19 @@ namespace Searn
free(hyp_pool[i].snapshot[j].data_ptr);
hyp_pool[i].snapshot.delete_v();
}
-
- hyp_pool.delete_v();
+ hyp_pool.erase();
+ hyp_pool_id = 0;
}
void beam_predict(vw&all, searn&srn, vector<example*>ec, v_array<beam_hyp> &hyp_pool, size_t &hyp_pool_id, bool is_learn) {
using namespace Beam;
- if (might_print_update(all)) {
- reset_searn_structure(srn);
- srn.priv->state = GET_TRUTH_STRING;
- srn.priv->should_produce_string = true;
- srn.priv->truth_string->str("");
- srn.task->structured_predict(srn, ec);
- }
-
beam* final_beam = new beam(max(1, min(srn.priv->beam_size, srn.priv->kbest))); // at least 1, but otherwise the min of beam_size and kbest
compute_full_beam(all, srn, ec, hyp_pool, hyp_pool_id, final_beam);
- { // TODO: check if this is going to be used at all!!!
+ if (srn.priv->should_produce_string && !is_learn) { // TODO: check if this is going to be used at all!!!
/*UNDOME*/cdbg << "========== FINAL ROLLOUT(S) ==" <<endl;
assert(final_beam->size() > 0);
stringstream spred;
@@ -1548,7 +1540,6 @@ namespace Searn
srn.priv->pred_string->str("");
bool is_first = true;
- srn.priv->should_produce_string = true;
for (beam_element * be = final_beam->begin(); be != final_beam->end(); ++be) {
beam_hyp* hyp = (beam_hyp*)be->data;
assert(hyp);
@@ -1587,37 +1578,29 @@ namespace Searn
delete final_beam;
}
-template <bool is_learn>
-void train_single_example(vw& all, searn& srn, vector<example*>ec)
-{
- // if we're going to have to print to the screen, generate the "truth" string
- cdbg << "======================================== GET TRUTH STRING (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
- if (might_print_update(all)) {
- reset_searn_structure(srn);
- srn.priv->state = GET_TRUTH_STRING;
- srn.priv->should_produce_string = true;
- srn.priv->truth_string->str("");
- srn.task->structured_predict(srn, ec);
- }
-
- // do an initial test pass to compute output (and loss)
- cdbg << "======================================== INIT TEST (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
-
- reset_searn_structure(srn);
- srn.priv->state = INIT_TEST;
-
- if ((all.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this
- might_print_update(all) || // if we have to print and update to stderr
- (!all.training) || // if we're just testing
- (all.current_pass == 0) || // we need error rates for progressive cost
+ bool must_run_test(vw&all, vector<example*>ec) {
+ return
+ (all.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this
+ might_print_update(all) || // if we have to print and update to stderr
+ (!all.training) || // if we're just testing
+ (all.current_pass == 0) || // we need error rates for progressive cost
(all.holdout_set_off) || // no holdout
(ec[0]->test_only) || // it's a holdout example
(all.raw_prediction > 0) // we need raw predictions
- ) {
+ ;
+ }
+
+ template <bool is_learn>
+ void train_single_example(vw& all, searn& srn, vector<example*>ec) {
+ // do an initial test pass to compute output (and loss)
+ cdbg << "======================================== INIT TEST (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
+ reset_searn_structure(srn);
+ srn.priv->state = INIT_TEST;
+
+ if (must_run_test(all, ec)) {
srn.priv->should_produce_string = might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0);
- //if (srn.priv->should_produce_string)
- srn.priv->pred_string->str("");
+ srn.priv->pred_string->str("");
assert(srn.priv->truth_string != NULL);
srn.task->structured_predict(srn, ec);
@@ -1748,7 +1731,7 @@ void train_single_example(vw& all, searn& srn, vector<example*>ec)
tset.erase(); tset.delete_v();
cdbg << endl;
}
-
+
clear_snapshot(all, srn, true);
srn.priv->train_action.delete_v();
srn.priv->train_action_ids.delete_v();
@@ -2021,15 +2004,14 @@ void print_update(vw& all, searn& srn)
void train_single_example_beam(vw&all, searn&srn, v_array<beam_hyp> &hyp_pool, size_t hyp_pool_size) {
searn_private* priv = srn.priv;
- if (priv->adaptive_beta)
- priv->beta = 1.f - powf(1.f - priv->alpha, (float)priv->total_examples_generated);
-
+
cdbg << "======================================== BEAM LEARN (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
priv->state = LEARN;
priv->should_produce_string = false;
for (size_t hyp_id = 0; hyp_id < hyp_pool_size; hyp_id++) {
beam_hyp me = hyp_pool[hyp_id];
if (me.pruned) continue;
+ if (me.num_actions == 0) continue;
cdbg << "=== hyp_id = " << hyp_id << " ===" << endl;
cdbg << "me = {" << endl << " t = " << me.t << endl << " action_taken = " << me.action_taken << endl << " parent = " << me.parent << endl << " parent.action_taken = " << ((me.parent == NULL) ? -1 : me.parent->action_taken) << endl << "}" << endl;
@@ -2081,6 +2063,8 @@ void print_update(vw& all, searn& srn)
example * ptr = priv->examples_dont_change ? priv->learn_example_ref : priv->learn_example_copy;
cdbg << "generate_training_example on " << priv->learn_example_len << " learn_example_copy items" << endl;
+ if (priv->adaptive_beta)
+ priv->beta = 1.f - powf(1.f - priv->alpha, (float)priv->total_examples_generated);
generate_training_example(all, srn, *priv->base_learner, ptr, priv->learn_example_len, &aset, priv->learn_losses);
if (!priv->examples_dont_change) {
@@ -2103,19 +2087,51 @@ void print_update(vw& all, searn& srn)
return; // nothing to do :)
add_neighbor_features(srn);
+
+ // if we're going to have to print to the screen, generate the "truth" string
+ cdbg << "======================================== GET TRUTH STRING (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
+ if (might_print_update(all)) {
+ reset_searn_structure(srn);
+ srn.priv->state = GET_TRUTH_STRING;
+ srn.priv->should_produce_string = true;
+ srn.priv->truth_string->str("");
+ srn.task->structured_predict(srn, srn.priv->ec_seq);
+ }
+
if (srn.priv->beam_size == 0)
train_single_example<is_learn>(all, srn, srn.priv->ec_seq);
else {
v_array<beam_hyp> hyp_pool;
size_t hyp_pool_id = 0;
- hyp_pool.resize(10000, true);
+ float cached_test_loss = 0.;
+
+ if (srn.priv->adaptive_beta)
+ srn.priv->beta = 1.f - powf(1.f - srn.priv->alpha, (float)srn.priv->total_examples_generated);
+
+ if (must_run_test(all, srn.priv->ec_seq)) {
+ cdbg << "======================================== BEAM TEST (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
+ srn.priv->beam_is_training = false;
+ srn.priv->should_produce_string = might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction > 0);
+ srn.priv->pred_string->str("");
+ hyp_pool.resize(10000, true);
+ beam_predict(all, srn, srn.priv->ec_seq, hyp_pool, hyp_pool_id, false);
+ cached_test_loss = srn.priv->test_loss;
+ free_hyp_pool(hyp_pool, hyp_pool_id);
+ }
+
srn.priv->beam_is_training = is_learn && all.training && !srn.priv->ec_seq[0]->test_only;
- beam_predict(all, srn, srn.priv->ec_seq, hyp_pool, hyp_pool_id, srn.priv->beam_is_training);
- if (srn.priv->beam_is_training)
+ if (srn.priv->beam_is_training) {
+ cdbg << "======================================== BEAM TRAIN (" << srn.priv->current_policy << "," << srn.priv->read_example_last_pass << ") ========================================" << endl;
+ hyp_pool.resize(10000, true);
+ srn.priv->should_produce_string = false;
+ beam_predict(all, srn, srn.priv->ec_seq, hyp_pool, hyp_pool_id, true);
train_single_example_beam(all, srn, hyp_pool, hyp_pool_id);
-
- free_hyp_pool(hyp_pool, hyp_pool_id);
+ free_hyp_pool(hyp_pool, hyp_pool_id);
+ }
+
+ hyp_pool.delete_v();
+ srn.priv->test_loss = cached_test_loss;
}
del_neighbor_features(srn);
@@ -2916,11 +2932,11 @@ void print_update(vw& all, searn& srn)
* write documentation
* pull munge/unmunge out of structured_predict
- * make searn tasks classes
- * hide stuff in the searn class (HOW?)
- * add --search_dont_rollout option
- * allow loss to also adjust count --> put in documentation
- * confusion matrix for faster errors?
+ * allow optional functions in searn tasks
+ * label constraints
+ * hypothesis recombination
+ * beam at train
+ * coreference
time ./vw -k -c -d pos.gz --search_as_dagger 1e-8 --search_task sequence --search 45 --holdout_off