github.com/moses-smt/vowpal_wabbit.git

author    John <jl@hunch.net>  2014-12-21 18:25:27 +0300
committer John <jl@hunch.net>  2014-12-21 18:25:27 +0300
commit    3995d24c7ed2903ad790e96ecc80cb86ff13f3b9 (patch)
tree      fabd676d74e745b5da256e0740cc5ae86905004d
parent    0f327c9fcdb3b493eae890044be823be31cd19bf (diff)
parent    878a24accb06f24ff3fc4ffbcecd536fbc8389cd (diff)
Merge pull request #460 from grafke/master
A new algorithm: FTRL http://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf
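
For context: FTRL-Proximal, as described in the linked paper, keeps two per-coordinate accumulators next to each weight: z (the regularized-leader state) and n (the sum of squared gradients). It recomputes each weight in closed form from them, which is what lets a large --l1 drive coordinates to exactly zero. A minimal sketch of the per-coordinate step, written to mirror update_accumulated_state() and update_w() in ftrl_proximal.cc below (names here are illustrative, not VW's API):

    #include <cmath>

    // One FTRL-Proximal step for a single coordinate (illustrative sketch).
    // g is the gradient of the loss w.r.t. this weight for the current example.
    struct ftrl_coord { double w = 0, z = 0, n = 0; };

    void ftrl_step(ftrl_coord& c, double g,
                   double alpha, double beta, double l1, double l2) {
      double sigma = (std::sqrt(c.n + g * g) - std::sqrt(c.n)) / alpha;
      c.z += g - sigma * c.w;   // accumulate the regularized-leader state
      c.n += g * g;             // accumulate squared gradients
      if (std::fabs(c.z) <= l1)
        c.w = 0.0;              // l1 keeps small coordinates exactly at zero
      else
        c.w = -(c.z - (c.z > 0 ? l1 : -l1)) /
              (l2 + (beta + std::sqrt(c.n)) / alpha);
    }

    int main() {
      ftrl_coord c;
      ftrl_step(c, /*g=*/0.3, /*alpha=*/0.01, /*beta=*/0.0, /*l1=*/2.0, /*l2=*/0.0);
      return 0;  // |z| is still below l1 after this step, so c.w remains exactly 0
    }
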
-rwxr-xr-x  test/RunTests                         |   9
-rw-r--r--  test/pred-sets/ref/0001_ftrl.predict  | 200
-rw-r--r--  test/test-sets/ref/0001_ftrl.stderr   |  29
-rw-r--r--  test/train-sets/ref/0001_ftrl.stderr  |  31
-rw-r--r--  vowpalwabbit/Makefile.am              |   2
-rw-r--r--  vowpalwabbit/ftrl_proximal.cc         | 229
-rw-r--r--  vowpalwabbit/ftrl_proximal.h          |  12
-rw-r--r--  vowpalwabbit/gd.cc                    |  28
-rw-r--r--  vowpalwabbit/gd.h                     |   1
-rw-r--r--  vowpalwabbit/global_data.h            |   1
-rw-r--r--  vowpalwabbit/parse_args.cc            |   4
11 files changed, 532 insertions(+), 14 deletions(-)
diff --git a/test/RunTests b/test/RunTests
index 518f9df1..2d04a6f5 100755
--- a/test/RunTests
+++ b/test/RunTests
@@ -1077,3 +1077,12 @@ __DATA__
{VW} -d train-sets/0002.dat --autolink 1 --examples 100 -p 0002.autolink.predict
train-sets/ref/0002.autolink.stderr
train-sets/ref/0002.autolink.predict
+
+# Test 72: train FTRL-Proximal
+{VW} -k -d train-sets/0001.dat -f models/0001_ftrl.model --passes 1 --ftrl --ftrl_alpha 0.01 --ftrl_beta 0 --l1 2
+ train-sets/ref/0001_ftrl.stderr
+
+# Test 73: test FTRL-Proximal
+{VW} -k -t train-sets/0001.dat -i models/0001_ftrl.model -p ftrl_001.predict.tmp
+ test-sets/ref/0001_ftrl.stderr
+ pred-sets/ref/0001_ftrl.predict
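
Test 72 trains on train-sets/0001.dat with --ftrl, --ftrl_alpha 0.01, --ftrl_beta 0 and a heavy --l1 2 penalty, saves models/0001_ftrl.model, and checks the training stderr; test 73 reloads that model in test-only mode (-t) on the same data and compares both the stderr and the 200 predictions against the reference files added below.
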
diff --git a/test/pred-sets/ref/0001_ftrl.predict b/test/pred-sets/ref/0001_ftrl.predict
new file mode 100644
index 00000000..7b0f7232
--- /dev/null
+++ b/test/pred-sets/ref/0001_ftrl.predict
@@ -0,0 +1,200 @@
+0.148139
+0.149886
+0.146102
+0.146261
+0.146004
+0.150632
+0.146453
+0.149597
+0.146937
+0.151828
+0.145943
+0.149081
+0.146763
+0.149353
+0.151105
+0.151723
+0.146879
+0.147548
+0.149135
+0.146180
+0.147393
+0.148680
+0.147577
+0.147491
+0.145806
+0.146659
+0.146162
+0.146864
+0.148203
+0.146874
+0.154894
+0.145806
+0.146263
+0.146565
+0.147723
+0.149017
+0.148744
+0.148413
+0.147824
+0.164194
+0.145806
+0.147537
+0.146497
+0.147713
+0.146387
+0.145806
+0.147941
+0.145994
+0.147120
+0.145839
+0.146758
+0.146780
+0.146082
+0.148553
+0.158495
+0.149754
+0.148530
+0.149789
+0.147992
+0.146164
+0.147383
+0.147015
+0.151542
+0.145806
+0.150415
+0.146394
+0.145806
+0.146673
+0.148820
+0.148958
+0.147902
+0.149351
+0.146609
+0.147084
+0.153529
+0.147889
+0.147304
+0.147790
+0.145806
+0.146484
+0.145951
+0.146190
+0.157696
+0.145881
+0.145916
+0.145806
+0.145841
+0.148447
+0.151769
+0.147781
+0.145867
+0.149931
+0.160551
+0.146357
+0.148946
+0.145857
+0.148735
+0.145806
+0.146633
+0.147460
+0.146732
+0.147819
+0.146551
+0.147912
+0.147477
+0.149064
+0.145985
+0.146388
+0.146095
+0.146254
+0.150747
+0.145985
+0.147522
+0.147671
+0.145806
+0.149411
+0.146786
+0.147408
+0.146400
+0.147492
+0.146148
+0.145990
+0.147976
+0.145955
+0.149059
+0.157779
+0.146221
+0.146270
+0.146986
+0.149931
+0.146057
+0.149236
+0.145806
+0.145806
+0.146095
+0.146229
+0.147008
+0.147346
+0.156727
+0.146704
+0.146059
+0.145983
+0.150402
+0.152494
+0.149853
+0.147685
+0.149237
+0.164478
+0.146583
+0.146943
+0.148574
+0.145806
+0.147213
+0.146807
+0.147313
+0.147939
+0.158719
+0.147603
+0.167189
+0.145957
+0.161452
+0.146334
+0.146107
+0.156393
+0.146060
+0.147053
+0.158706
+0.147079
+0.146229
+0.149644
+0.150721
+0.146385
+0.146222
+0.147244
+0.145906
+0.150127
+0.148533
+0.156075
+0.148192
+0.153242
+0.149490
+0.150472
+0.147509
+0.145806
+0.146288
+0.146351
+0.148459
+0.149844
+0.147283
+0.151779
+0.158779
+0.145806
+0.145806
+0.145913
+0.145806
+0.147957
+0.149762
+0.146178
+0.145888
+0.146106
diff --git a/test/test-sets/ref/0001_ftrl.stderr b/test/test-sets/ref/0001_ftrl.stderr
new file mode 100644
index 00000000..4d9c1c54
--- /dev/null
+++ b/test/test-sets/ref/0001_ftrl.stderr
@@ -0,0 +1,29 @@
+only testing
+Num weight bits = 18
+learning rate = 10
+initial_t = 1
+power_t = 0.5
+predictions = ftrl_001.predict.tmp
+using no cache
+Reading datafile = train-sets/0001.dat
+num sources = 1
+average since example example current current current
+loss last counter weight label predict features
+0.725668 0.725668 1 1.0 1.0000 0.1481 51
+0.374067 0.022466 2 2.0 0.0000 0.1499 104
+0.197718 0.021369 4 4.0 0.0000 0.1463 135
+0.197180 0.196643 8 8.0 0.0000 0.1496 146
+0.240375 0.283569 16 16.0 1.0000 0.1517 24
+0.262741 0.285108 32 32.0 0.0000 0.1458 32
+0.273560 0.284380 64 64.0 0.0000 0.1458 61
+0.312129 0.350697 128 128.0 1.0000 0.1463 106
+
+finished run
+number of examples per pass = 200
+passes used = 1
+weighted example sum = 200
+weighted label sum = 91
+average loss = 0.3403
+best constant = 0.455
+best constant's loss = 0.247975
+total feature number = 15482
diff --git a/test/train-sets/ref/0001_ftrl.stderr b/test/train-sets/ref/0001_ftrl.stderr
new file mode 100644
index 00000000..a1d2a211
--- /dev/null
+++ b/test/train-sets/ref/0001_ftrl.stderr
@@ -0,0 +1,31 @@
+using l1 regularization = 2
+final_regressor = models/0001_ftrl.model
+Enabling FTRL-Proximal based optimization
+ftrl_alpha = 0.01
+ftrl_beta = 0
+Num weight bits = 18
+learning rate = 0.5
+initial_t = 0
+power_t = 0.5
+using no cache
+Reading datafile = train-sets/0001.dat
+num sources = 1
+average since example example current current current
+loss last counter weight label predict features
+1.000000 1.000000 1 1.0 1.0000 0.0000 51
+0.500000 0.000000 2 2.0 0.0000 0.0000 104
+0.250000 0.000000 4 4.0 0.0000 0.0000 135
+0.250000 0.250000 8 8.0 0.0000 0.0000 146
+0.312500 0.375000 16 16.0 1.0000 0.0000 24
+0.343750 0.375000 32 32.0 0.0000 0.0000 32
+0.359375 0.375000 64 64.0 0.0000 0.0000 61
+0.414062 0.468750 128 128.0 1.0000 0.0000 106
+
+finished run
+number of examples = 200
+weighted example sum = 200
+weighted label sum = 91
+average loss = 0.455
+best constant = 0.455
+best constant's loss = 0.247975
+total feature number = 15482
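
In the training reference above every displayed prediction is 0.0000: with --l1 2 the closed-form update zeroes any coordinate whose accumulated |z| is still below the l1 threshold, and (assuming the default squared loss) the reported average loss of 0.455 matches the fraction of positive labels (weighted label sum 91 over 200 examples), i.e. the loss of always predicting 0. The test run recorded in test-sets/ref/0001_ftrl.stderr reloads the saved model and yields the non-zero predictions (roughly 0.146 to 0.167) listed in pred-sets/ref/0001_ftrl.predict.
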
diff --git a/vowpalwabbit/Makefile.am b/vowpalwabbit/Makefile.am
index dd8a9787..676d6f4e 100644
--- a/vowpalwabbit/Makefile.am
+++ b/vowpalwabbit/Makefile.am
@@ -4,7 +4,7 @@ liballreduce_la_SOURCES = allreduce.cc
bin_PROGRAMS = vw active_interactor
-libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc
+libvw_la_SOURCES = hash.cc memory.cc global_data.cc io_buf.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc rand48.cc simple_label.cc multiclass.cc oaa.cc ect.cc autolink.cc binary.cc lrq.cc cost_sensitive.cc csoaa.cc cb.cc cb_algs.cc search.cc search_sequencetask.cc search_dep_parser.cc search_hooktask.cc search_multiclasstask.cc search_entityrelationtask.cc parse_example.cc scorer.cc network.cc parse_args.cc accumulate.cc gd.cc learner.cc lda_core.cc gd_mf.cc mf.cc bfgs.cc noop.cc print.cc example.cc parser.cc loss_functions.cc sender.cc nn.cc bs.cc cbify.cc topk.cc stagewise_poly.cc log_multi.cc active.cc kernel_svm.cc best_constant.cc ftrl_proximal.cc
libvw_c_wrapper_la_SOURCES = vwdll.cpp
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
new file mode 100644
index 00000000..fd613367
--- /dev/null
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -0,0 +1,229 @@
+/*
+ Copyright (c) by respective owners including Yahoo!, Microsoft, and
+ individual contributors. All rights reserved. Released under a BSD (revised)
+ license as described in the file LICENSE.
+ */
+#include <fstream>
+#include <float.h>
+#ifndef _WIN32
+#include <netdb.h>
+#endif
+#include <string.h>
+#include <stdio.h>
+#include <assert.h>
+#include <sys/timeb.h>
+#include "parse_example.h"
+#include "constant.h"
+#include "cache.h"
+#include "simple_label.h"
+#include "vw.h"
+#include "gd.h"
+#include "accumulate.h"
+#include "memory.h"
+#include <exception>
+
+using namespace std;
+using namespace LEARNER;
+
+
+#define W_XT 0 // current parameter w(XT)
+#define W_GT 1 // current gradient g(GT)
+#define W_ZT 2 // accumulated z(t) = z(t-1) + g(t) + sigma*w(t)
+#define W_G2 3   // accumulated gradient square n(t) = n(t-1) + g(t)*g(t)
+
+/********************************************************************/
+/* mem & w definition ***********************************************/
+/********************************************************************/
+// w[0] = current weight
+// w[1] = current first derivative
+// w[2] = accumulated zt
+// w[3] = accumulated g2
+
+namespace FTRL {
+
+  // non-reentrant
+ struct ftrl {
+
+ vw* all;
+ // set by initializer
+ float ftrl_alpha;
+ float ftrl_beta;
+
+ // evaluation file pointer
+ FILE* fo;
+ bool progressive_validation;
+ };
+
+ void update_accumulated_state(weight* w, float ftrl_alpha) {
+ double ng2 = w[W_G2] + w[W_GT]*w[W_GT];
+ double sigma = (sqrt(ng2) - sqrt(w[W_G2]))/ ftrl_alpha;
+ w[W_ZT] += w[W_GT] - sigma * w[W_XT];
+ w[W_G2] = ng2;
+ }
+
+ struct update_data {
+ float update;
+ float ftrl_alpha;
+ float ftrl_beta;
+ float l1_lambda;
+ float l2_lambda;
+ };
+
+ //void update_grad(weight* weights, size_t mask, float loss_grad)
+ void update_grad(update_data& d, float x, float& wref) {
+ float* w = &wref;
+ w[W_GT] = d.update * x;
+ update_accumulated_state(w, d.ftrl_alpha);
+ }
+
+ float ftrl_predict(vw& all, example& ec) {
+ ec.partial_prediction = GD::inline_predict(all, ec);
+ return GD::finalize_prediction(all.sd, ec.partial_prediction);
+ }
+
+ float predict_and_gradient(vw& all, ftrl &b, example& ec) {
+ float fp = ftrl_predict(all, ec);
+ ec.updated_prediction = fp;
+
+ label_data& ld = ec.l.simple;
+ all.set_minmax(all.sd, ld.label);
+
+ struct update_data data;
+
+ data.update = all.loss->first_derivative(all.sd, fp, ld.label) * ld.weight;
+ data.ftrl_alpha = b.ftrl_alpha;
+
+ GD::foreach_feature<update_data,update_grad>(all, ec, data);
+
+ return fp;
+ }
+
+ inline float sign(float w){ if (w < 0.) return -1.; else return 1.;}
+
+ void update_w(update_data& d, float x, float& wref) {
+ float* w = &wref;
+ float flag = sign(w[W_ZT]);
+ float fabs_zt = w[W_ZT] * flag;
+ if (fabs_zt <= d.l1_lambda) {
+ w[W_XT] = 0.;
+ } else {
+ double step = 1/(d.l2_lambda + (d.ftrl_beta + sqrt(w[W_G2]))/d.ftrl_alpha);
+ w[W_XT] = step * flag * (d.l1_lambda - fabs_zt);
+ }
+ }
+
+ void update_weight(vw& all, ftrl &b, example& ec) {
+
+ struct update_data data;
+
+ data.ftrl_alpha = b.ftrl_alpha;
+ data.ftrl_beta = b.ftrl_beta;
+ data.l1_lambda = all.l1_lambda;
+ data.l2_lambda = all.l2_lambda;
+
+ GD::foreach_feature<update_data, update_w>(all, ec, data);
+
+ }
+
+ void evaluate_example(vw& all, ftrl& b , example& ec) {
+ label_data& ld = ec.l.simple;
+ ec.loss = all.loss->getLoss(all.sd, ec.updated_prediction, ld.label) * ld.weight;
+ if (b.progressive_validation) {
+ float v = 1./(1 + exp(-ec.updated_prediction));
+ fprintf(b.fo, "%.6f\t%d\n", v, (int)(ld.label * ld.weight));
+ }
+ }
+
+ //void learn(void* a, void* d, example* ec) {
+ void learn(ftrl& a, learner& base, example& ec) {
+ vw* all = a.all;
+ assert(ec.in_use);
+
+    // predict w*x, compute the gradient, and update the accumulated state
+    predict_and_gradient(*all, a, ec);
+    // evaluate the example and record statistics
+    evaluate_example(*all, a, ec);
+    // recompute the weights from the accumulated state
+ update_weight(*all, a, ec);
+ }
+
+ void save_load(ftrl& b, io_buf& model_file, bool read, bool text) {
+ vw* all = b.all;
+ if (read) {
+ initialize_regressor(*all);
+ }
+
+ if (model_file.files.size() > 0) {
+ bool resume = all->save_resume;
+ char buff[512];
+ uint32_t text_len = sprintf(buff, ":%d\n", resume);
+ bin_text_read_write_fixed(model_file,(char *)&resume, sizeof (resume), "", read, buff, text_len, text);
+
+ if (resume) {
+ GD::save_load_online_state(*all, model_file, read, text);
+ //save_load_online_state(*all, model_file, read, text);
+ } else {
+ GD::save_load_regressor(*all, model_file, read, text);
+ }
+ }
+
+ }
+
+ // placeholder
+ void predict(ftrl& b, learner& base, example& ec)
+ {
+ vw* all = b.all;
+ //ec.l.simple.prediction = ftrl_predict(*all,ec);
+ ec.pred.scalar = ftrl_predict(*all,ec);
+ }
+
+ learner* setup(vw& all, po::variables_map& vm) {
+
+ ftrl* b = (ftrl*)calloc_or_die(1, sizeof(ftrl));
+ b->all = &all;
+ b->ftrl_beta = 0.0;
+ b->ftrl_alpha = 0.1;
+
+ po::options_description ftrl_opts("FTRL options");
+
+ ftrl_opts.add_options()
+ ("ftrl_alpha", po::value<float>(&(b->ftrl_alpha)), "Learning rate for FTRL-proximal optimization")
+ ("ftrl_beta", po::value<float>(&(b->ftrl_beta)), "FTRL beta")
+ ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
+
+ vm = add_options(all, ftrl_opts);
+
+ if (vm.count("ftrl_alpha")) {
+ b->ftrl_alpha = vm["ftrl_alpha"].as<float>();
+ }
+
+ if (vm.count("ftrl_beta")) {
+ b->ftrl_beta = vm["ftrl_beta"].as<float>();
+ }
+
+    all.reg.stride_shift = 2; // NOTE: four floats per feature (weight, gradient, z, n)
+
+ b->progressive_validation = false;
+ if (vm.count("progressive_validation")) {
+ std::string filename = vm["progressive_validation"].as<string>();
+ b->fo = fopen(filename.c_str(), "w");
+ assert(b->fo != NULL);
+ b->progressive_validation = true;
+ }
+
+ if (!all.quiet) {
+ cerr << "Enabling FTRL-Proximal based optimization" << endl;
+ cerr << "ftrl_alpha = " << b->ftrl_alpha << endl;
+ cerr << "ftrl_beta = " << b->ftrl_beta << endl;
+ }
+
+ learner* l = new learner(b, 1 << all.reg.stride_shift);
+ l->set_learn<ftrl, learn>();
+ l->set_predict<ftrl, predict>();
+ l->set_save_load<ftrl,save_load>();
+
+ return l;
+ }
+
+
+} // end namespace
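
setup() above reserves four floats per feature (all.reg.stride_shift = 2, i.e. a stride of 4), matching the W_XT/W_GT/W_ZT/W_G2 slots used by the update callbacks: GD::foreach_feature hands each callback a reference to slot 0, and the other accumulators live in the next three floats of the same block. A rough, self-contained illustration of that layout (hypothetical names, not VW's internal API):

    #include <cstddef>
    #include <vector>

    // Illustrative 4-slot-per-feature layout: weight, last gradient,
    // accumulated z, accumulated squared gradient n.
    enum { W_XT = 0, W_GT = 1, W_ZT = 2, W_G2 = 3 };

    int main() {
      const std::size_t num_weights = std::size_t(1) << 18; // 18 weight bits
      const std::size_t stride = std::size_t(1) << 2;       // stride_shift = 2
      std::vector<float> reg(num_weights * stride, 0.f);

      std::size_t idx = 42;            // some hashed feature index
      float* w = &reg[idx * stride];   // slot 0: the weight itself
      w[W_GT] = 0.5f;                  // slot 1: gradient for this example
      w[W_ZT] += w[W_GT];              // slot 2: z accumulator
      w[W_G2] += w[W_GT] * w[W_GT];    // slot 3: n accumulator
      return 0;
    }
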
diff --git a/vowpalwabbit/ftrl_proximal.h b/vowpalwabbit/ftrl_proximal.h
new file mode 100644
index 00000000..934d91c8
--- /dev/null
+++ b/vowpalwabbit/ftrl_proximal.h
@@ -0,0 +1,12 @@
+/*
+Copyright (c) by respective owners including Yahoo!, Microsoft, and
+individual contributors. All rights reserved. Released under a BSD
+license as described in the file LICENSE.
+ */
+#ifndef FTRL_PROXIMAL_H
+#define FTRL_PROXIMAL_H
+
+namespace FTRL {
+ LEARNER::learner* setup(vw& all, po::variables_map& vm);
+}
+#endif
diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc
index 589d39c2..1e74da0f 100644
--- a/vowpalwabbit/gd.cc
+++ b/vowpalwabbit/gd.cc
@@ -32,7 +32,7 @@ using namespace LEARNER;
namespace GD
{
struct gd{
- double normalized_sum_norm_x;
+ //double normalized_sum_norm_x;
double total_weight;
size_t no_win_counter;
size_t early_stop_thres;
@@ -75,14 +75,14 @@ namespace GD
if (normalized) {
if (sqrt_rate)
{
- float avg_norm = (float) g.total_weight / (float) g.normalized_sum_norm_x;
+ float avg_norm = (float) g.total_weight / (float) g.all->normalized_sum_norm_x;
if (adaptive)
return sqrt(avg_norm);
else
return avg_norm;
}
else
- return powf( (float) g.normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power);
+ return powf( (float) g.all->normalized_sum_norm_x / (float) g.total_weight, g.neg_norm_power);
}
return 1.f;
}
@@ -451,7 +451,7 @@ template<bool sqrt_rate, bool feature_mask_off, size_t adaptive, size_t normaliz
foreach_feature<norm_data,pred_per_update_feature<sqrt_rate, feature_mask_off, adaptive, normalized, spare> >(all, ec, nd);
if(normalized) {
- g.normalized_sum_norm_x += ld.weight * nd.norm_x;
+ g.all->normalized_sum_norm_x += ld.weight * nd.norm_x;
g.total_weight += ld.weight;
g.update_multiplier = average_update<sqrt_rate, adaptive, normalized>(g, nd.pred_per_update);
@@ -609,9 +609,10 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
while ((!read && i < length) || (read && brw >0));
}
-void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
+//void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
+void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text)
{
- vw& all = *g.all;
+ //vw& all = *g.all;
char buff[512];
@@ -620,10 +621,10 @@ void save_load_online_state(gd& g, io_buf& model_file, bool read, bool text)
"", read,
buff, text_len, text);
- text_len = sprintf(buff, "norm normalizer %f\n", g.normalized_sum_norm_x);
- bin_text_read_write_fixed(model_file,(char*)&g.normalized_sum_norm_x, sizeof(g.normalized_sum_norm_x),
- "", read,
- buff, text_len, text);
+ text_len = sprintf(buff, "norm normalizer %f\n", all.normalized_sum_norm_x);
+ bin_text_read_write_fixed(model_file,(char*)&all.normalized_sum_norm_x, sizeof(all.normalized_sum_norm_x),
+ "", read,
+ buff, text_len, text);
text_len = sprintf(buff, "t %f\n", all.sd->t);
bin_text_read_write_fixed(model_file,(char*)&all.sd->t, sizeof(all.sd->t),
@@ -780,7 +781,8 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text)
"", read,
buff, text_len, text);
if (resume)
- save_load_online_state(g, model_file, read, text);
+ //save_load_online_state(g, model_file, read, text);
+ save_load_online_state(all, model_file, read, text);
else
save_load_regressor(all, model_file, read, text);
}
@@ -844,7 +846,7 @@ learner* setup(vw& all, po::variables_map& vm)
{
gd* g = (gd*)calloc_or_die(1, sizeof(gd));
g->all = &all;
- g->normalized_sum_norm_x = 0;
+ g->all->normalized_sum_norm_x = 0;
g->no_win_counter = 0;
g->total_weight = 0.;
g->early_stop_thres = 3;
@@ -853,7 +855,7 @@ learner* setup(vw& all, po::variables_map& vm)
if(all.initial_t > 0)//for the normalized update: if initial_t is bigger than 1 we interpret this as if we had seen (all.initial_t) previous fake datapoints all with norm 1
{
- g->normalized_sum_norm_x = all.initial_t;
+ g->all->normalized_sum_norm_x = all.initial_t;
g->total_weight = all.initial_t;
}
diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h
index 05bd5b5d..be52e62f 100644
--- a/vowpalwabbit/gd.h
+++ b/vowpalwabbit/gd.h
@@ -26,6 +26,7 @@ namespace GD{
void train_one_example_single_thread(regressor& r, example* ex);
LEARNER::learner* setup(vw& all, po::variables_map& vm);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
+ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text);
void output_and_account_example(example* ec);
// iterate through one namespace (or its part), callback function T(some_data_R, feature_value_x, feature_weight)
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index dd8567a8..c90ae05b 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -197,6 +197,7 @@ struct vw {
int m;
bool save_resume;
+ double normalized_sum_norm_x;
po::options_description opts;
std::string file_options;
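
The gd.cc, gd.h and global_data.h changes above move normalized_sum_norm_x from the gd struct into the global vw struct so that save_load_online_state() can be called without a gd instance; that is what lets the FTRL reduction's save_load() reuse GD::save_load_online_state(all, ...) when --save_resume is set.
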
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index efa0961e..d417ca0c 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -35,6 +35,7 @@ license as described in the file LICENSE.
#include "gd_mf.h"
#include "mf.h"
#include "vw.h"
+#include "ftrl_proximal.h"
#include "rand48.h"
#include "parse_args.h"
#include "binary.h"
@@ -722,6 +723,7 @@ void parse_base_algorithm(vw& all, po::variables_map& vm)
base_opt.add_options()
("sgd", "use regular stochastic gradient descent update.")
+ ("ftrl", "use ftrl-proximal optimization")
("adaptive", "use adaptive, individual learning rates.")
("invariant", "use safe/importance aware updates.")
("normalized", "use per feature normalized updates")
@@ -740,6 +742,8 @@ void parse_base_algorithm(vw& all, po::variables_map& vm)
all.l = BFGS::setup(all, vm);
else if (vm.count("lda"))
all.l = LDA::setup(all, vm);
+ else if (vm.count("ftrl"))
+ all.l = FTRL::setup(all, vm);
else if (vm.count("noop"))
all.l = NOOP::setup(all);
else if (vm.count("print"))