diff options
author | Jake Hofman <jhofman@gmail> | 2014-01-08 05:25:55 +0400 |
---|---|---|
committer | Jake Hofman <jhofman@gmail> | 2014-01-08 05:25:55 +0400 |
commit | de91f21a827ae6e53e239a26edcc53f632b49c4f (patch) | |
tree | 3fa09692c7cf0b702b783f2dfa1d707509fc1d06 /vowpalwabbit/parse_args.cc | |
parent | a73c38db84b44307e8efc4a7c77fecba18ad2932 (diff) |
removed precomputed_prediction in favor of base.update
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- | vowpalwabbit/parse_args.cc | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index e369d975..82809fe7 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -212,6 +212,7 @@ vw* parse_args(int argc, char *argv[]) ("cubic", po::value< vector<string> > (), "Create and use cubic features") ("rank", po::value<uint32_t>(&(all->rank)), "rank for matrix factorization.") + ("new_mf", "use new, reduction-based matrix factorization") ; po::options_description lrq_opt("Low Rank Quadratic options"); @@ -398,12 +399,12 @@ vw* parse_args(int argc, char *argv[]) } all->reg.stride = 4; //use stride of 4 for default invariant normalized adaptive updates - //if we are doing matrix factorization, or user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified - if( all->rank > 0 || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) ) + //if the user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified + if( (all->rank > 0 && !vm.count("new_mf")) || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) ) { - all->adaptive = all->training && (vm.count("adaptive") && all->rank == 0); + all->adaptive = all->training && vm.count("adaptive") && (all->rank == 0 && !vm.count("new_mf")); all->invariant_updates = all->training && vm.count("invariant"); - all->normalized_updates = all->training && (vm.count("normalized") && all->rank == 0); + all->normalized_updates = all->training && vm.count("normalized") && (all->rank == 0 && !vm.count("new_mf")); all->reg.stride = 1; @@ -665,8 +666,8 @@ vw* parse_args(int argc, char *argv[]) } } - // matrix factorization enabled - if (all->rank > 0) { + // (non-reduction) matrix factorization enabled + if (!vm.count("new_mf") && all->rank > 0) { // store linear + 2*rank weights per index, round up to power of two float temp = ceilf(logf((float)(all->rank*2+1)) / logf (2.f)); all->reg.stride = 1 << (int) temp; @@ -747,7 +748,7 @@ vw* parse_args(int argc, char *argv[]) if (vm.count("noop")) all->l = NOOP::setup(*all); - if (all->rank != 0) + if (!vm.count("new_mf") && all->rank > 0) all->l = GDMF::setup(*all); all->loss = getLossFunction(all, loss_function, (float)loss_parameter); @@ -845,8 +846,8 @@ vw* parse_args(int argc, char *argv[]) if(vm.count("nn") || vm_file.count("nn") ) all->l = NN::setup(*all, to_pass_further, vm, vm_file); - // if (all->rank != 0) - // all->l = MF::setup(*all); + if (vm.count("new_mf") && all->rank > 0) + all->l = MF::setup(*all, vm); if(vm.count("autolink") || vm_file.count("autolink") ) all->l = ALINK::setup(*all, to_pass_further, vm, vm_file); |