github.com/moses-smt/vowpal_wabbit.git
author    ariel faigon <github.2009@yendor.com>  2013-12-26 23:17:07 +0400
committer ariel faigon <github.2009@yendor.com>  2013-12-26 23:17:07 +0400
commit    9e0d9e19e6e8cf5a8e2392f723d8497984be8112 (patch)
tree      c77e24a7a72c9115d2d8a655d10547ba0ccd6ac5 /vowpalwabbit/parse_args.cc
parent    1097df5bffa4f9cf809c5cfc2f7637eb88275b0b (diff)
Clean up to reduce coupling: move initial_constant stuff into gd.cc
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r--  vowpalwabbit/parse_args.cc  88
1 file changed, 44 insertions(+), 44 deletions(-)
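
The commit message describes a decoupling pattern: rather than parse_args.cc registering and applying every learner-specific option (such as the initial value of the constant feature), each module registers and consumes its own options through its setup path. Below is a minimal, self-contained sketch of that pattern using boost::program_options, the same option library used in the diff; the GD::add_options/GD::setup names and the gd_state struct are illustrative assumptions, not the actual code that landed in gd.cc.

#include <iostream>
#include <boost/program_options.hpp>

namespace po = boost::program_options;

namespace GD {
  // The module owns its own state and option; parse_args only forwards vm.
  struct gd_state { float initial_constant; };

  // Register the option this module cares about (hypothetical helper).
  void add_options(po::options_description& desc) {
    desc.add_options()
      ("initial_constant", po::value<float>()->default_value(0.f),
       "initial weight of the constant feature");
  }

  // Consume the option the module registered (hypothetical helper).
  gd_state setup(const po::variables_map& vm) {
    gd_state s;
    s.initial_constant = vm["initial_constant"].as<float>();
    return s;
  }
}

int main(int argc, char* argv[]) {
  po::options_description desc("VW options (sketch)");
  GD::add_options(desc);                 // gd registers what it needs
  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, desc), vm);
  po::notify(vm);
  GD::gd_state gd = GD::setup(vm);       // gd consumes what it registered
  std::cout << "initial_constant = " << gd.initial_constant << std::endl;
  return 0;
}

With this arrangement the argument parser no longer needs to know that an initial_constant exists at all, which is the coupling the commit summary says it removes.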
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index aad2eae6..f386734c 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -75,7 +75,7 @@ void parse_affix_argument(vw&all, string str) {
if (q[1] != 0) {
if (valid_ns(q[1]))
ns = (uint16_t)q[1];
- else {
+ else {
cerr << "malformed affix argument (invalid namespace): " << p << endl;
throw exception();
}
@@ -88,17 +88,17 @@ void parse_affix_argument(vw&all, string str) {
uint16_t afx = (len << 1) | (prefix & 0x1);
all.affix_features[ns] <<= 4;
all.affix_features[ns] |= afx;
-
+
p = strtok(NULL, ",");
}
-
+
delete cstr;
}
vw* parse_args(int argc, char *argv[])
{
po::options_description desc("VW options");
-
+
vw* all = new vw();
size_t random_seed = 0;
@@ -135,18 +135,18 @@ vw* parse_args(int argc, char *argv[])
("quiet", "Don't output diagnostics")
("binary", "report loss as binary classification on -1,1")
("min_prediction", po::value<float>(&(all->sd->min_label)), "Smallest prediction to output")
- ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output")
+ ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output")
;
po::options_description update_opt("Update options");
-
+
update_opt.add_options()
("sgd", "use regular stochastic gradient descent update.")
("hessian_on", "use second derivative in line search")
("bfgs", "use bfgs optimization")
("mem", po::value<int>(&(all->m)), "memory in bfgs")
("termination", po::value<float>(&(all->rel_threshold)),"Termination threshold")
- ("adaptive", "use adaptive, individual learning rates.")
+ ("adaptive", "use adaptive, individual learning rates.")
("invariant", "use safe/importance aware updates.")
("normalized", "use per feature normalized updates")
("exact_adaptive_norm", "use current default invariant normalized adaptive update rule")
@@ -233,7 +233,7 @@ vw* parse_args(int argc, char *argv[])
cluster_opt.add_options()
("span_server", po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree")
("unique_id", po::value<size_t>(&(all->unique_id)),"unique id used for cluster parallel jobs")
- ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")
+ ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job")
("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job")
;
@@ -287,7 +287,7 @@ vw* parse_args(int argc, char *argv[])
po::store(parsed, vm);
po::notify(vm);
-
+
if(all->numpasses > 1)
all->holdout_set_off = false;
@@ -298,7 +298,7 @@ vw* parse_args(int argc, char *argv[])
{
all->holdout_set_off = true;
cerr<<"Making holdout_set_off=true since output regularizer specified\n";
- }
+ }
all->data_filename = "";
@@ -382,7 +382,7 @@ vw* parse_args(int argc, char *argv[])
all->feature_mask_idx = 1;
}
else if(all->reg.stride == 2){
- all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3
+ all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3
}
}
}
@@ -390,7 +390,7 @@ vw* parse_args(int argc, char *argv[])
all->l = GD::setup(*all, vm);
all->scorer = all->l;
- if (vm.count("bfgs") || vm.count("conjugate_gradient"))
+ if (vm.count("bfgs") || vm.count("conjugate_gradient"))
all->l = BFGS::setup(*all, to_pass_further, vm, vm_file);
if (vm.count("version") || argc == 1) {
@@ -418,7 +418,7 @@ vw* parse_args(int argc, char *argv[])
cout << "You can not skip unless ngram is > 1" << endl;
throw exception();
}
-
+
all->skip_strings = vm["skips"].as<vector<string> >();
compile_gram(all->skip_strings, all->skips, (char*)"skips", all->quiet);
}
@@ -433,7 +433,7 @@ vw* parse_args(int argc, char *argv[])
if (spelling_ns[id][0] == '_') all->spelling_features[' '] = true;
else all->spelling_features[(size_t)spelling_ns[id][0]] = true;
}
-
+
if (vm.count("bit_precision"))
{
all->default_bits = false;
@@ -444,7 +444,7 @@ vw* parse_args(int argc, char *argv[])
throw exception();
}
}
-
+
if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all->active) ) {
all->daemon = true;
@@ -454,7 +454,7 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("compressed"))
set_compressed(all->p);
-
+
if (vm.count("data")) {
all->data_filename = vm["data"].as<string>();
if (ends_with(all->data_filename, ".gz"))
@@ -470,14 +470,14 @@ vw* parse_args(int argc, char *argv[])
{
all->pairs = vm["quadratic"].as< vector<string> >();
vector<string> newpairs;
- //string tmp;
+ //string tmp;
char printable_start = '!';
char printable_end = '~';
int valid_ns_size = printable_end - printable_start - 1; //will skip two characters
if(!all->quiet)
- cerr<<"creating quadratic features for pairs: ";
-
+ cerr<<"creating quadratic features for pairs: ";
+
for (vector<string>::iterator i = all->pairs.begin(); i != all->pairs.end();i++){
if(!all->quiet){
cerr << *i << " ";
@@ -518,7 +518,7 @@ vw* parse_args(int argc, char *argv[])
}
else{
newpairs.push_back(string(*i));
- }
+ }
}
newpairs.swap(all->pairs);
if(!all->quiet)
@@ -644,21 +644,21 @@ vw* parse_args(int argc, char *argv[])
//if (vm.count("nonormalize"))
// all->nonormalize = true;
- if (vm.count("lda"))
+ if (vm.count("lda"))
all->l = LDA::setup(*all, to_pass_further, vm);
- if (!vm.count("lda") && !all->adaptive && !all->normalized_updates)
+ if (!vm.count("lda") && !all->adaptive && !all->normalized_updates)
all->eta *= powf((float)(all->sd->t), all->power_t);
-
+
if (vm.count("readable_model"))
all->text_regressor_name = vm["readable_model"].as<string>();
if (vm.count("invert_hash")){
all->inv_hash_regressor_name = vm["invert_hash"].as<string>();
- all->hash_inv = true;
+ all->hash_inv = true;
}
-
+
if (vm.count("save_per_pass"))
all->save_per_pass = true;
@@ -681,10 +681,10 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("quantile_tau"))
loss_parameter = vm["quantile_tau"].as<float>();
- if (vm.count("noop"))
+ if (vm.count("noop"))
all->l = NOOP::setup(*all);
-
- if (all->rank != 0)
+
+ if (all->rank != 0)
all->l = GDMF::setup(*all);
all->loss = getLossFunction(all, loss_function, (float)loss_parameter);
@@ -797,15 +797,15 @@ vw* parse_args(int argc, char *argv[])
bool got_cs = false;
bool got_cb = false;
- if(vm.count("nn") || vm_file.count("nn") )
+ if(vm.count("nn") || vm_file.count("nn") )
all->l = NN::setup(*all, to_pass_further, vm, vm_file);
- if(vm.count("autolink") || vm_file.count("autolink") )
+ if(vm.count("autolink") || vm_file.count("autolink") )
all->l = ALINK::setup(*all, to_pass_further, vm, vm_file);
- if(vm.count("top") || vm_file.count("top") )
+ if(vm.count("top") || vm_file.count("top") )
all->l = TOPK::setup(*all, to_pass_further, vm, vm_file);
-
+
if (vm.count("binary") || vm_file.count("binary"))
all->l = BINARY::setup(*all, to_pass_further, vm, vm_file);
@@ -815,7 +815,7 @@ vw* parse_args(int argc, char *argv[])
all->l = OAA::setup(*all, to_pass_further, vm, vm_file);
got_mc = true;
}
-
+
if (vm.count("ect") || vm_file.count("ect") ) {
if (got_mc) { cerr << "error: cannot specify multiple MC learners" << endl; throw exception(); }
@@ -825,14 +825,14 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("csoaa") || vm_file.count("csoaa") ) {
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
-
+
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file);
got_cs = true;
}
if(vm.count("wap") || vm_file.count("wap") ) {
if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); }
-
+
all->l = WAP::setup(*all, to_pass_further, vm, vm_file);
got_cs = true;
}
@@ -881,16 +881,16 @@ vw* parse_args(int argc, char *argv[])
all->l = CB::setup(*all, to_pass_further, vm, vm_file);
got_cb = true;
}
-
+
all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file);
}
all->searnstr = NULL;
- if (vm.count("searn") || vm_file.count("searn") ) {
+ if (vm.count("searn") || vm_file.count("searn") ) {
if (!got_cs && !got_cb) {
if( vm_file.count("searn") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["searn"]));
else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["searn"]));
-
+
all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified
got_cs = true;
}
@@ -903,7 +903,7 @@ vw* parse_args(int argc, char *argv[])
throw exception();
}
- if(vm.count("bs") || vm_file.count("bs") )
+ if(vm.count("bs") || vm_file.count("bs") )
all->l = BS::setup(*all, to_pass_further, vm, vm_file);
if (to_pass_further.size() > 0) {
@@ -962,7 +962,7 @@ namespace VW {
}
else {
//flag is present, need to replace old value with new value
-
+
//compute position after flag_to_replace
pos += flag_to_replace.size();
@@ -990,7 +990,7 @@ namespace VW {
v_array<substring> foo;
foo.end_array = foo.begin = foo.end = NULL;
tokenize(' ', ss, foo);
-
+
char** argv = (char**)calloc(foo.size(), sizeof(char*));
for (size_t i = 0; i < foo.size(); i++)
{
@@ -1004,15 +1004,15 @@ namespace VW {
foo.delete_v();
return argv;
}
-
+
vw* initialize(string s)
{
int argc = 0;
s += " --no_stdin";
char** argv = get_argv_from_string(s,argc);
-
+
vw* all = parse_args(argc, argv);
-
+
initialize_examples(*all);
for(int i = 0; i < argc; i++)