
github.com/moses-smt/vowpal_wabbit.git
author    Jake Hofman <jhofman@gmail>  2014-01-08 05:25:55 +0400
committer Jake Hofman <jhofman@gmail>  2014-01-08 05:25:55 +0400
commit    de91f21a827ae6e53e239a26edcc53f632b49c4f (patch)
tree      3fa09692c7cf0b702b783f2dfa1d707509fc1d06 /vowpalwabbit/parse_args.cc
parent    a73c38db84b44307e8efc4a7c77fecba18ad2932 (diff)
removed precomputed_prediction in favor of base.update
Diffstat (limited to 'vowpalwabbit/parse_args.cc')
-rw-r--r--  vowpalwabbit/parse_args.cc  19
1 file changed, 10 insertions(+), 9 deletions(-)
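
The patch below adds a boolean --new_mf flag: when it is passed together with --rank, parse_args skips the legacy GDMF path and wires up the reduction-based MF::setup learner instead. A hypothetical invocation (the -q ui quadratic namespaces and the dataset name are illustrative assumptions, not taken from this commit):

    vw --new_mf --rank 10 -q ui -d train.dat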
diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc
index e369d975..82809fe7 100644
--- a/vowpalwabbit/parse_args.cc
+++ b/vowpalwabbit/parse_args.cc
@@ -212,6 +212,7 @@ vw* parse_args(int argc, char *argv[])
("cubic", po::value< vector<string> > (),
"Create and use cubic features")
("rank", po::value<uint32_t>(&(all->rank)), "rank for matrix factorization.")
+ ("new_mf", "use new, reduction-based matrix factorization")
;
po::options_description lrq_opt("Low Rank Quadratic options");
@@ -398,12 +399,12 @@ vw* parse_args(int argc, char *argv[])
}
all->reg.stride = 4; //use stride of 4 for default invariant normalized adaptive updates
- //if we are doing matrix factorization, or user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
- if( all->rank > 0 || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
+ //if the user specified anything in sgd,adaptive,invariant,normalized, we turn off default update rules and use whatever user specified
+ if( (all->rank > 0 && !vm.count("new_mf")) || !all->training || ( ( vm.count("sgd") || vm.count("adaptive") || vm.count("invariant") || vm.count("normalized") ) && !vm.count("exact_adaptive_norm")) )
{
- all->adaptive = all->training && (vm.count("adaptive") && all->rank == 0);
+ all->adaptive = all->training && vm.count("adaptive") && (all->rank == 0 && !vm.count("new_mf"));
all->invariant_updates = all->training && vm.count("invariant");
- all->normalized_updates = all->training && (vm.count("normalized") && all->rank == 0);
+ all->normalized_updates = all->training && vm.count("normalized") && (all->rank == 0 && !vm.count("new_mf"));
all->reg.stride = 1;
@@ -665,8 +666,8 @@ vw* parse_args(int argc, char *argv[])
}
}
- // matrix factorization enabled
- if (all->rank > 0) {
+ // (non-reduction) matrix factorization enabled
+ if (!vm.count("new_mf") && all->rank > 0) {
// store linear + 2*rank weights per index, round up to power of two
float temp = ceilf(logf((float)(all->rank*2+1)) / logf (2.f));
all->reg.stride = 1 << (int) temp;
@@ -747,7 +748,7 @@ vw* parse_args(int argc, char *argv[])
if (vm.count("noop"))
all->l = NOOP::setup(*all);
- if (all->rank != 0)
+ if (!vm.count("new_mf") && all->rank > 0)
all->l = GDMF::setup(*all);
all->loss = getLossFunction(all, loss_function, (float)loss_parameter);
@@ -845,8 +846,8 @@ vw* parse_args(int argc, char *argv[])
if(vm.count("nn") || vm_file.count("nn") )
all->l = NN::setup(*all, to_pass_further, vm, vm_file);
- // if (all->rank != 0)
- // all->l = MF::setup(*all);
+ if (vm.count("new_mf") && all->rank > 0)
+ all->l = MF::setup(*all, vm);
if(vm.count("autolink") || vm_file.count("autolink") )
all->l = ALINK::setup(*all, to_pass_further, vm, vm_file);