Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorariel faigon <github.2009@yendor.com>2013-12-26 10:31:14 +0400
committerariel faigon <github.2009@yendor.com>2013-12-26 10:31:14 +0400
commit1097df5bffa4f9cf809c5cfc2f7637eb88275b0b (patch)
tree8ff8c5f8462dd1eb9911fb4606fdaa7b63ed7326 /vowpalwabbit/parse_args.cc
parent706b550c88d4bc902ca4e01bab42e15178e4ec0b (diff)
parent7c89547c50cdb9a4d876debef2aee7fa4828e3c6 (diff)
Merge branch 'master' of git://github.com/JohnLangford/vowpal_wabbit
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r--vowpalwabbit/parse_args.cc22
1 files changed, 22 insertions, 0 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 3cec507a..aad2eae6 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -15,6 +15,7 @@ license as described in the file LICENSE.
#include "network.h"
#include "global_data.h"
#include "nn.h"
+#include "cbify.h"
#include "oaa.h"
#include "bs.h"
#include "topk.h"
@@ -245,6 +246,7 @@ vw* parse_args(int argc, char *argv[])
("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics")
("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
+ ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve")
("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
;
@@ -863,6 +865,26 @@ vw* parse_args(int argc, char *argv[])
got_cb = true;
}
+ if (vm.count("cbify") || vm_file.count("cbify"))
+ {
+ if(!got_cs) {
+ if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["cbify"]));
+ else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"]));
+
+ all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified
+ got_cs = true;
+ }
+
+ if (!got_cb) {
+ if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("cb"),vm_file["cbify"]));
+ else vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"]));
+ all->l = CB::setup(*all, to_pass_further, vm, vm_file);
+ got_cb = true;
+ }
+
+ all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file);
+ }
+
all->searnstr = NULL;
if (vm.count("searn") || vm_file.count("searn") ) {
if (!got_cs && !got_cb) {