From 794867c555e82edf3bd12ffb3faa35fb24d6e0a1 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Mon, 24 Jun 2013 16:05:47 +0100
Subject: KenLM 6b4a1c7940a36026de1d96693ccb6ec0f16de8dc

---
 lm/builder/lmplz_main.cc | 31 ++++++++++++++++++++++---------
 lm/builder/ngram.hh      |  2 +-
 lm/model.cc              | 21 +++++++++++++++++++++
 lm/model.hh              |  5 +++++
 lm/search_hashed.cc      | 29 ++++++++++++++---------------
 lm/search_hashed.hh      | 19 +++++++------------
 lm/virtual_interface.hh  |  3 +++
 7 files changed, 73 insertions(+), 37 deletions(-)

(limited to 'lm')

diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index 1e086dcce..2e3002d12 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -33,6 +33,8 @@ int main(int argc, char *argv[]) {
   po::options_description options("Language model building options");
   lm::builder::PipelineConfig pipeline;
 
+  std::string text, arpa;
+
   options.add_options()
     ("order,o", po::value<std::size_t>(&pipeline.order)
 #if BOOST_VERSION >= 104200
@@ -47,18 +49,21 @@ int main(int argc, char *argv[]) {
     ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
     ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
     ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
-    ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
+    ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
+    ("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
+    ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout");
   if (argc == 1) {
     std::cerr <<
       "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
       "Please cite:\n"
-      "@inproceedings{kenlm,\n"
-      "author = {Kenneth Heafield},\n"
-      "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
-      "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
-      "month = {July}, year={2011},\n"
-      "address = {Edinburgh, UK},\n"
-      "publisher = {Association for Computational Linguistics},\n"
+      "@inproceedings{Heafield-estimate,\n"
+      "  author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n"
+      "  title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n"
+      "  year = {2013},\n"
+      "  month = {8},\n"
+      "  booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n"
+      "  address = {Sofia, Bulgaria},\n"
+      "  url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"
       "}\n\n"
       "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
       "the model (-o) is the only mandatory option. As this is an on-disk program,\n"
@@ -91,9 +96,17 @@ int main(int argc, char *argv[]) {
   initial.adder_out.block_count = 2;
   pipeline.read_backoffs = initial.adder_out;
 
+  util::scoped_fd in(0), out(1);
+  if (vm.count("text")) {
+    in.reset(util::OpenReadOrThrow(text.c_str()));
+  }
+  if (vm.count("arpa")) {
+    out.reset(util::CreateOrThrow(arpa.c_str()));
+  }
+
   // Read from stdin
   try {
-    lm::builder::Pipeline(pipeline, 0, 1);
+    lm::builder::Pipeline(pipeline, in.release(), out.release());
   } catch (const util::MallocException &e) {
     std::cerr << e.what() << std::endl;
     std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::size_t>() << std::endl;
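The two new options only choose file descriptors: stdin (0) and stdout (1) stay the defaults, util::scoped_fd closes any file it opened if an exception fires before the pipeline starts, and release() hands ownership to Pipeline. A minimal sketch of the same pattern follows; the helper name is hypothetical, but OpenReadOrThrow and CreateOrThrow are the util/file.hh calls used in the hunk above:

    #include "util/file.hh"

    #include <string>

    // Hypothetical helper (not in the patch): return the fallback descriptor,
    // e.g. 0 for stdin or 1 for stdout, unless a path was supplied.
    int DescriptorOrDefault(const std::string &path, int fallback, bool writing) {
      if (path.empty()) return fallback;
      // Both calls throw on failure rather than returning -1.
      return writing ? util::CreateOrThrow(path.c_str()) : util::OpenReadOrThrow(path.c_str());
    }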
As this is an on-disk program,\n" @@ -91,9 +96,17 @@ int main(int argc, char *argv[]) { initial.adder_out.block_count = 2; pipeline.read_backoffs = initial.adder_out; + util::scoped_fd in(0), out(1); + if (vm.count("text")) { + in.reset(util::OpenReadOrThrow(text.c_str())); + } + if (vm.count("arpa")) { + out.reset(util::CreateOrThrow(arpa.c_str())); + } + // Read from stdin try { - lm::builder::Pipeline(pipeline, 0, 1); + lm::builder::Pipeline(pipeline, in.release(), out.release()); } catch (const util::MallocException &e) { std::cerr << e.what() << std::endl; std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as() << std::endl; diff --git a/lm/builder/ngram.hh b/lm/builder/ngram.hh index 2984ed0b6..f5681516a 100644 --- a/lm/builder/ngram.hh +++ b/lm/builder/ngram.hh @@ -53,7 +53,7 @@ class NGram { Payload &Value() { return *reinterpret_cast(end_); } uint64_t &Count() { return Value().count; } - const uint64_t Count() const { return Value().count; } + uint64_t Count() const { return Value().count; } std::size_t Order() const { return end_ - begin_; } diff --git a/lm/model.cc b/lm/model.cc index a40fd2fb0..a26654a6f 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -304,5 +304,26 @@ template class GenericModel, SortedVocabulary>; } // namespace detail + +base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) { + RecognizeBinary(file_name, model_type); + switch (model_type) { + case PROBING: + return new ProbingModel(file_name, config); + case REST_PROBING: + return new RestProbingModel(file_name, config); + case TRIE: + return new TrieModel(file_name, config); + case QUANT_TRIE: + return new QuantTrieModel(file_name, config); + case ARRAY_TRIE: + return new ArrayTrieModel(file_name, config); + case QUANT_ARRAY_TRIE: + return new QuantArrayTrieModel(file_name, config); + default: + UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); + } +} + } // namespace ngram } // namespace lm diff --git a/lm/model.hh b/lm/model.hh index 13ff864e1..60f55110b 100644 --- a/lm/model.hh +++ b/lm/model.hh @@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel class ActivateUnigram { Weights *modify_; }; -// Find the lower order entry, inserting blanks along the way as necessary. +// Find the lower order entry, inserting blanks along the way as necessary. template void FindLower( const std::vector &keys, typename Value::Weights &unigram, @@ -64,7 +64,7 @@ template void FindLower( typename Value::ProbingEntry entry; // Backoff will always be 0.0. We'll get the probability and rest in another pass. entry.value.backoff = kNoExtensionBackoff; - // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. for (int lower = keys.size() - 2; ; --lower) { if (lower == -1) { between.push_back(&unigram); @@ -77,11 +77,11 @@ template void FindLower( } } -// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. +// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. 
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -56,7 +56,7 @@ template <class Weights> class ActivateUnigram {
     Weights *modify_;
 };
 
-// Find the lower order entry, inserting blanks along the way as necessary. 
+// Find the lower order entry, inserting blanks along the way as necessary.
 template <class Value> void FindLower(
     const std::vector<uint64_t> &keys,
     typename Value::Weights &unigram,
@@ -64,7 +64,7 @@ template <class Value> void FindLower(
   typename Value::ProbingEntry entry;
   // Backoff will always be 0.0. We'll get the probability and rest in another pass.
   entry.value.backoff = kNoExtensionBackoff;
-  // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. 
+  // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
   for (int lower = keys.size() - 2; ; --lower) {
     if (lower == -1) {
       between.push_back(&unigram);
@@ -77,11 +77,11 @@ template <class Value> void FindLower(
   }
 }
 
-// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here. 
+// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
 template <class Added, class Build> void AdjustLower(
     const Added &added,
     const Build &build,
-    std::vector<typename Build::Value::Weights *> &between, 
+    std::vector<typename Build::Value::Weights *> &between,
     const unsigned int n,
     const std::vector<WordIndex> &vocab_ids,
     typename Build::Value::Weights *unigrams,
@@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(
   }
   typedef util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> Middle;
   float prob = -fabs(between.back()->prob);
-  // Order of the n-gram on which probabilities are based. 
+  // Order of the n-gram on which probabilities are based.
   unsigned char basis = n - between.size();
   assert(basis != 0);
   typename Build::Value::Weights **change = &between.back();
   // Skip the basis.
   --change;
   if (basis == 1) {
-    // Hallucinate a bigram based on a unigram's backoff and a unigram probability. 
+    // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
     float &backoff = unigrams[vocab_ids[1]].backoff;
     SetExtension(backoff);
     prob += backoff;
@@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(
   typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
   build.MarkExtends(**i, added);
   const typename Value::Weights *longer = *i;
-  // Everything has probability but is not marked as extending. 
+  // Everything has probability but is not marked as extending.
   for (++i; i != between.end(); ++i) {
     build.MarkExtends(**i, *longer);
     longer = *i;
   }
 }
 
-// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds. 
+// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
 template <class Build> void MarkLower(
     const std::vector<uint64_t> &keys,
     const Build &build,
@@ -144,15 +144,15 @@ template <class Build> void MarkLower(
     int start_order,
     const typename Build::Value::Weights &longer) {
   if (start_order == 0) return;
-  typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
-  // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code. 
+  // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
   for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
     if (even_lower == -1) {
       build.MarkExtends(unigram, longer);
       return;
     }
-    middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
-    if (!build.MarkExtends(iter->value, longer)) return;
+    if (!build.MarkExtends(
+          middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
+          longer)) return;
   }
 }
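Both this hunk and the search_hashed.hh changes below replace the find-then-assert idiom with MustFind variants that assert internally and return the entry directly, which also eliminates the deliberately uninitialized iterator behind the -Wuninitialized pragma removed further down. A before/after sketch of the idiom, using a stand-in table type rather than the real util::ProbingHashTable interface:

    #include <cassert>
    #include <stdint.h>

    // Table is a stand-in for any hash table offering these two lookup styles.
    template <class Table> float Before(Table &table, uint64_t key) {
      typename Table::MutableIterator it;             // left uninitialized
      bool found = table.UnsafeMutableFind(key, it);  // every caller must check
      assert(found); (void)found;
      return it->value.prob;
    }

    template <class Table> float After(Table &table, uint64_t key) {
      // The MustFind form asserts existence itself and returns the entry.
      return table.UnsafeMutableMustFind(key)->value.prob;
    }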
@@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(
     Store &store,
     PositiveProbWarn &warn) {
   typedef typename Build::Value Value;
-  typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
   assert(n >= 2);
   ReadNGramHeader(f, n);
 
@@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
       for (unsigned int h = 1; h < n - 1; ++h) {
         keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
       }
-      // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. 
+      // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
       util::SetSign(entry.value.prob);
       entry.key = keys[n-2];
 
@@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
 } // namespace
 namespace detail {
- 
+
 template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
   std::size_t allocated = Unigram::Size(counts[0]);
   unigram_ = Unigram(start, counts[0], allocated);
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index 005957967..9d067bc2e 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -71,7 +71,7 @@ template <class Value> class HashedSearch {
     static const bool kDifferentRest = Value::kDifferentRest;
     static const unsigned int kVersion = 0;
 
-    // TODO: move probing_multiplier here with next binary file format update. 
+    // TODO: move probing_multiplier here with next binary file format update.
     static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
 
     static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -102,14 +102,9 @@ template <class Value> class HashedSearch {
       return ret;
     }
 
-#pragma GCC diagnostic ignored "-Wuninitialized"
     MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
       node = extend_pointer;
-      typename Middle::ConstIterator found;
-      bool got = middle_[extend_length - 2].Find(extend_pointer, found);
-      assert(got);
-      (void)got;
-      return MiddlePointer(found->value);
+      return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
     }
 
     MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
@@ -126,14 +121,14 @@ template <class Value> class HashedSearch {
     }
 
     LongestPointer LookupLongest(WordIndex word, const Node &node) const {
-      // Sign bit is always on because longest n-grams do not extend left. 
+      // Sign bit is always on because longest n-grams do not extend left.
       typename Longest::ConstIterator found;
       if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
       return LongestPointer(found->value.prob);
     }
 
-    // Generate a node without necessarily checking that it actually exists. 
-    // Optionally return false if it's know to not exist. 
+    // Generate a node without necessarily checking that it actually exists.
+    // Optionally return false if it's know to not exist.
     bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
       assert(begin != end);
       node = static_cast<Node>(*begin);
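FastMakeNode shows how hashed search addresses an n-gram: the node is just a 64-bit hash folded left to right over the context words, and existence is only checked at the subsequent table lookup, which is why the comment says it "optionally" reports non-existence. A sketch of that fold; the mixing function and its constants are illustrative stand-ins, not the actual detail::CombineWordHash:

    #include <stdint.h>

    typedef uint64_t Node;

    // Stand-in mixer in the multiply-and-xor style; constants are illustrative.
    inline Node CombineWordHashSketch(Node current, unsigned int next) {
      return (current * 8978948897894561157ULL) ^
             (static_cast<Node>(1 + next) * 17894857484156487943ULL);
    }

    // Mirrors the FastMakeNode loop above: seed with the first word, then fold.
    inline bool FastMakeNodeSketch(const unsigned int *begin, const unsigned int *end, Node &node) {
      node = static_cast<Node>(*begin);
      for (++begin; begin != end; ++begin) node = CombineWordHashSketch(node, *begin);
      return true;  // existence of the n-gram is not actually verified here
    }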
@@ -144,7 +139,7 @@ template <class Value> class HashedSearch {
     }
 
   private:
-    // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. 
+    // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
     void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
 
     template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
@@ -153,7 +148,7 @@ template <class Value> class HashedSearch {
       public:
         Unigram() {}
 
-        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : 
+        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
           unigram_(static_cast<typename Value::Weights*>(start))
 #ifdef DEBUG
           , count_(count)
diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh
index 6a5a0196f..17f064b2c 100644
--- a/lm/virtual_interface.hh
+++ b/lm/virtual_interface.hh
@@ -6,6 +6,7 @@
 #include "util/string_piece.hh"
 
 #include <string>
+#include <string.h>
 
 namespace lm {
 namespace base {
@@ -119,7 +120,9 @@ class Model {
     size_t StateSize() const { return state_size_; }
     const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
+    void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }
     const void *NullContextMemory() const { return null_context_memory_; }
+    void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
 
     // Requires in_state != out_state
     virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
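BeginSentenceWrite and NullContextWrite exist because code using the virtual interface cannot see the concrete State type and therefore cannot copy-assign a state; instead the model memcpys StateSize() bytes into a caller-owned buffer. A sketch of a scoring loop built on them; the helper is hypothetical:

    #include "lm/virtual_interface.hh"

    #include <vector>

    // Hypothetical helper: sum the scores of words against a model seen only
    // through base::Model.
    float ScoreWords(const lm::base::Model &model, const std::vector<lm::WordIndex> &words) {
      // Opaque state: allocate StateSize() bytes rather than a typed State.
      std::vector<char> state(model.StateSize()), scratch(model.StateSize());
      model.BeginSentenceWrite(&state[0]);  // start from the begin-sentence state
      float total = 0.0;
      for (std::vector<lm::WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) {
        total += model.Score(&state[0], *i, &scratch[0]);
        state.swap(scratch);  // Score requires in_state != out_state
      }
      return total;
    }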