diff options
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r-- | vowpalwabbit/cbify.cc | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 9b2e3147..05eda7d3 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -390,7 +390,12 @@ namespace CBIFY { data.k = (uint32_t)vm["cbify"].as<size_t>(); *all.file_options << " --cbify " << data.k; + if (!vm.count("cb")) + vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"])); + base_learner* base = setup_base(all,vm); + all.p->lp = MULTICLASS::mc_label; + learner<cbify>* l; data.recorder.reset(new vw_recorder()); data.mwt_explorer.reset(new MwtExplorer<vw_context>("vw", *data.recorder.get())); @@ -405,7 +410,7 @@ namespace CBIFY { epsilon = vm["epsilon"].as<float>(); data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k)); data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_cover<true>, + l = &init_learner(&data, base, predict_or_learn_cover<true>, predict_or_learn_cover<false>, cover + 1); } else if (vm.count("bag")) @@ -416,7 +421,7 @@ namespace CBIFY { data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i))); } data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_bag<true>, + l = &init_learner(&data, base, predict_or_learn_bag<true>, predict_or_learn_bag<false>, bags); } else if (vm.count("first") ) @@ -424,7 +429,7 @@ namespace CBIFY { uint32_t tau = (uint32_t)vm["first"].as<size_t>(); data.policy.reset(new vw_policy()); data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_first<true>, + l = &init_learner(&data, base, predict_or_learn_first<true>, predict_or_learn_first<false>, 1); } else @@ -434,7 +439,7 @@ namespace CBIFY { epsilon = vm["epsilon"].as<float>(); data.policy.reset(new vw_policy()); data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k)); - l = &init_learner(&data, all.l, predict_or_learn_greedy<true>, + l = &init_learner(&data, base, predict_or_learn_greedy<true>, predict_or_learn_greedy<false>, 1); } |