Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/vowpal_wabbit.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJuha Reunanen <juha.reunanen@gmail.com>2014-01-22 21:02:49 +0400
committerJuha Reunanen <juha.reunanen@gmail.com>2014-01-22 21:02:49 +0400
commitcb8d90352a0007c8b133f535ad0f9111c7a58cee (patch)
tree07d3c5c0cf7200dcc9477e4dadb93cd87439b770
parentc803da144613a117bf624e5e9dc4dce8f62671b1 (diff)
Added new param lda_epsilon to replace the hard-coded convergence threshold used in the lda_loop function.
-rw-r--r--vowpalwabbit/global_data.cc1
-rw-r--r--vowpalwabbit/global_data.h1
-rw-r--r--vowpalwabbit/lda_core.cc3
3 files changed, 4 insertions, 1 deletions
diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc
index 844e6db3..a0d602ee 100644
--- a/vowpalwabbit/global_data.cc
+++ b/vowpalwabbit/global_data.cc
@@ -238,6 +238,7 @@ vw::vw()
lda_alpha = 0.1f;
lda_rho = 0.1f;
lda_D = 10000.;
+ lda_epsilon = 0.001;
minibatch = 1;
span_server = "";
m = 15;
diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h
index 0b78ef42..40e12718 100644
--- a/vowpalwabbit/global_data.h
+++ b/vowpalwabbit/global_data.h
@@ -220,6 +220,7 @@ struct vw {
float lda_alpha;
float lda_rho;
float lda_D;
+ float lda_epsilon;
std::string text_regressor_name;
std::string inv_hash_regressor_name;
diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc
index c7be0d9f..db2281a2 100644
--- a/vowpalwabbit/lda_core.cc
+++ b/vowpalwabbit/lda_core.cc
@@ -495,7 +495,7 @@ v_array<float> old_gamma;
for (size_t k =0; k<all.lda; k++)
new_gamma[k] = new_gamma[k]*v[k]+all.lda_alpha;
}
- while (average_diff(all, old_gamma.begin, new_gamma.begin) > 0.001);
+ while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon);
ec->topic_predictions.erase();
ec->topic_predictions.resize(all.lda);
@@ -748,6 +748,7 @@ learner* setup(vw&all, std::vector<std::string>&opts, po::variables_map& vm)
("lda_alpha", po::value<float>(&all.lda_alpha), "Prior on sparsity of per-document topic weights")
("lda_rho", po::value<float>(&all.lda_rho), "Prior on sparsity of topic distributions")
("lda_D", po::value<float>(&all.lda_D), "Number of documents")
+ ("lda_epsilon", po::value<float>(&all.lda_epsilon), "Loop convergence threshold")
("minibatch", po::value<size_t>(&all.minibatch), "Minibatch size, for LDA");
po::parsed_options parsed = po::command_line_parser(opts).