Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/searn.cc')
-rw-r--r-- vowpalwabbit/searn.cc 190
1 file changed, 95 insertions, 95 deletions
diff --git a/vowpalwabbit/searn.cc b/vowpalwabbit/searn.cc
index 1533aeca..808860fe 100644
--- a/vowpalwabbit/searn.cc
+++ b/vowpalwabbit/searn.cc
@@ -362,13 +362,13 @@ namespace Searn {
//cerr << "predict: action=" << action << endl;
void* old_label = ecs[action].ld;
ecs[action].ld = &test_label;
- base.predict(&ecs[action], pol);
+ base.predict(ecs[action], pol);
srn->total_predictions_made++;
srn->num_features += ecs[action].num_features;
srn->empty_example->in_use = true;
//cerr << "predict: empty_example" << endl;
- base.predict(srn->empty_example);
+ base.predict(*(srn->empty_example));
ecs[action].ld = old_label;
if ((action == 0) ||
@@ -429,37 +429,37 @@ namespace Searn {
return ld->costs[0].action;
}
- uint32_t single_prediction_notLDF(vw& all, searn& srn, learner& base, example* ec, void*valid_labels, uint32_t pol, bool allow_exploration)
+ uint32_t single_prediction_notLDF(vw& all, searn& srn, learner& base, example& ec, void*valid_labels, uint32_t pol, bool allow_exploration)
{
assert(pol >= 0);
- void* old_label = ec->ld;
- ec->ld = valid_labels;
+ void* old_label = ec.ld;
+ ec.ld = valid_labels;
base.predict(ec, pol);
srn.total_predictions_made++;
- srn.num_features += ec->num_features;
- uint32_t final_prediction = (uint32_t)ec->final_prediction;
+ srn.num_features += ec.num_features;
+ uint32_t final_prediction = (uint32_t)ec.final_prediction;
if (allow_exploration && (srn.exploration_temperature > 0.)) {
if (srn.rollout_all_actions)
- final_prediction = sample_with_temperature_csoaa((CSOAA::label*)ec->ld, srn.exploration_temperature);
+ final_prediction = sample_with_temperature_csoaa((CSOAA::label*)ec.ld, srn.exploration_temperature);
else
- final_prediction = sample_with_temperature_cb( (CB::label *)ec->ld, srn.exploration_temperature);
+ final_prediction = sample_with_temperature_cb( (CB::label *)ec.ld, srn.exploration_temperature);
}
if ((srn.state == INIT_TEST) && (all.raw_prediction > 0) && (srn.rollout_all_actions)) { // srn.rollout_all_actions ==> this is not CB, so we have CSOAA::labels
string outputString;
stringstream outputStringStream(outputString);
- CSOAA::label *ld = (CSOAA::label*)ec->ld;
+ CSOAA::label *ld = (CSOAA::label*)ec.ld;
for (CSOAA::wclass* c = ld->costs.begin; c != ld->costs.end; ++c) {
if (c != ld->costs.begin) outputStringStream << ' ';
outputStringStream << c->weight_index << ':' << c->partial_prediction;
}
- all.print_text(all.raw_prediction, outputStringStream.str(), ec->tag);
+ all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag);
}
- ec->ld = old_label;
+ ec.ld = old_label;
return final_prediction;
}
@@ -485,7 +485,7 @@ namespace Searn {
} else { // learned policy
if (!srn.is_ldf) { // single example
if (srn.auto_history) add_history_to_example(all, srn.hinfo, ecs, srn.rollout_action.begin+srn.t);
- size_t action = single_prediction_notLDF(all, srn, base, ecs, valid_labels, pol, allow_exploration);
+ size_t action = single_prediction_notLDF(all, srn, base, *ecs, valid_labels, pol, allow_exploration);
if (srn.auto_history) remove_history_from_example(all, srn.hinfo, ecs);
return (uint32_t)action;
} else {
@@ -985,7 +985,7 @@ namespace Searn {
void* old_label = ec[0].ld;
ec[0].ld = labels;
if (srn.auto_history) add_history_to_example(all, srn.hinfo, ec, srn.rollout_action.begin+srn.learn_t);
- base.learn(&ec[0], srn.current_policy);
+ base.learn(ec[0], srn.current_policy);
if (srn.auto_history) remove_history_from_example(all, srn.hinfo, ec);
ec[0].ld = old_label;
srn.total_examples_generated++;
@@ -998,10 +998,10 @@ namespace Searn {
//clog << endl << "this_example = "; GD::print_audit_features(all, &ec[a]);
add_history_to_example(all, srn.hinfo, &ec[a], srn.rollout_action.begin+srn.learn_t,
((CSOAA::label*)ec[a].ld)->costs[0].weight_index);
- base.learn(&ec[a], srn.current_policy);
+ base.learn(ec[a], srn.current_policy);
}
//clog << "learn: generate empty example" << endl;
- base.learn(srn.empty_example);
+ base.learn(*srn.empty_example);
//clog << "learn done " << repeat << endl;
for (size_t a=0; a<len; a++)
remove_history_from_example(all, srn.hinfo, &ec[a]);
@@ -1207,23 +1207,23 @@ namespace Searn {
out[max_len] = 0;
}
-void print_update(vw& all, searn* srn)
+void print_update(vw& all, searn& srn)
{
- if (!srn->printed_output_header && !all.quiet) {
+ if (!srn.printed_output_header && !all.quiet) {
const char * header_fmt = "%-10s %-10s %8s %15s %24s %22s %8s %5s %5s %15s %15s\n";
fprintf(stderr, header_fmt, "average", "since", "sequence", "example", "current label", "current predicted", "current", "cur", "cur", "predic.", "examples");
fprintf(stderr, header_fmt, "loss", "last", "counter", "weight", "sequence prefix", "sequence prefix", "features", "pass", "pol", "made", "gener.");
cerr.precision(5);
- srn->printed_output_header = true;
+ srn.printed_output_header = true;
}
- if (!should_print_update(all, srn->hit_new_pass))
+ if (!should_print_update(all, srn.hit_new_pass))
return;
char true_label[21];
char pred_label[21];
- to_short_string(srn->truth_string->str(), 20, true_label);
- to_short_string(srn->pred_string->str() , 20, pred_label);
+ to_short_string(srn.truth_string->str(), 20, true_label);
+ to_short_string(srn.pred_string->str() , 20, pred_label);
float avg_loss = 0.;
float avg_loss_since = 0.;
@@ -1245,14 +1245,14 @@ void print_update(vw& all, searn* srn)
all.sd->weighted_examples,
true_label,
pred_label,
- (long unsigned int)srn->num_features,
- (int)srn->read_example_last_pass,
- (int)srn->current_policy,
- (long unsigned int)srn->total_predictions_made,
- (long unsigned int)srn->total_examples_generated);
+ (long unsigned int)srn.num_features,
+ (int)srn.read_example_last_pass,
+ (int)srn.current_policy,
+ (long unsigned int)srn.total_predictions_made,
+ (long unsigned int)srn.total_examples_generated);
if (PRINT_CLOCK_TIME) {
- size_t num_sec = (size_t)(((float)(clock() - srn->start_clock_time)) / CLOCKS_PER_SEC);
+ size_t num_sec = (size_t)(((float)(clock() - srn.start_clock_time)) / CLOCKS_PER_SEC);
fprintf(stderr, " %15lusec", num_sec);
}
@@ -1370,69 +1370,69 @@ void print_update(vw& all, searn* srn)
}
template <bool is_learn>
- void searn_predict_or_learn(searn* srn, learner& base, example*ec) {
- vw* all = srn->all;
- srn->base_learner = &base;
+ void searn_predict_or_learn(searn& srn, learner& base, example& ec) {
+ vw* all = srn.all;
+ srn.base_learner = &base;
bool is_real_example = true;
- if (example_is_newline(ec) || srn->ec_seq.size() >= all->p->ring_size - 2) {
- if (srn->ec_seq.size() >= all->p->ring_size - 2) { // give some wiggle room
- std::cerr << "warning: length of sequence at " << ec->example_counter << " exceeds ring size; breaking apart" << std::endl;
+ if (example_is_newline(ec) || srn.ec_seq.size() >= all->p->ring_size - 2) {
+ if (srn.ec_seq.size() >= all->p->ring_size - 2) { // give some wiggle room
+ std::cerr << "warning: length of sequence at " << ec.example_counter << " exceeds ring size; breaking apart" << std::endl;
}
- do_actual_learning<is_learn>(*all, *srn);
- clear_seq(*all, *srn);
- srn->hit_new_pass = false;
+ do_actual_learning<is_learn>(*all, srn);
+ clear_seq(*all, srn);
+ srn.hit_new_pass = false;
//VW::finish_example(*all, ec);
is_real_example = false;
} else {
- srn->ec_seq.push_back(ec);
+ srn.ec_seq.push_back(&ec);
}
if (is_real_example) {
- srn->read_example_last_id = ec->example_counter;
+ srn.read_example_last_id = ec.example_counter;
}
}
- void end_pass(searn* srn) {
- vw* all = srn->all;
- srn->hit_new_pass = true;
- srn->read_example_last_pass++;
- srn->passes_since_new_policy++;
- if (srn->passes_since_new_policy >= srn->passes_per_policy) {
- srn->passes_since_new_policy = 0;
+ void end_pass(searn& srn) {
+ vw* all = srn.all;
+ srn.hit_new_pass = true;
+ srn.read_example_last_pass++;
+ srn.passes_since_new_policy++;
+ if (srn.passes_since_new_policy >= srn.passes_per_policy) {
+ srn.passes_since_new_policy = 0;
if(all->training)
- srn->current_policy++;
- if (srn->current_policy > srn->total_number_of_policies) {
+ srn.current_policy++;
+ if (srn.current_policy > srn.total_number_of_policies) {
std::cerr << "internal error (bug): too many policies; not advancing" << std::endl;
- srn->current_policy = srn->total_number_of_policies;
+ srn.current_policy = srn.total_number_of_policies;
}
//reset searn_trained_nb_policies in options_from_file so it is saved to regressor file later
std::stringstream ss;
- ss << srn->current_policy;
+ ss << srn.current_policy;
VW::cmd_string_replace_value(all->options_from_file,"--searn_trained_nb_policies", ss.str());
}
}
- void finish_example(vw& all, searn* srn, example* ec) {
- if (ec->end_pass || example_is_newline(ec) || srn->ec_seq.size() >= all.p->ring_size - 2) {
+ void finish_example(vw& all, searn& srn, example& ec) {
+ if (ec.end_pass || example_is_newline(ec) || srn.ec_seq.size() >= all.p->ring_size - 2) {
print_update(all, srn);
- VW::finish_example(all, ec);
+ VW::finish_example(all, &ec);
}
}
- void end_examples(searn* srn) {
- vw* all = srn->all;
+ void end_examples(searn& srn) {
+ vw* all = srn.all;
- do_actual_learning<true>(*all, *srn);
+ do_actual_learning<true>(*all, srn);
if( all->training ) {
std::stringstream ss1;
std::stringstream ss2;
- ss1 << ((srn->passes_since_new_policy == 0) ? srn->current_policy : (srn->current_policy+1));
+ ss1 << ((srn.passes_since_new_policy == 0) ? srn.current_policy : (srn.current_policy+1));
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_trained_nb_policies
VW::cmd_string_replace_value(all->options_from_file,"--searn_trained_nb_policies", ss1.str());
- ss2 << srn->total_number_of_policies;
+ ss2 << srn.total_number_of_policies;
//use cmd_string_replace_value in case we already loaded a predictor which had a value stored for --searn_total_nb_policies
VW::cmd_string_replace_value(all->options_from_file,"--searn_total_nb_policies", ss2.str());
}
@@ -1492,58 +1492,58 @@ void print_update(vw& all, searn* srn)
srn.empty_example->in_use = true;
}
- void searn_finish(searn* srn)
+ void searn_finish(searn& srn)
{
- vw* all = srn->all;
+ vw* all = srn.all;
//cerr << "searn_finish" << endl;
- delete srn->truth_string;
- delete srn->pred_string;
- delete srn->neighbor_features_string;
- srn->neighbor_features.erase();
- srn->neighbor_features.delete_v();
+ delete srn.truth_string;
+ delete srn.pred_string;
+ delete srn.neighbor_features_string;
+ srn.neighbor_features.erase();
+ srn.neighbor_features.delete_v();
- if (srn->rollout_all_actions) { // dst should be a CSOAA::label*
- ((CSOAA::label*)srn->valid_labels)->costs.erase();
- ((CSOAA::label*)srn->valid_labels)->costs.delete_v();
+ if (srn.rollout_all_actions) { // dst should be a CSOAA::label*
+ ((CSOAA::label*)srn.valid_labels)->costs.erase();
+ ((CSOAA::label*)srn.valid_labels)->costs.delete_v();
} else {
- ((CB::label*)srn->valid_labels)->costs.erase();
- ((CB::label*)srn->valid_labels)->costs.delete_v();
+ ((CB::label*)srn.valid_labels)->costs.erase();
+ ((CB::label*)srn.valid_labels)->costs.delete_v();
}
- if (srn->rollout_all_actions) // labels are CSOAA
- delete (CSOAA::label*)srn->valid_labels;
+ if (srn.rollout_all_actions) // labels are CSOAA
+ delete (CSOAA::label*)srn.valid_labels;
else // labels are CB
- delete (CB::label*)srn->valid_labels;
+ delete (CB::label*)srn.valid_labels;
- dealloc_example(CSOAA::delete_label, *(srn->empty_example));
- free(srn->empty_example);
+ dealloc_example(CSOAA::delete_label, *(srn.empty_example));
+ free(srn.empty_example);
- srn->ec_seq.delete_v();
+ srn.ec_seq.delete_v();
- clear_snapshot(*all, *srn);
- srn->snapshot_data.delete_v();
+ clear_snapshot(*all, srn);
+ srn.snapshot_data.delete_v();
- for (size_t i=0; i<srn->train_labels.size(); i++) {
- if (srn->rollout_all_actions) {
- ((CSOAA::label*)srn->train_labels[i])->costs.erase();
- ((CSOAA::label*)srn->train_labels[i])->costs.delete_v();
- delete ((CSOAA::label*)srn->train_labels[i]);
+ for (size_t i=0; i<srn.train_labels.size(); i++) {
+ if (srn.rollout_all_actions) {
+ ((CSOAA::label*)srn.train_labels[i])->costs.erase();
+ ((CSOAA::label*)srn.train_labels[i])->costs.delete_v();
+ delete ((CSOAA::label*)srn.train_labels[i]);
} else {
- ((CB::label*)srn->train_labels[i])->costs.erase();
- ((CB::label*)srn->train_labels[i])->costs.delete_v();
- delete ((CB::label*)srn->train_labels[i]);
+ ((CB::label*)srn.train_labels[i])->costs.erase();
+ ((CB::label*)srn.train_labels[i])->costs.delete_v();
+ delete ((CB::label*)srn.train_labels[i]);
}
}
- srn->train_labels.delete_v();
- srn->train_action.delete_v();
- srn->train_action_ids.delete_v();
- srn->rollout_action.delete_v();
- srn->learn_losses.delete_v();
-
- if (srn->task->finish != NULL) {
- srn->task->finish(*srn);
- free(srn->task);
+ srn.train_labels.delete_v();
+ srn.train_action.delete_v();
+ srn.train_action_ids.delete_v();
+ srn.rollout_action.delete_v();
+ srn.learn_losses.delete_v();
+
+ if (srn.task->finish != NULL) {
+ srn.task->finish(srn);
+ free(srn.task);
}
}