From cb8d90352a0007c8b133f535ad0f9111c7a58cee Mon Sep 17 00:00:00 2001 From: Juha Reunanen Date: Wed, 22 Jan 2014 19:02:49 +0200 Subject: Added new param lda_epsilon to replace the hard-coded convergence threshold used in the lda_loop function. --- vowpalwabbit/global_data.cc | 1 + vowpalwabbit/global_data.h | 1 + vowpalwabbit/lda_core.cc | 3 ++- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vowpalwabbit/global_data.cc b/vowpalwabbit/global_data.cc index 844e6db3..a0d602ee 100644 --- a/vowpalwabbit/global_data.cc +++ b/vowpalwabbit/global_data.cc @@ -238,6 +238,7 @@ vw::vw() lda_alpha = 0.1f; lda_rho = 0.1f; lda_D = 10000.; + lda_epsilon = 0.001; minibatch = 1; span_server = ""; m = 15; diff --git a/vowpalwabbit/global_data.h b/vowpalwabbit/global_data.h index 0b78ef42..40e12718 100644 --- a/vowpalwabbit/global_data.h +++ b/vowpalwabbit/global_data.h @@ -220,6 +220,7 @@ struct vw { float lda_alpha; float lda_rho; float lda_D; + float lda_epsilon; std::string text_regressor_name; std::string inv_hash_regressor_name; diff --git a/vowpalwabbit/lda_core.cc b/vowpalwabbit/lda_core.cc index c7be0d9f..db2281a2 100644 --- a/vowpalwabbit/lda_core.cc +++ b/vowpalwabbit/lda_core.cc @@ -495,7 +495,7 @@ v_array old_gamma; for (size_t k =0; k 0.001); + while (average_diff(all, old_gamma.begin, new_gamma.begin) > all.lda_epsilon); ec->topic_predictions.erase(); ec->topic_predictions.resize(all.lda); @@ -748,6 +748,7 @@ learner* setup(vw&all, std::vector&opts, po::variables_map& vm) ("lda_alpha", po::value(&all.lda_alpha), "Prior on sparsity of per-document topic weights") ("lda_rho", po::value(&all.lda_rho), "Prior on sparsity of topic distributions") ("lda_D", po::value(&all.lda_D), "Number of documents") + ("lda_epsilon", po::value(&all.lda_epsilon), "Loop convergence threshold") ("minibatch", po::value(&all.minibatch), "Minibatch size, for LDA"); po::parsed_options parsed = po::command_line_parser(opts). -- cgit v1.2.3