Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'vowpalwabbit/cbify.cc')
-rw-r--r--vowpalwabbit/cbify.cc13
1 files changed, 9 insertions, 4 deletions
diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc
index 9b2e3147..05eda7d3 100644
--- a/vowpalwabbit/cbify.cc
+++ b/vowpalwabbit/cbify.cc
@@ -390,7 +390,12 @@ namespace CBIFY {
data.k = (uint32_t)vm["cbify"].as<size_t>();
*all.file_options << " --cbify " << data.k;
+ if (!vm.count("cb"))
+ vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"]));
+ base_learner* base = setup_base(all,vm);
+
all.p->lp = MULTICLASS::mc_label;
+
learner<cbify>* l;
data.recorder.reset(new vw_recorder());
data.mwt_explorer.reset(new MwtExplorer<vw_context>("vw", *data.recorder.get()));
@@ -405,7 +410,7 @@ namespace CBIFY {
epsilon = vm["epsilon"].as<float>();
data.scorer.reset(new vw_cover_scorer(epsilon, cover, (u32)data.k));
data.generic_explorer.reset(new GenericExplorer<vw_context>(*data.scorer.get(), (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_cover<true>,
+ l = &init_learner(&data, base, predict_or_learn_cover<true>,
predict_or_learn_cover<false>, cover + 1);
}
else if (vm.count("bag"))
@@ -416,7 +421,7 @@ namespace CBIFY {
data.policies.push_back(unique_ptr<IPolicy<vw_context>>(new vw_policy(i)));
}
data.bootstrap_explorer.reset(new BootstrapExplorer<vw_context>(data.policies, (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_bag<true>,
+ l = &init_learner(&data, base, predict_or_learn_bag<true>,
predict_or_learn_bag<false>, bags);
}
else if (vm.count("first") )
@@ -424,7 +429,7 @@ namespace CBIFY {
uint32_t tau = (uint32_t)vm["first"].as<size_t>();
data.policy.reset(new vw_policy());
data.tau_explorer.reset(new TauFirstExplorer<vw_context>(*data.policy.get(), (u32)tau, (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_first<true>,
+ l = &init_learner(&data, base, predict_or_learn_first<true>,
predict_or_learn_first<false>, 1);
}
else
@@ -434,7 +439,7 @@ namespace CBIFY {
epsilon = vm["epsilon"].as<float>();
data.policy.reset(new vw_policy());
data.greedy_explorer.reset(new EpsilonGreedyExplorer<vw_context>(*data.policy.get(), epsilon, (u32)data.k));
- l = &init_learner(&data, all.l, predict_or_learn_greedy<true>,
+ l = &init_learner(&data, base, predict_or_learn_greedy<true>,
predict_or_learn_greedy<false>, 1);
}