Diffstat (limited to 'vowpalwabbit/ftrl_proximal.cc')
-rw-r--r-- | vowpalwabbit/ftrl_proximal.cc | 225
1 file changed, 225 insertions, 0 deletions
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
new file mode 100644
index 00000000..6216058b
--- /dev/null
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -0,0 +1,225 @@
+/*
+  Copyright (c) by respective owners including Yahoo!, Microsoft, and
+  individual contributors. All rights reserved. Released under a BSD (revised)
+  license as described in the file LICENSE.
+*/
+#include <fstream>
+#include <float.h>
+#ifndef _WIN32
+#include <netdb.h>
+#endif
+#include <string.h>
+#include <stdio.h>
+#include <assert.h>
+#include <sys/timeb.h>
+#include "parse_example.h"
+#include "constant.h"
+#include "cache.h"
+#include "simple_label.h"
+#include "vw.h"
+#include "gd.h"
+#include "accumulate.h"
+#include "memory.h"
+#include <exception>
+
+using namespace std;
+using namespace LEARNER;
+
+#define W_XT 0 // current parameter w(XT)
+#define W_GT 1 // current gradient g(GT)
+#define W_ZT 2 // accumulated z(t) = z(t-1) + g(t) - sigma*w(t)
+#define W_G2 3 // accumulated gradient square n(t) = n(t-1) + g(t)*g(t)
+
+/********************************************************************/
+/* mem & w definition ***********************************************/
+/********************************************************************/
+// w[0] = current weight
+// w[1] = current first derivative
+// w[2] = accumulated z(t)
+// w[3] = accumulated g2
+
+namespace FTRL {
+
+  // non-reentrant
+  struct ftrl {
+    vw* all;
+    // set by initializer
+    float ftrl_alpha;
+    float ftrl_beta;
+
+    // evaluation file pointer
+    FILE* fo;
+    bool progressive_validation;
+  };
+
+  // fold the current gradient into the accumulated state z(t) and n(t)
+  void update_accumulated_state(weight* w, float ftrl_alpha) {
+    double ng2 = w[W_G2] + w[W_GT]*w[W_GT];
+    double sigma = (sqrt(ng2) - sqrt(w[W_G2])) / ftrl_alpha;
+    w[W_ZT] += w[W_GT] - sigma * w[W_XT];
+    w[W_G2] = ng2;
+  }
+
+  struct update_data {
+    float update;
+    float ftrl_alpha;
+    float ftrl_beta;
+    float l1_lambda;
+    float l2_lambda;
+  };
+
+  // per-feature: store the gradient, then update the accumulated state
+  void update_grad(update_data& d, float x, float& wref) {
+    float* w = &wref;
+    w[W_GT] = d.update * x;
+    update_accumulated_state(w, d.ftrl_alpha);
+  }
+
+  float ftrl_predict(vw& all, example& ec) {
+    ec.partial_prediction = GD::inline_predict(all, ec);
+    return GD::finalize_prediction(all.sd, ec.partial_prediction);
+  }
+
+  float predict_and_gradient(vw& all, ftrl& b, example& ec) {
+    float fp = ftrl_predict(all, ec);
+    ec.updated_prediction = fp;
+
+    label_data& ld = ec.l.simple;
+    all.set_minmax(all.sd, ld.label);
+
+    struct update_data data;
+    data.update = all.loss->first_derivative(all.sd, fp, ld.label) * ld.weight;
+    data.ftrl_alpha = b.ftrl_alpha;
+
+    GD::foreach_feature<update_data, update_grad>(all, ec, data);
+
+    return fp;
+  }
+
+  inline float sign(float w) { if (w < 0.) return -1.; else return 1.; }
+
+  // closed-form FTRL-proximal solution:
+  // w = 0 if |z| <= l1, else w = -(z - sign(z)*l1) / (l2 + (beta + sqrt(n))/alpha)
+  void update_w(update_data& d, float x, float& wref) {
+    float* w = &wref;
+    float flag = sign(w[W_ZT]);
+    float fabs_zt = w[W_ZT] * flag;
+    if (fabs_zt <= d.l1_lambda) {
+      w[W_XT] = 0.;
+    } else {
+      double step = 1 / (d.l2_lambda + (d.ftrl_beta + sqrt(w[W_G2])) / d.ftrl_alpha);
+      w[W_XT] = step * flag * (d.l1_lambda - fabs_zt);
+    }
+  }
+
+  void update_weight(vw& all, ftrl& b, example& ec) {
+    struct update_data data;
+
+    data.ftrl_alpha = b.ftrl_alpha;
+    data.ftrl_beta = b.ftrl_beta;
+    data.l1_lambda = all.l1_lambda;
+    data.l2_lambda = all.l2_lambda;
+
+    GD::foreach_feature<update_data, update_w>(all, ec, data);
+  }
+
+  void evaluate_example(vw& all, ftrl& b, example& ec) {
+    label_data& ld = ec.l.simple;
+    ec.loss = all.loss->getLoss(all.sd, ec.updated_prediction, ld.label) * ld.weight;
+    if (b.progressive_validation) {
+      float v = 1. / (1 + exp(-ec.updated_prediction));
+      fprintf(b.fo, "%.6f\t%d\n", v, (int)(ld.label * ld.weight));
+    }
+  }
+
+  void learn(ftrl& a, base_learner& base, example& ec) {
+    vw* all = a.all;
+    assert(ec.in_use);
+
+    // predict w*x, compute the gradient, update the accumulated state
+    ec.pred.scalar = predict_and_gradient(*all, a, ec);
+    // evaluate and record statistics
+    evaluate_example(*all, a, ec);
+    // update the weights
+    update_weight(*all, a, ec);
+  }
+
+  void save_load(ftrl& b, io_buf& model_file, bool read, bool text) {
+    vw* all = b.all;
+    if (read) {
+      initialize_regressor(*all);
+    }
+
+    if (model_file.files.size() > 0) {
+      bool resume = all->save_resume;
+      char buff[512];
+      uint32_t text_len = sprintf(buff, ":%d\n", resume);
+      bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, buff, text_len, text);
+
+      if (resume) {
+        GD::save_load_online_state(*all, model_file, read, text);
+      } else {
+        GD::save_load_regressor(*all, model_file, read, text);
+      }
+    }
+  }
+
+  void predict(ftrl& b, base_learner& base, example& ec) {
+    vw* all = b.all;
+    ec.pred.scalar = ftrl_predict(*all, ec);
+  }
+
+  base_learner* setup(vw& all, po::variables_map& vm) {
+    ftrl& b = calloc_or_die<ftrl>();
+    b.all = &all;
+    b.ftrl_beta = 0.0;
+    b.ftrl_alpha = 0.1;
+
+    po::options_description ftrl_opts("FTRL options");
+
+    ftrl_opts.add_options()
+      ("ftrl_alpha", po::value<float>(&(b.ftrl_alpha)), "Learning rate for FTRL-proximal optimization")
+      ("ftrl_beta", po::value<float>(&(b.ftrl_beta)), "FTRL beta")
+      ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
+
+    vm = add_options(all, ftrl_opts);
+
+    if (vm.count("ftrl_alpha")) {
+      b.ftrl_alpha = vm["ftrl_alpha"].as<float>();
+    }
+
+    if (vm.count("ftrl_beta")) {
+      b.ftrl_beta = vm["ftrl_beta"].as<float>();
+    }
+
+    all.reg.stride_shift = 2; // NOTE: 1 << 2 = 4 weights per feature, for the extra per-parameter state
+
+    b.progressive_validation = false;
+    if (vm.count("progressive_validation")) {
+      std::string filename = vm["progressive_validation"].as<string>();
+      b.fo = fopen(filename.c_str(), "w");
+      assert(b.fo != NULL);
+      b.progressive_validation = true;
+    }
+
+    if (!all.quiet) {
+      cerr << "Enabling FTRL-Proximal based optimization" << endl;
+      cerr << "ftrl_alpha = " << b.ftrl_alpha << endl;
+      cerr << "ftrl_beta = " << b.ftrl_beta << endl;
+    }
+
+    learner<ftrl>& l = init_learner(&b, learn, 1 << all.reg.stride_shift);
+    l.set_predict(predict);
+    l.set_save_load(save_load);
+    return make_base(l);
+  }
+} // end namespace
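The core of the patch is the pair update_accumulated_state/update_w, which together implement the per-coordinate FTRL-proximal recurrence of McMahan et al. For reference, here is a minimal self-contained C++ sketch of the same arithmetic outside VW's foreach_feature machinery. FtrlCoord and ftrl_update are illustrative names, not part of the VW codebase, and the sketch applies the accumulate and solve steps in a single call, whereas the patch defers update_w to update_weight after evaluation.

// Standalone sketch of the per-coordinate FTRL-proximal update implemented
// in the patch above. Names (FtrlCoord, ftrl_update) are illustrative only.
#include <cmath>

struct FtrlCoord {
    double w = 0.0;  // current weight               (W_XT)
    double z = 0.0;  // accumulated z(t)             (W_ZT)
    double n = 0.0;  // accumulated squared gradient (W_G2)
};

// g: gradient of the loss w.r.t. this coordinate for the current example.
void ftrl_update(FtrlCoord& c, double g,
                 double alpha, double beta, double l1, double l2) {
    // accumulate state, as in update_grad/update_accumulated_state:
    double n_new = c.n + g * g;
    double sigma = (std::sqrt(n_new) - std::sqrt(c.n)) / alpha;
    c.z += g - sigma * c.w;
    c.n = n_new;

    // closed-form weight, as in update_w: soft-threshold z by l1
    double sign_z = (c.z < 0.0) ? -1.0 : 1.0;
    if (sign_z * c.z <= l1)
        c.w = 0.0;
    else
        c.w = -(c.z - sign_z * l1) / (l2 + (beta + std::sqrt(c.n)) / alpha);
}

Note that w is pinned to exactly zero whenever |z| <= l1_lambda; this soft-thresholding is what lets FTRL-proximal produce sparse models. It is also why setup() sets all.reg.stride_shift = 2: 1 << 2 = 4 floats are reserved per feature to hold the (w, g, z, n) state the recurrence needs.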