diff options
Diffstat (limited to 'vowpalwabbit/search.cc')
-rw-r--r-- | vowpalwabbit/search.cc | 108 |
1 files changed, 62 insertions, 46 deletions
diff --git a/vowpalwabbit/search.cc b/vowpalwabbit/search.cc index fb14e65c..74444206 100644 --- a/vowpalwabbit/search.cc +++ b/vowpalwabbit/search.cc @@ -12,9 +12,8 @@ license as described in the file LICENSE. #include "rand48.h" #include "cost_sensitive.h" #include "multiclass.h" -#include "memory.h" #include "constant.h" -#include "example.h" +#include "reductions.h" #include "cb.h" #include "gd.h" // for GD::foreach_feature #include <math.h> @@ -1669,14 +1668,14 @@ namespace Search { ret = false; } - void handle_condition_options(vw& vw, auto_condition_settings& acset, po::variables_map& vm) { - po::options_description condition_options("Search Auto-conditioning Options"); - condition_options.add_options() - ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional") - ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)") - ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)"); + void handle_condition_options(vw& vw, auto_condition_settings& acset) { + new_options(vw, "Search Auto-conditioning Options") + ("search_max_bias_ngram_length", po::value<size_t>(), "add a \"bias\" feature for each ngram up to and including this length. eg., if it's 1 (default), then you get a single feature for each conditional") + ("search_max_quad_ngram_length", po::value<size_t>(), "add bias *times* input features for each ngram up to and including this length (def: 0)") + ("search_condition_feature_value", po::value<float> (), "how much weight should the conditional features get? (def: 1.)"); + add_options(vw); - vm = add_options(vw, condition_options); + po::variables_map& vm = vw.vm; check_option<size_t>(acset.max_bias_ngram_length, vw, vm, "search_max_bias_ngram_length", false, size_equal, "warning: you specified a different value for --search_max_bias_ngram_length than the one loaded from regressor. proceeding with loaded value: ", ""); @@ -1764,38 +1763,39 @@ namespace Search { delete[] cstr; } - base_learner* setup(vw&all, po::variables_map& vm) { - search& sch = calloc_or_die<search>(); - sch.priv = new search_private(); - search_initialize(&all, sch); - search_private& priv = *sch.priv; - - po::options_description search_opts("Search Options"); - search_opts.add_options() - ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)") - ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]") - ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]") - ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]") - - ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]") - ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]") - - ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]") - - ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained") - - ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file") - - ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]") - ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example") - ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them") - ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")") - ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]") - - ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)") - ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)") - ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)") - ; + base_learner* setup(vw&all) { + new_options(all,"Search Options") + ("search", po::value<size_t>(), "use search-based structured prediction, argument=maximum action id or 0 for LDF"); + if (missing_required(all)) return NULL; + new_options(all) + ("search_task", po::value<string>(), "the search task (use \"--search_task list\" to get a list of available tasks)") + ("search_interpolation", po::value<string>(), "at what level should interpolation happen? [*data|policy]") + ("search_rollout", po::value<string>(), "how should rollouts be executed? [policy|oracle|*mix_per_state|mix_per_roll|none]") + ("search_rollin", po::value<string>(), "how should past trajectories be generated? [policy|oracle|*mix_per_state|mix_per_roll]") + + ("search_passes_per_policy", po::value<size_t>(), "number of passes per policy (only valid for search_interpolation=policy) [def=1]") + ("search_beta", po::value<float>(), "interpolation rate for policies (only valid for search_interpolation=policy) [def=0.5]") + + ("search_alpha", po::value<float>(), "annealed beta = 1-(1-alpha)^t (only valid for search_interpolation=data) [def=1e-10]") + + ("search_total_nb_policies", po::value<size_t>(), "if we are going to train the policies through multiple separate calls to vw, we need to specify this parameter and tell vw how many policies are eventually going to be trained") + + ("search_trained_nb_policies", po::value<size_t>(), "the number of trained policies in a file") + + ("search_allowed_transitions",po::value<string>(),"read file of allowed transitions [def: all transitions are allowed]") + ("search_subsample_time", po::value<float>(), "instead of training at all timesteps, use a subset. if value in (0,1), train on a random v%. if v>=1, train on precisely v steps per example") + ("search_neighbor_features", po::value<string>(), "copy features from neighboring lines. argument looks like: '-1:a,+2' meaning copy previous line namespace a and next next line from namespace _unnamed_, where ',' separates them") + ("search_rollout_num_steps", po::value<size_t>(), "how many calls of \"loss\" before we stop really predicting on rollouts and switch to oracle (def: 0 means \"infinite\")") + ("search_history_length", po::value<size_t>(), "some tasks allow you to specify how much history their depend on; specify that here [def: 1]") + + ("search_no_caching", "turn off the built-in caching ability (makes things slower, but technically more safe)") + ("search_beam", po::value<size_t>(), "use beam search (arg = beam size, default 0 = no beam)") + ("search_kbest", po::value<size_t>(), "size of k-best list to produce (must be <= beam size)") + ; + add_options(all); + po::variables_map& vm = all.vm; + if (!vm.count("search")) + return NULL; bool has_hook_task = false; for (size_t i=0; i<all.args.size()-1; i++) @@ -1805,9 +1805,12 @@ namespace Search { for (int i = (int)all.args.size()-2; i >= 0; i--) if (all.args[i] == "--search_task" && all.args[i+1] != "hook") all.args.erase(all.args.begin() + i, all.args.begin() + i + 2); + + search& sch = calloc_or_die<search>(); + sch.priv = new search_private(); + search_initialize(&all, sch); + search_private& priv = *sch.priv; - vm = add_options(all, search_opts); - std::string task_string; std::string interpolation_string = "data"; std::string rollout_string = "mix_per_state"; @@ -1955,6 +1958,18 @@ namespace Search { } all.p->emptylines_separate_examples = true; + if (count(all.args.begin(), all.args.end(),"--csoaa") == 0 + && count(all.args.begin(), all.args.end(),"--csoaa_ldf") == 0 + && count(all.args.begin(), all.args.end(),"--wap_ldf") == 0 + && count(all.args.begin(), all.args.end(),"--cb") == 0) + { + all.args.push_back("--csoaa"); + stringstream ss; + ss << vm["search"].as<size_t>(); + all.args.push_back(ss.str()); + } + base_learner* base = setup_base(all); + // default to OAA labels unless the task wants to override this (which they can do in initialize) all.p->lp = MC::mc_label; if (priv.task) @@ -1964,7 +1979,7 @@ namespace Search { // set up auto-history if they want it if (priv.auto_condition_features) { - handle_condition_options(all, priv.acset, vm); + handle_condition_options(all, priv.acset); // turn off auto-condition if it's irrelevant if (((priv.acset.max_bias_ngram_length == 0) && (priv.acset.max_quad_ngram_length == 0)) || @@ -1981,7 +1996,8 @@ namespace Search { priv.start_clock_time = clock(); - learner<search>& l = init_learner(&sch, all.l, search_predict_or_learn<true>, + learner<search>& l = init_learner(&sch, base, + search_predict_or_learn<true>, search_predict_or_learn<false>, priv.total_number_of_policies); l.set_finish_example(finish_example); @@ -2063,7 +2079,7 @@ namespace Search { void search::set_num_learners(size_t num_learners) { this->priv->num_learners = num_learners; } - void search::add_program_options(po::variables_map& vm, po::options_description& opts) { vm = add_options( *this->priv->all, opts ); } + void search::add_program_options(po::variables_map& vw, po::options_description& opts) { add_options( *this->priv->all, opts ); } size_t search::get_mask() { return this->priv->all->reg.weight_mask;} size_t search::get_stride_shift() { return this->priv->all->reg.stride_shift;} |