Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/gd.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gd.cc')
-rw-r--r-- gd.cc 102
1 files changed, 49 insertions, 53 deletions
diff --git a/gd.cc b/gd.cc
index 2304a561..c245ff04 100644
--- a/gd.cc
+++ b/gd.cc
@@ -21,7 +21,6 @@ embodied in the content of this file are licensed under the BSD
#include "gd.h"
#include "cache.h"
#include "simple_label.h"
-#include "delay_ring.h"
#include "allreduce.h"
#include "accumulate.h"
@@ -40,43 +39,42 @@ void* gd_thread(void *in)
while ( true )
{//this is a poor man's select operation.
- if ((ec = get_delay_example()) != NULL)//nonblocking
+ if ((ec = get_example()) != NULL)//semiblocking operation.
{
+ assert(ec->in_use);
if (ec->pass != current_pass)
{
+ if(global.span_server != "") {
+ if(global.adaptive)
+ accumulate_weighted_avg(global.span_server, params->reg);
+ else
+ accumulate_avg(global.span_server, params->reg, 0);
+ }
+
if (global.save_per_pass)
sync_weights(&reg);
global.eta *= global.eta_decay_rate;
save_predictor(*(params->final_regressor_name), current_pass);
current_pass = ec->pass;
}
- if (global.adaptive)
- adaptive_inline_train(reg,ec,ec->eta_round);
- else
- inline_train(reg, ec, ec->eta_round);
- finish_example(ec);
- if (global.sd->contraction < 1e-10) // updating weights now to avoid numerical instability
- sync_weights(&reg);
- }
- else if ((ec = get_example()) != NULL)//semiblocking operation.
- {
- assert(ec->in_use);
- if (ec->pass != current_pass && global.span_server != "")
+
+ if (!command_example(ec, params))
{
- if(global.span_server != "") {
- if(global.adaptive)
- accumulate_weighted_avg(global.span_server, params->reg);
- else
- accumulate_avg(global.span_server, params->reg, 0);
- }
+ predict(reg,ec,*(params->vars));
+ if (ec->eta_round != 0.)
+ {
+ if (global.adaptive)
+ adaptive_inline_train(reg,ec,ec->eta_round);
+ else
+ inline_train(reg, ec, ec->eta_round);
+ if (global.sd->contraction < 1e-10) // updating weights now to avoid numerical instability
+ sync_weights(&reg);
+
+ }
}
-
- if (command_example(ec, params))
- delay_example(ec,0);
- else
- predict(reg,ec,*(params->vars));
+ finish_example(ec);
}
- else if (thread_done())
+ else if (parser_done())
{
sync_weights(&reg);
if(global.span_server != "") {
@@ -615,31 +613,35 @@ void local_predict(example* ec, gd_vars& vars, regressor& reg)
t = global.sd->weighted_unlabeled_examples;
else
t = ec->example_t;
-
+
+ ec->eta_round = 0;
if (ld->label != FLT_MAX)
{
ec->loss = reg.loss->getLoss(ec->final_prediction, ld->label) * ld->weight;
- double eta_t;
- float norm;
- if (global.adaptive && global.exact_adaptive_norm) {
- float magx = 0.;
- norm = compute_xGx(reg, ec, magx);
- eta_t = global.eta * norm / magx;
- } else {
- eta_t = global.eta / pow(t,vars.power_t) * ld->weight;
- norm = global.nonormalize ? 1. : ec->total_sum_feat_sq;
- }
-
- ec->eta_round = reg.loss->getUpdate(ec->final_prediction, ld->label, eta_t, norm) / global.sd->contraction;
-
- if (global.training && global.reg_mode && fabs(ec->eta_round) > 1e-8) {
- double dev1 = reg.loss->first_derivative(ec->final_prediction, ld->label);
- double eta_bar = (fabs(dev1) > 1e-8) ? (-ec->eta_round / dev1) : 0.0;
- if (fabs(dev1) > 1e-8)
- global.sd->contraction /= (1. + global.l2_lambda * eta_bar * norm);
- global.sd->gravity += eta_bar * sqrt(norm) * global.l1_lambda;
- }
+ if (global.training)
+ {
+ double eta_t;
+ float norm;
+ if (global.adaptive && global.exact_adaptive_norm) {
+ float magx = 0.;
+ norm = compute_xGx(reg, ec, magx);
+ eta_t = global.eta * norm / magx;
+ } else {
+ eta_t = global.eta / pow(t,vars.power_t) * ld->weight;
+ norm = global.nonormalize ? 1. : ec->total_sum_feat_sq;
+ }
+
+ ec->eta_round = reg.loss->getUpdate(ec->final_prediction, ld->label, eta_t, norm) / global.sd->contraction;
+
+ if (global.reg_mode && fabs(ec->eta_round) > 1e-8) {
+ double dev1 = reg.loss->first_derivative(ec->final_prediction, ld->label);
+ double eta_bar = (fabs(dev1) > 1e-8) ? (-ec->eta_round / dev1) : 0.0;
+ if (fabs(dev1) > 1e-8)
+ global.sd->contraction /= (1. + global.l2_lambda * eta_bar * norm);
+ global.sd->gravity += eta_bar * sqrt(norm) * global.l1_lambda;
+ }
+ }
}
else if(global.active)
ec->revert_weight = reg.loss->getRevertingWeight(ec->final_prediction, global.eta/pow(t,vars.power_t));
@@ -676,12 +678,6 @@ void predict(regressor& r, example* ex, gd_vars& vars)
local_predict(ex, vars,r);
ex->done = true;
-
- if (global.training && ((label_data*)(ex->ld))->label != FLT_MAX)
- delay_example(ex,1);
- else
- delay_example(ex,0);
-
}
// trains regressor r on one example ex.