Diffstat (limited to 'lm')
-rw-r--r--  lm/bhiksha.cc               |  15
-rw-r--r--  lm/bhiksha.hh               |   9
-rw-r--r--  lm/binary_format.cc         | 248
-rw-r--r--  lm/binary_format.hh         | 112
-rw-r--r--  lm/builder/corpus_count.cc  |  29
-rw-r--r--  lm/builder/interpolate.cc   |   8
-rw-r--r--  lm/config.cc                |   4
-rw-r--r--  lm/facade.hh                |   6
-rw-r--r--  lm/filter/arpa_io.hh        |   4
-rw-r--r--  lm/filter/count_io.hh       |  23
-rw-r--r--  lm/filter/filter_main.cc    | 167
-rw-r--r--  lm/filter/format.hh         |   2
-rw-r--r--  lm/filter/vocab.cc          |   6
-rw-r--r--  lm/model.cc                 |  84
-rw-r--r--  lm/model.hh                 |  12
-rw-r--r--  lm/model_test.cc            |   8
-rw-r--r--  lm/quantize.cc              |  12
-rw-r--r--  lm/quantize.hh              |   5
-rw-r--r--  lm/search_hashed.cc         |  33
-rw-r--r--  lm/search_hashed.hh         |  12
-rw-r--r--  lm/search_trie.cc           |  17
-rw-r--r--  lm/search_trie.hh           |  18
-rw-r--r--  lm/trie.hh                  |   9
-rw-r--r--  lm/trie_sort.cc             |   4
-rw-r--r--  lm/virtual_interface.hh     |   6
-rw-r--r--  lm/vocab.cc                 |  28
-rw-r--r--  lm/vocab.hh                 |  12
27 files changed, 462 insertions, 431 deletions
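This commit replaces the free functions SetupJustVocab, GrowForSearch, and FinishFile (and the Backing struct they operated on) with a BinaryFormat class that owns the file descriptor and memory mappings. For orientation, here is a minimal sketch of the new write path; it is not part of the commit itself. The method signatures come from the lm/binary_format.hh hunk below, while the function name, output path, order, sizes, and counts are placeholder values and the steps that populate the memory are elided:

  #include "lm/binary_format.hh"
  #include "lm/config.hh"

  #include <vector>

  using namespace lm::ngram;

  void WriteSketch(Config &config) {
    config.write_mmap = "sketch.binary";  // hypothetical output path
    BinaryFormat backing(config);
    // Reserve the vocabulary lookup table first (order = n-gram order).
    void *vocab_mem = backing.SetupJustVocab(/*memory_size=*/1 << 20, /*order=*/3);
    // Growing for the search structure may relocate the vocab base (see the
    // "Warning" comments in the header), so the new base comes back by reference.
    void *vocab_base;
    void *search_mem = backing.GrowForSearch(/*memory_size=*/1 << 24, /*vocab_pad=*/0, vocab_base);
    // ... populate vocab_mem / search_mem with the model data structures ...
    std::vector<uint64_t> counts(3);  // placeholder unigram/bigram/trigram counts
    backing.FinishFile(config, PROBING /* from lm/model_type.hh */, /*search_version=*/0, counts);
  }

As in the old code path, the header carrying the counts and model type is written only after everything else succeeds, so an interrupted build leaves a file stamped with kMagicIncomplete rather than a loadable header.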
diff --git a/lm/bhiksha.cc b/lm/bhiksha.cc index 088ea98d4..c8a18dfda 100644 --- a/lm/bhiksha.cc +++ b/lm/bhiksha.cc @@ -1,4 +1,6 @@ #include "lm/bhiksha.hh" + +#include "lm/binary_format.hh" #include "lm/config.hh" #include "util/file.hh" #include "util/exception.hh" @@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; // TODO: put this in binary file header instead when I change the binary file format again. -void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { - uint8_t version; - uint8_t configured_bits; - util::ReadOrThrow(fd, &version, 1); - util::ReadOrThrow(fd, &configured_bits, 1); +void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { + uint8_t buffer[2]; + file.ReadForConfig(buffer, 2, offset); + uint8_t version = buffer[0]; + uint8_t configured_bits = buffer[1]; if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); config.pointer_bhiksha_bits = configured_bits; } @@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) { *(head_write++) = config.pointer_bhiksha_bits; } -void ArrayBhiksha::LoadedBinary() { -} - } // namespace trie } // namespace ngram } // namespace lm diff --git a/lm/bhiksha.hh b/lm/bhiksha.hh index 8ff88654d..350571a6e 100644 --- a/lm/bhiksha.hh +++ b/lm/bhiksha.hh @@ -24,6 +24,7 @@ namespace lm { namespace ngram { struct Config; +class BinaryFormat; namespace trie { @@ -31,7 +32,7 @@ class DontBhiksha { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); - static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {} static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } @@ -53,8 +54,6 @@ class DontBhiksha { void FinishedLoading(const Config &/*config*/) {} - void LoadedBinary() {} - uint8_t InlineBits() const { return next_.bits; } private: @@ -65,7 +64,7 @@ class ArrayBhiksha { public: static const ModelType kModelTypeAdd = kArrayAdd; - static void UpdateConfigFromBinary(int fd, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); @@ -93,8 +92,6 @@ class ArrayBhiksha { void FinishedLoading(const Config &config); - void LoadedBinary(); - uint8_t InlineBits() const { return next_inline_.bits; } private: diff --git a/lm/binary_format.cc b/lm/binary_format.cc index bef51eb82..9c744b138 100644 --- a/lm/binary_format.cc +++ b/lm/binary_format.cc @@ -14,6 +14,9 @@ namespace lm { namespace ngram { + +const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; + namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; @@ -58,8 +61,6 @@ struct Sanity { } }; -const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization 
and array-compressed pointers"}; - std::size_t TotalHeaderSize(unsigned char order) { return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -81,83 +82,6 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { - if (config.write_mmap) { - std::size_t total = TotalHeaderSize(order) + memory_size; - backing.file.reset(util::CreateOrThrow(config.write_mmap)); - if (config.write_method == Config::WRITE_MMAP) { - backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); - } else { - util::ResizeOrThrow(backing.file.get(), 0); - util::MapAnonymous(total, backing.vocab); - } - strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); - return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); - } else { - util::MapAnonymous(memory_size, backing.vocab); - return reinterpret_cast<uint8_t*>(backing.vocab.get()); - } -} - -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { - std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; - if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. - try { - util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); - } catch (util::ErrnoException &e) { - e << " for file " << config.write_mmap; - throw e; - } - - if (config.write_method == Config::WRITE_AFTER) { - util::MapAnonymous(memory_size, backing.search); - return reinterpret_cast<uint8_t*>(backing.search.get()); - } - // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. - std::size_t page_size = util::SizePage(); - std::size_t alignment_cruft = adjusted_vocab % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; - } else { - util::MapAnonymous(memory_size, backing.search); - return reinterpret_cast<uint8_t*>(backing.search.get()); - } -} - -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { - if (!config.write_mmap) return; - switch (config.write_method) { - case Config::WRITE_MMAP: - util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); - util::SyncOrThrow(backing.search.get(), backing.search.size()); - break; - case Config::WRITE_AFTER: - util::SeekOrThrow(backing.file.get(), 0); - util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); - util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); - util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); - util::FSyncOrThrow(backing.file.get()); - break; - } - // header and vocab share the same mmap. The header is written here because we know the counts. 
- Parameters params = Parameters(); - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - params.fixed.search_version = search_version; - WriteHeader(backing.vocab.get(), params); - if (config.write_method == Config::WRITE_AFTER) { - util::SeekOrThrow(backing.file.get(), 0); - util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); - } -} - -namespace detail { - bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; @@ -209,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } -void SeekPastHeader(int fd, const Parameters ¶ms) { - util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +const std::size_t kInvalidSize = static_cast<std::size_t>(-1); + +BinaryFormat::BinaryFormat(const Config &config) + : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method), + header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {} + +void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms) { + file_.reset(fd); + write_mmap_ = NULL; // Ignore write requests; this is already in binary format. + ReadHeader(fd, params); + MatchCheck(model_type, search_version, params); + header_size_ = TotalHeaderSize(params.counts.size()); +} + +void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { + assert(header_size_ != kInvalidSize); + util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_); } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { - const uint64_t file_size = util::SizeFile(backing.file.get()); +void *BinaryFormat::LoadBinary(std::size_t size) { + assert(header_size_ != kInvalidSize); + const uint64_t file_size = util::SizeFile(file_.get()); // The header is smaller than a page, so we have to map the whole header as well. - std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); - if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) - UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size); + UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); - util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search); + util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_); - if (config.enumerate_vocab && !params.fixed.has_vocabulary) - UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. 
You may need to rebuild the binary file with an updated version of build_binary."); + vocab_string_offset_ = total_map; + return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; +} + +void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { + vocab_size_ = memory_size; + if (!write_mmap_) { + header_size_ = 0; + util::MapAnonymous(memory_size, memory_vocab_); + return reinterpret_cast<uint8_t*>(memory_vocab_.get()); + } + header_size_ = TotalHeaderSize(order); + std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size)); + file_.reset(util::CreateOrThrow(write_mmap_)); + // some gccs complain about uninitialized variables even though all enum values are covered. + void *vocab_base = NULL; + switch (write_method_) { + case Config::WRITE_MMAP: + mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); + vocab_base = mapping_.get(); + break; + case Config::WRITE_AFTER: + util::ResizeOrThrow(file_.get(), 0); + util::MapAnonymous(total, memory_vocab_); + vocab_base = memory_vocab_.get(); + break; + } + strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_); + return reinterpret_cast<uint8_t*>(vocab_base) + header_size_; +} - // Seek to vocabulary words - util::SeekOrThrow(backing.file.get(), total_map); - return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); +void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) { + assert(vocab_size_ != kInvalidSize); + vocab_pad_ = vocab_pad; + std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size; + vocab_string_offset_ = new_size; + if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) { + util::MapAnonymous(memory_size, memory_search_); + assert(header_size_ == 0 || write_mmap_); + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; + return reinterpret_cast<uint8_t*>(memory_search_.get()); + } + + assert(write_method_ == Config::WRITE_MMAP); + // Also known as total size without vocab words. + // Grow the file to accomodate the search, using zeros. + // According to man mmap, behavior is undefined when the file is resized + // underneath a mmap that is not a multiple of the page size. So to be + // safe, we'll unmap it and map it again. + mapping_.reset(); + util::ResizeOrThrow(file_.get(), new_size); + void *ret; + MapFile(vocab_base, ret); + return ret; } -void ComplainAboutARPA(const Config &config, ModelType model_type) { - if (config.write_mmap || !config.messages) return; - if (config.arpa_complain == Config::ALL) { - *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && - (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { - *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; +void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) { + // Checking Config's include_vocab is the responsibility of the caller. + assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize); + if (!write_mmap_) { + // Unchanged base. 
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()); + search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); + return; + } + if (write_method_ == Config::WRITE_MMAP) { + mapping_.reset(); + } + util::SeekOrThrow(file_.get(), VocabStringReadingOffset()); + util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); + if (write_method_ == Config::WRITE_MMAP) { + MapFile(vocab_base, search_base); + } else { + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; + search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); + } +} + +void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) { + if (!write_mmap_) return; + switch (write_method_) { + case Config::WRITE_MMAP: + util::SyncOrThrow(mapping_.get(), mapping_.size()); + break; + case Config::WRITE_AFTER: + util::SeekOrThrow(file_.get(), 0); + util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size()); + util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_); + util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size()); + util::FSyncOrThrow(file_.get()); + break; + } + // header and vocab share the same mmap. + Parameters params = Parameters(); + memset(¶ms, 0, sizeof(Parameters)); + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; + switch (write_method_) { + case Config::WRITE_MMAP: + WriteHeader(mapping_.get(), params); + util::SyncOrThrow(mapping_.get(), mapping_.size()); + break; + case Config::WRITE_AFTER: + { + std::vector<uint8_t> buffer(TotalHeaderSize(counts.size())); + WriteHeader(&buffer[0], params); + util::SeekOrThrow(file_.get(), 0); + util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); + } + break; } } -} // namespace detail +void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { + mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); + vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; + search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; +} bool RecognizeBinary(const char *file, ModelType &recognized) { util::scoped_fd fd(util::OpenReadOrThrow(file)); - if (!detail::IsBinaryFormat(fd.get())) return false; + if (!IsBinaryFormat(fd.get())) { + return false; + } Parameters params; - detail::ReadHeader(fd.get(), params); + ReadHeader(fd.get(), params); recognized = params.fixed.model_type; return true; } diff --git a/lm/binary_format.hh b/lm/binary_format.hh index bf699d5f4..f33f88d75 100644 --- a/lm/binary_format.hh +++ b/lm/binary_format.hh @@ -17,6 +17,8 @@ namespace lm { namespace ngram { +extern const char *kModelNames[6]; + /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -42,67 +44,63 @@ struct Parameters { std::vector<uint64_t> counts; }; -struct Backing { - // File behind memory, if any. - util::scoped_fd file; - // Vocabulary lookup table. Not to be confused with the vocab words themselves. 
- util::scoped_memory vocab; - // Raw block of memory backing the language model data structures - util::scoped_memory search; -}; - -// Create just enough of a binary file to write vocabulary to it. -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); -// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); - -// Write header to binary file. This is done last to prevent incomplete files -// from loading. -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing); +class BinaryFormat { + public: + explicit BinaryFormat(const Config &config); + + // Reading a binary file: + // Takes ownership of fd + void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms); + // Used to read parts of the file to update the config object before figuring out full size. + void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const; + // Actually load the binary file and return a pointer to the beginning of the search area. + void *LoadBinary(std::size_t size); + + uint64_t VocabStringReadingOffset() const { + assert(vocab_string_offset_ != kInvalidOffset); + return vocab_string_offset_; + } -namespace detail { + // Writing a binary file or initializing in RAM from ARPA: + // Size for vocabulary. + void *SetupJustVocab(std::size_t memory_size, uint8_t order); + // Warning: can change the vocaulary base pointer. + void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base); + // Warning: can change vocabulary and search base addresses. + void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base); + // Write the header at the beginning of the file. + void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts); + + private: + void MapFile(void *&vocab_base, void *&search_base); + + // Copied from configuration. + const Config::WriteMethod write_method_; + const char *write_mmap_; + util::LoadMethod load_method_; + + // File behind memory, if any. + util::scoped_fd file_; + + // If there is a file involved, a single mapping. + util::scoped_memory mapping_; + + // If the data is only in memory, separately allocate each because the trie + // knows vocab's size before it knows search's size (because SRILM might + // have pruned). + util::scoped_memory memory_vocab_, memory_search_; + + // Memory ranges. Note that these may not be contiguous and may not all + // exist. + std::size_t header_size_, vocab_size_, vocab_pad_; + // aka end of search. 
+ uint64_t vocab_string_offset_; + + static const uint64_t kInvalidOffset = (uint64_t)-1; +}; bool IsBinaryFormat(int fd); -void ReadHeader(int fd, Parameters ¶ms); - -void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); - -void SeekPastHeader(int fd, const Parameters ¶ms); - -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); - -void ComplainAboutARPA(const Config &config, ModelType model_type); - -} // namespace detail - -template <class To> void LoadLM(const char *file, const Config &config, To &to) { - Backing &backing = to.MutableBacking(); - backing.file.reset(util::OpenReadOrThrow(file)); - - try { - if (detail::IsBinaryFormat(backing.file.get())) { - Parameters params; - detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, To::kVersion, params); - // Replace the run-time configured probing_multiplier with the one in the file. - Config new_config(config); - new_config.probing_multiplier = params.fixed.probing_multiplier; - detail::SeekPastHeader(backing.file.get(), params); - To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); - uint64_t memory_size = To::Size(params.counts, new_config); - uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); - to.InitializeFromBinary(start, params, new_config, backing.file.get()); - } else { - detail::ComplainAboutARPA(config, To::kModelType); - to.InitializeFromARPA(file, config); - } - } catch (util::Exception &e) { - e << " File: " << file; - throw; - } -} - } // namespace ngram } // namespace lm #endif // LM_BINARY_FORMAT__ diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc index 6ad91dde7..ccc06efca 100644 --- a/lm/builder/corpus_count.cc +++ b/lm/builder/corpus_count.cc @@ -87,7 +87,7 @@ class VocabHandout { Table table_; std::size_t double_cutoff_; - + util::FakeOFStream word_list_; }; @@ -98,7 +98,7 @@ class DedupeHash : public std::unary_function<const WordIndex *, bool> { std::size_t operator()(const WordIndex *start) const { return util::MurmurHashNative(start, size_); } - + private: const std::size_t size_; }; @@ -106,11 +106,11 @@ class DedupeHash : public std::unary_function<const WordIndex *, bool> { class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> { public: explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {} - + bool operator()(const WordIndex *first, const WordIndex *second) const { return !memcmp(first, second, size_); - } - + } + private: const std::size_t size_; }; @@ -131,7 +131,7 @@ typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; class Writer { public: - Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) + Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) : block_(position), gram_(block_->Get(), order), dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()), dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), @@ -140,7 +140,7 @@ class Writer { dedupe_.Clear(); assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); if (order == 1) { - // Add special words. AdjustCounts is responsible if order != 1. + // Add special words. AdjustCounts is responsible if order != 1. 
AddUnigramWord(kUNK); AddUnigramWord(kBOS); } @@ -170,16 +170,16 @@ class Writer { memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); return; } - // Complete the write. + // Complete the write. gram_.Count() = 1; - // Prepare the next n-gram. + // Prepare the next n-gram. if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) { NGram last(gram_); gram_.NextInMemory(); std::copy(last.begin() + 1, last.end(), gram_.begin()); return; } - // Block end. Need to store the context in a temporary buffer. + // Block end. Need to store the context in a temporary buffer. std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); dedupe_.Clear(); block_->SetValidSize(block_size_); @@ -207,7 +207,7 @@ class Writer { // Hash table combiner implementation. Dedupe dedupe_; - // Small buffer to hold existing ngrams when shifting across a block boundary. + // Small buffer to hold existing ngrams when shifting across a block boundary. boost::scoped_array<WordIndex> buffer_; const std::size_t block_size_; @@ -223,7 +223,7 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { return VocabHandout::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { @@ -240,7 +240,10 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) { uint64_t count = 0; bool delimiters[256]; memset(delimiters, 0, sizeof(delimiters)); - delimiters['\0'] = delimiters['\t'] = delimiters['\n'] = delimiters['\r'] = delimiters[' '] = true; + const char kDelimiterSet[] = "\0\t\n\r "; + for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { + delimiters[static_cast<unsigned char>(*i)] = true; + } try { while(true) { StringPiece line(from_.ReadLine()); diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc index 52e69f02e..500268069 100644 --- a/lm/builder/interpolate.cc +++ b/lm/builder/interpolate.cc @@ -33,12 +33,12 @@ class Callback { pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; probs_[order_minus_1 + 1] = pay.complete.prob; pay.complete.prob = log10(pay.complete.prob); - // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. - if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1].Get()) { // check valid pointer at tht end + // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); ++backoffs_[order_minus_1]; } else { - // Not a context. + // Not a context. 
pay.complete.backoff = 0.0; } } @@ -52,7 +52,7 @@ class Callback { }; } // namespace -Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} // perform order-wise interpolation diff --git a/lm/config.cc b/lm/config.cc index dc3365319..9520c41c8 100644 --- a/lm/config.cc +++ b/lm/config.cc @@ -11,11 +11,7 @@ Config::Config() : enumerate_vocab(NULL), unknown_missing(COMPLAIN), sentence_marker_missing(THROW_UP), -#if defined(_WIN32) || defined(_WIN64) - positive_log_probability(SILENT), -#else positive_log_probability(THROW_UP), -#endif unknown_missing_logprob(-100.0), probing_multiplier(1.5), building_memory(1073741824ULL), // 1 GB diff --git a/lm/facade.hh b/lm/facade.hh index 760e839e0..de1551f12 100644 --- a/lm/facade.hh +++ b/lm/facade.hh @@ -17,14 +17,14 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ typedef VocabularyT Vocabulary; /* Translate from void* to State */ - FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { + FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->FullScore( *reinterpret_cast<const State*>(in_state), new_word, *reinterpret_cast<State*>(out_state)); } - FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { + FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->FullScoreForgotState( context_rbegin, context_rend, @@ -37,7 +37,7 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; } - float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->Score( *reinterpret_cast<const State*>(in_state), new_word, diff --git a/lm/filter/arpa_io.hh b/lm/filter/arpa_io.hh index 08e658666..602b5b31b 100644 --- a/lm/filter/arpa_io.hh +++ b/lm/filter/arpa_io.hh @@ -14,10 +14,6 @@ #include <string> #include <vector> -#if !defined __MINGW32__ -#include <err.h> -#endif - #include <string.h> #include <stdint.h> diff --git a/lm/filter/count_io.hh b/lm/filter/count_io.hh index 740b8d50e..d992026ff 100644 --- a/lm/filter/count_io.hh +++ b/lm/filter/count_io.hh @@ -5,27 +5,18 @@ #include <iostream> #include <string> -#if !defined __MINGW32__ -#include <err.h> -#endif - +#include "util/fake_ofstream.hh" +#include "util/file.hh" #include "util/file_piece.hh" namespace lm { class CountOutput : boost::noncopyable { public: - explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {} void AddNGram(const StringPiece &line) { - if (!(file_ << line << '\n')) { -#if defined __MINGW32__ - std::cerr<<"Writing counts file failed"<<std::endl; - exit(3); -#else - err(3, "Writing counts file failed"); -#endif - } + file_ << line << '\n'; } template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { 
@@ -37,12 +28,12 @@ class CountOutput : boost::noncopyable { } private: - std::fstream file_; + util::FakeOFStream file_; }; class CountBatch { public: - explicit CountBatch(std::streamsize initial_read) + explicit CountBatch(std::streamsize initial_read) : initial_read_(initial_read) { buffer_.reserve(initial_read); } @@ -75,7 +66,7 @@ class CountBatch { private: std::streamsize initial_read_; - // This could have been a std::string but that's less happy with raw writes. + // This could have been a std::string but that's less happy with raw writes. std::vector<char> buffer_; }; diff --git a/lm/filter/filter_main.cc b/lm/filter/filter_main.cc index f89ac4df3..82fdc1ef7 100644 --- a/lm/filter/filter_main.cc +++ b/lm/filter/filter_main.cc @@ -6,6 +6,7 @@ #endif #include "lm/filter/vocab.hh" #include "lm/filter/wrapper.hh" +#include "util/exception.hh" #include "util/file_piece.hh" #include <boost/ptr_container/ptr_vector.hpp> @@ -57,7 +58,7 @@ typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION, MODE_UNSET} Fil typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; struct Config { - Config() : + Config() : #ifndef NTHREAD batch_size(25000), threads(boost::thread::hardware_concurrency()), @@ -157,102 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr } // namespace lm int main(int argc, char *argv[]) { - if (argc < 4) { - lm::DisplayHelp(argv[0]); - return 1; - } + try { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } - // I used to have boost::program_options, but some users didn't want to compile boost. - lm::Config config; - config.mode = lm::MODE_UNSET; - for (int i = 1; i < argc - 2; ++i) { - const char *str = argv[i]; - if (!std::strcmp(str, "copy")) { - config.mode = lm::MODE_COPY; - } else if (!std::strcmp(str, "single")) { - config.mode = lm::MODE_SINGLE; - } else if (!std::strcmp(str, "multiple")) { - config.mode = lm::MODE_MULTIPLE; - } else if (!std::strcmp(str, "union")) { - config.mode = lm::MODE_UNION; - } else if (!std::strcmp(str, "phrase")) { - config.phrase = true; - } else if (!std::strcmp(str, "context")) { - config.context = true; - } else if (!std::strcmp(str, "arpa")) { - config.format = lm::FORMAT_ARPA; - } else if (!std::strcmp(str, "raw")) { - config.format = lm::FORMAT_COUNT; + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; #ifndef NTHREAD - } else if (!std::strncmp(str, "threads:", 8)) { - config.threads = boost::lexical_cast<size_t>(str + 8); - if (!config.threads) { - std::cerr << "Specify at least one thread." << std::endl; + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast<size_t>(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." 
<< std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast<size_t>(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); return 1; } - } else if (!std::strncmp(str, "batch_size:", 11)) { - config.batch_size = boost::lexical_cast<size_t>(str + 11); - if (config.batch_size < 5000) { - std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; - if (!config.batch_size) return 1; - } -#endif - } else { + } + + if (config.mode == lm::MODE_UNSET) { lm::DisplayHelp(argv[0]); return 1; } - } - if (config.mode == lm::MODE_UNSET) { - lm::DisplayHelp(argv[0]); - return 1; - } - - if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { - std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; - return 1; - } + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } - bool cmd_is_model = true; - const char *cmd_input = argv[argc - 2]; - if (!strncmp(cmd_input, "vocab:", 6)) { - cmd_is_model = false; - cmd_input += 6; - } else if (!strncmp(cmd_input, "model:", 6)) { - cmd_input += 6; - } else if (strchr(cmd_input, ':')) { -#if defined __MINGW32__ - std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; - exit(1); -#else - errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); -#endif // defined - } else { - std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; - } - std::ifstream cmd_file; - std::istream *vocab; - if (cmd_is_model) { - vocab = &std::cin; - } else { - cmd_file.open(cmd_input, std::ios::in); - if (!cmd_file) { -#if defined __MINGW32__ - std::cerr << "Could not open input file " << cmd_input << std::endl; - exit(2); -#else - err(2, "Could not open input file %s", cmd_input); -#endif // defined + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; + return 1; + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input); + vocab = &cmd_file; } - vocab = &cmd_file; - } - util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? 
cmd_input : NULL, &std::cerr); - if (config.format == lm::FORMAT_ARPA) { - lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); - } else if (config.format == lm::FORMAT_COUNT) { - lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + } + return 0; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; } - return 0; } diff --git a/lm/filter/format.hh b/lm/filter/format.hh index 7f945b0d6..7d8c28dbc 100644 --- a/lm/filter/format.hh +++ b/lm/filter/format.hh @@ -1,5 +1,5 @@ #ifndef LM_FILTER_FORMAT_H__ -#define LM_FITLER_FORMAT_H__ +#define LM_FILTER_FORMAT_H__ #include "lm/filter/arpa_io.hh" #include "lm/filter/count_io.hh" diff --git a/lm/filter/vocab.cc b/lm/filter/vocab.cc index 7ed5d92fb..011ab5992 100644 --- a/lm/filter/vocab.cc +++ b/lm/filter/vocab.cc @@ -5,10 +5,6 @@ #include <ctype.h> -#if !defined __MINGW32__ -#include <err.h> -#endif - namespace lm { namespace vocab { @@ -34,7 +30,7 @@ bool IsLineEnd(std::istream &in) { }// namespace // Read space separated words in enter separated lines. These lines can be -// very long, so don't read an entire line at a time. +// very long, so don't read an entire line at a time. unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) { in.exceptions(std::istream::badbit); unsigned int sentence = 0; diff --git a/lm/model.cc b/lm/model.cc index a26654a6f..a5a16bf8e 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size); } -template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { - LoadLM(file, config, *this); - - // g++ prints warnings unless these are fully initialized. - State begin_sentence = State(); - begin_sentence.length = 1; - begin_sentence.words[0] = vocab_.BeginSentence(); - typename Search::Node ignored_node; - bool ignored_independent_left; - uint64_t ignored_extend_left; - begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); - State null_context = State(); - null_context.length = 0; - P::Init(begin_sentence, null_context, vocab_, search_.Order()); +namespace { +void ComplainAboutARPA(const Config &config, ModelType model_type) { + if (config.write_mmap || !config.messages) return; + if (config.arpa_complain == Config::ALL) { + *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { + *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." 
<< std::endl; + } } -namespace { void CheckCounts(const std::vector<uint64_t> &counts) { UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); if (sizeof(uint64_t) > sizeof(std::size_t)) { @@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) { } } } + } // namespace -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { - CheckCounts(params.counts); - SetupMemory(start, params.counts, config); - vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); - search_.LoadedBinary(); +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) { + util::scoped_fd fd(util::OpenReadOrThrow(file)); + if (IsBinaryFormat(fd.get())) { + Parameters parameters; + int fd_shallow = fd.release(); + backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters); + CheckCounts(parameters.counts); + + Config new_config(init_config); + new_config.probing_multiplier = parameters.fixed.probing_multiplier; + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config); + UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config); + vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset()); + } else { + ComplainAboutARPA(init_config, kModelType); + InitializeFromARPA(fd.release(), file, init_config); + } + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); + State null_context = State(); + null_context.length = 0; + P::Init(begin_sentence, null_context, vocab_, search_.Order()); } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) { + // Backing file is the ARPA. + util::FilePiece f(fd, file, config.ProgressMessages()); try { std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. 
@@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config); - if (config.write_mmap) { + if (config.write_mmap && config.include_vocab) { WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); + void *vocab_rebase, *search_rebase; + backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase); + // Due to writing at the end of file, mmap may have relocated data. So remap. + vocab_.Relocate(vocab_rebase); + search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.UnknownUnigram().backoff = 0.0; search_.UnknownUnigram().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); + backing_.FinishFile(config, kModelType, kVersion, counts); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; } } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { - util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); - Search::UpdateConfigFromBinary(fd, counts, config); -} - template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { diff --git a/lm/model.hh b/lm/model.hh index c9c17c4b3..e75da93bf 100644 --- a/lm/model.hh +++ b/lm/model.hh @@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod } private: - friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); - - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); - FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; // Score bigrams and above. Do not include backoff. @@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // Appears after Size in the cc file. 
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config); - void InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd); - - void InitializeFromARPA(const char *file, const Config &config); + void InitializeFromARPA(int fd, const char *file, const Config &config); float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const; - Backing &MutableBacking() { return backing_; } - - Backing backing_; + BinaryFormat backing_; VocabularyT vocab_; diff --git a/lm/model_test.cc b/lm/model_test.cc index eb1590942..7005b05ea 100644 --- a/lm/model_test.cc +++ b/lm/model_test.cc @@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { LoadingTest<QuantArrayTrieModel>(); } -template <class ModelT> void BinaryTest() { +template <class ModelT> void BinaryTest(Config::WriteMethod write_method) { Config config; config.write_mmap = "test.binary"; config.messages = NULL; + config.write_method = write_method; ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; @@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() { unlink("test_nounk.binary"); } +template <class ModelT> void BinaryTest() { + BinaryTest<ModelT>(Config::WRITE_MMAP); + BinaryTest<ModelT>(Config::WRITE_AFTER); +} + BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest<ProbingModel>(); } diff --git a/lm/quantize.cc b/lm/quantize.cc index b58c3f3f6..273ea3989 100644 --- a/lm/quantize.cc +++ b/lm/quantize.cc @@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2; } // namespace -void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) { - char version; - util::ReadOrThrow(fd, &version, 1); - util::ReadOrThrow(fd, &config.prob_bits, 1); - util::ReadOrThrow(fd, &config.backoff_bits, 1); +void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { + unsigned char buffer[3]; + file.ReadForConfig(buffer, 3, offset); + char version = buffer[0]; + config.prob_bits = buffer[1]; + config.backoff_bits = buffer[2]; if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); - util::AdvanceOrThrow(fd, -3); } void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) { diff --git a/lm/quantize.hh b/lm/quantize.hh index 8ce2378a7..9d3a2f439 100644 --- a/lm/quantize.hh +++ b/lm/quantize.hh @@ -18,12 +18,13 @@ namespace lm { namespace ngram { struct Config; +class BinaryFormat; /* Store values directly and don't quantize. 
*/ class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); - static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -136,7 +137,7 @@ class SeparatelyQuantize { public: static const ModelType kModelTypeAdd = kQuantAdd; - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); static uint64_t Size(uint8_t order, const Config &config) { uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc index 62275d277..354a56b46 100644 --- a/lm/search_hashed.cc +++ b/lm/search_hashed.cc @@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams( namespace detail { template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram_ = Unigram(start, counts[0], allocated); - start += allocated; + unigram_ = Unigram(start, counts[0]); + start += Unigram::Size(counts[0]); + std::size_t allocated; + middle_.clear(); for (unsigned int n = 2; n < counts.size(); ++n) { allocated = Middle::Size(counts[n - 1], config.probing_multiplier); middle_.push_back(Middle(start, allocated)); @@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, return start; } -template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) { - // TODO: fix sorted. 
- SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); +/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { + unigram_ = Unigram(start, counts[0]); + start += Unigram::Size(counts[0]); + for (unsigned int n = 2; n < counts.size(); ++n) { + middle[n-2].Relocate(start); + start += Middle::Size(counts[n - 1], config.probing_multiplier) + } + longest_.Relocate(start); +}*/ + +template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) { + void *vocab_rebase; + void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase); + vocab.Relocate(vocab_rebase); + SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config); PositiveProbWarn warn(config.positive_log_probability); Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn); @@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui ReadEnd(f); } -template <class Value> void HashedSearch<Value>::LoadedBinary() { - unigram_.LoadedBinary(); - for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - template class HashedSearch<BackoffValue>; template class HashedSearch<RestValue>; diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh index 9d067bc2e..8193262b0 100644 --- a/lm/search_hashed.hh +++ b/lm/search_hashed.hh @@ -18,7 +18,7 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { -struct Backing; +class BinaryFormat; class ProbingVocabulary; namespace detail { @@ -72,7 +72,7 @@ template <class Value> class HashedSearch { static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. - static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {} static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { uint64_t ret = Unigram::Size(counts[0]); @@ -84,9 +84,7 @@ template <class Value> class HashedSearch { uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); - - void LoadedBinary(); + void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing); unsigned char Order() const { return middle_.size() + 2; @@ -148,7 +146,7 @@ template <class Value> class HashedSearch { public: Unigram() {} - Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + Unigram(void *start, uint64_t count) : unigram_(static_cast<typename Value::Weights*>(start)) #ifdef DEBUG , count_(count) @@ -168,8 +166,6 @@ template <class Value> class HashedSearch { typename Value::Weights &Unknown() { return unigram_[0]; } - void LoadedBinary() {} - // For building. 
typename Value::Weights *Raw() { return unigram_; } diff --git a/lm/search_trie.cc b/lm/search_trie.cc index 27605e548..4a88194e8 100644 --- a/lm/search_trie.cc +++ b/lm/search_trie.cc @@ -459,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) { RecordReader inputs[KENLM_MAX_ORDER - 1]; RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -488,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); - out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); + void *vocab_relocate; + void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); + vocab.Relocate(vocab_relocate); + out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -571,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { - unigram_.LoadedBinary(); - for (Middle *i = middle_begin_; i != middle_end_; ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) { std::string temporary_prefix; if (config.temporary_directory_prefix) { temporary_prefix = config.temporary_directory_prefix; diff --git a/lm/search_trie.hh b/lm/search_trie.hh index 763fd1a72..299262a5d 100644 --- a/lm/search_trie.hh +++ b/lm/search_trie.hh @@ -17,13 +17,13 @@ namespace lm { namespace ngram { -struct Backing; +class BinaryFormat; class SortedVocabulary; namespace trie { template <class Quant, class Bhiksha> class TrieSearch; class SortedFiles; -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); template <class Quant, class Bhiksha> class TrieSearch { public: @@ -39,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch { static const unsigned int kVersion = 1; - static void UpdateConfigFromBinary(int fd, const 
-      Quant::UpdateConfigFromBinary(fd, counts, config);
-      util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+    static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
+      Quant::UpdateConfigFromBinary(file, offset, config);
       // Currently the unigram pointers are not compressed, so there will only be a header for order > 2.
-      if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);
+      if (counts.size() > 2)
+        Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
     }

     static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -60,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch {

     uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);

-    void LoadedBinary();
-
-    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);

     unsigned char Order() const {
       return middle_end_ - middle_begin_ + 2;
@@ -103,7 +101,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
     }

   private:
-    friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+    friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);

     // Middles are managed manually so we can delay construction and they don't have to be copyable.
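     // (They are constructed with placement new into raw storage; FreeMiddles below invokes the destructors explicitly.)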
void FreeMiddles() { diff --git a/lm/trie.hh b/lm/trie.hh index 9ea3c5466..d858ab5e4 100644 --- a/lm/trie.hh +++ b/lm/trie.hh @@ -62,8 +62,6 @@ class Unigram { return unigram_; } - void LoadedBinary() {} - UnigramPointer Find(WordIndex word, NodeRange &next) const { UnigramValue *val = unigram_ + word; next.begin = val->next; @@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end, const Config &config); - void LoadedBinary() { bhiksha_.LoadedBinary(); } - util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { @@ -138,14 +134,9 @@ class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_bits); } - void LoadedBinary() {} - util::BitAddress Insert(WordIndex word); util::BitAddress Find(WordIndex word, const NodeRange &node) const; - - private: - uint8_t quant_bits_; }; } // namespace trie diff --git a/lm/trie_sort.cc b/lm/trie_sort.cc index dc542bb32..126d43aba 100644 --- a/lm/trie_sort.cc +++ b/lm/trie_sort.cc @@ -50,6 +50,10 @@ class PartialViewProxy { const void *Data() const { return inner_.Data(); } void *Data() { return inner_.Data(); } + friend void swap(PartialViewProxy first, PartialViewProxy second) { + std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data())); + } + private: friend class util::ProxyIterator<PartialViewProxy>; diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh index ff4a388e7..7a3e23796 100644 --- a/lm/virtual_interface.hh +++ b/lm/virtual_interface.hh @@ -125,13 +125,13 @@ class Model { void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); } // Requires in_state != out_state - virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; // Requires in_state != out_state - virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; // Prefer to use FullScore. The context words should be provided in reverse order. - virtual FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0; + virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0; unsigned char Order() const { return order_; } diff --git a/lm/vocab.cc b/lm/vocab.cc index fd7f96dc4..7f0878f40 100644 --- a/lm/vocab.cc +++ b/lm/vocab.cc @@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { + util::SeekOrThrow(fd, offset); // Check that we're at the right place by reading <unk> which is always first. 
char check_unk[6]; util::ReadOrThrow(fd, check_unk, 6); @@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } -void WriteWordsWrapper::Write(int fd, uint64_t start) { - util::SeekOrThrow(fd, start); - util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); -} - SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { @@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size saw_unk_ = false; } +void SortedVocabulary::Relocate(void *new_start) { + std::size_t delta = end_ - begin_; + begin_ = reinterpret_cast<uint64_t*>(new_start) + 1; + end_ = begin_ + delta; +} + void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) { enumerate_ = to; if (enumerate_) { @@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); SetSpecial(Index("<s>"), Index("</s>"), 0); bound_ = end_ - begin_ + 1; - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } namespace { @@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz saw_unk_ = false; } +void ProbingVocabulary::Relocate(void *new_start) { + header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start); + lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader))); +} + void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) { enumerate_ = to; if (enumerate_) { @@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() { SetSpecial(Index("<s>"), Index("</s>"), 0); } -void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); - lookup_.LoadedBinary(); bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { diff --git a/lm/vocab.hh b/lm/vocab.hh index 226ae4385..074b74d86 100644 --- a/lm/vocab.hh +++ b/lm/vocab.hh @@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab { void Add(WordIndex index, const StringPiece &str); - void Write(int fd, uint64_t start); + const std::string &Buffer() const { return buffer_; } private: EnumerateVocab *inner_; @@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary { // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. 
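    // Relocate (added below) rebases begin_/end_ when GrowForSearch remaps the backing memory.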
     void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);

+    void Relocate(void *new_start);
+
     void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);

     WordIndex Insert(const StringPiece &str);
@@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary {

     bool SawUnk() const { return saw_unk_; }

-    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);

   private:
     uint64_t *begin_, *end_;

     WordIndex bound_;

-    WordIndex highest_value_;
-
     bool saw_unk_;

     EnumerateVocab *enumerate_;
@@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary {
     // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
     void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);

+    void Relocate(void *new_start);
+
     void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);

     WordIndex Insert(const StringPiece &str);
@@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary {

     bool SawUnk() const { return saw_unk_; }

-    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);

   private:
     void InternalFinishedLoading();
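Taken together, the vocabulary changes move seek bookkeeping out of the vocab classes: LoadedBinary now receives an explicit byte offset to the word strings, and WriteWordsWrapper no longer writes itself; callers fetch Buffer() and write it wherever BinaryFormat dictates. A minimal caller-side sketch under those assumptions; the wrapper functions, fd, and words_offset below are illustrative names, not part of this patch, while the member functions called are the ones this diff introduces:

    #include "lm/vocab.hh"
    #include "util/file.hh"

    // Binary load path: the sorted array itself was already mapped by
    // SetupMemory; the offset merely locates the NUL-delimited word strings
    // for enumeration.
    void LoadWords(lm::ngram::SortedVocabulary &vocab, bool have_words, int fd,
                   lm::EnumerateVocab *enumerate, uint64_t words_offset) {
      vocab.LoadedBinary(have_words, fd, enumerate, words_offset);
    }

    // Write path, replacing the deleted WriteWordsWrapper::Write(fd, start):
    // the caller now owns both the seek and the write.
    void SaveWords(const lm::ngram::WriteWordsWrapper &wrap, int fd, uint64_t start) {
      util::SeekOrThrow(fd, start);
      util::WriteOrThrow(fd, wrap.Buffer().data(), wrap.Buffer().size());
    }

Keeping the seek in the caller leaves BinaryFormat as the single owner of file layout, the same pattern the search classes now follow through GrowForSearch and Relocate.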