Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: John Langford <jl@jl-desktop.(none)> 2010-12-03 19:19:09 +0300
committer: John Langford <jl@jl-desktop.(none)> 2010-12-03 19:19:09 +0300
commit: c13ae76b0e85fa4987b31862ba47a1f17f17ae03 (patch)
tree: 9941fb3af44da8f4c99fe4667a675014cd23688a /parse_args.cc
parent: 5d424bddd267c28e5bb15ece8781214080d0d182 (diff)
synced lda
Diffstat (limited to 'parse_args.cc')
-rw-r--r-- parse_args.cc 24
1 files changed, 23 insertions, 1 deletions
diff --git a/parse_args.cc b/parse_args.cc
index 06964955..a3af35f6 100644
--- a/parse_args.cc
+++ b/parse_args.cc
@@ -53,6 +53,10 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt
("initial_weight", po::value<float>(&global.initial_weight)->default_value(0.), "Set all weights to an initial value of 1.")
("initial_regressor,i", po::value< vector<string> >(), "Initial regressor(s)")
("initial_t", po::value<float>(&(par->t))->default_value(1.), "initial t value")
+ ("lda", po::value<size_t>(&global.lda), "Run lda with <int> topics")
+ ("lda_alpha", po::value<float>(&global.lda_alpha)->default_value(0.1), "Prior on sparsity of per-document topic weights")
+ ("lda_rho", po::value<float>(&global.lda_rho)->default_value(0.1), "Prior on sparsity of topic distributions")
+ ("lda_D", po::value<float>(&global.lda_D)->default_value(10000.), "Number of documents")
("min_prediction", po::value<double>(&global.min_label), "Smallest prediction to output")
("max_prediction", po::value<double>(&global.max_label), "Largest prediction to output")
("multisource", po::value<size_t>(), "multiple sources for daemon input")
@@ -68,6 +72,7 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt
("quadratic,q", po::value< vector<string> > (),
"Create and use quadratic features")
("quiet", "Don't output diagnostics")
+ ("random_weights", po::value<bool>(&global.random_weights), "make initial weights random")
("raw_predictions,r", po::value< string >(),
"File to output unnormalized predictions to")
("sendto", po::value< vector<string> >(), "send example to <hosts>")
@@ -103,6 +108,8 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt
global.print = print_result;
global.min_label = 0.;
global.max_label = 1.;
+ global.lda =0;
+ global.random_weights = false;
global.adaptive = false;
global.audit = false;
@@ -232,6 +239,21 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt
}
}
+ if (vm.count("lda"))
+ {
+ float temp = ceilf(logf((float)(global.lda+1)) / logf (2.f));
+ global.stride = powf(2,temp);
+ global.random_weights = true;
+ }
+
+ if (vm.count("lda") && global.eta > 1.)
+ {
+ cerr << "your learning rate is too high, setting it to 1" << endl;
+ global.eta = min(global.eta,1.f);
+ }
+ if (!vm.count("lda"))
+ global.eta *= pow(par->t, vars.power_t);
+
parse_regressor_args(vm, r, final_regressor_name, global.quiet);
if (vm.count("active_c0"))
@@ -256,7 +278,7 @@ po::variables_map parse_args(int argc, char *argv[], boost::program_options::opt
r.loss = getLossFunction(loss_function, loss_parameter);
global.loss = r.loss;
- global.eta *= pow(par->t, vars.power_t);
+// global.eta *= pow(par->t, vars.power_t);
if (global.eta_decay_rate != default_decay && global.numpasses == 1)
cerr << "Warning: decay_learning_rate has no effect when there is only one pass" << endl;