diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2014-11-12 01:25:57 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2014-11-12 01:25:57 +0300 |
commit | 454c695cfb029ede789fe1edd80f87ff332fe664 (patch) | |
tree | fe4a03f6cdb79fcc915421f0e4772c74c634447d | |
parent | 15b82b43712095136b125b48dedfc63fc3a82836 (diff) |
clean up, options, vocab pruning
-rw-r--r-- | lm/builder/adjust_counts.cc | 18 |
-rw-r--r-- | lm/builder/lmplz_main.cc | 8 |
-rw-r--r-- | lm/builder/pipeline.cc | 24 |
-rw-r--r-- | lm/builder/pipeline.hh | 1 |
4 files changed, 23 insertions, 28 deletions
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc index 6650c9c..47ff76c 100644 --- a/lm/builder/adjust_counts.cc +++ b/lm/builder/adjust_counts.cc @@ -223,9 +223,9 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { // Only unigrams. Just collect stats. for (NGramStream full(positions[0]); full; ++full) { - if(*full->begin() != kBOS && *full->begin() != kEOS && *full->begin() != kUNK) { - uint64_t realCount = full->Count(); - if(prune_thresholds_[0] && realCount <= prune_thresholds_[0]) + // Do not prune <s> </s> <unk> + if(*full->begin() > 2) { + if(full->Count() <= prune_thresholds_[0]) full->Mark(); if(!prune_words_.empty() && prune_words_[*full->begin()]) @@ -267,17 +267,11 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { uint64_t lower_order = (*lower_valid)->Order(); uint64_t lower_count = lower_counts[lower_order - 1]; - if(lower_order > 1 && prune_thresholds_[lower_order - 1] && lower_count <= prune_thresholds_[lower_order - 1]) + if(lower_order > 1 && lower_count <= prune_thresholds_[lower_order - 1]) (*lower_valid)->Mark(); - bool special = false; - if(lower_order == 1) { - WordIndex w = *(*lower_valid)->begin(); - if(w == kBOS || w == kEOS || w == kUNK) - special = true; - } - - if(!special && prune_thresholds_[lower_order - 1] && lower_count <= prune_thresholds_[lower_order - 1]) + // Do not prune unigrams <unk> <s> </s> + if(lower_order == 1 && *(*lower_valid)->begin() > 2 && lower_count <= prune_thresholds_[0]) (*lower_valid)->Mark(); if(!prune_words_.empty()) { diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc index afa670a..214f1e2 100644 --- a/lm/builder/lmplz_main.cc +++ b/lm/builder/lmplz_main.cc @@ -115,6 +115,7 @@ int main(int argc, char *argv[]) { ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout") ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the 
same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.") ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.") + ("limit_vocab_file", po::value<std::string>(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg") ("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail."); po::variables_map vm; po::store(po::parse_command_line(argc, argv, options), vm); @@ -180,6 +181,13 @@ int main(int argc, char *argv[]) { // parse pruning thresholds. These depend on order, so it is not done as a notifier. 
pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order); + + if (!vm["limit_vocab_file"].as<std::string>().empty()) { + pipeline.prune_vocab = true; + } + else { + pipeline.prune_vocab = false; + } util::NormalizeTempPrefix(pipeline.sort.temp_prefix); diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc index 9284575..6b368b8 100644 --- a/lm/builder/pipeline.cc +++ b/lm/builder/pipeline.cc @@ -304,7 +304,6 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { UTIL_TIMER("(%w s) Total wall time elapsed\n"); - config.prune_vocab = true; Master master(config); // master's destructor will wait for chains. But they might be deadlocked if // this thread dies because e.g. it ran out of memory. @@ -315,30 +314,23 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { uint64_t token_count; std::string text_file_name; CountText(text_file, vocab_file.get(), master, token_count, text_file_name); - - - // for compact size - std::set<std::string> keepSet; - std::string line; - std::ifstream keepFile("keep.txt"); - while(std::getline(keepFile, line)) { - std::cerr << "Adding: " << line << std::endl; - keepSet.insert(line); - } std::vector<bool> prune_words; - { - // TODO: create this in corpus_count!!! 
+ if(config.prune_vocab) { + std::set<std::string> keep_set; + std::string vocab_item; + std::ifstream keep_file(config.prune_vocab_file.c_str()); + while(keep_file >> vocab_item) + keep_set.insert(vocab_item); + VocabReconstitute vocab(vocab_file.get()); prune_words.resize(vocab.Size(), true); prune_words[kUNK] = false; prune_words[kBOS] = false; prune_words[kEOS] = false; for(size_t i = 3; i < vocab.Size(); ++i) - if(keepSet.count(vocab.Lookup(i))) { - std::cerr << "Keeping: " << vocab.Lookup(i) << " " << i << std::endl; + if(keep_set.count(vocab.Lookup(i))) prune_words[i] = false; - } } std::vector<uint64_t> counts; diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh index 6eeb50e..15c3981 100644 --- a/lm/builder/pipeline.hh +++ b/lm/builder/pipeline.hh @@ -38,6 +38,7 @@ struct PipelineConfig { // corresponding n-gram order std::vector<uint64_t> prune_thresholds; //mjd bool prune_vocab; + std::string prune_vocab_file; // What to do with discount failures. DiscountConfig discount; |