author | John Langford &lt;jl@hunch.net&gt; | 2014-02-26 01:02:22 +0400
---|---|---
committer | John Langford &lt;jl@hunch.net&gt; | 2014-02-26 01:02:22 +0400
commit | 3e368e9b2e6e3cf01d9ab8c59127b3d1806f43bd (patch) |
tree | 7a77db2292e9724506c6f35207e87a7d04bf3088 /vowpalwabbit/parse_args.cc |
parent | b26fe54fc782b2dcc75ce3d49ac623a3e2193abe (diff) |
parent | 6a8158a84ccd5cc091131f0cc954c1bd7e1963e7 (diff) |
fixed conflict
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- | vowpalwabbit/parse_args.cc | 455
1 file changed, 302 insertions, 153 deletions
```diff
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 16596b98..2a501f45 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -15,22 +15,28 @@ license as described in the file LICENSE.
 #include "network.h"
 #include "global_data.h"
 #include "nn.h"
+#include "cbify.h"
 #include "oaa.h"
+#include "rand48.h"
 #include "bs.h"
 #include "topk.h"
 #include "ect.h"
 #include "csoaa.h"
 #include "wap.h"
 #include "cb.h"
+#include "scorer.h"
 #include "searn.h"
 #include "bfgs.h"
 #include "lda_core.h"
 #include "noop.h"
+#include "print.h"
 #include "gd_mf.h"
+#include "mf.h"
 #include "vw.h"
 #include "rand48.h"
 #include "parse_args.h"
 #include "binary.h"
+#include "lrq.h"
 #include "autolink.h"
 
 using namespace std;
@@ -56,7 +62,7 @@ bool valid_ns(char c)
 
 void parse_affix_argument(vw&all, string str) {
   if (str.length() == 0) return;
-  char*cstr = new char[str.length()+1];
+  char* cstr = (char*)calloc(str.length()+1, sizeof(char));
   strcpy(cstr, str.c_str());
 
   char*p = strtok(cstr, ",");
@@ -74,7 +80,7 @@ void parse_affix_argument(vw&all, string str) {
     if (q[1] != 0) {
       if (valid_ns(q[1]))
         ns = (uint16_t)q[1];
-      else { 
+      else {
         cerr << "malformed affix argument (invalid namespace): " << p << endl;
         throw exception();
       }
@@ -87,130 +93,200 @@ void parse_affix_argument(vw&all, string str) {
     uint16_t afx = (len << 1) | (prefix & 0x1);
     all.affix_features[ns] <<= 4;
     all.affix_features[ns] |= afx;
-    
+
     p = strtok(NULL, ",");
   }
-  
-  delete cstr;
+
+  free(cstr);
 }
 
 vw* parse_args(int argc, char *argv[])
 {
   po::options_description desc("VW options");
-  
+
   vw* all = new vw();
 
   size_t random_seed = 0;
   all->program_name = argv[0];
-  // Declare the supported options.
-  desc.add_options()
-    ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.")
-    ("active_learning", "active learning mode")
-    ("active_simulation", "active learning simulation mode")
-    ("active_mellowness", po::value<float>(&(all->active_c0)), "active learning mellowness parameter c_0. Default 8")
+
+  po::options_description in_opt("Input options");
+
+  in_opt.add_options()
+    ("data,d", po::value< string >(), "Example Set")
+    ("ring_size", po::value<size_t>(&(all->p->ring_size)), "size of example ring")
+    ("examples", po::value<size_t>(&(all->max_examples)), "number of examples to parse")
+    ("testonly,t", "Ignore label information and just test")
+    ("daemon", "persistent daemon mode on port 26542")
+    ("port", po::value<size_t>(),"port to listen on")
+    ("num_children", po::value<size_t>(&(all->num_children)), "number of children for persistent daemon mode")
+    ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode")
+    ("passes", po::value<size_t>(&(all->numpasses)),"Number of Training Passes")
+    ("cache,c", "Use a cache.  The default is <data>.cache")
+    ("cache_file", po::value< vector<string> >(), "The location(s) of cache_file.")
+    ("kill_cache,k", "do not reuse existing cache: create a new one always")
+    ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
+    ("no_stdin", "do not default to reading from stdin")
+    ("save_resume", "save extra state so learning can be resumed later with new data")
+    ;
+
+  po::options_description out_opt("Output options");
+
+  out_opt.add_options()
+    ("audit,a", "print weights of features")
+    ("predictions,p", po::value< string >(), "File to output predictions to")
+    ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to")
+    ("sendto", po::value< vector<string> >(), "send examples to <host>")
+    ("quiet", "Don't output disgnostics and progress updates")
+    ("progress,P", po::value< string >(), "Progress update frequency. int: additive, float: multiplicative")
     ("binary", "report loss as binary classification on -1,1")
-    ("bs", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling")
-    ("top", po::value<size_t>(), "top k recommendation")
-    ("bs_type", po::value<string>(), "bootstrap mode - currently 'mean' or 'vote'")
-    ("autolink", po::value<size_t>(), "create link function with polynomial d")
+    ("min_prediction", po::value<float>(&(all->sd->min_label)), "Smallest prediction to output")
+    ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output")
+    ;
+
+  po::options_description update_opt("Update options");
+
+  update_opt.add_options()
     ("sgd", "use regular stochastic gradient descent update.")
+    ("hessian_on", "use second derivative in line search")
+    ("bfgs", "use bfgs optimization")
+    ("mem", po::value<int>(&(all->m)), "memory in bfgs")
+    ("termination", po::value<float>(&(all->rel_threshold)),"Termination threshold")
     ("adaptive", "use adaptive, individual learning rates.")
     ("invariant", "use safe/importance aware updates.")
     ("normalized", "use per feature normalized updates")
     ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule")
-    ("audit,a", "print weights of features")
-    ("bit_precision,b", po::value<size_t>(), "number of bits in the feature table")
-    ("bfgs", "use bfgs optimization")
-    ("cache,c", "Use a cache.  The default is <data>.cache")
-    ("cache_file", po::value< vector<string> >(), "The location(s) of cache_file.")
-    ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
-    ("no_stdin", "do not default to reading from stdin")
     ("conjugate_gradient", "use conjugate gradient based optimization")
-    ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs")
-    ("wap", po::value<size_t>(), "Use weighted all-pairs multiclass learning with <k> costs")
-    ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features.  Specify singleline or multiline.")
-    ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features.  Specify singleline or multiline.")
-    ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
     ("l1", po::value<float>(&(all->l1_lambda)), "l_1 lambda")
     ("l2", po::value<float>(&(all->l2_lambda)), "l_2 lambda")
-    ("data,d", po::value< string >(), "Example Set")
-    ("daemon", "persistent daemon mode on port 26542")
-    ("num_children", po::value<size_t>(&(all->num_children)), "number of children for persistent daemon mode")
-    ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode")
+    ("learning_rate,l", po::value<float>(&(all->eta)), "Set Learning Rate")
+    ("loss_function", po::value<string>()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.")
+    ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
+    ("power_t", po::value<float>(&(all->power_t)), "t power value")
     ("decay_learning_rate", po::value<float>(&(all->eta_decay_rate)), "Set Decay factor for learning_rate between passes")
-    ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file")
+    ("initial_pass_length", po::value<size_t>(&(all->pass_length)), "initial number of examples per pass")
+    ("initial_t", po::value<double>(&((all->sd->t))), "initial t value")
+    ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated.  If no initial_regressor given, also used for initial weights.")
+    ;
+
+  po::options_description weight_opt("Weight options");
+
+  weight_opt.add_options()
+    ("bit_precision,b", po::value<size_t>(), "number of bits in the feature table")
+    ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
     ("final_regressor,f", po::value< string >(), "Final regressor")
+    ("initial_weight", po::value<float>(&(all->initial_weight)), "Set all weights to an initial value of 1.")
+    ("random_weights", po::value<bool>(&(all->random_weights)), "make initial weights random")
     ("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features")
     ("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names")
-    ("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
-    ("hessian_on", "use second derivative in line search")
+    ("save_per_pass", "Save the model after every pass over data")
+    ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file")
+    ("output_feature_regularizer_binary", po::value< string >(&(all->per_feature_regularizer_output)), "Per feature regularization output file")
+    ("output_feature_regularizer_text", po::value< string >(&(all->per_feature_regularizer_text)), "Per feature regularization output file, in text")
+    ;
+
+  po::options_description holdout_opt("Holdout options");
+  holdout_opt.add_options()
     ("holdout_off", "no holdout data in multiple passes")
     ("holdout_period", po::value<uint32_t>(&(all->holdout_period)), "holdout period for test only, default 10")
     ("holdout_after", po::value<uint32_t>(&(all->holdout_after)), "holdout after n training examples, default off (disables holdout_period)")
-    ("version","Version information")
+    ("early_terminate", po::value<size_t>(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3")
+    ;
+
+  po::options_description namespace_opt("Feature namespace options");
+  namespace_opt.add_options()
+    ("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
     ("ignore", po::value< vector<unsigned char> >(), "ignore namespaces beginning with character <arg>")
    ("keep", po::value< vector<unsigned char> >(), "keep namespaces beginning with character <arg>")
-    ("kill_cache,k", "do not reuse existing cache: create a new one always")
-    ("initial_weight", po::value<float>(&(all->initial_weight)), "Set all weights to an initial value of 1.")
-    ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
-    ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated.  If no initial_regressor given, also used for initial weights.")
-    ("initial_pass_length", po::value<size_t>(&(all->pass_length)), "initial number of examples per pass")
-    ("initial_t", po::value<double>(&((all->sd->t))), "initial t value")
-    ("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics")
-    ("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree")
-    ("min_prediction", po::value<float>(&(all->sd->min_label)), "Smallest prediction to output")
-    ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output")
-    ("mem", po::value<int>(&(all->m)), "memory in bfgs")
-    ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
     ("noconstant", "Don't add a constant feature")
-    ("noop","do no learning")
-    ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels")
-    ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels")
-    ("output_feature_regularizer_binary", po::value< string >(&(all->per_feature_regularizer_output)), "Per feature regularization output file")
-    ("output_feature_regularizer_text", po::value< string >(&(all->per_feature_regularizer_text)), "Per feature regularization output file, in text")
-    ("port", po::value<size_t>(),"port to listen on")
-    ("power_t", po::value<float>(&(all->power_t)), "t power value")
-    ("learning_rate,l", po::value<float>(&(all->eta)), "Set Learning Rate")
-    ("passes", po::value<size_t>(&(all->numpasses)),"Number of Training Passes")
-    ("termination", po::value<float>(&(all->rel_threshold)),"Termination threshold")
-    ("predictions,p", po::value< string >(), "File to output predictions to")
+    ("constant,C", po::value<float>(&(all->initial_constant)), "Set initial value of constant")
+    ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes")
+    ("ngram", po::value< vector<string> >(), "Generate N grams")
+    ("skips", po::value< vector<string> >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram.")
+    ("affix", po::value<string>(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace")
+    ("spelling", po::value< vector<string> >(), "compute spelling features for a give namespace (use '_' for default namespace)");
+    ;
+
+  po::options_description mf_opt("Matrix factorization options");
+  mf_opt.add_options()
     ("quadratic,q", po::value< vector<string> > (), "Create and use quadratic features")
     ("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters")
     ("cubic", po::value< vector<string> > (), "Create and use cubic features")
-    ("quiet", "Don't output diagnostics")
     ("rank", po::value<uint32_t>(&(all->rank)), "rank for matrix factorization.")
-    ("random_weights", po::value<bool>(&(all->random_weights)), "make initial weights random")
-    ("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
-    ("raw_predictions,r", po::value< string >(),
-     "File to output unnormalized predictions to")
-    ("ring_size", po::value<size_t>(&(all->p->ring_size)), "size of example ring")
-    ("examples", po::value<size_t>(&(all->max_examples)), "number of examples to parse")
-    ("save_per_pass", "Save the model after every pass over data")
-    ("early_terminate", po::value<size_t>(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3")
-    ("save_resume", "save extra state so learning can be resumed later with new data")
-    ("sendto", po::value< vector<string> >(), "send examples to <host>")
-    ("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
-    ("testonly,t", "Ignore label information and just test")
-    ("loss_function", po::value<string>()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.")
-    ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
+    ("new_mf", "use new, reduction-based matrix factorization")
+    ;
+
+  po::options_description lrq_opt("Low Rank Quadratic options");
+  lrq_opt.add_options()
+    ("lrq", po::value<vector<string> > (), "use low rank quadratic features")
+    ("lrqdropout", "use dropout training for low rank quadratic features")
+    ;
+
+  po::options_description multiclass_opt("Multiclass options");
+  multiclass_opt.add_options()
+    ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels")
+    ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels")
+    ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs")
+    ("wap", po::value<size_t>(), "Use weighted all-pairs multiclass learning with <k> costs")
+    ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features.  Specify singleline or multiline.")
+    ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features.  Specify singleline or multiline.")
+    ;
+
+  po::options_description active_opt("Active Learning options");
+  active_opt.add_options()
+    ("active_learning", "active learning mode")
+    ("active_simulation", "active learning simulation mode")
+    ("active_mellowness", po::value<float>(&(all->active_c0)), "active learning mellowness parameter c_0. Default 8")
+    ;
+
+  po::options_description cluster_opt("Parallelization options");
+  cluster_opt.add_options()
+    ("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree")
     ("unique_id", po::value<size_t>(&(all->unique_id)),"unique id used for cluster parallel jobs")
-    ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")    
-    ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job")    
+    ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")
+    ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job")
+    ;
 
-    ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes")
-    ("ngram", po::value< vector<string> >(), "Generate N grams")
-    ("skips", po::value< vector<string> >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram.")
-    ("affix", po::value<string>(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace")
-    ("spelling", po::value< vector<string> >(), "compute spelling features for a give namespace (use '_' for default namespace)");
+  po::options_description other_opt("Other options");
+  other_opt.add_options()
+    ("bs", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling")
+    ("top", po::value<size_t>(), "top k recommendation")
+    ("bs_type", po::value<string>(), "bootstrap mode - currently 'mean' or 'vote'")
+    ("autolink", po::value<size_t>(), "create link function with polynomial d")
+    ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
+    ("lda", po::value<uint32_t>(&(all->lda)), "Run lda with <int> topics")
+    ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
+    ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve")
+    ("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
+    ;
+
+  // Declare the supported options.
+  desc.add_options()
+    ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.")
+    ("version","Version information")
+    ("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
+    ("noop","do no learning")
+    ("print","print examples");
 
   //po::positional_options_description p;
   // Be friendly: if -d was left out, treat positional param as data file
   //p.add("data", -1);
 
+  desc.add(in_opt)
+    .add(out_opt)
+    .add(update_opt)
+    .add(weight_opt)
+    .add(holdout_opt)
+    .add(namespace_opt)
+    .add(mf_opt)
+    .add(lrq_opt)
+    .add(multiclass_opt)
+    .add(active_opt)
+    .add(cluster_opt)
+    .add(other_opt);
+
   po::variables_map vm = po::variables_map();
   po::variables_map vm_file = po::variables_map(); //separate variable map for storing flags in regressor file
@@ -225,7 +301,7 @@ vw* parse_args(int argc, char *argv[])
   po::store(parsed, vm);
   po::notify(vm);
-  
+
   if(all->numpasses > 1)
     all->holdout_set_off = false;
 
@@ -236,7 +312,7 @@ vw* parse_args(int argc, char *argv[])
   {
     all->holdout_set_off = true;
     cerr<<"Making holdout_set_off=true since output regularizer specified\n";
-  }  
+  }
 
   all->data_filename = "";
 
@@ -257,11 +333,47 @@ vw* parse_args(int argc, char *argv[])
       exit(0);
     }
 
-  if (vm.count("quiet"))
+  if (vm.count("quiet")) {
     all->quiet = true;
-  else
+    // --quiet wins over --progress
+  } else {
    all->quiet = false;
+    if (vm.count("progress")) {
+      string progress_str = vm["progress"].as<string>();
+      all->progress_arg = (float)::atof(progress_str.c_str());
+
+      // --progress interval is dual: either integer or floating-point
+      if (progress_str.find_first_of(".") == string::npos) {
+        // No "." in arg: assume integer -> additive
+        all->progress_add = true;
+        if (all->progress_arg < 1) {
+          cerr << "warning: additive --progress <int>"
+               << " can't be < 1: forcing to 1\n";
+          all->progress_arg = 1;
+
+        }
+        all->sd->dump_interval = all->progress_arg;
+
+      } else {
+        // A "." in arg: assume floating-point -> multiplicative
+        all->progress_add = false;
+
+        if (all->progress_arg <= 1.0) {
+          cerr << "warning: multiplicative --progress <float>: "
+               << vm["progress"].as<string>()
+               << " is <= 1.0: adding 1.0\n";
+          all->progress_arg += 1.0;
+
+        } else if (all->progress_arg > 9.0) {
+          cerr << "warning: multiplicative --progress <float>"
+               << " is > 9.0: you probably meant to use an integer\n";
+        }
+        all->sd->dump_interval = 1.0;
+      }
+    }
+  }
+
   msrand48(random_seed);
 
   if (vm.count("active_simulation"))
@@ -291,12 +403,12 @@ vw* parse_args(int argc, char *argv[])
   }
 
   all->reg.stride = 4; //use stride of 4 for default invariant normalized adaptive updates
-  //if we are doing matrix factorization, or user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
-  if( all->rank > 0 || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
+  //if the user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
+  if( (all->rank > 0 && !vm.count("new_mf")) || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
   {
-    all->adaptive = all->training && (vm.count("adaptive") && all->rank == 0);
+    all->adaptive = all->training && vm.count("adaptive") && (all->rank == 0 && !vm.count("new_mf"));
     all->invariant_updates = all->training && vm.count("invariant");
-    all->normalized_updates = all->training && (vm.count("normalized") && all->rank == 0);
+    all->normalized_updates = all->training && vm.count("normalized") && (all->rank == 0 && !vm.count("new_mf"));
 
     all->reg.stride = 1;
 
@@ -306,7 +418,8 @@ vw* parse_args(int argc, char *argv[])
     if( all->normalized_updates ) all->reg.stride *= 2;
 
     if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates))
-      all->eta = 10; //default learning rate to 10 for non default update rule
+      if (all->lda == 0)
+        all->eta = 10; //default learning rate to 10 for non default update rule
 
     //if not using normalized or adaptive, default initial_t to 1 instead of 0
    if(!all->adaptive && !all->normalized_updates && !vm.count("initial_t")) {
@@ -320,15 +433,33 @@ vw* parse_args(int argc, char *argv[])
        all->feature_mask_idx = 1;
      }
      else if(all->reg.stride == 2){
-       all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3 
+       all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3
      }
    }
  }
 
+  if (all->l1_lambda < 0.) {
+    cerr << "l1_lambda should be nonnegative: resetting from " << all->l1_lambda << " to 0" << endl;
+    all->l1_lambda = 0.;
+  }
+  if (all->l2_lambda < 0.) {
+    cerr << "l2_lambda should be nonnegative: resetting from " << all->l2_lambda << " to 0" << endl;
+    all->l2_lambda = 0.;
+  }
+  all->reg_mode += (all->l1_lambda > 0.) ? 1 : 0;
+  all->reg_mode += (all->l2_lambda > 0.) ? 2 : 0;
+  if (!all->quiet)
+    {
+      if (all->reg_mode %2 && !vm.count("bfgs"))
+        cerr << "using l1 regularization = " << all->l1_lambda << endl;
+      if (all->reg_mode > 1)
+        cerr << "using l2 regularization = " << all->l2_lambda << endl;
+    }
+
   all->l = GD::setup(*all, vm);
   all->scorer = all->l;
 
-  if (vm.count("bfgs") || vm.count("conjugate_gradient")) 
+  if (vm.count("bfgs") || vm.count("conjugate_gradient"))
     all->l = BFGS::setup(*all, to_pass_further, vm, vm_file);
 
   if (vm.count("version") || argc == 1) {
@@ -356,7 +487,7 @@ vw* parse_args(int argc, char *argv[])
        cout << "You can not skip unless ngram is > 1" << endl;
        throw exception();
      }
-    
+
    all->skip_strings = vm["skips"].as<vector<string> >();
    compile_gram(all->skip_strings, all->skips, (char*)"skips", all->quiet);
  }
@@ -368,10 +499,10 @@ vw* parse_args(int argc, char *argv[])
  if (vm.count("spelling")) {
    vector<string> spelling_ns = vm["spelling"].as< vector<string> >();
    for (size_t id=0; id<spelling_ns.size(); id++)
-      if (spelling_ns[id][0] == '_') all->spelling_features[' '] = true;
+      if (spelling_ns[id][0] == '_') all->spelling_features[(unsigned char)' '] = true;
      else all->spelling_features[(size_t)spelling_ns[id][0]] = true;
  }
-  
+
  if (vm.count("bit_precision"))
    {
      all->default_bits = false;
@@ -382,7 +513,7 @@ vw* parse_args(int argc, char *argv[])
          throw exception();
        }
    }
-  
+
  if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all->active) )
    {
      all->daemon = true;
@@ -392,7 +523,7 @@ vw* parse_args(int argc, char *argv[])
 
  if (vm.count("compressed"))
    set_compressed(all->p);
-  
+
  if (vm.count("data")) {
    all->data_filename = vm["data"].as<string>();
    if (ends_with(all->data_filename, ".gz"))
@@ -408,14 +539,14 @@ vw* parse_args(int argc, char *argv[])
    {
      all->pairs = vm["quadratic"].as< vector<string> >();
      vector<string> newpairs;
-      //string tmp; 
+      //string tmp;
      char printable_start = '!';
      char printable_end = '~';
      int valid_ns_size = printable_end - printable_start - 1; //will skip two characters
 
      if(!all->quiet)
-        cerr<<"creating quadratic features for pairs: "; 
-      
+        cerr<<"creating quadratic features for pairs: ";
+
      for (vector<string>::iterator i = all->pairs.begin(); i != all->pairs.end();i++){
        if(!all->quiet){
          cerr << *i << " ";
@@ -456,7 +587,7 @@ vw* parse_args(int argc, char *argv[])
        }
        else{
          newpairs.push_back(string(*i));
-        }  
+        }
      }
      newpairs.swap(all->pairs);
      if(!all->quiet)
@@ -540,8 +671,8 @@ vw* parse_args(int argc, char *argv[])
    }
  }
 
-  // matrix factorization enabled
-  if (all->rank > 0) {
+  // (non-reduction) matrix factorization enabled
+  if (!vm.count("new_mf") && all->rank > 0) {
    // store linear + 2*rank weights per index, round up to power of two
    float temp = ceilf(logf((float)(all->rank*2+1)) / logf (2.f));
    all->reg.stride = 1 << (int) temp;
@@ -582,21 +713,21 @@ vw* parse_args(int argc, char *argv[])
  //if (vm.count("nonormalize"))
  //  all->nonormalize = true;
 
-  if (vm.count("lda")) 
+  if (vm.count("lda"))
    all->l = LDA::setup(*all, to_pass_further, vm);
 
-  if (!vm.count("lda") && !all->adaptive && !all->normalized_updates) 
+  if (!vm.count("lda") && !all->adaptive && !all->normalized_updates)
    all->eta *= powf((float)(all->sd->t), all->power_t);
-  
+
  if (vm.count("readable_model"))
    all->text_regressor_name = vm["readable_model"].as<string>();
 
  if (vm.count("invert_hash")){
    all->inv_hash_regressor_name = vm["invert_hash"].as<string>();
-    all->hash_inv = true; 
+    all->hash_inv = true;
  }
-  
+
  if (vm.count("save_per_pass"))
    all->save_per_pass = true;
@@ -619,10 +750,16 @@ vw* parse_args(int argc, char *argv[])
  if(vm.count("quantile_tau"))
    loss_parameter = vm["quantile_tau"].as<float>();
 
-  if (vm.count("noop")) 
+  if (vm.count("noop"))
    all->l = NOOP::setup(*all);
-  
-  if (all->rank != 0) 
+
+  if (vm.count("print"))
+    {
+      all->l = PRINT::setup(*all);
+      all->reg.stride = 1;
+    }
+
+  if (!vm.count("new_mf") && all->rank > 0)
    all->l = GDMF::setup(*all);
 
  all->loss = getLossFunction(all, loss_function, (float)loss_parameter);
@@ -713,37 +850,27 @@ vw* parse_args(int argc, char *argv[])
      io_temp.close_file();
    }
 
-  if (all->l1_lambda < 0.) {
-    cerr << "l1_lambda should be nonnegative: resetting from " << all->l1_lambda << " to 0" << endl;
-    all->l1_lambda = 0.;
-  }
-  if (all->l2_lambda < 0.) {
-    cerr << "l2_lambda should be nonnegative: resetting from " << all->l2_lambda << " to 0" << endl;
-    all->l2_lambda = 0.;
-  }
-  all->reg_mode += (all->l1_lambda > 0.) ? 1 : 0;
-  all->reg_mode += (all->l2_lambda > 0.) ? 2 : 0;
-  if (!all->quiet)
-    {
-      if (all->reg_mode %2 && !vm.count("bfgs"))
-        cerr << "using l1 regularization = " << all->l1_lambda << endl;
-      if (all->reg_mode > 1)
-        cerr << "using l2 regularization = " << all->l2_lambda << endl;
-    }
-
  bool got_mc = false;
  bool got_cs = false;
  bool got_cb = false;
 
-  if(vm.count("nn") || vm_file.count("nn") ) 
+  if(vm.count("nn") || vm_file.count("nn") )
    all->l = NN::setup(*all, to_pass_further, vm, vm_file);
 
-  if(vm.count("autolink") || vm_file.count("autolink") ) 
+  if (vm.count("new_mf") && all->rank > 0)
+    all->l = MF::setup(*all, vm);
+
+  if(vm.count("autolink") || vm_file.count("autolink") )
    all->l = ALINK::setup(*all, to_pass_further, vm, vm_file);
 
-  if(vm.count("top") || vm_file.count("top") ) 
+  if (vm.count("lrq") || vm_file.count("lrq"))
+    all->l = LRQ::setup(*all, to_pass_further, vm, vm_file);
+
+  all->l = Scorer::setup(*all, to_pass_further, vm, vm_file);
+
+  if(vm.count("top") || vm_file.count("top") )
    all->l = TOPK::setup(*all, to_pass_further, vm, vm_file);
-  
+
  if (vm.count("binary") || vm_file.count("binary"))
    all->l = BINARY::setup(*all, to_pass_further, vm, vm_file);
@@ -753,7 +880,7 @@ vw* parse_args(int argc, char *argv[])
    all->l = OAA::setup(*all, to_pass_further, vm, vm_file);
    got_mc = true;
  }
-  
+
  if (vm.count("ect") || vm_file.count("ect") ) {
    if (got_mc) { cerr << "error: cannot specify multiple MC learners" << endl; throw exception(); }
@@ -763,15 +890,17 @@ vw* parse_args(int argc, char *argv[])
 
  if(vm.count("csoaa") || vm_file.count("csoaa") ) {
    if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
-    
+
    all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);
+    all->cost_sensitive = all->l;
    got_cs = true;
  }
 
  if(vm.count("wap") || vm_file.count("wap") ) {
    if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
-    
+
    all->l = WAP::setup(*all, to_pass_further, vm, vm_file);
+    all->cost_sensitive = all->l;
    got_cs = true;
  }
@@ -779,6 +908,7 @@ vw* parse_args(int argc, char *argv[])
    if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
 
    all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file);
+    all->cost_sensitive = all->l;
    got_cs = true;
  }
@@ -786,6 +916,7 @@ vw* parse_args(int argc, char *argv[])
    if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
 
    all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file);
+    all->cost_sensitive = all->l;
    got_cs = true;
  }
@@ -796,6 +927,7 @@ vw* parse_args(int argc, char *argv[])
      else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"]));
 
      all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);  // default to CSOAA unless wap is specified
+      all->cost_sensitive = all->l;
      got_cs = true;
    }
@@ -803,16 +935,37 @@ vw* parse_args(int argc, char *argv[])
    got_cb = true;
  }
 
-  all->searnstr = NULL;
-  if (vm.count("searn") || vm_file.count("searn") ) {
+  if (vm.count("cbify") || vm_file.count("cbify"))
+    {
+      if(!got_cs) {
+        if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["cbify"]));
+        else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"]));
+
+        all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);  // default to CSOAA unless wap is specified
+        all->cost_sensitive = all->l;
+        got_cs = true;
+      }
+
+      if (!got_cb) {
+        if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("cb"),vm_file["cbify"]));
+        else vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"]));
+        all->l = CB::setup(*all, to_pass_further, vm, vm_file);
+        got_cb = true;
+      }
+
+      all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file);
+    }
+
+  if (vm.count("searn") || vm_file.count("searn") ) {
    if (!got_cs && !got_cb) {
      if( vm_file.count("searn") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["searn"]));
      else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["searn"]));
-      
+
      all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);  // default to CSOAA unless others have been specified
+      all->cost_sensitive = all->l;
      got_cs = true;
    }
-    all->searnstr = (Searn::searn*)calloc(1, sizeof(Searn::searn));
+    //all->searnstr = (Searn::searn*)calloc(1, sizeof(Searn::searn));
    all->l = Searn::setup(*all, to_pass_further, vm, vm_file);
  }
@@ -821,7 +974,7 @@ vw* parse_args(int argc, char *argv[])
    throw exception();
  }
 
-  if(vm.count("bs") || vm_file.count("bs") ) 
+  if(vm.count("bs") || vm_file.count("bs") )
    all->l = BS::setup(*all, to_pass_further, vm, vm_file);
 
  if (to_pass_further.size() > 0) {
@@ -857,9 +1010,9 @@ vw* parse_args(int argc, char *argv[])
 
  parse_source_args(*all, vm, all->quiet,all->numpasses);
 
-  // force stride * weights_per_problem to be a power of 2 to avoid 32-bit overflow
+  // force wpp to be a power of 2 to avoid 32-bit overflow
  uint32_t i = 0;
-  size_t params_per_problem = all->l->increment * all->l->weights;
+  size_t params_per_problem = all->l->increment;
  while (params_per_problem > (uint32_t)(1 << i)) i++;
  all->wpp = (1 << i) / all->reg.stride;
@@ -880,7 +1033,7 @@ namespace VW {
    }
    else {
      //flag is present, need to replace old value with new value
-      
+
      //compute position after flag_to_replace
      pos += flag_to_replace.size();
@@ -908,7 +1061,7 @@ namespace VW {
    v_array<substring> foo;
    foo.end_array = foo.begin = foo.end = NULL;
    tokenize(' ', ss, foo);
-    
+
    char** argv = (char**)calloc(foo.size(), sizeof(char*));
    for (size_t i = 0; i < foo.size(); i++)
      {
@@ -922,13 +1075,13 @@ namespace VW {
    foo.delete_v();
    return argv;
  }
-  
+
  vw* initialize(string s)
  {
    int argc = 0;
    s += " --no_stdin";
    char** argv = get_argv_from_string(s,argc);
-    
+
    vw* all = parse_args(argc, argv);
    initialize_parser_datastructures(*all);
@@ -944,11 +1097,11 @@ namespace VW {
  {
    finalize_regressor(all, all.final_regressor_name);
    all.l->finish();
+    delete all.l;
    if (all.reg.weight_vector != NULL)
      free(all.reg.weight_vector);
    free_parser(all);
    finalize_source(all.p);
-    free(all.p->lp);
    all.p->parse_name.erase();
    all.p->parse_name.delete_v();
    free(all.p);
@@ -958,11 +1111,7 @@ namespace VW {
    free(all.options_from_file_argv);
    for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
      if (all.final_prediction_sink[i] != 1)
-#ifdef _WIN32
-        _close(all.final_prediction_sink[i]);
-#else
-        close(all.final_prediction_sink[i]);
-#endif
+        io_buf::close_file_or_socket(all.final_prediction_sink[i]);
    all.final_prediction_sink.delete_v();
    delete all.loss;
    delete &all;
```
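The `--progress` option introduced in this merge is dual-mode: an argument without a decimal point is read as an integer and sets an additive reporting interval, while an argument containing "." is read as a float and sets a multiplicative one. A minimal standalone sketch of that convention (hypothetical demo code, not part of the VW sources):

```cpp
// Sketch of the --progress argument convention added in this commit.
// Assumption: this is an illustrative standalone program; VW's actual
// handling lives in parse_args.cc above.
#include <cstdlib>
#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
  std::string arg = (argc > 1) ? argv[1] : "10";
  float interval = (float)std::atof(arg.c_str());

  if (arg.find('.') == std::string::npos) {
    // No "." in the argument: integer -> additive, report every n examples.
    if (interval < 1) interval = 1;  // mirrors VW's warning-and-clamp
    std::cout << "additive: report every " << interval << " examples\n";
  } else {
    // "." present: float -> multiplicative, report at n, arg*n, arg*arg*n, ...
    if (interval <= 1.0f) interval += 1.0f;  // mirrors VW's adjustment
    std::cout << "multiplicative: grow reporting interval by x" << interval << "\n";
  }
  return 0;
}
```

In VW terms, `--progress 1000` would print a status line every 1000 examples, while `--progress 2.0` doubles the gap between status lines, which is why the code above warns on additive values below 1 and multiplicative values at or below 1.0.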