Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/search.cc')
-rw-r--r--vowpalwabbit/search.cc108
1 files changed, 62 insertions, 46 deletions
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc
index fb14e65c..74444206 100644
--- a/vowpalwabbit/search.cc
+++ b/vowpalwabbit/search.cc
@@ -12,9 +12,8 @@ license as described in the file LICENSE.
#include "rand48.h"
#include "cost_sensitive.h"
#include "multiclass.h"
-#include "memory.h"
#include "constant.h"
-#include "example.h"
+#include "reductions.h"
#include "cb.h"
#include "gd.h" // for GD::foreach_feature
#include <math.h>
@@ -1669,14 +1668,14 @@ namespace Search {
ret = false;
}
- void handle_condition_options(vw& vw, auto_condition_settings& acset, po::variables_map& vm) {
- po::options_description condition_options("Search Auto-conditioning Options");
- condition_options.add_options()
- ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional")
- ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)")
- ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)");
+ void handle_condition_options(vw& vw, auto_condition_settings& acset) {
+ new_options(vw, "Search Auto-conditioning Options")
+ ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional")
+ ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)")
+ ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)");
+ add_options(vw);
- vm = add_options(vw, condition_options);
+ po::variables_map& vm = vw.vm;
check_option<size_t>(acset.max_bias_ngram_length, vw, vm, "search_max_bias_ngram_length", false, size_equal,
"warning: you specified a different value for --search_max_bias_ngram_length than the one loaded from regressor. proceeding with loaded value: ", "");
@@ -1764,38 +1763,39 @@ namespace Search {
delete[] cstr;
}
- base_learner* setup(vw&all, po::variables_map& vm) {
- search& sch = calloc_or_die<search>();
- sch.priv = new search_private();
- search_initialize(&all, sch);
- search_private& priv = *sch.priv;
-
- po::options_description search_opts("Search Options");
- search_opts.add_options()
- ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)")
- ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]")
- ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]")
- ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]")
-
- ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]")
- ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]")
-
- ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]")
-
- ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained")
-
- ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file")
-
- ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]")
- ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example")
- ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them")
- ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")")
- ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]")
-
- ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)")
- ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)")
- ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)")
- ;
+ base_learner* setup(vw&all) {
+ new_options(all,"Search Options")
+ ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)")
+ ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]")
+ ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]")
+ ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]")
+
+ ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]")
+ ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]")
+
+ ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]")
+
+ ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained")
+
+ ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file")
+
+ ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]")
+ ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example")
+ ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them")
+ ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")")
+ ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]")
+
+ ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)")
+ ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)")
+ ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)")
+ ;
+ add_options(all);
+ po::variables_map& vm = all.vm;
+ if (!vm.count("search"))
+ return NULL;
bool has_hook_task = false;
for (size_t i=0; i<all.args.size()-1; i++)
@@ -1805,9 +1805,12 @@ namespace Search {
for (int i = (int)all.args.size()-2; i >= 0; i--)
if (all.args[i] == "--search_task" && all.args[i+1] != "hook")
all.args.erase(all.args.begin() + i, all.args.begin() + i + 2);
+
+ search& sch = calloc_or_die<search>();
+ sch.priv = new search_private();
+ search_initialize(&all, sch);
+ search_private& priv = *sch.priv;
- vm = add_options(all, search_opts);
-
std::string task_string;
std::string interpolation_string = "data";
std::string rollout_string = "mix_per_state";
@@ -1955,6 +1958,18 @@ namespace Search {
}
all.p->emptylines_separate_examples = true;
+ if (count(all.args.begin(), all.args.end(),"--csoaa") == 0
+ && count(all.args.begin(), all.args.end(),"--csoaa_ldf") == 0
+ && count(all.args.begin(), all.args.end(),"--wap_ldf") == 0
+ && count(all.args.begin(), all.args.end(),"--cb") == 0)
+ {
+ all.args.push_back("--csoaa");
+ stringstream ss;
+ ss << vm["search"].as<size_t>();
+ all.args.push_back(ss.str());
+ }
+ base_learner* base = setup_base(all);
+
// default to OAA labels unless the task wants to override this (which they can do in initialize)
all.p->lp = MC::mc_label;
if (priv.task)
@@ -1964,7 +1979,7 @@ namespace Search {
// set up auto-history if they want it
if (priv.auto_condition_features) {
- handle_condition_options(all, priv.acset, vm);
+ handle_condition_options(all, priv.acset);
// turn off auto-condition if it's irrelevant
if (((priv.acset.max_bias_ngram_length == 0) && (priv.acset.max_quad_ngram_length == 0)) ||
@@ -1981,7 +1996,8 @@ namespace Search {
priv.start_clock_time = clock();
- learner<search>& l = init_learner(&sch, all.l, search_predict_or_learn<true>,
+ learner<search>& l = init_learner(&sch, base,
+ search_predict_or_learn<true>,
search_predict_or_learn<false>,
priv.total_number_of_policies);
l.set_finish_example(finish_example);
@@ -2063,7 +2079,7 @@ namespace Search {
void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; }
- void search::add_program_options(po::variables_map& vm, po::options_description& opts) { vm = add_options( *this->priv->all, opts ); }
+ void search::add_program_options(po::variables_map& vw, po::options_description& opts) { add_options( *this->priv->all, opts ); }
size_t search::get_mask() { return this->priv->all->reg.weight_mask;}
size_t search::get_stride_shift() { return this->priv->all->reg.stride_shift;}