diff options
Diffstat (limited to 'gd.cc')
-rw-r--r-- | gd.cc | 102 |
1 files changed, 49 insertions, 53 deletions
@@ -21,7 +21,6 @@ embodied in the content of this file are licensed under the BSD #include "gd.h" #include "cache.h" #include "simple_label.h" -#include "delay_ring.h" #include "allreduce.h" #include "accumulate.h" @@ -40,43 +39,42 @@ void* gd_thread(void *in) while ( true ) {//this is a poor man's select operation. - if ((ec = get_delay_example()) != NULL)//nonblocking + if ((ec = get_example()) != NULL)//semiblocking operation. { + assert(ec->in_use); if (ec->pass != current_pass) { + if(global.span_server != "") { + if(global.adaptive) + accumulate_weighted_avg(global.span_server, params->reg); + else + accumulate_avg(global.span_server, params->reg, 0); + } + if (global.save_per_pass) sync_weights(®); global.eta *= global.eta_decay_rate; save_predictor(*(params->final_regressor_name), current_pass); current_pass = ec->pass; } - if (global.adaptive) - adaptive_inline_train(reg,ec,ec->eta_round); - else - inline_train(reg, ec, ec->eta_round); - finish_example(ec); - if (global.sd->contraction < 1e-10) // updating weights now to avoid numerical instability - sync_weights(®); - } - else if ((ec = get_example()) != NULL)//semiblocking operation. - { - assert(ec->in_use); - if (ec->pass != current_pass && global.span_server != "") + + if (!command_example(ec, params)) { - if(global.span_server != "") { - if(global.adaptive) - accumulate_weighted_avg(global.span_server, params->reg); - else - accumulate_avg(global.span_server, params->reg, 0); - } + predict(reg,ec,*(params->vars)); + if (ec->eta_round != 0.) + { + if (global.adaptive) + adaptive_inline_train(reg,ec,ec->eta_round); + else + inline_train(reg, ec, ec->eta_round); + if (global.sd->contraction < 1e-10) // updating weights now to avoid numerical instability + sync_weights(®); + + } } - - if (command_example(ec, params)) - delay_example(ec,0); - else - predict(reg,ec,*(params->vars)); + finish_example(ec); } - else if (thread_done()) + else if (parser_done()) { sync_weights(®); if(global.span_server != "") { @@ -615,31 +613,35 @@ void local_predict(example* ec, gd_vars& vars, regressor& reg) t = global.sd->weighted_unlabeled_examples; else t = ec->example_t; - + + ec->eta_round = 0; if (ld->label != FLT_MAX) { ec->loss = reg.loss->getLoss(ec->final_prediction, ld->label) * ld->weight; - double eta_t; - float norm; - if (global.adaptive && global.exact_adaptive_norm) { - float magx = 0.; - norm = compute_xGx(reg, ec, magx); - eta_t = global.eta * norm / magx; - } else { - eta_t = global.eta / pow(t,vars.power_t) * ld->weight; - norm = global.nonormalize ? 1. : ec->total_sum_feat_sq; - } - - ec->eta_round = reg.loss->getUpdate(ec->final_prediction, ld->label, eta_t, norm) / global.sd->contraction; - - if (global.training && global.reg_mode && fabs(ec->eta_round) > 1e-8) { - double dev1 = reg.loss->first_derivative(ec->final_prediction, ld->label); - double eta_bar = (fabs(dev1) > 1e-8) ? (-ec->eta_round / dev1) : 0.0; - if (fabs(dev1) > 1e-8) - global.sd->contraction /= (1. + global.l2_lambda * eta_bar * norm); - global.sd->gravity += eta_bar * sqrt(norm) * global.l1_lambda; - } + if (global.training) + { + double eta_t; + float norm; + if (global.adaptive && global.exact_adaptive_norm) { + float magx = 0.; + norm = compute_xGx(reg, ec, magx); + eta_t = global.eta * norm / magx; + } else { + eta_t = global.eta / pow(t,vars.power_t) * ld->weight; + norm = global.nonormalize ? 1. : ec->total_sum_feat_sq; + } + + ec->eta_round = reg.loss->getUpdate(ec->final_prediction, ld->label, eta_t, norm) / global.sd->contraction; + + if (global.reg_mode && fabs(ec->eta_round) > 1e-8) { + double dev1 = reg.loss->first_derivative(ec->final_prediction, ld->label); + double eta_bar = (fabs(dev1) > 1e-8) ? (-ec->eta_round / dev1) : 0.0; + if (fabs(dev1) > 1e-8) + global.sd->contraction /= (1. + global.l2_lambda * eta_bar * norm); + global.sd->gravity += eta_bar * sqrt(norm) * global.l1_lambda; + } + } } else if(global.active) ec->revert_weight = reg.loss->getRevertingWeight(ec->final_prediction, global.eta/pow(t,vars.power_t)); @@ -676,12 +678,6 @@ void predict(regressor& r, example* ex, gd_vars& vars) local_predict(ex, vars,r); ex->done = true; - - if (global.training && ((label_data*)(ex->ld))->label != FLT_MAX) - delay_example(ex,1); - else - delay_example(ex,0); - } // trains regressor r on one example ex. |