diff options
author | Hal Daume III <me@hal3.name> | 2014-03-01 21:20:01 +0400 |
---|---|---|
committer | Hal Daume III <me@hal3.name> | 2014-03-01 21:20:01 +0400 |
commit | f3b36073f693621d7858f660fa88a05f2571b8e4 (patch) | |
tree | 4f4a19b14ff113fa4a29e0b2799bb8311d09b571 /vowpalwabbit/parse_args.cc | |
parent | ef827509126c45e01f72d506a499f19dff945205 (diff) | |
parent | 3e368e9b2e6e3cf01d9ab8c59127b3d1806f43bd (diff) |
fixed neighbor feature auditing, and affix feature storing in file
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- | vowpalwabbit/parse_args.cc | 63 |
1 files changed, 43 insertions, 20 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index c941928b..5ebde865 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -17,6 +17,7 @@ license as described in the file LICENSE. #include "nn.h" #include "cbify.h" #include "oaa.h" +#include "rand48.h" #include "bs.h" #include "topk.h" #include "ect.h" @@ -28,6 +29,7 @@ license as described in the file LICENSE. #include "bfgs.h" #include "lda_core.h" #include "noop.h" +#include "print.h" #include "gd_mf.h" #include "mf.h" #include "vw.h" @@ -254,7 +256,7 @@ vw* parse_args(int argc, char *argv[]) ("bs_type", po::value<string>(), "bootstrap mode - currently 'mean' or 'vote'") ("autolink", po::value<size_t>(), "create link function with polynomial d") ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs") - ("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics") + ("lda", po::value<uint32_t>(&(all->lda)), "Run lda with <int> topics") ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units") ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve") ("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF") @@ -265,7 +267,8 @@ vw* parse_args(int argc, char *argv[]) ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.") ("version","Version information") ("random_seed", po::value<size_t>(&random_seed), "seed random number generator") - ("noop","do no learning") ; + ("noop","do no learning") + ("print","print examples"); //po::positional_options_description p; // Be friendly: if -d was left out, treat positional param as data file @@ -338,7 +341,7 @@ vw* parse_args(int argc, char *argv[]) if (vm.count("progress")) { string progress_str = vm["progress"].as<string>(); - all->progress_arg = ::atof(progress_str.c_str()); + all->progress_arg = (float)::atof(progress_str.c_str()); // --progress interval is dual: either integer or floating-point if (progress_str.find_first_of(".") == string::npos) { @@ -415,7 +418,8 @@ vw* parse_args(int argc, char *argv[]) if( all->normalized_updates ) all->reg.stride *= 2; if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates)) - all->eta = 10; //default learning rate to 10 for non default update rule + if (all->lda == 0) + all->eta = 10; //default learning rate to 10 for non default update rule //if not using normalized or adaptive, default initial_t to 1 instead of 0 if(!all->adaptive && !all->normalized_updates && !vm.count("initial_t")) { @@ -488,14 +492,10 @@ vw* parse_args(int argc, char *argv[]) compile_gram(all->skip_strings, all->skips, (char*)"skips", all->quiet); } - if (vm.count("affix")) { - parse_affix_argument(*all, vm["affix"].as<string>()); - } - if (vm.count("spelling")) { vector<string> spelling_ns = vm["spelling"].as< vector<string> >(); for (size_t id=0; id<spelling_ns.size(); id++) - if (spelling_ns[id][0] == '_') all->spelling_features[' '] = true; + if (spelling_ns[id][0] == '_') all->spelling_features[(unsigned char)' '] = true; else all->spelling_features[(size_t)spelling_ns[id][0]] = true; } @@ -625,7 +625,7 @@ vw* parse_args(int argc, char *argv[]) for (size_t i = 0; i < 256; i++) all->ignore[i] = false; all->ignore_some = false; - + if (vm.count("ignore")) { all->ignore_some = true; @@ -749,6 +749,12 @@ vw* parse_args(int argc, char *argv[]) if (vm.count("noop")) all->l = NOOP::setup(*all); + if (vm.count("print")) + { + all->l = PRINT::setup(*all); + all->reg.stride = 1; + } + if (!vm.count("new_mf") && all->rank > 0) all->l = GDMF::setup(*all); @@ -853,6 +859,9 @@ vw* parse_args(int argc, char *argv[]) if(vm.count("autolink") || vm_file.count("autolink") ) all->l = ALINK::setup(*all, to_pass_further, vm, vm_file); + if (vm.count("lrq") || vm_file.count("lrq")) + all->l = LRQ::setup(*all, to_pass_further, vm, vm_file); + all->l = Scorer::setup(*all, to_pass_further, vm, vm_file); if(vm.count("top") || vm_file.count("top") ) @@ -861,9 +870,6 @@ vw* parse_args(int argc, char *argv[]) if (vm.count("binary") || vm_file.count("binary")) all->l = BINARY::setup(*all, to_pass_further, vm, vm_file); - if (vm.count("lrq") || vm_file.count("lrq")) - all->l = LRQ::setup(*all, to_pass_further, vm, vm_file); - if(vm.count("oaa") || vm_file.count("oaa") ) { if (got_mc) { cerr << "error: cannot specify multiple MC learners" << endl; throw exception(); } @@ -882,6 +888,7 @@ vw* parse_args(int argc, char *argv[]) if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); } all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); + all->cost_sensitive = all->l; got_cs = true; } @@ -889,6 +896,7 @@ vw* parse_args(int argc, char *argv[]) if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); } all->l = WAP::setup(*all, to_pass_further, vm, vm_file); + all->cost_sensitive = all->l; got_cs = true; } @@ -896,6 +904,7 @@ vw* parse_args(int argc, char *argv[]) if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); } all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file); + all->cost_sensitive = all->l; got_cs = true; } @@ -903,6 +912,7 @@ vw* parse_args(int argc, char *argv[]) if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); } all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file); + all->cost_sensitive = all->l; got_cs = true; } @@ -913,6 +923,7 @@ vw* parse_args(int argc, char *argv[]) else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"])); all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified + all->cost_sensitive = all->l; got_cs = true; } @@ -927,6 +938,7 @@ vw* parse_args(int argc, char *argv[]) else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"])); all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified + all->cost_sensitive = all->l; got_cs = true; } @@ -940,12 +952,27 @@ vw* parse_args(int argc, char *argv[]) all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file); } + + if (vm_file.count("affix") && vm.count("affix")) { + cerr << "should not specify --affix when loading a model trained with affix features (they're turned on by default)" << endl; + throw exception(); + } + if (vm_file.count("affix")) + parse_affix_argument(*all, vm_file["affix"].as<string>()); + if (vm.count("affix")) { + parse_affix_argument(*all, vm["affix"].as<string>()); + stringstream ss; + ss << " --affix " << vm["affix"].as<string>(); + all->options_from_file.append(ss.str()); + } + if (vm.count("searn") || vm_file.count("searn") ) { if (!got_cs && !got_cb) { if( vm_file.count("searn") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["searn"])); else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["searn"])); all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified + all->cost_sensitive = all->l; got_cs = true; } //all->searnstr = (Searn::searn*)calloc(1, sizeof(Searn::searn)); @@ -1066,8 +1093,8 @@ namespace VW { char** argv = get_argv_from_string(s,argc); vw* all = parse_args(argc, argv); - - initialize_examples(*all); + + initialize_parser_datastructures(*all); for(int i = 0; i < argc; i++) free(argv[i]); @@ -1094,11 +1121,7 @@ namespace VW { free(all.options_from_file_argv); for (size_t i = 0; i < all.final_prediction_sink.size(); i++) if (all.final_prediction_sink[i] != 1) -#ifdef _WIN32 - _close(all.final_prediction_sink[i]); -#else - close(all.final_prediction_sink[i]); -#endif + io_buf::close_file_or_socket(all.final_prediction_sink[i]); all.final_prediction_sink.delete_v(); delete all.loss; delete &all; |