Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Hal Daume III <me@hal3.name> 2014-03-01 21:20:01 +0400
committer Hal Daume III <me@hal3.name> 2014-03-01 21:20:01 +0400
commit f3b36073f693621d7858f660fa88a05f2571b8e4 (patch)
tree 4f4a19b14ff113fa4a29e0b2799bb8311d09b571 /vowpalwabbit/parse_args.cc
parent ef827509126c45e01f72d506a499f19dff945205 (diff)
parent 3e368e9b2e6e3cf01d9ab8c59127b3d1806f43bd (diff)
fixed neighbor feature auditing, and affix feature storing in file
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- vowpalwabbit/parse_args.cc 63
1 files changed, 43 insertions, 20 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index c941928b..5ebde865 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -17,6 +17,7 @@ license as described in the file LICENSE.
#include "nn.h"
#include "cbify.h"
#include "oaa.h"
+#include "rand48.h"
#include "bs.h"
#include "topk.h"
#include "ect.h"
@@ -28,6 +29,7 @@ license as described in the file LICENSE.
#include "bfgs.h"
#include "lda_core.h"
#include "noop.h"
+#include "print.h"
#include "gd_mf.h"
#include "mf.h"
#include "vw.h"
@@ -254,7 +256,7 @@ vw* parse_args(int argc, char *argv[])
("bs_type", po::value<string>(), "bootstrap mode - currently 'mean' or 'vote'")
("autolink", po::value<size_t>(), "create link function with polynomial d")
("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
- ("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics")
+ ("lda", po::value<uint32_t>(&(all->lda)), "Run lda with <int> topics")
("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve")
("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
@@ -265,7 +267,8 @@ vw* parse_args(int argc, char *argv[])
("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.")
("version","Version information")
("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
- ("noop","do no learning") ;
+ ("noop","do no learning")
+ ("print","print examples");
//po::positional_options_description p;
// Be friendly: if -d was left out, treat positional param as data file
@@ -338,7 +341,7 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("progress")) {
string progress_str = vm["progress"].as<string>();
- all->progress_arg = ::atof(progress_str.c_str());
+ all->progress_arg = (float)::atof(progress_str.c_str());
// --progress interval is dual: either integer or floating-point
if (progress_str.find_first_of(".") == string::npos) {
@@ -415,7 +418,8 @@ vw* parse_args(int argc, char *argv[])
if( all->normalized_updates ) all->reg.stride *= 2;
if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates))
- all->eta = 10; //default learning rate to 10 for non default update rule
+ if (all->lda == 0)
+ all->eta = 10; //default learning rate to 10 for non default update rule
//if not using normalized or adaptive, default initial_t to 1 instead of 0
if(!all->adaptive && !all->normalized_updates && !vm.count("initial_t")) {
@@ -488,14 +492,10 @@ vw* parse_args(int argc, char *argv[])
compile_gram(all->skip_strings, all->skips, (char*)"skips", all->quiet);
}
- if (vm.count("affix")) {
- parse_affix_argument(*all, vm["affix"].as<string>());
- }
-
if (vm.count("spelling")) {
vector<string> spelling_ns = vm["spelling"].as< vector<string> >();
for (size_t id=0; id<spelling_ns.size(); id++)
- if (spelling_ns[id][0] == '_') all->spelling_features[' '] = true;
+ if (spelling_ns[id][0] == '_') all->spelling_features[(unsigned char)' '] = true;
else all->spelling_features[(size_t)spelling_ns[id][0]] = true;
}
@@ -625,7 +625,7 @@ vw* parse_args(int argc, char *argv[])
for (size_t i = 0; i < 256; i++)
all->ignore[i] = false;
all->ignore_some = false;
-
+
if (vm.count("ignore"))
{
all->ignore_some = true;
@@ -749,6 +749,12 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("noop"))
all->l = NOOP::setup(*all);
+ if (vm.count("print"))
+ {
+ all->l = PRINT::setup(*all);
+ all->reg.stride = 1;
+ }
+
if (!vm.count("new_mf") && all->rank > 0)
all->l = GDMF::setup(*all);
@@ -853,6 +859,9 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("autolink") || vm_file.count("autolink") )
all->l = ALINK::setup(*all, to_pass_further, vm, vm_file);
+ if (vm.count("lrq") || vm_file.count("lrq"))
+ all->l = LRQ::setup(*all, to_pass_further, vm, vm_file);
+
all->l = Scorer::setup(*all, to_pass_further, vm, vm_file);
if(vm.count("top") || vm_file.count("top") )
@@ -861,9 +870,6 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("binary") || vm_file.count("binary"))
all->l = BINARY::setup(*all, to_pass_further, vm, vm_file);
- if (vm.count("lrq") || vm_file.count("lrq"))
- all->l = LRQ::setup(*all, to_pass_further, vm, vm_file);
-
if(vm.count("oaa") || vm_file.count("oaa") ) {
if (got_mc) { cerr << "error: cannot specify multiple MC learners" << endl; throw exception(); }
@@ -882,6 +888,7 @@ vw* parse_args(int argc, char *argv[])
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -889,6 +896,7 @@ vw* parse_args(int argc, char *argv[])
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
all->l = WAP::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -896,6 +904,7 @@ vw* parse_args(int argc, char *argv[])
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -903,6 +912,7 @@ vw* parse_args(int argc, char *argv[])
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -913,6 +923,7 @@ vw* parse_args(int argc, char *argv[])
else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"]));
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -927,6 +938,7 @@ vw* parse_args(int argc, char *argv[])
else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"]));
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -940,12 +952,27 @@ vw* parse_args(int argc, char *argv[])
all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file);
}
+
+ if (vm_file.count("affix") && vm.count("affix")) {
+ cerr << "should not specify --affix when loading a model trained with affix features (they're turned on by default)" << endl;
+ throw exception();
+ }
+ if (vm_file.count("affix"))
+ parse_affix_argument(*all, vm_file["affix"].as<string>());
+ if (vm.count("affix")) {
+ parse_affix_argument(*all, vm["affix"].as<string>());
+ stringstream ss;
+ ss << " --affix " << vm["affix"].as<string>();
+ all->options_from_file.append(ss.str());
+ }
+
if (vm.count("searn") || vm_file.count("searn") ) {
if (!got_cs && !got_cb) {
if( vm_file.count("searn") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["searn"]));
else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["searn"]));
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified
+ all->cost_sensitive = all->l;
got_cs = true;
}
//all->searnstr = (Searn::searn*)calloc(1, sizeof(Searn::searn));
@@ -1066,8 +1093,8 @@ namespace VW {
char** argv = get_argv_from_string(s,argc);
vw* all = parse_args(argc, argv);
-
- initialize_examples(*all);
+
+ initialize_parser_datastructures(*all);
for(int i = 0; i < argc; i++)
free(argv[i]);
@@ -1094,11 +1121,7 @@ namespace VW {
free(all.options_from_file_argv);
for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
if (all.final_prediction_sink[i] != 1)
-#ifdef _WIN32
- _close(all.final_prediction_sink[i]);
-#else
- close(all.final_prediction_sink[i]);
-#endif
+ io_buf::close_file_or_socket(all.final_prediction_sink[i]);
all.final_prediction_sink.delete_v();
delete all.loss;
delete &all;