diff options
author | ariel faigon <github.2009@yendor.com> | 2013-12-26 10:31:14 +0400 |
---|---|---|
committer | ariel faigon <github.2009@yendor.com> | 2013-12-26 10:31:14 +0400 |
commit | 1097df5bffa4f9cf809c5cfc2f7637eb88275b0b (patch) | |
tree | 8ff8c5f8462dd1eb9911fb4606fdaa7b63ed7326 /vowpalwabbit/parse_args.cc | |
parent | 706b550c88d4bc902ca4e01bab42e15178e4ec0b (diff) | |
parent | 7c89547c50cdb9a4d876debef2aee7fa4828e3c6 (diff) |
Merge branch 'master' of git://github.com/JohnLangford/vowpal_wabbit
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- | vowpalwabbit/parse_args.cc | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index 3cec507a..aad2eae6 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -15,6 +15,7 @@ license as described in the file LICENSE. #include "network.h" #include "global_data.h" #include "nn.h" +#include "cbify.h" #include "oaa.h" #include "bs.h" #include "topk.h" @@ -245,6 +246,7 @@ vw* parse_args(int argc, char *argv[]) ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs") ("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics") ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units") + ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve") ("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF") ; @@ -863,6 +865,26 @@ vw* parse_args(int argc, char *argv[]) got_cb = true; } + if (vm.count("cbify") || vm_file.count("cbify")) + { + if(!got_cs) { + if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["cbify"])); + else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"])); + + all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified + got_cs = true; + } + + if (!got_cb) { + if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("cb"),vm_file["cbify"])); + else vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"])); + all->l = CB::setup(*all, to_pass_further, vm, vm_file); + got_cb = true; + } + + all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file); + } + all->searnstr = NULL; if (vm.count("searn") || vm_file.count("searn") ) { if (!got_cs && !got_cb) { |