diff options
author | John <jl@hunch.net> | 2014-12-21 18:25:27 +0300 |
---|---|---|
committer | John <jl@hunch.net> | 2014-12-21 18:25:27 +0300 |
commit | 3995d24c7ed2903ad790e96ecc80cb86ff13f3b9 (patch) | |
tree | fabd676d74e745b5da256e0740cc5ae86905004d /vowpalwabbit | |
parent | 0f327c9fcdb3b493eae890044be823be31cd19bf (diff) | |
parent | 878a24accb06f24ff3fc4ffbcecd536fbc8389cd (diff) |
Merge pull request #460 from grafke/master
A new algorithm: FTRL http://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf
Diffstat (limited to 'vowpalwabbit')
-rw-r--r-- | vowpalwabbit/Makefile.am | 2 | ||||
-rw-r--r-- | vowpalwabbit/ftrl_proximal.cc | 229 | ||||
-rw-r--r-- | vowpalwabbit/ftrl_proximal.h | 12 | ||||
-rw-r--r-- | vowpalwabbit/gd.cc | 28 | ||||
-rw-r--r-- | vowpalwabbit/gd.h | 1 | ||||
-rw-r--r-- | vowpalwabbit/global_data.h | 1 | ||||
-rw-r--r-- | vowpalwabbit/parse_args.cc | 4 |
7 files changed, 263 insertions, 14 deletions
diff --git a/vowpalwabbit/Makefile.am b/vowpalwabbit/Makefile.am index dd8a9787..676d6f4e 100644 --- a/vowpalwabbit/Makefile.am +++ b/vowpalwabbit/Makefile.am @@ -4,7 +4,7 @@ liballreduce_la_SOURCES = allreduce.cc bin_PROGRAMS = vw active_interactor -libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc +libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc ftrl_proximal.cc libvw_c_wrapper_la_SOURCES = vwdll.cpp diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc new file mode 100644 index 00000000..fd613367 --- /dev/null +++ b/vowpalwabbit/ftrl_proximal.cc @@ -0,0 +1,229 @@ +/* + Copyright (c) by respective owners including Yahoo!, Microsoft, and + individual contributors. All rights reserved. Released under a BSD (revised) + license as described in the file LICENSE. + */ +#include <fstream> +#include <float.h> +#ifndef _WIN32 +#include <netdb.h> +#endif +#include <string.h> +#include <stdio.h> +#include <assert.h> +#include <sys/timeb.h> +#include "parse_example.h" +#include "constant.h" +#include "cache.h" +#include "simple_label.h" +#include "vw.h" +#include "gd.h" +#include "accumulate.h" +#include "memory.h" +#include <exception> + +using namespace std; +using namespace LEARNER; + + +#define W_XT 0 // current parameter w(XT) +#define W_GT 1 // current gradient g(GT) +#define W_ZT 2 // accumulated z(t) = z(t-1) + g(t) + sigma*w(t) +#define W_G2 3 // accumulated gradient squre n(t) = n(t-1) + g(t)*g(t) + +/********************************************************************/ +/* mem & w definition ***********************************************/ +/********************************************************************/ +// w[0] = current weight +// w[1] = current first derivative +// w[2] = accumulated zt +// w[3] = accumulated g2 + +namespace FTRL { + + //nonrentrant + struct ftrl { + + vw* all; + // set by initializer + float ftrl_alpha; + float ftrl_beta; + + // evaluation file pointer + FILE* fo; + bool progressive_validation; + }; + + void update_accumulated_state(weight* w, float ftrl_alpha) { + double ng2 = w[W_G2] + w[W_GT]*w[W_GT]; + double sigma = (sqrt(ng2) - sqrt(w[W_G2]))/ ftrl_alpha; + w[W_ZT] += w[W_GT] - sigma * w[W_XT]; + w[W_G2] = ng2; + } + + struct update_data { + float update; + float ftrl_alpha; + float ftrl_beta; + float l1_lambda; + float l2_lambda; + }; + + //void update_grad(weight* weights, size_t mask, float loss_grad) + void update_grad(update_data& d, float x, float& wref) { + float* w = &wref; + w[W_GT] = d.update * x; + update_accumulated_state(w, d.ftrl_alpha); + } + + float ftrl_predict(vw& all, example& ec) { + ec.partial_prediction = GD::inline_predict(all, ec); + return GD::finalize_prediction(all.sd, ec.partial_prediction); + } + + float predict_and_gradient(vw& all, ftrl &b, example& ec) { + float fp = ftrl_predict(all, ec); + ec.updated_prediction = fp; + + label_data& ld = ec.l.simple; + all.set_minmax(all.sd, ld.label); + + struct update_data data; + + data.update = all.loss->first_derivative(all.sd, fp, ld.label) * ld.weight; + data.ftrl_alpha = b.ftrl_alpha; + + GD::foreach_feature<update_data,update_grad>(all, ec, data); + + return fp; + } + + inline float sign(float w){ if (w < 0.) return -1.; else return 1.;} + + void update_w(update_data& d, float x, float& wref) { + float* w = &wref; + float flag = sign(w[W_ZT]); + float fabs_zt = w[W_ZT] * flag; + if (fabs_zt <= d.l1_lambda) { + w[W_XT] = 0.; + } else { + double step = 1/(d.l2_lambda + (d.ftrl_beta + sqrt(w[W_G2]))/d.ftrl_alpha); + w[W_XT] = step * flag * (d.l1_lambda - fabs_zt); + } + } + + void update_weight(vw& all, ftrl &b, example& ec) { + + struct update_data data; + + data.ftrl_alpha = b.ftrl_alpha; + data.ftrl_beta = b.ftrl_beta; + data.l1_lambda = all.l1_lambda; + data.l2_lambda = all.l2_lambda; + + GD::foreach_feature<update_data, update_w>(all, ec, data); + + } + + void evaluate_example(vw& all, ftrl& b , example& ec) { + label_data& ld = ec.l.simple; + ec.loss = all.loss->getLoss(all.sd, ec.updated_prediction, ld.label) * ld.weight; + if (b.progressive_validation) { + float v = 1./(1 + exp(-ec.updated_prediction)); + fprintf(b.fo, "%.6f\t%d\n", v, (int)(ld.label * ld.weight)); + } + } + + //void learn(void* a, void* d, example* ec) { + void learn(ftrl& a, learner& base, example& ec) { + vw* all = a.all; + assert(ec.in_use); + + // predict w*x, compute gradient, update accumulate state + predict_and_gradient(*all, a, ec); + // evaluate, statistic + evaluate_example(*all, a, ec); + // update weight + update_weight(*all, a, ec); + } + + void save_load(ftrl& b, io_buf& model_file, bool read, bool text) { + vw* all = b.all; + if (read) { + initialize_regressor(*all); + } + + if (model_file.files.size() > 0) { + bool resume = all->save_resume; + char buff[512]; + uint32_t text_len = sprintf(buff, ":%d\n", resume); + bin_text_read_write_fixed(model_file,(char *)&resume, sizeof (resume), "", read, buff, text_len, text); + + if (resume) { + GD::save_load_online_state(*all, model_file, read, text); + //save_load_online_state(*all, model_file, read, text); + } else { + GD::save_load_regressor(*all, model_file, read, text); + } + } + + } + + // placeholder + void predict(ftrl& b, learner& base, example& ec) + { + vw* all = b.all; + //ec.l.simple.prediction = ftrl_predict(*all,ec); + ec.pred.scalar = ftrl_predict(*all,ec); + } + + learner* setup(vw& all, po::variables_map& vm) { + + ftrl* b = (ftrl*)calloc_or_die(1, sizeof(ftrl)); + b->all = &all; + b->ftrl_beta = 0.0; + b->ftrl_alpha = 0.1; + + po::options_description ftrl_opts("FTRL options"); + + ftrl_opts.add_options() + ("ftrl_alpha", po::value<float>(&(b->ftrl_alpha)), "Learning rate for FTRL-proximal optimization") + ("ftrl_beta", po::value<float>(&(b->ftrl_beta)), "FTRL beta") + ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal"); + + vm = add_options(all, ftrl_opts); + + if (vm.count("ftrl_alpha")) { + b->ftrl_alpha = vm["ftrl_alpha"].as<float>(); + } + + if (vm.count("ftrl_beta")) { + b->ftrl_beta = vm["ftrl_beta"].as<float>(); + } + + all.reg.stride_shift = 2; // NOTE: for more parameter storage + + b->progressive_validation = false; + if (vm.count("progressive_validation")) { + std::string filename = vm["progressive_validation"].as<string>(); + b->fo = fopen(filename.c_str(), "w"); + assert(b->fo != NULL); + b->progressive_validation = true; + } + + if (!all.quiet) { + cerr << "Enabling FTRL-Proximal based optimization" << endl; + cerr << "ftrl_alpha = " << b->ftrl_alpha << endl; + cerr << "ftrl_beta = " << b->ftrl_beta << endl; + } + + learner* l = new learner(b, 1 << all.reg.stride_shift); + l->set_learn<ftrl, learn>(); + l->set_predict<ftrl, predict>(); + l->set_save_load<ftrl,save_load>(); + + return l; + } + + +} // end namespace diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h new file mode 100644 index 00000000..934d91c8 --- /dev/null +++ b/vowpalwabbit/ftrl_proximal.h @@ -0,0 +1,12 @@ +/* +Copyright (c) by respective owners including Yahoo!, Microsoft, and +individual contributors. All rights reserved. Released under a BSD +license as described in the file LICENSE. + */ +#ifndef FTRL_PROXIMAL_H +#define FTRL_PROXIMAL_H + +namespace FTRL { + LEARNER::learner* setup(vw& all, po::variables_map& vm); +} +#endif diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 589d39c2..1e74da0f 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -32,7 +32,7 @@ using namespace LEARNER; namespace GD { struct gd{ - double normalized_sum_norm_x; + //double normalized_sum_norm_x; double total_weight; size_t no_win_counter; size_t early_stop_thres; @@ -75,14 +75,14 @@ namespace GD if (normalized) { if (sqrt_rate) { - float avg_norm = (float) g.total_weight / (float) g.normalized_sum_norm_x; + float avg_norm = (float) g.total_weight / (float) g.all->normalized_sum_norm_x; if (adaptive) return sqrt(avg_norm); else return avg_norm; } else - return powf( (float) g.normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power); + return powf( (float) g.all->normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power); } return 1.f; } @@ -451,7 +451,7 @@ template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normaliz foreach_feature<norm_data,pred_per_update_feature<sqrt_rate, feature_mask_off, adaptive, normalized, spare> >(all, ec, nd); if(normalized) { - g.normalized_sum_norm_x += ld.weight * nd.norm_x; + g.all->normalized_sum_norm_x += ld.weight * nd.norm_x; g.total_weight += ld.weight; g.update_multiplier = average_update<sqrt_rate, adaptive, normalized>(g, nd.pred_per_update); @@ -609,9 +609,10 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text) while ((!read && i < length) || (read && brw >0)); } -void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text) +//void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text) +void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text) { - vw& all = *g.all; + //vw& all = *g.all; char buff[512]; @@ -620,10 +621,10 @@ void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text) "", read, buff, text_len, text); - text_len = sprintf(buff, "norm normalizer %f\n", g.normalized_sum_norm_x); - bin_text_read_write_fixed(model_file,(char*)&g.normalized_sum_norm_x, sizeof(g.normalized_sum_norm_x), - "", read, - buff, text_len, text); + text_len = sprintf(buff, "norm normalizer %f\n", all.normalized_sum_norm_x); + bin_text_read_write_fixed(model_file,(char*)&all.normalized_sum_norm_x, sizeof(all.normalized_sum_norm_x), + "", read, + buff, text_len, text); text_len = sprintf(buff, "t %f\n", all.sd->t); bin_text_read_write_fixed(model_file,(char*)&all.sd->t, sizeof(all.sd->t), @@ -780,7 +781,8 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text) "", read, buff, text_len, text); if (resume) - save_load_online_state(g, model_file, read, text); + //save_load_online_state(g, model_file, read, text); + save_load_online_state(all, model_file, read, text); else save_load_regressor(all, model_file, read, text); } @@ -844,7 +846,7 @@ learner* setup(vw& all, po::variables_map& vm) { gd* g = (gd*)calloc_or_die(1, sizeof(gd)); g->all = &all; - g->normalized_sum_norm_x = 0; + g->all->normalized_sum_norm_x = 0; g->no_win_counter = 0; g->total_weight = 0.; g->early_stop_thres = 3; @@ -853,7 +855,7 @@ learner* setup(vw& all, po::variables_map& vm) if(all.initial_t > 0)//for the normalized update: if initial_t is bigger than 1 we interpret this as if we had seen (all.initial_t) previous fake datapoints all with norm 1 { - g->normalized_sum_norm_x = all.initial_t; + g->all->normalized_sum_norm_x = all.initial_t; g->total_weight = all.initial_t; } diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h index 05bd5b5d..be52e62f 100644 --- a/vowpalwabbit/gd.h +++ b/vowpalwabbit/gd.h @@ -26,6 +26,7 @@ namespace GD{ void train_one_example_single_thread(regressor& r, example* ex); LEARNER::learner* setup(vw& all, po::variables_map& vm); void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text); + void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text); void output_and_account_example(example* ec); // iterate through one namespace (or its part), callback function T(some_data_R, feature_value_x, feature_weight) diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index dd8567a8..c90ae05b 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -197,6 +197,7 @@ struct vw { int m; bool save_resume; + double normalized_sum_norm_x; po::options_description opts; std::string file_options; diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index efa0961e..d417ca0c 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -35,6 +35,7 @@ license as described in the file LICENSE. #include "gd_mf.h" #include "mf.h" #include "vw.h" +#include "ftrl_proximal.h" #include "rand48.h" #include "parse_args.h" #include "binary.h" @@ -722,6 +723,7 @@ void parse_base_algorithm(vw& all, po::variables_map& vm) base_opt.add_options() ("sgd", "use regular stochastic gradient descent update.") + ("ftrl", "use ftrl-proximal optimization") ("adaptive", "use adaptive, individual learning rates.") ("invariant", "use safe/importance aware updates.") ("normalized", "use per feature normalized updates") @@ -740,6 +742,8 @@ void parse_base_algorithm(vw& all, po::variables_map& vm) all.l = BFGS::setup(all, vm); else if (vm.count("lda")) all.l = LDA::setup(all, vm); + else if (vm.count("ftrl")) + all.l = FTRL::setup(all, vm); else if (vm.count("noop")) all.l = NOOP::setup(all); else if (vm.count("print")) |