Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Langford <jl@hunch.net>2014-02-26 01:02:22 +0400
committerJohn Langford <jl@hunch.net>2014-02-26 01:02:22 +0400
commit3e368e9b2e6e3cf01d9ab8c59127b3d1806f43bd (patch)
tree7a77db2292e9724506c6f35207e87a7d04bf3088 /vowpalwabbit/parse_args.cc
parentb26fe54fc782b2dcc75ce3d49ac623a3e2193abe (diff)
parent6a8158a84ccd5cc091131f0cc954c1bd7e1963e7 (diff)
fixed conflict
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r--vowpalwabbit/parse_args.cc455
1 files changed, 302 insertions, 153 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index 16596b98..2a501f45 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -15,22 +15,28 @@ license as described in the file LICENSE.
#include "network.h"
#include "global_data.h"
#include "nn.h"
+#include "cbify.h"
#include "oaa.h"
+#include "rand48.h"
#include "bs.h"
#include "topk.h"
#include "ect.h"
#include "csoaa.h"
#include "wap.h"
#include "cb.h"
+#include "scorer.h"
#include "searn.h"
#include "bfgs.h"
#include "lda_core.h"
#include "noop.h"
+#include "print.h"
#include "gd_mf.h"
+#include "mf.h"
#include "vw.h"
#include "rand48.h"
#include "parse_args.h"
#include "binary.h"
+#include "lrq.h"
#include "autolink.h"
using namespace std;
@@ -56,7 +62,7 @@ bool valid_ns(char c)
void parse_affix_argument(vw&all, string str) {
if (str.length() == 0) return;
- char*cstr = new char[str.length()+1];
+ char* cstr = (char*)calloc(str.length()+1, sizeof(char));
strcpy(cstr, str.c_str());
char*p = strtok(cstr, ",");
@@ -74,7 +80,7 @@ void parse_affix_argument(vw&all, string str) {
if (q[1] != 0) {
if (valid_ns(q[1]))
ns = (uint16_t)q[1];
- else {
+ else {
cerr << "malformed affix argument (invalid namespace): " << p << endl;
throw exception();
}
@@ -87,130 +93,200 @@ void parse_affix_argument(vw&all, string str) {
uint16_t afx = (len << 1) | (prefix & 0x1);
all.affix_features[ns] <<= 4;
all.affix_features[ns] |= afx;
-
+
p = strtok(NULL, ",");
}
-
- delete cstr;
+
+ free(cstr);
}
vw* parse_args(int argc, char *argv[])
{
po::options_description desc("VW options");
-
+
vw* all = new vw();
size_t random_seed = 0;
all->program_name = argv[0];
- // Declare the supported options.
- desc.add_options()
- ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.")
- ("active_learning", "active learning mode")
- ("active_simulation", "active learning simulation mode")
- ("active_mellowness", po::value<float>(&(all->active_c0)), "active learning mellowness parameter c_0. Default 8")
+
+ po::options_description in_opt("Input options");
+
+ in_opt.add_options()
+ ("data,d", po::value< string >(), "Example Set")
+ ("ring_size", po::value<size_t>(&(all->p->ring_size)), "size of example ring")
+ ("examples", po::value<size_t>(&(all->max_examples)), "number of examples to parse")
+ ("testonly,t", "Ignore label information and just test")
+ ("daemon", "persistent daemon mode on port 26542")
+ ("port", po::value<size_t>(),"port to listen on")
+ ("num_children", po::value<size_t>(&(all->num_children)), "number of children for persistent daemon mode")
+ ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode")
+ ("passes", po::value<size_t>(&(all->numpasses)),"Number of Training Passes")
+ ("cache,c", "Use a cache. The default is <data>.cache")
+ ("cache_file", po::value< vector<string> >(), "The location(s) of cache_file.")
+ ("kill_cache,k", "do not reuse existing cache: create a new one always")
+ ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
+ ("no_stdin", "do not default to reading from stdin")
+ ("save_resume", "save extra state so learning can be resumed later with new data")
+ ;
+
+ po::options_description out_opt("Output options");
+
+ out_opt.add_options()
+ ("audit,a", "print weights of features")
+ ("predictions,p", po::value< string >(), "File to output predictions to")
+ ("raw_predictions,r", po::value< string >(), "File to output unnormalized predictions to")
+ ("sendto", po::value< vector<string> >(), "send examples to <host>")
+ ("quiet", "Don't output disgnostics and progress updates")
+ ("progress,P", po::value< string >(), "Progress update frequency. int: additive, float: multiplicative")
("binary", "report loss as binary classification on -1,1")
- ("bs", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling")
- ("top", po::value<size_t>(), "top k recommendation")
- ("bs_type", po::value<string>(), "bootstrap mode - currently 'mean' or 'vote'")
- ("autolink", po::value<size_t>(), "create link function with polynomial d")
+ ("min_prediction", po::value<float>(&(all->sd->min_label)), "Smallest prediction to output")
+ ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output")
+ ;
+
+ po::options_description update_opt("Update options");
+
+ update_opt.add_options()
("sgd", "use regular stochastic gradient descent update.")
+ ("hessian_on", "use second derivative in line search")
+ ("bfgs", "use bfgs optimization")
+ ("mem", po::value<int>(&(all->m)), "memory in bfgs")
+ ("termination", po::value<float>(&(all->rel_threshold)),"Termination threshold")
("adaptive", "use adaptive, individual learning rates.")
("invariant", "use safe/importance aware updates.")
("normalized", "use per feature normalized updates")
("exact_adaptive_norm", "use current default invariant normalized adaptive update rule")
- ("audit,a", "print weights of features")
- ("bit_precision,b", po::value<size_t>(), "number of bits in the feature table")
- ("bfgs", "use bfgs optimization")
- ("cache,c", "Use a cache. The default is <data>.cache")
- ("cache_file", po::value< vector<string> >(), "The location(s) of cache_file.")
- ("compressed", "use gzip format whenever possible. If a cache file is being created, this option creates a compressed cache file. A mixture of raw-text & compressed inputs are supported with autodetection.")
- ("no_stdin", "do not default to reading from stdin")
("conjugate_gradient", "use conjugate gradient based optimization")
- ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs")
- ("wap", po::value<size_t>(), "Use weighted all-pairs multiclass learning with <k> costs")
- ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features. Specify singleline or multiline.")
- ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features. Specify singleline or multiline.")
- ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
("l1", po::value<float>(&(all->l1_lambda)), "l_1 lambda")
("l2", po::value<float>(&(all->l2_lambda)), "l_2 lambda")
- ("data,d", po::value< string >(), "Example Set")
- ("daemon", "persistent daemon mode on port 26542")
- ("num_children", po::value<size_t>(&(all->num_children)), "number of children for persistent daemon mode")
- ("pid_file", po::value< string >(), "Write pid file in persistent daemon mode")
+ ("learning_rate,l", po::value<float>(&(all->eta)), "Set Learning Rate")
+ ("loss_function", po::value<string>()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.")
+ ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
+ ("power_t", po::value<float>(&(all->power_t)), "t power value")
("decay_learning_rate", po::value<float>(&(all->eta_decay_rate)),
"Set Decay factor for learning_rate between passes")
- ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file")
+ ("initial_pass_length", po::value<size_t>(&(all->pass_length)), "initial number of examples per pass")
+ ("initial_t", po::value<double>(&((all->sd->t))), "initial t value")
+ ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated. If no initial_regressor given, also used for initial weights.")
+ ;
+
+ po::options_description weight_opt("Weight options");
+
+ weight_opt.add_options()
+ ("bit_precision,b", po::value<size_t>(), "number of bits in the feature table")
+ ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
("final_regressor,f", po::value< string >(), "Final regressor")
+ ("initial_weight", po::value<float>(&(all->initial_weight)), "Set all weights to an initial value of 1.")
+ ("random_weights", po::value<bool>(&(all->random_weights)), "make initial weights random")
("readable_model", po::value< string >(), "Output human-readable final regressor with numeric features")
("invert_hash", po::value< string >(), "Output human-readable final regressor with feature names")
- ("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
- ("hessian_on", "use second derivative in line search")
+ ("save_per_pass", "Save the model after every pass over data")
+ ("input_feature_regularizer", po::value< string >(&(all->per_feature_regularizer_input)), "Per feature regularization input file")
+ ("output_feature_regularizer_binary", po::value< string >(&(all->per_feature_regularizer_output)), "Per feature regularization output file")
+ ("output_feature_regularizer_text", po::value< string >(&(all->per_feature_regularizer_text)), "Per feature regularization output file, in text")
+ ;
+
+ po::options_description holdout_opt("Holdout options");
+ holdout_opt.add_options()
("holdout_off", "no holdout data in multiple passes")
("holdout_period", po::value<uint32_t>(&(all->holdout_period)), "holdout period for test only, default 10")
("holdout_after", po::value<uint32_t>(&(all->holdout_after)), "holdout after n training examples, default off (disables holdout_period)")
- ("version","Version information")
+ ("early_terminate", po::value<size_t>(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3")
+ ;
+
+ po::options_description namespace_opt("Feature namespace options");
+ namespace_opt.add_options()
+ ("hash", po::value< string > (), "how to hash the features. Available options: strings, all")
("ignore", po::value< vector<unsigned char> >(), "ignore namespaces beginning with character <arg>")
("keep", po::value< vector<unsigned char> >(), "keep namespaces beginning with character <arg>")
- ("kill_cache,k", "do not reuse existing cache: create a new one always")
- ("initial_weight", po::value<float>(&(all->initial_weight)), "Set all weights to an initial value of 1.")
- ("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
- ("feature_mask", po::value< string >(), "Use existing regressor to determine which parameters may be updated. If no initial_regressor given, also used for initial weights.")
- ("initial_pass_length", po::value<size_t>(&(all->pass_length)), "initial number of examples per pass")
- ("initial_t", po::value<double>(&((all->sd->t))), "initial t value")
- ("lda", po::value<size_t>(&(all->lda)), "Run lda with <int> topics")
- ("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree")
- ("min_prediction", po::value<float>(&(all->sd->min_label)), "Smallest prediction to output")
- ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output")
- ("mem", po::value<int>(&(all->m)), "memory in bfgs")
- ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
("noconstant", "Don't add a constant feature")
- ("noop","do no learning")
- ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels")
- ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels")
- ("output_feature_regularizer_binary", po::value< string >(&(all->per_feature_regularizer_output)), "Per feature regularization output file")
- ("output_feature_regularizer_text", po::value< string >(&(all->per_feature_regularizer_text)), "Per feature regularization output file, in text")
- ("port", po::value<size_t>(),"port to listen on")
- ("power_t", po::value<float>(&(all->power_t)), "t power value")
- ("learning_rate,l", po::value<float>(&(all->eta)), "Set Learning Rate")
- ("passes", po::value<size_t>(&(all->numpasses)),"Number of Training Passes")
- ("termination", po::value<float>(&(all->rel_threshold)),"Termination threshold")
- ("predictions,p", po::value< string >(), "File to output predictions to")
+ ("constant,C", po::value<float>(&(all->initial_constant)), "Set initial value of constant")
+ ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes")
+ ("ngram", po::value< vector<string> >(), "Generate N grams")
+ ("skips", po::value< vector<string> >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram.")
+ ("affix", po::value<string>(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace")
+ ("spelling", po::value< vector<string> >(), "compute spelling features for a give namespace (use '_' for default namespace)");
+ ;
+
+ po::options_description mf_opt("Matrix factorization options");
+ mf_opt.add_options()
("quadratic,q", po::value< vector<string> > (),
"Create and use quadratic features")
("q:", po::value< string >(), ": corresponds to a wildcard for all printable characters")
("cubic", po::value< vector<string> > (),
"Create and use cubic features")
- ("quiet", "Don't output diagnostics")
("rank", po::value<uint32_t>(&(all->rank)), "rank for matrix factorization.")
- ("random_weights", po::value<bool>(&(all->random_weights)), "make initial weights random")
- ("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
- ("raw_predictions,r", po::value< string >(),
- "File to output unnormalized predictions to")
- ("ring_size", po::value<size_t>(&(all->p->ring_size)), "size of example ring")
- ("examples", po::value<size_t>(&(all->max_examples)), "number of examples to parse")
- ("save_per_pass", "Save the model after every pass over data")
- ("early_terminate", po::value<size_t>(), "Specify the number of passes tolerated when holdout loss doesn't decrease before early termination, default is 3")
- ("save_resume", "save extra state so learning can be resumed later with new data")
- ("sendto", po::value< vector<string> >(), "send examples to <host>")
- ("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
- ("testonly,t", "Ignore label information and just test")
- ("loss_function", po::value<string>()->default_value("squared"), "Specify the loss function to be used, uses squared by default. Currently available ones are squared, classic, hinge, logistic and quantile.")
- ("quantile_tau", po::value<float>()->default_value(0.5), "Parameter \\tau associated with Quantile loss. Defaults to 0.5")
+ ("new_mf", "use new, reduction-based matrix factorization")
+ ;
+ po::options_description lrq_opt("Low Rank Quadratic options");
+ lrq_opt.add_options()
+ ("lrq", po::value<vector<string> > (), "use low rank quadratic features")
+ ("lrqdropout", "use dropout training for low rank quadratic features")
+ ;
+
+ po::options_description multiclass_opt("Multiclass options");
+ multiclass_opt.add_options()
+ ("oaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> labels")
+ ("ect", po::value<size_t>(), "Use error correcting tournament with <k> labels")
+ ("csoaa", po::value<size_t>(), "Use one-against-all multiclass learning with <k> costs")
+ ("wap", po::value<size_t>(), "Use weighted all-pairs multiclass learning with <k> costs")
+ ("csoaa_ldf", po::value<string>(), "Use one-against-all multiclass learning with label dependent features. Specify singleline or multiline.")
+ ("wap_ldf", po::value<string>(), "Use weighted all-pairs multiclass learning with label dependent features. Specify singleline or multiline.")
+ ;
+
+ po::options_description active_opt("Active Learning options");
+ active_opt.add_options()
+ ("active_learning", "active learning mode")
+ ("active_simulation", "active learning simulation mode")
+ ("active_mellowness", po::value<float>(&(all->active_c0)), "active learning mellowness parameter c_0. Default 8")
+ ;
+
+ po::options_description cluster_opt("Parallelization options");
+ cluster_opt.add_options()
+ ("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree")
("unique_id", po::value<size_t>(&(all->unique_id)),"unique id used for cluster parallel jobs")
- ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")
- ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job")
+ ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")
+ ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job")
+ ;
- ("sort_features", "turn this on to disregard order in which features have been defined. This will lead to smaller cache sizes")
- ("ngram", po::value< vector<string> >(), "Generate N grams")
- ("skips", po::value< vector<string> >(), "Generate skips in N grams. This in conjunction with the ngram tag can be used to generate generalized n-skip-k-gram.")
- ("affix", po::value<string>(), "generate prefixes/suffixes of features; argument '+2a,-3b,+1' means generate 2-char prefixes for namespace a, 3-char suffixes for b and 1 char prefixes for default namespace")
- ("spelling", po::value< vector<string> >(), "compute spelling features for a give namespace (use '_' for default namespace)");
+ po::options_description other_opt("Other options");
+ other_opt.add_options()
+ ("bs", po::value<size_t>(), "bootstrap mode with k rounds by online importance resampling")
+ ("top", po::value<size_t>(), "top k recommendation")
+ ("bs_type", po::value<string>(), "bootstrap mode - currently 'mean' or 'vote'")
+ ("autolink", po::value<size_t>(), "create link function with polynomial d")
+ ("cb", po::value<size_t>(), "Use contextual bandit learning with <k> costs")
+ ("lda", po::value<uint32_t>(&(all->lda)), "Run lda with <int> topics")
+ ("nn", po::value<size_t>(), "Use sigmoidal feedforward network with <k> hidden units")
+ ("cbify", po::value<size_t>(), "Convert multiclass on <k> classes into a contextual bandit problem and solve")
+ ("searn", po::value<size_t>(), "use searn, argument=maximum action id or 0 for LDF")
+ ;
+
+ // Declare the supported options.
+ desc.add_options()
+ ("help,h","Look here: http://hunch.net/~vw/ and click on Tutorial.")
+ ("version","Version information")
+ ("random_seed", po::value<size_t>(&random_seed), "seed random number generator")
+ ("noop","do no learning")
+ ("print","print examples");
//po::positional_options_description p;
// Be friendly: if -d was left out, treat positional param as data file
//p.add("data", -1);
+ desc.add(in_opt)
+ .add(out_opt)
+ .add(update_opt)
+ .add(weight_opt)
+ .add(holdout_opt)
+ .add(namespace_opt)
+ .add(mf_opt)
+ .add(lrq_opt)
+ .add(multiclass_opt)
+ .add(active_opt)
+ .add(cluster_opt)
+ .add(other_opt);
+
po::variables_map vm = po::variables_map();
po::variables_map vm_file = po::variables_map(); //separate variable map for storing flags in regressor file
@@ -225,7 +301,7 @@ vw* parse_args(int argc, char *argv[])
po::store(parsed, vm);
po::notify(vm);
-
+
if(all->numpasses > 1)
all->holdout_set_off = false;
@@ -236,7 +312,7 @@ vw* parse_args(int argc, char *argv[])
{
all->holdout_set_off = true;
cerr<<"Making holdout_set_off=true since output regularizer specified\n";
- }
+ }
all->data_filename = "";
@@ -257,11 +333,47 @@ vw* parse_args(int argc, char *argv[])
exit(0);
}
- if (vm.count("quiet"))
+ if (vm.count("quiet")) {
all->quiet = true;
- else
+ // --quiet wins over --progress
+ } else {
all->quiet = false;
+ if (vm.count("progress")) {
+ string progress_str = vm["progress"].as<string>();
+ all->progress_arg = (float)::atof(progress_str.c_str());
+
+ // --progress interval is dual: either integer or floating-point
+ if (progress_str.find_first_of(".") == string::npos) {
+ // No "." in arg: assume integer -> additive
+ all->progress_add = true;
+ if (all->progress_arg < 1) {
+ cerr << "warning: additive --progress <int>"
+ << " can't be < 1: forcing to 1\n";
+ all->progress_arg = 1;
+
+ }
+ all->sd->dump_interval = all->progress_arg;
+
+ } else {
+ // A "." in arg: assume floating-point -> multiplicative
+ all->progress_add = false;
+
+ if (all->progress_arg <= 1.0) {
+ cerr << "warning: multiplicative --progress <float>: "
+ << vm["progress"].as<string>()
+ << " is <= 1.0: adding 1.0\n";
+ all->progress_arg += 1.0;
+
+ } else if (all->progress_arg > 9.0) {
+ cerr << "warning: multiplicative --progress <float>"
+ << " is > 9.0: you probably meant to use an integer\n";
+ }
+ all->sd->dump_interval = 1.0;
+ }
+ }
+ }
+
msrand48(random_seed);
if (vm.count("active_simulation"))
@@ -291,12 +403,12 @@ vw* parse_args(int argc, char *argv[])
}
all->reg.stride = 4; //use stride of 4 for default invariant normalized adaptive updates
- //if we are doing matrix factorization, or user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
- if( all->rank > 0 || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
+ //if the user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
+ if( (all->rank > 0 && !vm.count("new_mf")) || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
{
- all->adaptive = all->training && (vm.count("adaptive") && all->rank == 0);
+ all->adaptive = all->training && vm.count("adaptive") && (all->rank == 0 && !vm.count("new_mf"));
all->invariant_updates = all->training && vm.count("invariant");
- all->normalized_updates = all->training && (vm.count("normalized") && all->rank == 0);
+ all->normalized_updates = all->training && vm.count("normalized") && (all->rank == 0 && !vm.count("new_mf"));
all->reg.stride = 1;
@@ -306,7 +418,8 @@ vw* parse_args(int argc, char *argv[])
if( all->normalized_updates ) all->reg.stride *= 2;
if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates))
- all->eta = 10; //default learning rate to 10 for non default update rule
+ if (all->lda == 0)
+ all->eta = 10; //default learning rate to 10 for non default update rule
//if not using normalized or adaptive, default initial_t to 1 instead of 0
if(!all->adaptive && !all->normalized_updates && !vm.count("initial_t")) {
@@ -320,15 +433,33 @@ vw* parse_args(int argc, char *argv[])
all->feature_mask_idx = 1;
}
else if(all->reg.stride == 2){
- all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3
+ all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3
}
}
}
+ if (all->l1_lambda < 0.) {
+ cerr << "l1_lambda should be nonnegative: resetting from " << all->l1_lambda << " to 0" << endl;
+ all->l1_lambda = 0.;
+ }
+ if (all->l2_lambda < 0.) {
+ cerr << "l2_lambda should be nonnegative: resetting from " << all->l2_lambda << " to 0" << endl;
+ all->l2_lambda = 0.;
+ }
+ all->reg_mode += (all->l1_lambda > 0.) ? 1 : 0;
+ all->reg_mode += (all->l2_lambda > 0.) ? 2 : 0;
+ if (!all->quiet)
+ {
+ if (all->reg_mode %2 && !vm.count("bfgs"))
+ cerr << "using l1 regularization = " << all->l1_lambda << endl;
+ if (all->reg_mode > 1)
+ cerr << "using l2 regularization = " << all->l2_lambda << endl;
+ }
+
all->l = GD::setup(*all, vm);
all->scorer = all->l;
- if (vm.count("bfgs") || vm.count("conjugate_gradient"))
+ if (vm.count("bfgs") || vm.count("conjugate_gradient"))
all->l = BFGS::setup(*all, to_pass_further, vm, vm_file);
if (vm.count("version") || argc == 1) {
@@ -356,7 +487,7 @@ vw* parse_args(int argc, char *argv[])
cout << "You can not skip unless ngram is > 1" << endl;
throw exception();
}
-
+
all->skip_strings = vm["skips"].as<vector<string> >();
compile_gram(all->skip_strings, all->skips, (char*)"skips", all->quiet);
}
@@ -368,10 +499,10 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("spelling")) {
vector<string> spelling_ns = vm["spelling"].as< vector<string> >();
for (size_t id=0; id<spelling_ns.size(); id++)
- if (spelling_ns[id][0] == '_') all->spelling_features[' '] = true;
+ if (spelling_ns[id][0] == '_') all->spelling_features[(unsigned char)' '] = true;
else all->spelling_features[(size_t)spelling_ns[id][0]] = true;
}
-
+
if (vm.count("bit_precision"))
{
all->default_bits = false;
@@ -382,7 +513,7 @@ vw* parse_args(int argc, char *argv[])
throw exception();
}
}
-
+
if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all->active) ) {
all->daemon = true;
@@ -392,7 +523,7 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("compressed"))
set_compressed(all->p);
-
+
if (vm.count("data")) {
all->data_filename = vm["data"].as<string>();
if (ends_with(all->data_filename, ".gz"))
@@ -408,14 +539,14 @@ vw* parse_args(int argc, char *argv[])
{
all->pairs = vm["quadratic"].as< vector<string> >();
vector<string> newpairs;
- //string tmp;
+ //string tmp;
char printable_start = '!';
char printable_end = '~';
int valid_ns_size = printable_end - printable_start - 1; //will skip two characters
if(!all->quiet)
- cerr<<"creating quadratic features for pairs: ";
-
+ cerr<<"creating quadratic features for pairs: ";
+
for (vector<string>::iterator i = all->pairs.begin(); i != all->pairs.end();i++){
if(!all->quiet){
cerr << *i << " ";
@@ -456,7 +587,7 @@ vw* parse_args(int argc, char *argv[])
}
else{
newpairs.push_back(string(*i));
- }
+ }
}
newpairs.swap(all->pairs);
if(!all->quiet)
@@ -540,8 +671,8 @@ vw* parse_args(int argc, char *argv[])
}
}
- // matrix factorization enabled
- if (all->rank > 0) {
+ // (non-reduction) matrix factorization enabled
+ if (!vm.count("new_mf") && all->rank > 0) {
// store linear + 2*rank weights per index, round up to power of two
float temp = ceilf(logf((float)(all->rank*2+1)) / logf (2.f));
all->reg.stride = 1 << (int) temp;
@@ -582,21 +713,21 @@ vw* parse_args(int argc, char *argv[])
//if (vm.count("nonormalize"))
// all->nonormalize = true;
- if (vm.count("lda"))
+ if (vm.count("lda"))
all->l = LDA::setup(*all, to_pass_further, vm);
- if (!vm.count("lda") && !all->adaptive && !all->normalized_updates)
+ if (!vm.count("lda") && !all->adaptive && !all->normalized_updates)
all->eta *= powf((float)(all->sd->t), all->power_t);
-
+
if (vm.count("readable_model"))
all->text_regressor_name = vm["readable_model"].as<string>();
if (vm.count("invert_hash")){
all->inv_hash_regressor_name = vm["invert_hash"].as<string>();
- all->hash_inv = true;
+ all->hash_inv = true;
}
-
+
if (vm.count("save_per_pass"))
all->save_per_pass = true;
@@ -619,10 +750,16 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("quantile_tau"))
loss_parameter = vm["quantile_tau"].as<float>();
- if (vm.count("noop"))
+ if (vm.count("noop"))
all->l = NOOP::setup(*all);
-
- if (all->rank != 0)
+
+ if (vm.count("print"))
+ {
+ all->l = PRINT::setup(*all);
+ all->reg.stride = 1;
+ }
+
+ if (!vm.count("new_mf") && all->rank > 0)
all->l = GDMF::setup(*all);
all->loss = getLossFunction(all, loss_function, (float)loss_parameter);
@@ -713,37 +850,27 @@ vw* parse_args(int argc, char *argv[])
io_temp.close_file();
}
- if (all->l1_lambda < 0.) {
- cerr << "l1_lambda should be nonnegative: resetting from " << all->l1_lambda << " to 0" << endl;
- all->l1_lambda = 0.;
- }
- if (all->l2_lambda < 0.) {
- cerr << "l2_lambda should be nonnegative: resetting from " << all->l2_lambda << " to 0" << endl;
- all->l2_lambda = 0.;
- }
- all->reg_mode += (all->l1_lambda > 0.) ? 1 : 0;
- all->reg_mode += (all->l2_lambda > 0.) ? 2 : 0;
- if (!all->quiet)
- {
- if (all->reg_mode %2 && !vm.count("bfgs"))
- cerr << "using l1 regularization = " << all->l1_lambda << endl;
- if (all->reg_mode > 1)
- cerr << "using l2 regularization = " << all->l2_lambda << endl;
- }
-
bool got_mc = false;
bool got_cs = false;
bool got_cb = false;
- if(vm.count("nn") || vm_file.count("nn") )
+ if(vm.count("nn") || vm_file.count("nn") )
all->l = NN::setup(*all, to_pass_further, vm, vm_file);
- if(vm.count("autolink") || vm_file.count("autolink") )
+ if (vm.count("new_mf") && all->rank > 0)
+ all->l = MF::setup(*all, vm);
+
+ if(vm.count("autolink") || vm_file.count("autolink") )
all->l = ALINK::setup(*all, to_pass_further, vm, vm_file);
- if(vm.count("top") || vm_file.count("top") )
+ if (vm.count("lrq") || vm_file.count("lrq"))
+ all->l = LRQ::setup(*all, to_pass_further, vm, vm_file);
+
+ all->l = Scorer::setup(*all, to_pass_further, vm, vm_file);
+
+ if(vm.count("top") || vm_file.count("top") )
all->l = TOPK::setup(*all, to_pass_further, vm, vm_file);
-
+
if (vm.count("binary") || vm_file.count("binary"))
all->l = BINARY::setup(*all, to_pass_further, vm, vm_file);
@@ -753,7 +880,7 @@ vw* parse_args(int argc, char *argv[])
all->l = OAA::setup(*all, to_pass_further, vm, vm_file);
got_mc = true;
}
-
+
if (vm.count("ect") || vm_file.count("ect") ) {
if (got_mc) { cerr << "error: cannot specify multiple MC learners" << endl; throw exception(); }
@@ -763,15 +890,17 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("csoaa") || vm_file.count("csoaa") ) {
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
-
+
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
if(vm.count("wap") || vm_file.count("wap") ) {
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
-
+
all->l = WAP::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -779,6 +908,7 @@ vw* parse_args(int argc, char *argv[])
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -786,6 +916,7 @@ vw* parse_args(int argc, char *argv[])
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
all->l = CSOAA_AND_WAP_LDF::setup(*all, to_pass_further, vm, vm_file);
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -796,6 +927,7 @@ vw* parse_args(int argc, char *argv[])
else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cb"]));
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified
+ all->cost_sensitive = all->l;
got_cs = true;
}
@@ -803,16 +935,37 @@ vw* parse_args(int argc, char *argv[])
got_cb = true;
}
- all->searnstr = NULL;
- if (vm.count("searn") || vm_file.count("searn") ) {
+ if (vm.count("cbify") || vm_file.count("cbify"))
+ {
+ if(!got_cs) {
+ if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["cbify"]));
+ else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["cbify"]));
+
+ all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless wap is specified
+ all->cost_sensitive = all->l;
+ got_cs = true;
+ }
+
+ if (!got_cb) {
+ if( vm_file.count("cbify") ) vm.insert(pair<string,po::variable_value>(string("cb"),vm_file["cbify"]));
+ else vm.insert(pair<string,po::variable_value>(string("cb"),vm["cbify"]));
+ all->l = CB::setup(*all, to_pass_further, vm, vm_file);
+ got_cb = true;
+ }
+
+ all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file);
+ }
+
+ if (vm.count("searn") || vm_file.count("searn") ) {
if (!got_cs && !got_cb) {
if( vm_file.count("searn") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["searn"]));
else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["searn"]));
-
+
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified
+ all->cost_sensitive = all->l;
got_cs = true;
}
- all->searnstr = (Searn::searn*)calloc(1, sizeof(Searn::searn));
+ //all->searnstr = (Searn::searn*)calloc(1, sizeof(Searn::searn));
all->l = Searn::setup(*all, to_pass_further, vm, vm_file);
}
@@ -821,7 +974,7 @@ vw* parse_args(int argc, char *argv[])
throw exception();
}
- if(vm.count("bs") || vm_file.count("bs") )
+ if(vm.count("bs") || vm_file.count("bs") )
all->l = BS::setup(*all, to_pass_further, vm, vm_file);
if (to_pass_further.size() > 0) {
@@ -857,9 +1010,9 @@ vw* parse_args(int argc, char *argv[])
parse_source_args(*all, vm, all->quiet,all->numpasses);
- // force stride * weights_per_problem to be a power of 2 to avoid 32-bit overflow
+ // force wpp to be a power of 2 to avoid 32-bit overflow
uint32_t i = 0;
- size_t params_per_problem = all->l->increment * all->l->weights;
+ size_t params_per_problem = all->l->increment;
while (params_per_problem > (uint32_t)(1 << i))
i++;
all->wpp = (1 << i) / all->reg.stride;
@@ -880,7 +1033,7 @@ namespace VW {
}
else {
//flag is present, need to replace old value with new value
-
+
//compute position after flag_to_replace
pos += flag_to_replace.size();
@@ -908,7 +1061,7 @@ namespace VW {
v_array<substring> foo;
foo.end_array = foo.begin = foo.end = NULL;
tokenize(' ', ss, foo);
-
+
char** argv = (char**)calloc(foo.size(), sizeof(char*));
for (size_t i = 0; i < foo.size(); i++)
{
@@ -922,13 +1075,13 @@ namespace VW {
foo.delete_v();
return argv;
}
-
+
vw* initialize(string s)
{
int argc = 0;
s += " --no_stdin";
char** argv = get_argv_from_string(s,argc);
-
+
vw* all = parse_args(argc, argv);
initialize_parser_datastructures(*all);
@@ -944,11 +1097,11 @@ namespace VW {
{
finalize_regressor(all, all.final_regressor_name);
all.l->finish();
+ delete all.l;
if (all.reg.weight_vector != NULL)
free(all.reg.weight_vector);
free_parser(all);
finalize_source(all.p);
- free(all.p->lp);
all.p->parse_name.erase();
all.p->parse_name.delete_v();
free(all.p);
@@ -958,11 +1111,7 @@ namespace VW {
free(all.options_from_file_argv);
for (size_t i = 0; i < all.final_prediction_sink.size(); i++)
if (all.final_prediction_sink[i] != 1)
-#ifdef _WIN32
- _close(all.final_prediction_sink[i]);
-#else
- close(all.final_prediction_sink[i]);
-#endif
+ io_buf::close_file_or_socket(all.final_prediction_sink[i]);
all.final_prediction_sink.delete_v();
delete all.loss;
delete &all;