Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Jake Hofman <jhofman@gmail.com> 2013-12-28 02:03:42 +0400
committer: Jake Hofman <jhofman@gmail.com> 2013-12-28 02:03:42 +0400
commit: cedc01406d6aca7fe5c5236b95804208ebbaa341 (patch)
tree: 4d9b8b4c491e59bb441f0530955f994afe03003f /vowpalwabbit/parse_args.cc
parent: 43d4d21ad3c02c0e3cdc052adaf4baa7136b57d9 (diff)
disabled mf in favor of gd_mf for time being
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r-- vowpalwabbit/parse_args.cc | 19
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index fa98503f..a81bebc2 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -27,6 +27,7 @@ license as described in the file LICENSE.
#include "bfgs.h"
#include "lda_core.h"
#include "noop.h"
+#include "gd_mf.h"
#include "mf.h"
#include "vw.h"
#include "rand48.h"
@@ -354,11 +355,11 @@ vw* parse_args(int argc, char *argv[])
all->reg.stride = 4; //use stride of 4 for default invariant normalized adaptive updates
//if we are doing matrix factorization, or user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
- if( !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
+ if( all->rank > 0 || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
{
- all->adaptive = all->training && vm.count("adaptive");
+ all->adaptive = all->training && (vm.count("adaptive") && all->rank == 0);
all->invariant_updates = all->training && vm.count("invariant");
- all->normalized_updates = all->training && vm.count("normalized");
+ all->normalized_updates = all->training && (vm.count("normalized") && all->rank == 0);
all->reg.stride = 1;
@@ -367,7 +368,7 @@ vw* parse_args(int argc, char *argv[])
if( all->normalized_updates ) all->reg.stride *= 2;
- if(!vm.count("learning_rate") && !(all->adaptive && all->normalized_updates))
+ if(!vm.count("learning_rate") && !vm.count("l") && !(all->adaptive && all->normalized_updates))
all->eta = 10; //default learning rate to 10 for non default update rule
//if not using normalized or adaptive, default initial_t to 1 instead of 0
@@ -602,7 +603,6 @@ vw* parse_args(int argc, char *argv[])
}
}
- /*
// matrix factorization enabled
if (all->rank > 0) {
// store linear + 2*rank weights per index, round up to power of two
@@ -610,7 +610,6 @@ vw* parse_args(int argc, char *argv[])
all->reg.stride = 1 << (int) temp;
all->random_weights = true;
-
if ( vm.count("adaptive") )
{
cerr << "adaptive is not implemented for matrix factorization" << endl;
@@ -639,7 +638,6 @@ vw* parse_args(int argc, char *argv[])
all->initial_t = 1.f;
}
}
- */
if (vm.count("noconstant"))
all->add_constant = false;
@@ -687,6 +685,9 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("noop"))
all->l = NOOP::setup(*all);
+ if (all->rank != 0)
+ all->l = GDMF::setup(*all);
+
all->loss = getLossFunction(all, loss_function, (float)loss_parameter);
if (pow((double)all->eta_decay_rate, (double)all->numpasses) < 0.0001 )
@@ -800,8 +801,8 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("nn") || vm_file.count("nn") )
all->l = NN::setup(*all, to_pass_further, vm, vm_file);
- if (all->rank != 0)
- all->l = MF::setup(*all, vm);
+ // if (all->rank != 0)
+ // all->l = MF::setup(*all);
if(vm.count("autolink") || vm_file.count("autolink") )
all->l = ALINK::setup(*all, to_pass_further, vm, vm_file);