diff options
author | John Langford <jl@hunch.net> | 2014-12-31 20:34:33 +0300 |
---|---|---|
committer | John Langford <jl@hunch.net> | 2014-12-31 20:34:33 +0300 |
commit | fc8034d53385a87d34cad922cfbb238a512fe664 (patch) | |
tree | a9e04cca8fa80794b37d2e6e88b7b84417f25e13 | |
parent | 1149962d20dac0b65c07576aeb2f35c75af69193 (diff) |
fixing irregularities in parsing order
-rw-r--r-- | vowpalwabbit/ftrl_proximal.cc | 32 | ||||
-rw-r--r-- | vowpalwabbit/parse_args.cc | 12 | ||||
-rw-r--r-- | vowpalwabbit/parse_args.h | 2 | ||||
-rw-r--r-- | vowpalwabbit/reductions.h | 1 | ||||
-rw-r--r-- | vowpalwabbit/scorer.cc | 4 | ||||
-rw-r--r-- | vowpalwabbit/search.cc | 6 |
6 files changed, 22 insertions, 35 deletions
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc index 6216058b..9c5410d7 100644 --- a/vowpalwabbit/ftrl_proximal.cc +++ b/vowpalwabbit/ftrl_proximal.cc @@ -179,27 +179,21 @@ namespace FTRL { base_learner* setup(vw& all, po::variables_map& vm) { + po::options_description opts("FTRL options"); + opts.add_options() + ("ftrl", "Follow the Regularized Leader") + ("ftrl_alpha", po::value<float>()->default_value(0.0), "Learning rate for FTRL-proximal optimization") + ("ftrl_beta", po::value<float>()->default_value(0.1), "FTRL beta") + ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal"); + vm = add_options(all, opts); + + if (!vm.count("ftrl")) + return NULL; + ftrl& b = calloc_or_die<ftrl>(); b.all = &all; - b.ftrl_beta = 0.0; - b.ftrl_alpha = 0.1; - - po::options_description ftrl_opts("FTRL options"); - - ftrl_opts.add_options() - ("ftrl_alpha", po::value<float>(&(b.ftrl_alpha)), "Learning rate for FTRL-proximal optimization") - ("ftrl_beta", po::value<float>(&(b.ftrl_beta)), "FTRL beta") - ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal"); - - vm = add_options(all, ftrl_opts); - - if (vm.count("ftrl_alpha")) { - b.ftrl_alpha = vm["ftrl_alpha"].as<float>(); - } - - if (vm.count("ftrl_beta")) { - b.ftrl_beta = vm["ftrl_beta"].as<float>(); - } + b.ftrl_beta = vm["ftrl_beta"].as<float>(); + b.ftrl_alpha = vm["ftrl_alpha"].as<float>(); all.reg.stride_shift = 2; // NOTE: for more parameter storage diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index 715c9fca..2ff0fd4b 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -715,14 +715,6 @@ void parse_output_model(vw& all, po::variables_map& vm) all.save_resume = true; } -void parse_base_algorithm(vw& all, po::variables_map& vm) -{ - // all.l = GD::setup(all, vm); - all.scorer = all.l; - if (vm.count("ftrl")) - all.l = FTRL::setup(all, vm); -} - void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp) { // Need to see if we have to load feature mask first or second. @@ -750,7 +742,7 @@ LEARNER::base_learner* setup_base(vw& all, po::variables_map& vm) { LEARNER::base_learner* ret = all.reduction_stack.pop()(all,vm); if (ret == NULL) - return setup_next(all,vm); + return setup_base(all,vm); else return ret; } @@ -919,8 +911,6 @@ vw* parse_args(int argc, char *argv[]) parse_output_model(*all, vm); - parse_base_algorithm(*all, vm); - if (!all->quiet) { cerr << "Num weight bits = " << all->num_bits << endl; diff --git a/vowpalwabbit/parse_args.h b/vowpalwabbit/parse_args.h index 23531050..9be0b2c5 100644 --- a/vowpalwabbit/parse_args.h +++ b/vowpalwabbit/parse_args.h @@ -7,4 +7,4 @@ license as described in the file LICENSE. #include "global_data.h" vw* parse_args(int argc, char *argv[]); -LEARNER::base_learner* setup_next(vw& all, po::variables_map& vm); +LEARNER::base_learner* setup_base(vw& all, po::variables_map& vm); diff --git a/vowpalwabbit/reductions.h b/vowpalwabbit/reductions.h index a1c18ecf..35571cb2 100644 --- a/vowpalwabbit/reductions.h +++ b/vowpalwabbit/reductions.h @@ -13,3 +13,4 @@ namespace po = boost::program_options; #include "learner.h" // for core reduction definition #include "global_data.h" // for vw datastructure #include "memory.h" +#include "parse_args.h" diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc index 4d396841..76089bb8 100644 --- a/vowpalwabbit/scorer.cc +++ b/vowpalwabbit/scorer.cc @@ -63,6 +63,8 @@ namespace Scorer { cerr << "Unknown link function: " << link << endl; throw exception(); } - return make_base(*l); + all.scorer = make_base(*l); + + return all.scorer; } } diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index 71f00e89..70efde7e 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -12,9 +12,8 @@ license as described in the file LICENSE. #include "rand48.h" #include "cost_sensitive.h" #include "multiclass.h" -#include "memory.h" #include "constant.h" -#include "example.h" +#include "reductions.h" #include "cb.h" #include "gd.h" // for GD::foreach_feature #include <math.h> @@ -1987,7 +1986,8 @@ namespace Search { vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"])); base_learner* base = setup_base(all,vm); - learner<search>& l = init_learner(&sch, all.l, search_predict_or_learn<true>, + learner<search>& l = init_learner(&sch, base, + search_predict_or_learn<true>, search_predict_or_learn<false>, priv.total_number_of_policies); l.set_finish_example(finish_example); |