Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-12-31 20:34:33 +0300
committerJohn Langford <jl@hunch.net>2014-12-31 20:34:33 +0300
commitfc8034d53385a87d34cad922cfbb238a512fe664 (patch)
treea9e04cca8fa80794b37d2e6e88b7b84417f25e13
parent1149962d20dac0b65c07576aeb2f35c75af69193 (diff)
fixing irregularities in parsing order
-rw-r--r--vowpalwabbit/ftrl_proximal.cc32
-rw-r--r--vowpalwabbit/parse_args.cc12
-rw-r--r--vowpalwabbit/parse_args.h2
-rw-r--r--vowpalwabbit/reductions.h1
-rw-r--r--vowpalwabbit/scorer.cc4
-rw-r--r--vowpalwabbit/search.cc6
6 files changed, 22 insertions, 35 deletions
diff --git a/vowpalwabbit/ftrl_proximal.cc b/vowpalwabbit/ftrl_proximal.cc
index 6216058b..9c5410d7 100644
--- a/vowpalwabbit/ftrl_proximal.cc
+++ b/vowpalwabbit/ftrl_proximal.cc
@@ -179,27 +179,21 @@ namespace FTRL {
base_learner* setup(vw& all, po::variables_map& vm)
{
+ po::options_description opts("FTRL options");
+ opts.add_options()
+ ("ftrl", "Follow the Regularized Leader")
+ ("ftrl_alpha", po::value<float>()->default_value(0.0), "Learning rate for FTRL-proximal optimization")
+ ("ftrl_beta", po::value<float>()->default_value(0.1), "FTRL beta")
+ ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
+ vm = add_options(all, opts);
+
+ if (!vm.count("ftrl"))
+ return NULL;
+
ftrl& b = calloc_or_die<ftrl>();
b.all = &all;
- b.ftrl_beta = 0.0;
- b.ftrl_alpha = 0.1;
-
- po::options_description ftrl_opts("FTRL options");
-
- ftrl_opts.add_options()
- ("ftrl_alpha", po::value<float>(&(b.ftrl_alpha)), "Learning rate for FTRL-proximal optimization")
- ("ftrl_beta", po::value<float>(&(b.ftrl_beta)), "FTRL beta")
- ("progressive_validation", po::value<string>()->default_value("ftrl.evl"), "File to record progressive validation for ftrl-proximal");
-
- vm = add_options(all, ftrl_opts);
-
- if (vm.count("ftrl_alpha")) {
- b.ftrl_alpha = vm["ftrl_alpha"].as<float>();
- }
-
- if (vm.count("ftrl_beta")) {
- b.ftrl_beta = vm["ftrl_beta"].as<float>();
- }
+ b.ftrl_beta = vm["ftrl_beta"].as<float>();
+ b.ftrl_alpha = vm["ftrl_alpha"].as<float>();
all.reg.stride_shift = 2; // NOTE: for more parameter storage
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 715c9fca..2ff0fd4b 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -715,14 +715,6 @@ void parse_output_model(vw& all, po::variables_map& vm)
all.save_resume = true;
}
-void parse_base_algorithm(vw& all, po::variables_map& vm)
-{
- // all.l = GD::setup(all, vm);
- all.scorer = all.l;
- if (vm.count("ftrl"))
- all.l = FTRL::setup(all, vm);
-}
-
void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp)
{
// Need to see if we have to load feature mask first or second.
@@ -750,7 +742,7 @@ LEARNER::base_learner* setup_base(vw& all, po::variables_map& vm)
{
LEARNER::base_learner* ret = all.reduction_stack.pop()(all,vm);
if (ret == NULL)
- return setup_next(all,vm);
+ return setup_base(all,vm);
else
return ret;
}
@@ -919,8 +911,6 @@ vw* parse_args(int argc, char *argv[])
parse_output_model(*all, vm);
- parse_base_algorithm(*all, vm);
-
if (!all->quiet)
{
cerr << "Num weight bits = " << all->num_bits << endl;
diff --git a/vowpalwabbit/parse_args.h b/vowpalwabbit/parse_args.h
index 23531050..9be0b2c5 100644
--- a/vowpalwabbit/parse_args.h
+++ b/vowpalwabbit/parse_args.h
@@ -7,4 +7,4 @@ license as described in the file LICENSE.
#include "global_data.h"
vw* parse_args(int argc, char *argv[]);
-LEARNER::base_learner* setup_next(vw& all, po::variables_map& vm);
+LEARNER::base_learner* setup_base(vw& all, po::variables_map& vm);
diff --git a/vowpalwabbit/reductions.h b/vowpalwabbit/reductions.h
index a1c18ecf..35571cb2 100644
--- a/vowpalwabbit/reductions.h
+++ b/vowpalwabbit/reductions.h
@@ -13,3 +13,4 @@ namespace po = boost::program_options;
#include "learner.h" // for core reduction definition
#include "global_data.h" // for vw datastructure
#include "memory.h"
+#include "parse_args.h"
diff --git a/vowpalwabbit/scorer.cc b/vowpalwabbit/scorer.cc
index 4d396841..76089bb8 100644
--- a/vowpalwabbit/scorer.cc
+++ b/vowpalwabbit/scorer.cc
@@ -63,6 +63,8 @@ namespace Scorer {
cerr << "Unknown link function: " << link << endl;
throw exception();
}
- return make_base(*l);
+ all.scorer = make_base(*l);
+
+ return all.scorer;
}
}
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index 71f00e89..70efde7e 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -12,9 +12,8 @@ license as described in the file LICENSE.
#include "rand48.h"
#include "cost_sensitive.h"
#include "multiclass.h"
-#include "memory.h"
#include "constant.h"
-#include "example.h"
+#include "reductions.h"
#include "cb.h"
#include "gd.h" // for GD::foreach_feature
#include <math.h>
@@ -1987,7 +1986,8 @@ namespace Search {
vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["search"]));
base_learner* base = setup_base(all,vm);
- learner<search>& l = init_learner(&sch, all.l, search_predict_or_learn<true>,
+ learner<search>& l = init_learner(&sch, base,
+ search_predict_or_learn<true>,
search_predict_or_learn<false>,
priv.total_number_of_policies);
l.set_finish_example(finish_example);