#include "lm/builder/corpus_count.hh"

#include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "lm/word_index.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include "util/stream/chain.hh"
#include "util/stream/timer.hh"
#include "util/tokenize_piece.hh"

#include <boost/scoped_array.hpp>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <stdint.h>

namespace lm {
namespace builder {
namespace {

// Hash-table entry mapping a 64-bit key to a WordIndex.  pack(4) drops the
// tail padding a uint64_t member would otherwise force (12 bytes instead of
// 16, assuming WordIndex is 32-bit).
// NOTE(review): nothing visible in this file references VocabEntry; it may be
// left over from an older vocabulary implementation -- confirm before removing.
#pragma pack(push)
#pragma pack(4)
struct VocabEntry {
  typedef uint64_t Key;

  uint64_t GetKey() const { return key; }
  void SetKey(uint64_t to) { key = to; }

  uint64_t key;
  lm::WordIndex value;
};
#pragma pack(pop)

// Hashes an n-gram given a pointer to its first word.  The n-gram length is
// fixed at construction, so callers only pass the pointer.
// No std::unary_function base: it was deprecated in C++11 and removed in C++17;
// presumably the hash table only needs operator().
class DedupeHash {
  public:
    explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    std::size_t operator()(const WordIndex *start) const {
      return util::MurmurHashNative(start, size_);
    }

  private:
    // Length of one n-gram's words in bytes.
    const std::size_t size_;
};

// Compares two n-grams word-by-word through their pointers.
// No std::binary_function base, for the same reason as DedupeHash.
class DedupeEquals {
  public:
    explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    bool operator()(const WordIndex *first, const WordIndex *second) const {
      return !std::memcmp(first, second, size_);
    }

  private:
    // Length of one n-gram's words in bytes.
    const std::size_t size_;
};

// Hash-table entry whose key is a pointer to an n-gram's first word.
// DedupeHash and DedupeEquals above look at the pointed-to words, so two
// entries match when their n-grams are equal, not when the pointers are.
struct DedupeEntry {
  typedef WordIndex *Key;
  Key GetKey() const { return key; }
  void SetKey(WordIndex *to) { key = to; }
  Key key;
  static DedupeEntry Construct(WordIndex *at) {
    DedupeEntry ret;
    ret.key = at;
    return ret;
  }
};

// TODO: don't have this here, should be with probing hash table defaults?
const float kProbingMultiplier = 1.5; typedef util::ProbingHashTable Dedupe; class Writer { public: Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) : block_(position), gram_(block_->Get(), order), dedupe_invalid_(order, std::numeric_limits::max()), dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), buffer_(new WordIndex[order - 1]), block_size_(position.GetChain().BlockSize()) { dedupe_.Clear(); assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); if (order == 1) { // Add special words. AdjustCounts is responsible if order != 1. AddUnigramWord(kUNK); AddUnigramWord(kBOS); } } ~Writer() { block_->SetValidSize(reinterpret_cast(gram_.begin()) - static_cast(block_->Get())); (++block_).Poison(); } // Write context with a bunch of void StartSentence() { for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) { *i = kBOS; } } void Append(WordIndex word) { *(gram_.end() - 1) = word; Dedupe::MutableIterator at; bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at); if (found) { // Already present. NGram already(at->key, gram_.Order()); ++(already.Count()); // Shift left by one. memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); return; } // Complete the write. gram_.Count() = 1; // Prepare the next n-gram. if (reinterpret_cast(gram_.begin()) + gram_.TotalSize() != static_cast(block_->Get()) + block_size_) { NGram last(gram_); gram_.NextInMemory(); std::copy(last.begin() + 1, last.end(), gram_.begin()); return; } // Block end. Need to store the context in a temporary buffer. 
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); dedupe_.Clear(); block_->SetValidSize(block_size_); gram_.ReBase((++block_)->Get()); std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); } private: void AddUnigramWord(WordIndex index) { *gram_.begin() = index; gram_.Count() = 0; gram_.NextInMemory(); if (gram_.Base() == static_cast(block_->Get()) + block_size_) { block_->SetValidSize(block_size_); gram_.ReBase((++block_)->Get()); } } util::stream::Link block_; NGram gram_; // This is the memory behind the invalid value in dedupe_. std::vector dedupe_invalid_; // Hash table combiner implementation. Dedupe dedupe_; // Small buffer to hold existing ngrams when shifting across a block boundary. boost::scoped_array buffer_; const std::size_t block_size_; }; } // namespace float CorpusCount::DedupeMultiplier(std::size_t order) { return kProbingMultiplier * static_cast(sizeof(DedupeEntry)) / static_cast(NGram::TotalSize(order)); } std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { return ngram::GrowableVocab::MemUsage(vocab_estimate); } CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), prune_words_(prune_words), prune_vocab_filename_(prune_vocab_filename), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), disallowed_symbol_action_(disallowed_symbol) { } namespace { void ComplainDisallowed(StringPiece word, WarningAction &action) { switch (action) { case SILENT: return; case COMPLAIN: std::cerr << "Warning: " << word << " appears in the input. All instances of , , and will be interpreted as whitespace." 
<< std::endl; action = SILENT; return; case THROW_UP: UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing in the future. Pass --skip_symbols to convert these symbols to whitespace."); } } } // namespace void CorpusCount::Run(const util::stream::ChainPosition &position) { ngram::GrowableVocab vocab(type_count_, vocab_write_); token_count_ = 0; type_count_ = 0; const WordIndex end_sentence = vocab.FindOrInsert(""); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; bool delimiters[256]; util::BoolCharacter::Build("\0\t\n\r ", delimiters); try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); for (util::TokenIter w(line, delimiters); w; ++w) { WordIndex word = vocab.FindOrInsert(*w); if (word <= 2) { ComplainDisallowed(*w, disallowed_symbol_action_); continue; } writer.Append(word); ++count; } writer.Append(end_sentence); } } catch (const util::EndOfFileException &e) {} token_count_ = count; type_count_ = vocab.Size(); // Create list of unigrams that are supposed to be pruned if (!prune_vocab_filename_.empty()) { try { util::FilePiece prune_vocab_file(prune_vocab_filename_.c_str()); prune_words_.resize(vocab.Size(), true); try { while (true) { StringPiece line(prune_vocab_file.ReadLine()); for (util::TokenIter w(line, delimiters); w; ++w) prune_words_[vocab.Index(*w)] = false; } } catch (const util::EndOfFileException &e) {} // Never prune , , prune_words_[kUNK] = false; prune_words_[kBOS] = false; prune_words_[kEOS] = false; } catch (const util::Exception &e) { std::cerr << e.what() << std::endl; abort(); } } } } // namespace builder } // namespace lm