diff options
author | Jake Hofman <jhofman@gmail.com> | 2013-12-28 02:03:42 +0400 |
---|---|---|
committer | Jake Hofman <jhofman@gmail.com> | 2013-12-28 02:03:42 +0400 |
commit | cedc01406d6aca7fe5c5236b95804208ebbaa341 (patch) | |
tree | 4d9b8b4c491e59bb441f0530955f994afe03003f /vowpalwabbit/parse_args.cc | |
parent | 43d4d21ad3c02c0e3cdc052adaf4baa7136b57d9 (diff) |
disabled mf in favor of gd_mf for time being
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- | vowpalwabbit/parse_args.cc | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index fa98503f..a81bebc2 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -27,6 +27,7 @@ license as described in the file LICENSE. #include "bfgs.h" #include "lda_core.h" #include "noop.h" +#include "gd_mf.h" #include "mf.h" #include "vw.h" #include "rand48.h" @@ -354,11 +355,11 @@ vw* parse_args(int argc, char *argv[]) all->reg.stride = 4; //use stride of 4 for default invariant normalized adaptive updates //if we are doing matrix factorization, or user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified - if( !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) ) + if( all->rank > 0 || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) ) { - all->adaptive = all->training && vm.count("adaptive"); + all->adaptive = all->training && (vm.count("adaptive") && all->rank == 0); all->invariant_updates = all->training && vm.count("invariant"); - all->normalized_updates = all->training && vm.count("normalized"); + all->normalized_updates = all->training && (vm.count("normalized") && all->rank == 0); all->reg.stride = 1; @@ -367,7 +368,7 @@ vw* parse_args(int argc, char *argv[]) if( all->normalized_updates ) all->reg.stride *= 2; - if(!vm.count("learning_rate") && !(all->adaptive && all->normalized_updates)) + if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates)) all->eta = 10; //default learning rate to 10 for non default update rule //if not using normalized or adaptive, default initial_t to 1 instead of 0 @@ -602,7 +603,6 @@ vw* parse_args(int argc, char *argv[]) } } - /* // matrix factorization enabled if (all->rank > 0) { // store linear + 2*rank weights per index, round up to power of two @@ -610,7 +610,6 @@ vw* parse_args(int argc, char *argv[]) all->reg.stride = 1 << (int) temp; all->random_weights = true; - if ( vm.count("adaptive") ) { cerr << "adaptive is not implemented for matrix factorization" << endl; @@ -639,7 +638,6 @@ vw* parse_args(int argc, char *argv[]) all->initial_t = 1.f; } } - */ if (vm.count("noconstant")) all->add_constant = false; @@ -687,6 +685,9 @@ vw* parse_args(int argc, char *argv[]) if (vm.count("noop")) all->l = NOOP::setup(*all); + if (all->rank != 0) + all->l = GDMF::setup(*all); + all->loss = getLossFunction(all, loss_function, (float)loss_parameter); if (pow((double)all->eta_decay_rate, (double)all->numpasses) < 0.0001 ) @@ -800,8 +801,8 @@ vw* parse_args(int argc, char *argv[]) if(vm.count("nn") || vm_file.count("nn") ) all->l = NN::setup(*all, to_pass_further, vm, vm_file); - if (all->rank != 0) - all->l = MF::setup(*all, vm); + // if (all->rank != 0) + // all->l = MF::setup(*all); if(vm.count("autolink") || vm_file.count("autolink") ) all->l = ALINK::setup(*all, to_pass_further, vm, vm_file); |