diff options
Diffstat (limited to 'vowpalwabbit/cb_algs.cc')
-rw-r--r-- | vowpalwabbit/cb_algs.cc | 42 |
1 files changed, 25 insertions, 17 deletions
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc index bf9a6e5b..b14328d3 100644 --- a/vowpalwabbit/cb_algs.cc +++ b/vowpalwabbit/cb_algs.cc @@ -436,36 +436,35 @@ namespace CB_ALGS VW::finish_example(all, &ec); } - base_learner* setup(vw& all, po::variables_map& vm) + base_learner* setup(vw& all) { + new_options(all, "CB options") + ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs"); + if (missing_required(all)) return NULL; + new_options(all) + ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}") + ("eval", "Evaluate a policy rather than optimizing."); + add_options(all); + cb& c = calloc_or_die<cb>(); c.all = &all; - uint32_t nb_actions = (uint32_t)vm["cb"].as<size_t>(); - //append cb with nb_actions to file_options so it is saved to regressor later - - po::options_description cb_opts("CB options"); - cb_opts.add_options() - ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}") - ("eval", "Evaluate a policy rather than optimizing.") - ; - - vm = add_options(all, cb_opts); + uint32_t nb_actions = (uint32_t)all.vm["cb"].as<size_t>(); *all.file_options << " --cb " << nb_actions; all.sd->k = nb_actions; bool eval = false; - if (vm.count("eval")) + if (all.vm.count("eval")) eval = true; size_t problem_multiplier = 2;//default for DR - if (vm.count("cb_type")) + if (all.vm.count("cb_type")) { std::string type_string; - type_string = vm["cb_type"].as<std::string>(); + type_string = all.vm["cb_type"].as<std::string>(); *all.file_options << " --cb_type " << type_string; if (type_string.compare("dr") == 0) @@ -496,6 +495,15 @@ namespace CB_ALGS *all.file_options << " --cb_type dr"; } + if (count(all.args.begin(), all.args.end(),"--csoaa") == 0) + { + all.args.push_back("--csoaa"); + stringstream ss; + ss << all.vm["cb"].as<size_t>(); + all.args.push_back(ss.str()); + } + + base_learner* base = setup_base(all); if (eval) all.p->lp = CB_EVAL::cb_eval; else @@ -504,18 +512,18 @@ namespace CB_ALGS learner<cb>* l; if (eval) { - l = &init_learner(&c, all.l, learn_eval, predict_eval, problem_multiplier); + l = &init_learner(&c, base, learn_eval, predict_eval, problem_multiplier); l->set_finish_example(eval_finish_example); } else { - l = &init_learner(&c, all.l, predict_or_learn<true>, predict_or_learn<false>, + l = &init_learner(&c, base, predict_or_learn<true>, predict_or_learn<false>, problem_multiplier); l->set_finish_example(finish_example); } // preserve the increment of the base learner since we are // _adding_ to the number of problems rather than multiplying. - l->increment = all.l->increment; + l->increment = base->increment; l->set_init_driver(init_driver); l->set_finish(finish); |