Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/cb_algs.cc')
-rw-r--r-- vowpalwabbit/cb_algs.cc | 42
1 file changed, 25 insertions, 17 deletions
diff --git a/vowpalwabbit/cb_algs.cc b/vowpalwabbit/cb_algs.cc
index bf9a6e5b..b14328d3 100644
--- a/vowpalwabbit/cb_algs.cc
+++ b/vowpalwabbit/cb_algs.cc
@@ -436,36 +436,35 @@ namespace CB_ALGS
VW::finish_example(all, &ec);
}
- base_learner* setup(vw& all, po::variables_map& vm)
+ base_learner* setup(vw& all)
{
+ new_options(all, "CB options")
+ ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs");
+ if (missing_required(all)) return NULL;
+ new_options(all)
+ ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}")
+ ("eval", "Evaluate a policy rather than optimizing.");
+ add_options(all);
+
cb& c = calloc_or_die<cb>();
c.all = &all;
- uint32_t nb_actions = (uint32_t)vm["cb"].as<size_t>();
- //append cb with nb_actions to file_options so it is saved to regressor later
-
- po::options_description cb_opts("CB options");
- cb_opts.add_options()
- ("cb_type", po::value<string>(), "contextual bandit method to use in {ips,dm,dr}")
- ("eval", "Evaluate a policy rather than optimizing.")
- ;
-
- vm = add_options(all, cb_opts);
+ uint32_t nb_actions = (uint32_t)all.vm["cb"].as<size_t>();
*all.file_options << " --cb " << nb_actions;
all.sd->k = nb_actions;
bool eval = false;
- if (vm.count("eval"))
+ if (all.vm.count("eval"))
eval = true;
size_t problem_multiplier = 2;//default for DR
- if (vm.count("cb_type"))
+ if (all.vm.count("cb_type"))
{
std::string type_string;
- type_string = vm["cb_type"].as<std::string>();
+ type_string = all.vm["cb_type"].as<std::string>();
*all.file_options << " --cb_type " << type_string;
if (type_string.compare("dr") == 0)
@@ -496,6 +495,15 @@ namespace CB_ALGS
*all.file_options << " --cb_type dr";
}
+ if (count(all.args.begin(), all.args.end(),"--csoaa") == 0)
+ {
+ all.args.push_back("--csoaa");
+ stringstream ss;
+ ss << all.vm["cb"].as<size_t>();
+ all.args.push_back(ss.str());
+ }
+
+ base_learner* base = setup_base(all);
if (eval)
all.p->lp = CB_EVAL::cb_eval;
else
@@ -504,18 +512,18 @@ namespace CB_ALGS
learner<cb>* l;
if (eval)
{
- l = &init_learner(&c, all.l, learn_eval, predict_eval, problem_multiplier);
+ l = &init_learner(&c, base, learn_eval, predict_eval, problem_multiplier);
l->set_finish_example(eval_finish_example);
}
else
{
- l = &init_learner(&c, all.l, predict_or_learn<true>, predict_or_learn<false>,
+ l = &init_learner(&c, base, predict_or_learn<true>, predict_or_learn<false>,
problem_multiplier);
l->set_finish_example(finish_example);
}
// preserve the increment of the base learner since we are
// _adding_ to the number of problems rather than multiplying.
- l->increment = all.l->increment;
+ l->increment = base->increment;
l->set_init_driver(init_driver);
l->set_finish(finish);