Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2014-11-12 01:25:57 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2014-11-12 01:25:57 +0300
commit454c695cfb029ede789fe1edd80f87ff332fe664 (patch)
treefe4a03f6cdb79fcc915421f0e4772c74c634447d
parent15b82b43712095136b125b48dedfc63fc3a82836 (diff)
clean up, options, vocab pruning
-rw-r--r--lm/builder/adjust_counts.cc18
-rw-r--r--lm/builder/lmplz_main.cc8
-rw-r--r--lm/builder/pipeline.cc24
-rw-r--r--lm/builder/pipeline.hh1
4 files changed, 23 insertions, 28 deletions
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc
index 6650c9c..47ff76c 100644
--- a/lm/builder/adjust_counts.cc
+++ b/lm/builder/adjust_counts.cc
@@ -223,9 +223,9 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
// Only unigrams. Just collect stats.
for (NGramStream full(positions[0]); full; ++full) {
- if(*full->begin() != kBOS && *full->begin() != kEOS && *full->begin() != kUNK) {
- uint64_t realCount = full->Count();
- if(prune_thresholds_[0] && realCount <= prune_thresholds_[0])
+ // Do not prune <s> </s> <unk>
+ if(*full->begin() > 2) {
+ if(full->Count() <= prune_thresholds_[0])
full->Mark();
if(!prune_words_.empty() && prune_words_[*full->begin()])
@@ -267,17 +267,11 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
uint64_t lower_order = (*lower_valid)->Order();
uint64_t lower_count = lower_counts[lower_order - 1];
- if(lower_order > 1 && prune_thresholds_[lower_order - 1] && lower_count <= prune_thresholds_[lower_order - 1])
+ if(lower_order > 1 && lower_count <= prune_thresholds_[lower_order - 1])
(*lower_valid)->Mark();
- bool special = false;
- if(lower_order == 1) {
- WordIndex w = *(*lower_valid)->begin();
- if(w == kBOS || w == kEOS || w == kUNK)
- special = true;
- }
-
- if(!special && prune_thresholds_[lower_order - 1] && lower_count <= prune_thresholds_[lower_order - 1])
+ // Do not prune unigrams <unk> <s> </s>
+ if(lower_order == 1 && *(*lower_valid)->begin() > 2 && lower_count <= prune_thresholds_[0])
(*lower_valid)->Mark();
if(!prune_words_.empty()) {
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index afa670a..214f1e2 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -115,6 +115,7 @@ int main(int argc, char *argv[]) {
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.")
+ ("limit_vocab_file", po::value<std::string>(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg")
("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail.");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, options), vm);
@@ -180,6 +181,13 @@ int main(int argc, char *argv[]) {
// parse pruning thresholds. These depend on order, so it is not done as a notifier.
pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);
+
+ if (!vm["limit_vocab_file"].as<std::string>().empty()) {
+ pipeline.prune_vocab = true;
+ }
+ else {
+ pipeline.prune_vocab = false;
+ }
util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index 9284575..6b368b8 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -304,7 +304,6 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
UTIL_TIMER("(%w s) Total wall time elapsed\n");
- config.prune_vocab = true;
Master master(config);
// master's destructor will wait for chains. But they might be deadlocked if
// this thread dies because e.g. it ran out of memory.
@@ -315,30 +314,23 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
uint64_t token_count;
std::string text_file_name;
CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
-
-
- // for compact size
- std::set<std::string> keepSet;
- std::string line;
- std::ifstream keepFile("keep.txt");
- while(std::getline(keepFile, line)) {
- std::cerr << "Adding: " << line << std::endl;
- keepSet.insert(line);
- }
std::vector<bool> prune_words;
- {
- // TODO: create this in corpus_count!!!
+ if(config.prune_vocab) {
+ std::set<std::string> keep_set;
+ std::string vocab_item;
+ std::ifstream keep_file(config.prune_vocab_file.c_str());
+ while(keep_file >> vocab_item)
+ keep_set.insert(vocab_item);
+
VocabReconstitute vocab(vocab_file.get());
prune_words.resize(vocab.Size(), true);
prune_words[kUNK] = false;
prune_words[kBOS] = false;
prune_words[kEOS] = false;
for(size_t i = 3; i < vocab.Size(); ++i)
- if(keepSet.count(vocab.Lookup(i))) {
- std::cerr << "Keeping: " << vocab.Lookup(i) << " " << i << std::endl;
+ if(keep_set.count(vocab.Lookup(i)))
prune_words[i] = false;
- }
}
std::vector<uint64_t> counts;
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index 6eeb50e..15c3981 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -38,6 +38,7 @@ struct PipelineConfig {
// corresponding n-gram order
std::vector<uint64_t> prune_thresholds; //mjd
bool prune_vocab;
+ std::string prune_vocab_file;
// What to do with discount failures.
DiscountConfig discount;