Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/oaa.cc')
-rw-r--r--vowpalwabbit/oaa.cc60
1 files changed, 30 insertions, 30 deletions
diff --git a/vowpalwabbit/oaa.cc b/vowpalwabbit/oaa.cc
index 0a7f6649..01849199 100644
--- a/vowpalwabbit/oaa.cc
+++ b/vowpalwabbit/oaa.cc
@@ -110,11 +110,11 @@ namespace OAA {
}
}
- void print_update(vw& all, example *ec)
+ void print_update(vw& all, example &ec)
{
if (all.sd->weighted_examples >= all.sd->dump_interval && !all.quiet && !all.bfgs)
{
- mc_label* ld = (mc_label*) ec->ld;
+ mc_label* ld = (mc_label*) ec.ld;
char label_buf[32];
if (ld->label == INT_MAX)
strcpy(label_buf," unknown");
@@ -137,8 +137,8 @@ namespace OAA {
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long int)ec->final_prediction,
- (long unsigned int)ec->num_features);
+ (long int)ec.final_prediction,
+ (long unsigned int)ec.num_features);
all.sd->weighted_holdout_examples_since_last_dump = 0;
all.sd->holdout_sum_loss_since_last_dump = 0.0;
@@ -150,8 +150,8 @@ namespace OAA {
(long int)all.sd->example_number,
all.sd->weighted_examples,
label_buf,
- (long int)ec->final_prediction,
- (long unsigned int)ec->num_features);
+ (long int)ec.final_prediction,
+ (long unsigned int)ec.num_features);
all.sd->sum_loss_since_last_dump = 0.0;
all.sd->old_weighted_examples = all.sd->weighted_examples;
@@ -159,19 +159,19 @@ namespace OAA {
}
}
- void output_example(vw& all, example* ec)
+ void output_example(vw& all, example& ec)
{
- mc_label* ld = (mc_label*)ec->ld;
+ mc_label* ld = (mc_label*)ec.ld;
size_t loss = 1;
- if (ld->label == (uint32_t)ec->final_prediction)
+ if (ld->label == (uint32_t)ec.final_prediction)
loss = 0;
- if(ec->test_only)
+ if(ec.test_only)
{
- all.sd->weighted_holdout_examples += ec->global_weight;//test weight seen
- all.sd->weighted_holdout_examples_since_last_dump += ec->global_weight;
- all.sd->weighted_holdout_examples_since_last_pass += ec->global_weight;
+ all.sd->weighted_holdout_examples += ec.global_weight;//test weight seen
+ all.sd->weighted_holdout_examples_since_last_dump += ec.global_weight;
+ all.sd->weighted_holdout_examples_since_last_pass += ec.global_weight;
all.sd->holdout_sum_loss += loss;
all.sd->holdout_sum_loss_since_last_dump += loss;
all.sd->holdout_sum_loss_since_last_pass += loss;//since last pass
@@ -179,36 +179,36 @@ namespace OAA {
else
{
all.sd->weighted_examples += ld->weight;
- all.sd->total_features += ec->num_features;
+ all.sd->total_features += ec.num_features;
all.sd->sum_loss += loss;
all.sd->sum_loss_since_last_dump += loss;
all.sd->example_number++;
}
for (int* sink = all.final_prediction_sink.begin; sink != all.final_prediction_sink.end; sink++)
- all.print(*sink, ec->final_prediction, 0, ec->tag);
+ all.print(*sink, ec.final_prediction, 0, ec.tag);
OAA::print_update(all, ec);
}
- void finish_example(vw& all, oaa*, example* ec)
+ void finish_example(vw& all, oaa&, example& ec)
{
output_example(all, ec);
- VW::finish_example(all, ec);
+ VW::finish_example(all, &ec);
}
template <bool is_learn>
- void predict_or_learn(oaa* o, learner& base, example* ec) {
- vw* all = o->all;
+ void predict_or_learn(oaa& o, learner& base, example& ec) {
+ vw* all = o.all;
bool shouldOutput = all->raw_prediction > 0;
- mc_label* mc_label_data = (mc_label*)ec->ld;
+ mc_label* mc_label_data = (mc_label*)ec.ld;
float prediction = 1;
float score = INT_MIN;
- if (mc_label_data->label == 0 || (mc_label_data->label > o->k && mc_label_data->label != (uint32_t)-1))
- cout << "label " << mc_label_data->label << " is not in {1,"<< o->k << "} This won't work right." << endl;
+ if (mc_label_data->label == 0 || (mc_label_data->label > o.k && mc_label_data->label != (uint32_t)-1))
+ cout << "label " << mc_label_data->label << " is not in {1,"<< o.k << "} This won't work right." << endl;
string outputString;
stringstream outputStringStream(outputString);
@@ -216,9 +216,9 @@ namespace OAA {
label_data simple_temp;
simple_temp.initial = 0.;
simple_temp.weight = mc_label_data->weight;
- ec->ld = &simple_temp;
+ ec.ld = &simple_temp;
- for (size_t i = 1; i <= o->k; i++)
+ for (size_t i = 1; i <= o.k; i++)
{
if (is_learn)
{
@@ -232,22 +232,22 @@ namespace OAA {
else
base.predict(ec, i-1);
- if (ec->partial_prediction > score)
+ if (ec.partial_prediction > score)
{
- score = ec->partial_prediction;
+ score = ec.partial_prediction;
prediction = (float)i;
}
if (shouldOutput) {
if (i > 1) outputStringStream << ' ';
- outputStringStream << i << ':' << ec->partial_prediction;
+ outputStringStream << i << ':' << ec.partial_prediction;
}
}
- ec->ld = mc_label_data;
- ec->final_prediction = prediction;
+ ec.ld = mc_label_data;
+ ec.final_prediction = prediction;
if (shouldOutput)
- all->print_text(all->raw_prediction, outputStringStream.str(), ec->tag);
+ all->print_text(all->raw_prediction, outputStringStream.str(), ec.tag);
}
learner* setup(vw& all, std::vector<std::string>&opts, po::variables_map& vm, po::variables_map& vm_file)