diff options
author | Kenneth Heafield <github@kheafield.com> | 2014-02-18 03:50:21 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2014-02-18 03:50:21 +0400 |
commit | 0339a879a86f2a7c5194dedff5ba5910bf6920e6 (patch) | |
tree | 44ca6073c4ebd4ac20bde49458fc000473a5f688 | |
parent | 8142108fb0f810a135753ab9dc54bc18395df592 (diff) |
Add corpus_count fixed vocab option
-rw-r--r-- | lm/builder/corpus_count.cc | 50 |
-rw-r--r-- | lm/builder/corpus_count.hh | 13 |
2 files changed, 56 insertions, 7 deletions
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc index b99edd0..d3a0c2d 100644 --- a/lm/builder/corpus_count.cc +++ b/lm/builder/corpus_count.cc @@ -39,6 +39,7 @@ struct VocabEntry { const float kProbingMultiplier = 1.5; +// Hand out vocab ids on the fly. class VocabHandout { public: static std::size_t MemUsage(WordIndex initial_guess) { @@ -46,7 +47,7 @@ class VocabHandout { return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); } - explicit VocabHandout(int fd, WordIndex initial_guess) : + VocabHandout(int fd, WordIndex initial_guess) : table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), table_(table_backing_.get(), MemUsage(initial_guess)), double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)), @@ -91,6 +92,35 @@ class VocabHandout { util::FakeOFStream word_list_; }; +// Vocab ids are given in a precompiled hash table. +class VocabGiven { + public: + explicit VocabGiven(int fd) { + util::MapRead(util::POPULATE_OR_READ, fd, 0, util::CheckOverflow(util::SizeOrThrow(fd)), table_backing_); + // Leave space for header with size. + table_ = Table(static_cast<char*>(table_backing_.get()) + sizeof(uint64_t), table_backing_.size() - sizeof(uint64_t)); + } + + WordIndex Lookup(const StringPiece &word) const { + Table::ConstIterator it; + if (table_.Find(util::MurmurHashNative(word.data(), word.size()), it)) { + return it->value; + } else { + return 0; // <unk>. 
+ } + } + + WordIndex Size() const { + return *static_cast<const uint64_t*>(table_backing_.get()); + } + + private: + util::scoped_memory table_backing_; + + typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table; + Table table_; +}; + class DedupeHash : public std::unary_function<const WordIndex *, bool> { public: explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} @@ -223,11 +253,12 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { return VocabHandout::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol) - : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_file, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol, bool dynamic_vocab) + : from_(from), vocab_file_(vocab_file), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), - disallowed_symbol_action_(disallowed_symbol) { + disallowed_symbol_action_(disallowed_symbol), + dynamic_vocab_(dynamic_vocab) { } namespace { @@ -246,9 +277,18 @@ namespace { } // namespace void CorpusCount::Run(const util::stream::ChainPosition &position) { - VocabHandout vocab(vocab_write_, type_count_); token_count_ = 0; type_count_ = 0; + if (dynamic_vocab_) { + VocabHandout vocab(vocab_file_, type_count_); + RunWithVocab(position, vocab); + } else { + VocabGiven vocab(vocab_file_); + RunWithVocab(position, vocab); + } +} + +template <class Voc> void CorpusCount::RunWithVocab(const util::stream::ChainPosition &position, Voc &vocab) { const WordIndex end_sentence = vocab.Lookup("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), 
position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh index 17fc7db..1bb1dbd 100644 --- a/lm/builder/corpus_count.hh +++ b/lm/builder/corpus_count.hh @@ -29,13 +29,20 @@ class CorpusCount { // token_count: out. // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. - CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol); + // + // If dynamic_vocab is true, then vocab ids are created on the fly and the + // words are written to vocab_file. If dynamic_vocab is false, then + // vocab_file is expected to contain an 8-byte count followed by a probing + // hash table with precomputed vocab ids. + CorpusCount(util::FilePiece &from, int vocab_file, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol, bool dynamic_vocab = true); void Run(const util::stream::ChainPosition &position); private: + template <class Voc> void RunWithVocab(const util::stream::ChainPosition &position, Voc &vocab); + util::FilePiece &from_; - int vocab_write_; + int vocab_file_; uint64_t &token_count_; WordIndex &type_count_; @@ -43,6 +50,8 @@ class CorpusCount { util::scoped_malloc dedupe_mem_; WarningAction disallowed_symbol_action_; + + bool dynamic_vocab_; }; } // namespace builder |