diff options
author | ariel faigon <github.2009@yendor.com> | 2013-12-26 23:17:07 +0400 |
---|---|---|
committer | ariel faigon <github.2009@yendor.com> | 2013-12-26 23:17:07 +0400 |
commit | 9e0d9e19e6e8cf5a8e2392f723d8497984be8112 (patch) | |
tree | c77e24a7a72c9115d2d8a655d10547ba0ccd6ac5 /vowpalwabbit/parse_args.cc | |
parent | 1097df5bffa4f9cf809c5cfc2f7637eb88275b0b (diff) |
Clean up to reduce coupling: move initial_constant stuff into gd.cc
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- | vowpalwabbit/parse_args.cc | 88 |
1 file changed, 44 insertions, 44 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index aad2eae6..f386734c 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -75,7 +75,7 @@ void parse_affix_argument(vw&all, string str) { if (q[1] != 0) { if (valid_ns(q[1])) ns = (uint16_t)q[1]; - else { + else { cerr << "malformed affix argument (invalid namespace): " << p << endl; throw exception(); } @@ -88,17 +88,17 @@ void parse_affix_argument(vw&all, string str) { uint16_t afx = (len << 1) | (prefix & 0x1); all.affix_features[ns] <<= 4; all.affix_features[ns] |= afx; - + p = strtok(NULL, ","); } - + delete cstr; } vw* parse_args(int argc, char *argv[]) { po::options_description desc("VW options"); - + vw* all = new vw(); size_t random_seed = 0; @@ -135,18 +135,18 @@ vw* parse_args(int argc, char *argv[]) ("quiet", "Don't output diagnostics") ("binary", "report loss as binary classification on -1,1") ("min_prediction", po::value<float>(&(all->sd->min_label)), "Smallest prediction to output") - ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output") + ("max_prediction", po::value<float>(&(all->sd->max_label)), "Largest prediction to output") ; po::options_description update_opt("Update options"); - + update_opt.add_options() ("sgd", "use regular stochastic gradient descent update.") ("hessian_on", "use second derivative in line search") ("bfgs", "use bfgs optimization") ("mem", po::value<int>(&(all->m)), "memory in bfgs") ("termination", po::value<float>(&(all->rel_threshold)),"Termination threshold") - ("adaptive", "use adaptive, individual learning rates.") + ("adaptive", "use adaptive, individual learning rates.") ("invariant", "use safe/importance aware updates.") ("normalized", "use per feature normalized updates") ("exact_adaptive_norm", "use current default invariant normalized adaptive update rule") @@ -233,7 +233,7 @@ vw* parse_args(int argc, char *argv[]) cluster_opt.add_options() ("span_server", 
po::value<string>(&(all->span_server)), "Location of server for setting up spanning tree") ("unique_id", po::value<size_t>(&(all->unique_id)),"unique id used for cluster parallel jobs") - ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job") + ("total", po::value<size_t>(&(all->total)),"total number of nodes used in cluster parallel job") ("node", po::value<size_t>(&(all->node)),"node number in cluster parallel job") ; @@ -287,7 +287,7 @@ vw* parse_args(int argc, char *argv[]) po::store(parsed, vm); po::notify(vm); - + if(all->numpasses > 1) all->holdout_set_off = false; @@ -298,7 +298,7 @@ vw* parse_args(int argc, char *argv[]) { all->holdout_set_off = true; cerr<<"Making holdout_set_off=true since output regularizer specified\n"; - } + } all->data_filename = ""; @@ -382,7 +382,7 @@ vw* parse_args(int argc, char *argv[]) all->feature_mask_idx = 1; } else if(all->reg.stride == 2){ - all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3 + all->reg.stride *= 2;//if either normalized or adaptive, stride->4, mask_idx is still 3 } } } @@ -390,7 +390,7 @@ vw* parse_args(int argc, char *argv[]) all->l = GD::setup(*all, vm); all->scorer = all->l; - if (vm.count("bfgs") || vm.count("conjugate_gradient")) + if (vm.count("bfgs") || vm.count("conjugate_gradient")) all->l = BFGS::setup(*all, to_pass_further, vm, vm_file); if (vm.count("version") || argc == 1) { @@ -418,7 +418,7 @@ vw* parse_args(int argc, char *argv[]) cout << "You can not skip unless ngram is > 1" << endl; throw exception(); } - + all->skip_strings = vm["skips"].as<vector<string> >(); compile_gram(all->skip_strings, all->skips, (char*)"skips", all->quiet); } @@ -433,7 +433,7 @@ vw* parse_args(int argc, char *argv[]) if (spelling_ns[id][0] == '_') all->spelling_features[' '] = true; else all->spelling_features[(size_t)spelling_ns[id][0]] = true; } - + if (vm.count("bit_precision")) { all->default_bits = false; @@ -444,7 +444,7 @@ vw* 
parse_args(int argc, char *argv[]) throw exception(); } } - + if (vm.count("daemon") || vm.count("pid_file") || (vm.count("port") && !all->active) ) { all->daemon = true; @@ -454,7 +454,7 @@ vw* parse_args(int argc, char *argv[]) if (vm.count("compressed")) set_compressed(all->p); - + if (vm.count("data")) { all->data_filename = vm["data"].as<string>(); if (ends_with(all->data_filename, ".gz")) @@ -470,14 +470,14 @@ vw* parse_args(int argc, char *argv[]) { all->pairs = vm["quadratic"].as< vector<string> >(); vector<string> newpairs; - //string tmp; + //string tmp; char printable_start = '!'; char printable_end = '~'; int valid_ns_size = printable_end - printable_start - 1; //will skip two characters if(!all->quiet) - cerr<<"creating quadratic features for pairs: "; - + cerr<<"creating quadratic features for pairs: "; + for (vector<string>::iterator i = all->pairs.begin(); i != all->pairs.end();i++){ if(!all->quiet){ cerr << *i << " "; @@ -518,7 +518,7 @@ vw* parse_args(int argc, char *argv[]) } else{ newpairs.push_back(string(*i)); - } + } } newpairs.swap(all->pairs); if(!all->quiet) @@ -644,21 +644,21 @@ vw* parse_args(int argc, char *argv[]) //if (vm.count("nonormalize")) // all->nonormalize = true; - if (vm.count("lda")) + if (vm.count("lda")) all->l = LDA::setup(*all, to_pass_further, vm); - if (!vm.count("lda") && !all->adaptive && !all->normalized_updates) + if (!vm.count("lda") && !all->adaptive && !all->normalized_updates) all->eta *= powf((float)(all->sd->t), all->power_t); - + if (vm.count("readable_model")) all->text_regressor_name = vm["readable_model"].as<string>(); if (vm.count("invert_hash")){ all->inv_hash_regressor_name = vm["invert_hash"].as<string>(); - all->hash_inv = true; + all->hash_inv = true; } - + if (vm.count("save_per_pass")) all->save_per_pass = true; @@ -681,10 +681,10 @@ vw* parse_args(int argc, char *argv[]) if(vm.count("quantile_tau")) loss_parameter = vm["quantile_tau"].as<float>(); - if (vm.count("noop")) + if (vm.count("noop")) 
all->l = NOOP::setup(*all); - - if (all->rank != 0) + + if (all->rank != 0) all->l = GDMF::setup(*all); all->loss = getLossFunction(all, loss_function, (float)loss_parameter); @@ -797,15 +797,15 @@ vw* parse_args(int argc, char *argv[]) bool got_cs = false; bool got_cb = false; - if(vm.count("nn") || vm_file.count("nn") ) + if(vm.count("nn") || vm_file.count("nn") ) all->l = NN::setup(*all, to_pass_further, vm, vm_file); - if(vm.count("autolink") || vm_file.count("autolink") ) + if(vm.count("autolink") || vm_file.count("autolink") ) all->l = ALINK::setup(*all, to_pass_further, vm, vm_file); - if(vm.count("top") || vm_file.count("top") ) + if(vm.count("top") || vm_file.count("top") ) all->l = TOPK::setup(*all, to_pass_further, vm, vm_file); - + if (vm.count("binary") || vm_file.count("binary")) all->l = BINARY::setup(*all, to_pass_further, vm, vm_file); @@ -815,7 +815,7 @@ vw* parse_args(int argc, char *argv[]) all->l = OAA::setup(*all, to_pass_further, vm, vm_file); got_mc = true; } - + if (vm.count("ect") || vm_file.count("ect") ) { if (got_mc) { cerr << "error: cannot specify multiple MC learners" << endl; throw exception(); } @@ -825,14 +825,14 @@ vw* parse_args(int argc, char *argv[]) if(vm.count("csoaa") || vm_file.count("csoaa") ) { if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); } - + all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); got_cs = true; } if(vm.count("wap") || vm_file.count("wap") ) { if (got_cs) { cerr << "error: cannot specify multiple CS learners" << endl; throw exception(); } - + all->l = WAP::setup(*all, to_pass_further, vm, vm_file); got_cs = true; } @@ -881,16 +881,16 @@ vw* parse_args(int argc, char *argv[]) all->l = CB::setup(*all, to_pass_further, vm, vm_file); got_cb = true; } - + all->l = CBIFY::setup(*all, to_pass_further, vm, vm_file); } all->searnstr = NULL; - if (vm.count("searn") || vm_file.count("searn") ) { + if (vm.count("searn") || vm_file.count("searn") ) { if (!got_cs 
&& !got_cb) { if( vm_file.count("searn") ) vm.insert(pair<string,po::variable_value>(string("csoaa"),vm_file["searn"])); else vm.insert(pair<string,po::variable_value>(string("csoaa"),vm["searn"])); - + all->l = CSOAA::setup(*all, to_pass_further, vm, vm_file); // default to CSOAA unless others have been specified got_cs = true; } @@ -903,7 +903,7 @@ vw* parse_args(int argc, char *argv[]) throw exception(); } - if(vm.count("bs") || vm_file.count("bs") ) + if(vm.count("bs") || vm_file.count("bs") ) all->l = BS::setup(*all, to_pass_further, vm, vm_file); if (to_pass_further.size() > 0) { @@ -962,7 +962,7 @@ namespace VW { } else { //flag is present, need to replace old value with new value - + //compute position after flag_to_replace pos += flag_to_replace.size(); @@ -990,7 +990,7 @@ namespace VW { v_array<substring> foo; foo.end_array = foo.begin = foo.end = NULL; tokenize(' ', ss, foo); - + char** argv = (char**)calloc(foo.size(), sizeof(char*)); for (size_t i = 0; i < foo.size(); i++) { @@ -1004,15 +1004,15 @@ namespace VW { foo.delete_v(); return argv; } - + vw* initialize(string s) { int argc = 0; s += " --no_stdin"; char** argv = get_argv_from_string(s,argc); - + vw* all = parse_args(argc, argv); - + initialize_examples(*all); for(int i = 0; i < argc; i++) |