// ---------------------------------------------------------------------------
// Reviewer overview (documentation only; every original line below is untouched).
//
// NOTE(review): this copy of the file has had its line breaks collapsed and the
// text between angle brackets removed (include targets, template arguments and
// a few statement runs, e.g. inside the "spelling" loop, before the "cubic"
// handling and in the "unrecognized options" loop). As laid out here the "//"
// comments also run into the code that follows them. Restore the file from
// version control before building; do not try to repair it by hand.
//
// Command-line handling for the vw driver. Contents, in order:
//   ends_with               - true only when fullString is strictly longer than
//                             ending and finishes with it (equal-length strings
//                             yield false because the length test is '>').
//   valid_ns                - a namespace character is anything except '|' and ':'.
//   parse_affix_argument    - splits the argument on ',', each item being an
//                             optional '+'/'-' (prefix/suffix), one digit 1..7 and
//                             an optional namespace character; shifts
//                             all.affix_features[ns] left by 4 and ORs in
//                             (len << 1) | prefix. Throws exception() on a
//                             malformed item.
//                             NOTE(review): the throw paths skip free(cstr);
//                             q[1] is cast straight to uint16_t, so a negative
//                             char would give a large index - confirm the size
//                             of affix_features.
//   parse_diagnostics       - help/version print to stdout and exit(0); sets
//                             quiet, the --progress interval (integer text is
//                             additive, text containing '.' is multiplicative)
//                             and audit.
//   parse_source            - no_stdin, the all-or-none check on
//                             unique_id/total/node, daemon mode, compression and
//                             the data file name.
//   parse_feature_tweaks    - hash, spelling, affix, ngram/skips, bit_precision,
//                             quadratic pairs (':' expands to every printable
//                             character accepted by valid_ns), cubic, ignore,
//                             keep and noconstant.
//   parse_example_tweaks    - testonly/training, holdout, sort_features,
//                             min/max prediction, loss function, l1/l2 and
//                             reg_mode.
//   parse_output_preds      - opens the predictions and raw_predictions sinks
//                             ("stdout" maps to descriptor 1).
//                             NOTE(review): a failed open of the predictions
//                             file is reported but the descriptor is still
//                             pushed; the raw_predictions open is not checked.
//   parse_output_model      - final/readable/invert_hash regressor names and the
//                             save_per_pass/save_resume flags.
//   parse_base_algorithm    - picks the base learner (bfgs, lda, noop, print,
//                             gdmf, sender, otherwise gd which also becomes
//                             all.scorer).
//   load_input_model        - loads the regressor and the feature mask; the
//                             regressor goes first when both options name the
//                             same file.
//   parse_scorer_reductions - stacks nn, mf, autolink, lrq and finally the
//                             scorer on all.l.
//   exclusive_setup         - runs a setup callback, throwing if another direct
//                             score consumer was already selected.
//   parse_score_users       - top, binary, oaa, ect, csoaa, wap, csoaa_ldf and
//                             wap_ldf; the cost-sensitive ones set got_cs.
//   parse_cb / parse_search - add cb, cbify and search, inserting a csoaa (and
//                             cb) entry into vm when none was set up earlier.
//   parse_args              - declares all options, parses argv and the options
//                             stored in the regressor file, calls the helpers
//                             above and returns a vw allocated with new.
//                             NOTE(review): that allocation is not released if
//                             a helper throws.
//   VW::cmd_string_replace_value - replaces the value after a flag in a command
//                             string, or appends the flag and value if absent.
//   VW::get_argv_from_string - tokenizes "b " + s on spaces into a calloc'd argv
//                             whose entries are also calloc'd ("b" is presumably
//                             a placeholder argv[0] - confirm); initialize() and
//                             finish() show the free() pattern expected.
//   VW::initialize          - appends " --no_stdin", calls parse_args and
//                             initialize_parser_datastructures, frees argv.
//   VW::finish              - finalizes the regressor, releases learner, weights,
//                             parser, sinks and loss, then does delete &all, so
//                             the reference is invalid once it returns.
// ---------------------------------------------------------------------------
/* Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD (revised) license as described in the file LICENSE. */ #include #include #include #include "cache.h" #include "io_buf.h" #include "parse_regressor.h" #include "parser.h" #include "parse_args.h" #include "sender.h" #include "network.h" #include "global_data.h" #include "nn.h" #include "cbify.h" #include "oaa.h" #include "rand48.h" #include "bs.h" #include "topk.h" #include "ect.h" #include "csoaa.h" #include "wap.h" #include "cb.h" #include "cb_algs.h" #include "scorer.h" #include "searn.h" #include "bfgs.h" #include "lda_core.h" #include "noop.h" #include "print.h" #include "gd_mf.h" #include "mf.h" #include "vw.h" #include "rand48.h" #include "parse_args.h" #include "binary.h" #include "lrq.h" #include "autolink.h" #include "memory.h" using namespace std; // // Does string end with a certain substring? // bool ends_with(string const &fullString, string const &ending) { if (fullString.length() > ending.length()) { return (fullString.compare(fullString.length() - ending.length(), ending.length(), ending) == 0); } else { return false; } } bool valid_ns(char c) { if (c=='|'||c==':') return false; return true; } void parse_affix_argument(vw&all, string str) { if (str.length() == 0) return; char* cstr = (char*)calloc_or_die(str.length()+1, sizeof(char)); strcpy(cstr, str.c_str()); char*p = strtok(cstr, ","); while (p != 0) { char*q = p; uint16_t prefix = 1; if (q[0] == '+') { q++; } else if (q[0] == '-') { prefix = 0; q++; } if ((q[0] < '1') || (q[0] > '7')) { cerr << "malformed affix argument (length must be 1..7): " << p << endl; throw exception(); } uint16_t len = (uint16_t)(q[0] - '0'); uint16_t ns = (uint16_t)' '; // default namespace if (q[1] != 0) { if (valid_ns(q[1])) ns = (uint16_t)q[1]; else { cerr << "malformed affix argument (invalid namespace): " << p << endl; throw exception(); } if (q[2] != 0) { cerr << "malformed affix 
argument (too long): " << p << endl; throw exception(); } } uint16_t afx = (len << 1) | (prefix & 0x1); all.affix_features[ns] <<= 4; all.affix_features[ns] |= afx; p = strtok(NULL, ","); } free(cstr); } void parse_diagnostics(vw& all, po::variables_map& vm, po::options_description& desc, int argc) { // Begin diagnostic options if (vm.count("help") || argc == 1) { /* upon direct query for help -- spit it out to stdout */ cout << "\n" << desc << "\n"; exit(0); } if (vm.count("version")) { /* upon direct query for version -- spit it out to stdout */ cout << version.to_string() << "\n"; exit(0); } if (vm.count("quiet")) { all.quiet = true; // --quiet wins over --progress } else { all.quiet = false; if (vm.count("progress")) { string progress_str = vm["progress"].as(); all.progress_arg = (float)::atof(progress_str.c_str()); // --progress interval is dual: either integer or floating-point if (progress_str.find_first_of(".") == string::npos) { // No "." in arg: assume integer -> additive all.progress_add = true; if (all.progress_arg < 1) { cerr << "warning: additive --progress " << " can't be < 1: forcing to 1\n"; all.progress_arg = 1; } all.sd->dump_interval = all.progress_arg; } else { // A "." 
in arg: assume floating-point -> multiplicative all.progress_add = false; if (all.progress_arg <= 1.0) { cerr << "warning: multiplicative --progress : " << vm["progress"].as() << " is <= 1.0: adding 1.0\n"; all.progress_arg += 1.0; } else if (all.progress_arg > 9.0) { cerr << "warning: multiplicative --progress " << " is > 9.0: you probably meant to use an integer\n"; } all.sd->dump_interval = 1.0; } } } if (vm.count("audit")){ all.audit = true; } } void parse_source(vw& all, po::variables_map& vm) { //begin input source if (vm.count("no_stdin")) all.stdin_off = true; if ( (vm.count("total") || vm.count("node") || vm.count("unique_id")) && !(vm.count("total") && vm.count("node") && vm.count("unique_id")) ) { cout << "you must specificy unique_id, total, and node if you specify any" << endl; throw exception(); } if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all.active) ) { all.daemon = true; // allow each child to process up to 1e5 connections all.numpasses = (size_t) 1e5; } if (vm.count("compressed")) set_compressed(all.p); if (vm.count("data")) { all.data_filename = vm["data"].as(); if (ends_with(all.data_filename, ".gz")) set_compressed(all.p); } else all.data_filename = ""; } void parse_feature_tweaks(vw& all, po::variables_map& vm, po::variables_map& vm_file) { //feature manipulation string hash_function("strings"); if(vm.count("hash")) hash_function = vm["hash"].as(); all.p->hasher = getHasher(hash_function); if (vm.count("spelling")) { vector spelling_ns = vm["spelling"].as< vector >(); for (size_t id=0; id()); if (vm.count("affix")) { parse_affix_argument(all, vm["affix"].as()); stringstream ss; ss << " --affix " << vm["affix"].as(); all.options_from_file.append(ss.str()); } if(vm.count("ngram")){ if(vm.count("sort_features")) { cerr << "ngram is incompatible with sort_features. 
" << endl; throw exception(); } all.ngram_strings = vm["ngram"].as< vector >(); compile_gram(all.ngram_strings, all.ngram, (char*)"grams", all.quiet); } if(vm.count("skips")) { if(!vm.count("ngram")) { cout << "You can not skip unless ngram is > 1" << endl; throw exception(); } all.skip_strings = vm["skips"].as >(); compile_gram(all.skip_strings, all.skips, (char*)"skips", all.quiet); } if (vm.count("bit_precision")) { all.default_bits = false; all.num_bits = (uint32_t)vm["bit_precision"].as< size_t>(); if (all.num_bits > min(32, sizeof(size_t)*8 - 3)) { cout << "Only " << min(32, sizeof(size_t)*8 - 3) << " or fewer bits allowed. If this is a serious limit, speak up." << endl; throw exception(); } } if (vm.count("quadratic")) { all.pairs = vm["quadratic"].as< vector >(); vector newpairs; //string tmp; char printable_start = '!'; char printable_end = '~'; int valid_ns_size = printable_end - printable_start - 1; //will skip two characters if(!all.quiet) cerr<<"creating quadratic features for pairs: "; for (vector::iterator i = all.pairs.begin(); i != all.pairs.end();i++){ if(!all.quiet){ cerr << *i << " "; if (i->length() > 2) cerr << endl << "warning, ignoring characters after the 2nd.\n"; if (i->length() < 2) { cerr << endl << "error, quadratic features must involve two sets.\n"; throw exception(); } } //-q x: if((*i)[0]!=':'&&(*i)[1]==':'){ newpairs.reserve(newpairs.size() + valid_ns_size); for (char j=printable_start; j<=printable_end; j++){ if(valid_ns(j)) newpairs.push_back(string(1,(*i)[0])+j); } } //-q :x else if((*i)[0]==':'&&(*i)[1]!=':'){ newpairs.reserve(newpairs.size() + valid_ns_size); for (char j=printable_start; j<=printable_end; j++){ if(valid_ns(j)){ stringstream ss; ss << j << (*i)[1]; newpairs.push_back(ss.str()); } } } //-q :: else if((*i)[0]==':'&&(*i)[1]==':'){ cout << "in pair creation" << endl; newpairs.reserve(newpairs.size() + valid_ns_size*valid_ns_size); for (char j=printable_start; j<=printable_end; j++){ if(valid_ns(j)){ for (char 
k=printable_start; k<=printable_end; k++){ if(valid_ns(k)){ stringstream ss; ss << j << k; newpairs.push_back(ss.str()); } } } } } else{ newpairs.push_back(string(*i)); } } newpairs.swap(all.pairs); if(!all.quiet) cerr< >(); if (!all.quiet) { cerr << "creating cubic features for triples: "; for (vector::iterator i = all.triples.begin(); i != all.triples.end();i++) { cerr << *i << " "; if (i->length() > 3) cerr << endl << "warning, ignoring characters after the 3rd.\n"; if (i->length() < 3) { cerr << endl << "error, cubic features must involve three sets.\n"; throw exception(); } } cerr << endl; } } for (size_t i = 0; i < 256; i++) all.ignore[i] = false; all.ignore_some = false; if (vm.count("ignore")) { all.ignore_some = true; vector ignore = vm["ignore"].as< vector >(); for (vector::iterator i = ignore.begin(); i != ignore.end();i++) { all.ignore[*i] = true; } if (!all.quiet) { cerr << "ignoring namespaces beginning with: "; for (vector::iterator i = ignore.begin(); i != ignore.end();i++) cerr << *i << " "; cerr << endl; } } if (vm.count("keep")) { for (size_t i = 0; i < 256; i++) all.ignore[i] = true; all.ignore_some = true; vector keep = vm["keep"].as< vector >(); for (vector::iterator i = keep.begin(); i != keep.end();i++) { all.ignore[*i] = false; } if (!all.quiet) { cerr << "using namespaces beginning with: "; for (vector::iterator i = keep.begin(); i != keep.end();i++) cerr << *i << " "; cerr << endl; } } if (vm.count("noconstant")) all.add_constant = false; } void parse_example_tweaks(vw& all, po::variables_map& vm) { if (vm.count("testonly") || all.eta == 0.) 
{ if (!all.quiet) cerr << "only testing" << endl; all.training = false; if (all.lda > 0) all.eta = 0; } else all.training = true; if(all.numpasses > 1) all.holdout_set_off = false; if(vm.count("holdout_off")) all.holdout_set_off = true; if(!all.holdout_set_off && (vm.count("output_feature_regularizer_binary") || vm.count("output_feature_regularizer_text"))) { all.holdout_set_off = true; cerr<<"Making holdout_set_off=true since output regularizer specified\n"; } if(vm.count("sort_features")) all.p->sort_features = true; if (vm.count("min_prediction")) all.sd->min_label = vm["min_prediction"].as(); if (vm.count("max_prediction")) all.sd->max_label = vm["max_prediction"].as(); if (vm.count("min_prediction") || vm.count("max_prediction") || vm.count("testonly")) all.set_minmax = noop_mm; string loss_function; if(vm.count("loss_function")) loss_function = vm["loss_function"].as(); else loss_function = "squaredloss"; float loss_parameter = 0.0; if(vm.count("quantile_tau")) loss_parameter = vm["quantile_tau"].as(); all.loss = getLossFunction(&all, loss_function, (float)loss_parameter); if (all.l1_lambda < 0.) { cerr << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl; all.l1_lambda = 0.; } if (all.l2_lambda < 0.) { cerr << "l2_lambda should be nonnegative: resetting from " << all.l2_lambda << " to 0" << endl; all.l2_lambda = 0.; } all.reg_mode += (all.l1_lambda > 0.) ? 1 : 0; all.reg_mode += (all.l2_lambda > 0.) ? 
2 : 0; if (!all.quiet) { if (all.reg_mode %2 && !vm.count("bfgs")) cerr << "using l1 regularization = " << all.l1_lambda << endl; if (all.reg_mode > 1) cerr << "using l2 regularization = " << all.l2_lambda << endl; } } void parse_output_preds(vw& all, po::variables_map& vm, po::variables_map& vm_file) { if (vm.count("predictions")) { if (!all.quiet) cerr << "predictions = " << vm["predictions"].as< string >() << endl; if (strcmp(vm["predictions"].as< string >().c_str(), "stdout") == 0) { all.final_prediction_sink.push_back((size_t) 1);//stdout } else { const char* fstr = (vm["predictions"].as< string >().c_str()); int f; #ifdef _WIN32 _sopen_s(&f, fstr, _O_CREAT|_O_WRONLY|_O_BINARY|_O_TRUNC, _SH_DENYWR, _S_IREAD|_S_IWRITE); #else f = open(fstr, O_CREAT|O_WRONLY|O_LARGEFILE|O_TRUNC,0666); #endif if (f < 0) cerr << "Error opening the predictions file: " << fstr << endl; all.final_prediction_sink.push_back((size_t) f); } } if (vm.count("raw_predictions")) { if (!all.quiet) { cerr << "raw predictions = " << vm["raw_predictions"].as< string >() << endl; if (vm.count("binary") || vm_file.count("binary")) cerr << "Warning: --raw has no defined value when --binary specified, expect no output" << endl; } if (strcmp(vm["raw_predictions"].as< string >().c_str(), "stdout") == 0) all.raw_prediction = 1;//stdout else { const char* t = vm["raw_predictions"].as< string >().c_str(); int f; #ifdef _WIN32 _sopen_s(&f, t, _O_CREAT|_O_WRONLY|_O_BINARY|_O_TRUNC, _SH_DENYWR, _S_IREAD|_S_IWRITE); #else f = open(t, O_CREAT|O_WRONLY|O_LARGEFILE|O_TRUNC,0666); #endif all.raw_prediction = f; } } } void parse_output_model(vw& all, po::variables_map& vm) { if (vm.count("final_regressor")) { all.final_regressor_name = vm["final_regressor"].as(); if (!all.quiet) cerr << "final_regressor = " << vm["final_regressor"].as() << endl; } else all.final_regressor_name = ""; if (vm.count("readable_model")) all.text_regressor_name = vm["readable_model"].as(); if (vm.count("invert_hash")){ 
all.inv_hash_regressor_name = vm["invert_hash"].as(); all.hash_inv = true; } if (vm.count("save_per_pass")) all.save_per_pass = true; if (vm.count("save_resume")) all.save_resume = true; } void parse_base_algorithm(vw& all, vector& to_pass_further, po::variables_map& vm) { //base learning algorithm. if (vm.count("bfgs") || vm.count("conjugate_gradient")) all.l = BFGS::setup(all, to_pass_further, vm); else if (vm.count("lda")) all.l = LDA::setup(all, to_pass_further, vm); else if (vm.count("noop")) all.l = NOOP::setup(all); else if (vm.count("print")) all.l = PRINT::setup(all); else if (!vm.count("new_mf") && all.rank > 0) all.l = GDMF::setup(all, vm); else if (vm.count("sendto")) all.l = SENDER::setup(all, vm, all.pairs); else { all.l = GD::setup(all, vm); all.scorer = all.l; } } void load_input_model(vw& all, po::variables_map& vm, io_buf& io_temp) { // Need to see if we have to load feature mask first or second. // -i and -mask are from same file, load -i file first so mask can use it if (vm.count("feature_mask") && vm.count("initial_regressor") && vm["feature_mask"].as() == vm["initial_regressor"].as< vector >()[0]) { // load rest of regressor all.l->save_load(io_temp, true, false); io_temp.close_file(); // set the mask, which will reuse -i file we just loaded parse_mask_regressor_args(all, vm); } else { // load mask first parse_mask_regressor_args(all, vm); // load rest of regressor all.l->save_load(io_temp, true, false); io_temp.close_file(); } } void parse_scorer_reductions(vw& all, vector& to_pass_further, po::variables_map& vm, po::variables_map vm_file) { if(vm.count("nn") || vm_file.count("nn") ) all.l = NN::setup(all, to_pass_further, vm, vm_file); if (vm.count("new_mf") && all.rank > 0) all.l = MF::setup(all, vm); if(vm.count("autolink") || vm_file.count("autolink") ) all.l = ALINK::setup(all, to_pass_further, vm, vm_file); if (vm.count("lrq") || vm_file.count("lrq")) all.l = LRQ::setup(all, to_pass_further, vm, vm_file); all.l = Scorer::setup(all, 
to_pass_further, vm, vm_file); } LEARNER::learner* exclusive_setup(vw& all, vector& to_pass_further, po::variables_map& vm, po::variables_map vm_file, bool& score_consumer, LEARNER::learner* (*setup)(vw&, vector&, po::variables_map&, po::variables_map&)) { if (score_consumer) { cerr << "error: cannot specify multiple direct score consumers" << endl; throw exception(); } score_consumer = true; return setup(all, to_pass_further, vm, vm_file); } void parse_score_users(vw& all, vector& to_pass_further, po::variables_map& vm, po::variables_map vm_file, bool& got_cs) { bool score_consumer = false; if(vm.count("top") || vm_file.count("top") ) all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, TOPK::setup); if (vm.count("binary") || vm_file.count("binary")) all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, BINARY::setup); if (vm.count("oaa") || vm_file.count("oaa") ) all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, OAA::setup); if (vm.count("ect") || vm_file.count("ect") ) all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, ECT::setup); if(vm.count("csoaa") || vm_file.count("csoaa") ) { all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, CSOAA::setup); all.cost_sensitive = all.l; got_cs = true; } if(vm.count("wap") || vm_file.count("wap") ) { all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, WAP::setup); all.cost_sensitive = all.l; got_cs = true; } if(vm.count("csoaa_ldf") || vm_file.count("csoaa_ldf")) { all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, CSOAA_AND_WAP_LDF::setup); all.cost_sensitive = all.l; got_cs = true; } if(vm.count("wap_ldf") || vm_file.count("wap_ldf") ) { all.l = exclusive_setup(all, to_pass_further, vm, vm_file, score_consumer, CSOAA_AND_WAP_LDF::setup); all.cost_sensitive = all.l; got_cs = true; } } void parse_cb(vw& all, vector& to_pass_further, po::variables_map& vm, 
po::variables_map vm_file, bool& got_cs, bool& got_cb) { if( vm.count("cb") || vm_file.count("cb") ) { if(!got_cs) { if( vm_file.count("cb") ) vm.insert(pair(string("csoaa"),vm_file["cb"])); else vm.insert(pair(string("csoaa"),vm["cb"])); all.l = CSOAA::setup(all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified all.cost_sensitive = all.l; got_cs = true; } all.l = CB_ALGS::setup(all, to_pass_further, vm, vm_file); got_cb = true; } if (vm.count("cbify") || vm_file.count("cbify")) { if(!got_cs) { if( vm_file.count("cbify") ) vm.insert(pair(string("csoaa"),vm_file["cbify"])); else vm.insert(pair(string("csoaa"),vm["cbify"])); all.l = CSOAA::setup(all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified all.cost_sensitive = all.l; got_cs = true; } if (!got_cb) { if( vm_file.count("cbify") ) vm.insert(pair(string("cb"),vm_file["cbify"])); else vm.insert(pair(string("cb"),vm["cbify"])); all.l = CB_ALGS::setup(all, to_pass_further, vm, vm_file); got_cb = true; } all.l = CBIFY::setup(all, to_pass_further, vm, vm_file); } } void parse_search(vw& all, vector& to_pass_further, po::variables_map& vm, po::variables_map vm_file, bool& got_cs, bool& got_cb) { if (vm.count("search") || vm_file.count("search") ) { if (!got_cs && !got_cb) { if( vm_file.count("search") ) vm.insert(pair(string("csoaa"),vm_file["search"])); else vm.insert(pair(string("csoaa"),vm["search"])); all.l = CSOAA::setup(all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified all.cost_sensitive = all.l; got_cs = true; } //all.searnstr = (Searn::searn*)calloc_or_die(1, sizeof(Searn::searn)); all.l = Searn::setup(all, to_pass_further, vm, vm_file); } } vw* parse_args(int argc, char *argv[]) { po::options_description desc("VW options"); vw* all = new vw(); size_t random_seed = 0; all->program_name = argv[0]; po::options_description in_opt("Input options"); in_opt.add_options() ("data,d", po::value< string >(), "Example Set") 
("ring_size", po::value(&(all->p->ring_size)), "size of example ring") ("examples", po::value(&(all->max_examples)), "number of examples to parse") ("testonly,t", "Ignore label information and just test") ("daemon", "persistent daemon mode on port 26542") ("port", po::value(),"port to listen on; use 0 to pick unused port") ("num_children", po::value(&(all->num_children)), "number of children for persistent daemon mode") ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode") ("port_file", po::value< string >(), "Write port used in persistent daemon mode") ("passes", po::value(&(all->numpasses)),"Number of Training Passes") ("cache,c", "Use a cache. The default is .cache") ("cache_file", po::value< vector >(), "The location(s) of cache_file.") ("kill_cache,k", "do not reuse existing cache: create a new one always") ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.") ("no_stdin", "do not default to reading from stdin") ("save_resume", "save extra state so learning can be resumed later with new data") ; po::options_description out_opt("Output options"); out_opt.add_options() ("audit,a", "print weights of features") ("predictions,p", po::value< string >(), "File to output predictions to") ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to") ("sendto", po::value< vector >(), "send examples to ") ("quiet", "Don't output disgnostics and progress updates") ("progress,P", po::value< string >(), "Progress update frequency. 
int: additive, float: multiplicative") ("binary", "report loss as binary classification on -1,1") ("min_prediction", po::value(&(all->sd->min_label)), "Smallest prediction to output") ("max_prediction", po::value(&(all->sd->max_label)), "Largest prediction to output") ; po::options_description update_opt("Update options"); update_opt.add_options() ("sgd", "use regular stochastic gradient descent update.") ("hessian_on", "use second derivative in line search") ("bfgs", "use bfgs optimization") ("mem", po::value(&(all->m)), "memory in bfgs") ("termination", po::value(&(all->rel_threshold)),"Termination threshold") ("adaptive", "use adaptive, individual learning rates.") ("invariant", "use safe/importance aware updates.") ("normalized", "use per feature normalized updates") ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule") ("conjugate_gradient", "use conjugate gradient based optimization") ("l1", po::value(&(all->l1_lambda)), "l_1 lambda") ("l2", po::value(&(all->l2_lambda)), "l_2 lambda") ("learning_rate,l", po::value(&(all->eta)), "Set Learning Rate") ("loss_function", po::value()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.") ("quantile_tau", po::value()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5") ("power_t", po::value(&(all->power_t)), "t power value") ("decay_learning_rate", po::value(&(all->eta_decay_rate)), "Set Decay factor for learning_rate between passes") ("initial_pass_length", po::value(&(all->pass_length)), "initial number of examples per pass") ("initial_t", po::value(&((all->sd->t))), "initial t value") ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated. 
If no initial_regressor given, also used for initial weights.") ; po::options_description weight_opt("Weight options"); weight_opt.add_options() ("bit_precision,b", po::value(), "number of bits in the feature table") ("initial_regressor,i", po::value< vector >(), "Initial regressor(s)") ("final_regressor,f", po::value< string >(), "Final regressor") ("initial_weight", po::value(&(all->initial_weight)), "Set all weights to an initial value of 1.") ("random_weights", po::value(&(all->random_weights)), "make initial weights random") ("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features") ("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names") ("save_per_pass", "Save the model after every pass over data") ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file") ("output_feature_regularizer_binary", po::value< string >(&(all->per_feature_regularizer_output)), "Per feature regularization output file") ("output_feature_regularizer_text", po::value< string >(&(all->per_feature_regularizer_text)), "Per feature regularization output file, in text") ; po::options_description holdout_opt("Holdout options"); holdout_opt.add_options() ("holdout_off", "no holdout data in multiple passes") ("holdout_period", po::value(&(all->holdout_period)), "holdout period for test only, default 10") ("holdout_after", po::value(&(all->holdout_after)), "holdout after n training examples, default off (disables holdout_period)") ("early_terminate", po::value(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3") ; po::options_description namespace_opt("Feature namespace options"); namespace_opt.add_options() ("hash", po::value< string > (), "how to hash the features. 
Available options: strings, all") ("ignore", po::value< vector >(), "ignore namespaces beginning with character ") ("keep", po::value< vector >(), "keep namespaces beginning with character ") ("noconstant", "Don't add a constant feature") ("constant,C", po::value(&(all->initial_constant)), "Set initial value of constant") ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes") ("ngram", po::value< vector >(), "Generate N grams. To generate N grams for a single namespace 'foo', arg should be fN.") ("skips", po::value< vector >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram. To generate n-skips for a single namespace 'foo', arg should be fn.") ("affix", po::value(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace") ("spelling", po::value< vector >(), "compute spelling features for a give namespace (use '_' for default namespace)"); ; po::options_description mf_opt("Matrix factorization options"); mf_opt.add_options() ("quadratic,q", po::value< vector > (), "Create and use quadratic features") ("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters") ("cubic", po::value< vector > (), "Create and use cubic features") ("rank", po::value(&(all->rank)), "rank for matrix factorization.") ("new_mf", "use new, reduction-based matrix factorization") ; po::options_description lrq_opt("Low Rank Quadratic options"); lrq_opt.add_options() ("lrq", po::value > (), "use low rank quadratic features") ("lrqdropout", "use dropout training for low rank quadratic features") ; po::options_description multiclass_opt("Multiclass options"); multiclass_opt.add_options() ("oaa", po::value(), "Use one-against-all multiclass learning with labels") ("ect", po::value(), "Use error correcting 
tournament with labels") ("csoaa", po::value(), "Use one-against-all multiclass learning with costs") ("wap", po::value(), "Use weighted all-pairs multiclass learning with costs") ("csoaa_ldf", po::value(), "Use one-against-all multiclass learning with label dependent features. Specify singleline or multiline.") ("wap_ldf", po::value(), "Use weighted all-pairs multiclass learning with label dependent features. Specify singleline or multiline.") ; po::options_description active_opt("Active Learning options"); active_opt.add_options() ("active_learning", "active learning mode") ("active_simulation", "active learning simulation mode") ("active_mellowness", po::value(&(all->active_c0)), "active learning mellowness parameter c_0. Default 8") ; po::options_description cluster_opt("Parallelization options"); cluster_opt.add_options() ("span_server", po::value(&(all->span_server)), "Location of server for setting up spanning tree") ("unique_id", po::value(&(all->unique_id)),"unique id used for cluster parallel jobs") ("total", po::value(&(all->total)),"total number of nodes used in cluster parallel job") ("node", po::value(&(all->node)),"node number in cluster parallel job") ; po::options_description other_opt("Other options"); other_opt.add_options() ("bs", po::value(), "bootstrap mode with k rounds by online importance resampling") ("top", po::value(), "top k recommendation") ("bs_type", po::value(), "bootstrap mode - currently 'mean' or 'vote'") ("autolink", po::value(), "create link function with polynomial d") ("cb", po::value(), "Use contextual bandit learning with costs") ("lda", po::value(&(all->lda)), "Run lda with topics") ("nn", po::value(), "Use sigmoidal feedforward network with hidden units") ("cbify", po::value(), "Convert multiclass on classes into a contextual bandit problem and solve") ("search", po::value(), "use search-based structured prediction, argument=maximum action id or 0 for LDF") ; // Declare the supported options. 
desc.add_options() ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.") ("version","Version information") ("random_seed", po::value(&random_seed), "seed random number generator") ("noop","do no learning") ("print","print examples"); //po::positional_options_description p; // Be friendly: if -d was left out, treat positional param as data file //p.add("data", -1); desc.add(in_opt) .add(out_opt) .add(update_opt) .add(weight_opt) .add(holdout_opt) .add(namespace_opt) .add(mf_opt) .add(lrq_opt) .add(multiclass_opt) .add(active_opt) .add(cluster_opt) .add(other_opt); po::variables_map vm = po::variables_map(); po::variables_map vm_file = po::variables_map(); //separate variable map for storing flags in regressor file po::parsed_options parsed = po::command_line_parser(argc, argv). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). options(desc).allow_unregistered().run(); // got rid of ".positional(p)" because it doesn't work well with unrecognized options vector to_pass_further = po::collect_unrecognized(parsed.options, po::include_positional); string last_unrec_arg = (to_pass_further.size() > 0) ? 
string(to_pass_further[to_pass_further.size()-1]) // we want to write this down in case it's a data argument ala the positional option we got rid of : ""; po::store(parsed, vm); po::notify(vm); msrand48(random_seed); parse_diagnostics(*all, vm, desc, argc); if (vm.count("active_simulation")) all->active_simulation = true; if (vm.count("active_learning") && !all->active_simulation) all->active = true; parse_source(*all, vm); all->sd->weighted_unlabeled_examples = all->sd->t; all->initial_t = (float)all->sd->t; if(all->initial_t > 0)//for the normalized update: if initial_t is bigger than 1 we interpret this as if we had seen (all->initial_t) previous fake datapoints all with norm 1 all->normalized_sum_norm_x = all->initial_t; //Input regressor header io_buf io_temp; parse_regressor_args(*all, vm, io_temp); all->options_from_file_argv = VW::get_argv_from_string(all->options_from_file,all->options_from_file_argc); po::parsed_options parsed_file = po::command_line_parser(all->options_from_file_argc, all->options_from_file_argv). style(po::command_line_style::default_style ^ po::command_line_style::allow_guessing). 
options(desc).allow_unregistered().run(); po::store(parsed_file, vm_file); po::notify(vm_file); parse_feature_tweaks(*all, vm, vm_file); //feature tweaks parse_example_tweaks(*all, vm); //example manipulation parse_base_algorithm(*all, to_pass_further, vm); if (!all->quiet) { cerr << "Num weight bits = " << all->num_bits << endl; cerr << "learning rate = " << all->eta << endl; cerr << "initial_t = " << all->sd->t << endl; cerr << "power_t = " << all->power_t << endl; if (all->numpasses > 1) cerr << "decay_learning_rate = " << all->eta_decay_rate << endl; if (all->rank > 0) cerr << "rank = " << all->rank << endl; } parse_output_model(*all, vm); parse_output_preds(*all, vm, vm_file); load_input_model(*all, vm, io_temp); parse_scorer_reductions(*all, to_pass_further, vm, vm_file); bool got_cs = false; parse_score_users(*all, to_pass_further, vm, vm_file, got_cs); bool got_cb = false; parse_cb(*all, to_pass_further, vm, vm_file, got_cs, got_cb); parse_search(*all, to_pass_further, vm, vm_file, got_cs, got_cb); if(vm.count("bs") || vm_file.count("bs") ) all->l = BS::setup(*all, to_pass_further, vm, vm_file); if (to_pass_further.size() > 0) { bool is_actually_okay = false; // special case to try to emulate the missing -d if ((to_pass_further.size() == 1) && (to_pass_further[to_pass_further.size()-1] == last_unrec_arg)) { int f = io_buf().open_file(last_unrec_arg.c_str(), all->stdin_off, io_buf::READ); if (f != -1) { #ifdef _WIN32 _close(f); #else close(f); #endif all->data_filename = last_unrec_arg; if (ends_with(last_unrec_arg, ".gz")) set_compressed(all->p); is_actually_okay = true; } } if (!is_actually_okay) { cerr << "unrecognized options:"; for (size_t i=0; iquiet,all->numpasses); // force wpp to be a power of 2 to avoid 32-bit overflow uint32_t i = 0; size_t params_per_problem = all->l->increment; while (params_per_problem > (uint32_t)(1 << i)) i++; all->wpp = (1 << i) >> all->reg.stride_shift; return all; } namespace VW { void cmd_string_replace_value( string& 
cmd, string flag_to_replace, string new_value ) { flag_to_replace.append(" "); //add a space to make sure we obtain the right flag in case 2 flags start with the same set of characters size_t pos = cmd.find(flag_to_replace); if( pos == string::npos ) { //flag currently not present in command string, so just append it to command string cmd.append(" "); cmd.append(flag_to_replace); cmd.append(new_value); } else { //flag is present, need to replace old value with new value //compute position after flag_to_replace pos += flag_to_replace.size(); //now pos is position where value starts //find position of next space size_t pos_after_value = cmd.find(" ",pos); if(pos_after_value == string::npos) { //we reach the end of the string, so replace the all characters after pos by new_value cmd.replace(pos,cmd.size()-pos,new_value); } else { //replace characters between pos and pos_after_value by new_value cmd.replace(pos,pos_after_value-pos,new_value); } } } char** get_argv_from_string(string s, int& argc) { char* c = (char*)calloc_or_die(s.length()+3, sizeof(char)); c[0] = 'b'; c[1] = ' '; strcpy(c+2, s.c_str()); substring ss = {c, c+s.length()+2}; v_array foo; foo.end_array = foo.begin = foo.end = NULL; tokenize(' ', ss, foo); char** argv = (char**)calloc_or_die(foo.size(), sizeof(char*)); for (size_t i = 0; i < foo.size(); i++) { *(foo[i].end) = '\0'; argv[i] = (char*)calloc_or_die(foo[i].end-foo[i].begin+1, sizeof(char)); sprintf(argv[i],"%s",foo[i].begin); } argc = (int)foo.size(); free(c); foo.delete_v(); return argv; } vw* initialize(string s) { int argc = 0; s += " --no_stdin"; char** argv = get_argv_from_string(s,argc); vw* all = parse_args(argc, argv); initialize_parser_datastructures(*all); for(int i = 0; i < argc; i++) free(argv[i]); free (argv); return all; } void finish(vw& all) { finalize_regressor(all, all.final_regressor_name); all.l->finish(); delete all.l; if (all.reg.weight_vector != NULL) free(all.reg.weight_vector); free_parser(all); finalize_source(all.p); 
all.p->parse_name.erase(); all.p->parse_name.delete_v(); free(all.p); free(all.sd); for (int i = 0; i < all.options_from_file_argc; i++) free(all.options_from_file_argv[i]); free(all.options_from_file_argv); for (size_t i = 0; i < all.final_prediction_sink.size(); i++) if (all.final_prediction_sink[i] != 1) io_buf::close_file_or_socket(all.final_prediction_sink[i]); all.final_prediction_sink.delete_v(); delete all.loss; delete &all; } }