github.com/moses-smt/mosesdecoder.git
path: root/lm
Diffstat (limited to 'lm')
-rw-r--r-- lm/bhiksha.cc | 15
-rw-r--r-- lm/bhiksha.hh | 9
-rw-r--r-- lm/binary_format.cc | 248
-rw-r--r-- lm/binary_format.hh | 112
-rw-r--r-- lm/builder/corpus_count.cc | 29
-rw-r--r-- lm/builder/interpolate.cc | 8
-rw-r--r-- lm/config.cc | 4
-rw-r--r-- lm/facade.hh | 6
-rw-r--r-- lm/filter/arpa_io.hh | 4
-rw-r--r-- lm/filter/count_io.hh | 23
-rw-r--r-- lm/filter/filter_main.cc | 167
-rw-r--r-- lm/filter/format.hh | 2
-rw-r--r-- lm/filter/vocab.cc | 6
-rw-r--r-- lm/model.cc | 84
-rw-r--r-- lm/model.hh | 12
-rw-r--r-- lm/model_test.cc | 8
-rw-r--r-- lm/quantize.cc | 12
-rw-r--r-- lm/quantize.hh | 5
-rw-r--r-- lm/search_hashed.cc | 33
-rw-r--r-- lm/search_hashed.hh | 12
-rw-r--r-- lm/search_trie.cc | 17
-rw-r--r-- lm/search_trie.hh | 18
-rw-r--r-- lm/trie.hh | 9
-rw-r--r-- lm/trie_sort.cc | 4
-rw-r--r-- lm/virtual_interface.hh | 6
-rw-r--r-- lm/vocab.cc | 28
-rw-r--r-- lm/vocab.hh | 12
27 files changed, 462 insertions, 431 deletions
diff --git a/lm/bhiksha.cc b/lm/bhiksha.cc
index 088ea98d4..c8a18dfda 100644
--- a/lm/bhiksha.cc
+++ b/lm/bhiksha.cc
@@ -1,4 +1,6 @@
#include "lm/bhiksha.hh"
+
+#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "util/file.hh"
#include "util/exception.hh"
@@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_
const uint8_t kArrayBhikshaVersion = 0;
// TODO: put this in binary file header instead when I change the binary file format again.
-void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
- uint8_t version;
- uint8_t configured_bits;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &configured_bits, 1);
+void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ uint8_t buffer[2];
+ file.ReadForConfig(buffer, 2, offset);
+ uint8_t version = buffer[0];
+ uint8_t configured_bits = buffer[1];
if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
config.pointer_bhiksha_bits = configured_bits;
}
@@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) {
*(head_write++) = config.pointer_bhiksha_bits;
}
-void ArrayBhiksha::LoadedBinary() {
-}
-
} // namespace trie
} // namespace ngram
} // namespace lm
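
The bhiksha.cc hunk above shows the commit's central change in miniature: config inspection now goes through positional reads on the BinaryFormat object instead of sequential ReadOrThrow calls on a raw descriptor, so it no longer matters where earlier readers left the file cursor. A minimal sketch of that pattern, assuming POSIX pread (util::PReadOrThrow in this codebase is a checked wrapper over the same call):

    #include <cstdint>
    #include <stdexcept>
    #include <unistd.h>

    // Positional read: the descriptor's cursor never moves, so callers can
    // probe headers at known offsets without coordinating seeks.
    void ReadAt(int fd, void *to, std::size_t amount, uint64_t offset) {
      ssize_t got = pread(fd, to, amount, static_cast<off_t>(offset));
      if (got < 0 || static_cast<std::size_t>(got) != amount)
        throw std::runtime_error("Short read while inspecting binary header");
    }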
diff --git a/lm/bhiksha.hh b/lm/bhiksha.hh
index 8ff88654d..350571a6e 100644
--- a/lm/bhiksha.hh
+++ b/lm/bhiksha.hh
@@ -24,6 +24,7 @@
namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
namespace trie {
@@ -31,7 +32,7 @@ class DontBhiksha {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {}
static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
@@ -53,8 +54,6 @@ class DontBhiksha {
void FinishedLoading(const Config &/*config*/) {}
- void LoadedBinary() {}
-
uint8_t InlineBits() const { return next_.bits; }
private:
@@ -65,7 +64,7 @@ class ArrayBhiksha {
public:
static const ModelType kModelTypeAdd = kArrayAdd;
- static void UpdateConfigFromBinary(int fd, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
@@ -93,8 +92,6 @@ class ArrayBhiksha {
void FinishedLoading(const Config &config);
- void LoadedBinary();
-
uint8_t InlineBits() const { return next_inline_.bits; }
private:
diff --git a/lm/binary_format.cc b/lm/binary_format.cc
index bef51eb82..9c744b138 100644
--- a/lm/binary_format.cc
+++ b/lm/binary_format.cc
@@ -14,6 +14,9 @@
namespace lm {
namespace ngram {
+
+const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
@@ -58,8 +61,6 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
@@ -81,83 +82,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
- if (config.write_mmap) {
- std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.file.reset(util::CreateOrThrow(config.write_mmap));
- if (config.write_method == Config::WRITE_MMAP) {
- backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
- } else {
- util::ResizeOrThrow(backing.file.get(), 0);
- util::MapAnonymous(total, backing.vocab);
- }
- strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
- return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
- } else {
- util::MapAnonymous(memory_size, backing.vocab);
- return reinterpret_cast<uint8_t*>(backing.vocab.get());
- }
-}
-
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
- std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
- if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
- try {
- util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
- } catch (util::ErrnoException &e) {
- e << " for file " << config.write_mmap;
- throw e;
- }
-
- if (config.write_method == Config::WRITE_AFTER) {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
- // mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- std::size_t page_size = util::SizePage();
- std::size_t alignment_cruft = adjusted_vocab % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
- } else {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
-}
-
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
- if (!config.write_mmap) return;
- switch (config.write_method) {
- case Config::WRITE_MMAP:
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
- util::SyncOrThrow(backing.search.get(), backing.search.size());
- break;
- case Config::WRITE_AFTER:
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
- util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
- util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
- util::FSyncOrThrow(backing.file.get());
- break;
- }
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params = Parameters();
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
- if (config.write_method == Config::WRITE_AFTER) {
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
- }
-}
-
-namespace detail {
-
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
@@ -209,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
-void SeekPastHeader(int fd, const Parameters &params) {
- util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+const std::size_t kInvalidSize = static_cast<std::size_t>(-1);
+
+BinaryFormat::BinaryFormat(const Config &config)
+ : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
+ header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {}
+
+void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params) {
+ file_.reset(fd);
+ write_mmap_ = NULL; // Ignore write requests; this is already in binary format.
+ ReadHeader(fd, params);
+ MatchCheck(model_type, search_version, params);
+ header_size_ = TotalHeaderSize(params.counts.size());
+}
+
+void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
+ assert(header_size_ != kInvalidSize);
+ util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_);
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
- const uint64_t file_size = util::SizeFile(backing.file.get());
+void *BinaryFormat::LoadBinary(std::size_t size) {
+ assert(header_size_ != kInvalidSize);
+ const uint64_t file_size = util::SizeFile(file_.get());
// The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
- if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
- UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+ uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
+ UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
- util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
+ util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
- if (config.enumerate_vocab && !params.fixed.has_vocabulary)
- UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+ vocab_string_offset_ = total_map;
+ return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+}
+
+void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
+ vocab_size_ = memory_size;
+ if (!write_mmap_) {
+ header_size_ = 0;
+ util::MapAnonymous(memory_size, memory_vocab_);
+ return reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ }
+ header_size_ = TotalHeaderSize(order);
+ std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
+ file_.reset(util::CreateOrThrow(write_mmap_));
+ // some gccs complain about uninitialized variables even though all enum values are covered.
+ void *vocab_base = NULL;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = mapping_.get();
+ break;
+ case Config::WRITE_AFTER:
+ util::ResizeOrThrow(file_.get(), 0);
+ util::MapAnonymous(total, memory_vocab_);
+ vocab_base = memory_vocab_.get();
+ break;
+ }
+ strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
+ return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
+}
- // Seek to vocabulary words
- util::SeekOrThrow(backing.file.get(), total_map);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
+void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) {
+ assert(vocab_size_ != kInvalidSize);
+ vocab_pad_ = vocab_pad;
+ std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
+ vocab_string_offset_ = new_size;
+ if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, memory_search_);
+ assert(header_size_ == 0 || write_mmap_);
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ return reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+
+ assert(write_method_ == Config::WRITE_MMAP);
+ // Also known as total size without vocab words.
+ // Grow the file to accommodate the search, using zeros.
+ // According to man mmap, behavior is undefined when the file is resized
+ // underneath a mmap that is not a multiple of the page size. So to be
+ // safe, we'll unmap it and map it again.
+ mapping_.reset();
+ util::ResizeOrThrow(file_.get(), new_size);
+ void *ret;
+ MapFile(vocab_base, ret);
+ return ret;
}
-void ComplainAboutARPA(const Config &config, ModelType model_type) {
- if (config.write_mmap || !config.messages) return;
- if (config.arpa_complain == Config::ALL) {
- *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE &&
- (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
- *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) {
+ // Checking Config's include_vocab is the responsibility of the caller.
+ assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
+ if (!write_mmap_) {
+ // Unchanged base.
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ return;
+ }
+ if (write_method_ == Config::WRITE_MMAP) {
+ mapping_.reset();
+ }
+ util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ if (write_method_ == Config::WRITE_MMAP) {
+ MapFile(vocab_base, search_base);
+ } else {
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+}
+
+void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
+ if (!write_mmap_) return;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
+ util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
+ util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
+ util::FSyncOrThrow(file_.get());
+ break;
+ }
+ // header and vocab share the same mmap.
+ Parameters params = Parameters();
+ memset(&params, 0, sizeof(Parameters));
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ WriteHeader(mapping_.get(), params);
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ {
+ std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
+ WriteHeader(&buffer[0], params);
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ }
+ break;
}
}
-} // namespace detail
+void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) {
+ mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_;
+}
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
- if (!detail::IsBinaryFormat(fd.get())) return false;
+ if (!IsBinaryFormat(fd.get())) {
+ return false;
+ }
Parameters params;
- detail::ReadHeader(fd.get(), params);
+ ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}
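
Note the comment inside the new GrowForSearch: resizing a file underneath a live mapping is the reason the mapping is dropped and rebuilt rather than grown in place. A sketch of that rule under POSIX (illustrative names, not the project's API):

    #include <sys/mman.h>
    #include <unistd.h>

    // man mmap: resizing a file under a mapping whose size is not a page
    // multiple is undefined, so unmap first, then ftruncate, then remap.
    void *GrowMapping(int fd, void *old_base, size_t old_size, size_t new_size) {
      munmap(old_base, old_size);
      if (ftruncate(fd, static_cast<off_t>(new_size))) return NULL;
      void *fresh = mmap(NULL, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
      return fresh == MAP_FAILED ? NULL : fresh;
    }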
diff --git a/lm/binary_format.hh b/lm/binary_format.hh
index bf699d5f4..f33f88d75 100644
--- a/lm/binary_format.hh
+++ b/lm/binary_format.hh
@@ -17,6 +17,8 @@
namespace lm {
namespace ngram {
+extern const char *kModelNames[6];
+
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
* this header designed for use by decoder authors.
@@ -42,67 +44,63 @@ struct Parameters {
std::vector<uint64_t> counts;
};
-struct Backing {
- // File behind memory, if any.
- util::scoped_fd file;
- // Vocabulary lookup table. Not to be confused with the vocab words themselves.
- util::scoped_memory vocab;
- // Raw block of memory backing the language model data structures
- util::scoped_memory search;
-};
-
-// Create just enough of a binary file to write vocabulary to it.
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
-// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
-
-// Write header to binary file. This is done last to prevent incomplete files
-// from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
+class BinaryFormat {
+ public:
+ explicit BinaryFormat(const Config &config);
+
+ // Reading a binary file:
+ // Takes ownership of fd
+ void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params);
+ // Used to read parts of the file to update the config object before figuring out full size.
+ void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const;
+ // Actually load the binary file and return a pointer to the beginning of the search area.
+ void *LoadBinary(std::size_t size);
+
+ uint64_t VocabStringReadingOffset() const {
+ assert(vocab_string_offset_ != kInvalidOffset);
+ return vocab_string_offset_;
+ }
-namespace detail {
+ // Writing a binary file or initializing in RAM from ARPA:
+ // Size for vocabulary.
+ void *SetupJustVocab(std::size_t memory_size, uint8_t order);
+ // Warning: can change the vocabulary base pointer.
+ void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base);
+ // Warning: can change vocabulary and search base addresses.
+ void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base);
+ // Write the header at the beginning of the file.
+ void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts);
+
+ private:
+ void MapFile(void *&vocab_base, void *&search_base);
+
+ // Copied from configuration.
+ const Config::WriteMethod write_method_;
+ const char *write_mmap_;
+ util::LoadMethod load_method_;
+
+ // File behind memory, if any.
+ util::scoped_fd file_;
+
+ // If there is a file involved, a single mapping.
+ util::scoped_memory mapping_;
+
+ // If the data is only in memory, separately allocate each because the trie
+ // knows vocab's size before it knows search's size (because SRILM might
+ // have pruned).
+ util::scoped_memory memory_vocab_, memory_search_;
+
+ // Memory ranges. Note that these may not be contiguous and may not all
+ // exist.
+ std::size_t header_size_, vocab_size_, vocab_pad_;
+ // aka end of search.
+ uint64_t vocab_string_offset_;
+
+ static const uint64_t kInvalidOffset = (uint64_t)-1;
+};
bool IsBinaryFormat(int fd);
-void ReadHeader(int fd, Parameters &params);
-
-void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params);
-
-void SeekPastHeader(int fd, const Parameters &params);
-
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
-
-void ComplainAboutARPA(const Config &config, ModelType model_type);
-
-} // namespace detail
-
-template <class To> void LoadLM(const char *file, const Config &config, To &to) {
- Backing &backing = to.MutableBacking();
- backing.file.reset(util::OpenReadOrThrow(file));
-
- try {
- if (detail::IsBinaryFormat(backing.file.get())) {
- Parameters params;
- detail::ReadHeader(backing.file.get(), params);
- detail::MatchCheck(To::kModelType, To::kVersion, params);
- // Replace the run-time configured probing_multiplier with the one in the file.
- Config new_config(config);
- new_config.probing_multiplier = params.fixed.probing_multiplier;
- detail::SeekPastHeader(backing.file.get(), params);
- To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
- uint64_t memory_size = To::Size(params.counts, new_config);
- uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
- to.InitializeFromBinary(start, params, new_config, backing.file.get());
- } else {
- detail::ComplainAboutARPA(config, To::kModelType);
- to.InitializeFromARPA(file, config);
- }
- } catch (util::Exception &e) {
- e << " File: " << file;
- throw;
- }
-}
-
} // namespace ngram
} // namespace lm
#endif // LM_BINARY_FORMAT__
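
The private members at the end of the new class describe a single on-disk layout: header, vocabulary lookup table, optional padding, search structures, then vocabulary strings. Hypothetical offset arithmetic mirroring header_size_, vocab_size_, vocab_pad_, and vocab_string_offset_:

    #include <cstdint>

    // Layout: [header][vocab table][pad][search][vocab strings]
    struct Layout {
      uint64_t header, vocab, pad, search;
      uint64_t SearchOffset() const { return header + vocab + pad; }
      // The diff's "aka end of search": where the vocab strings begin.
      uint64_t VocabStringOffset() const { return SearchOffset() + search; }
    };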
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index 6ad91dde7..ccc06efca 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -87,7 +87,7 @@ class VocabHandout {
Table table_;
std::size_t double_cutoff_;
-
+
util::FakeOFStream word_list_;
};
@@ -98,7 +98,7 @@ class DedupeHash : public std::unary_function<const WordIndex *, bool> {
std::size_t operator()(const WordIndex *start) const {
return util::MurmurHashNative(start, size_);
}
-
+
private:
const std::size_t size_;
};
@@ -106,11 +106,11 @@ class DedupeHash : public std::unary_function<const WordIndex *, bool> {
class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
public:
explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
-
+
bool operator()(const WordIndex *first, const WordIndex *second) const {
return !memcmp(first, second, size_);
- }
-
+ }
+
private:
const std::size_t size_;
};
@@ -131,7 +131,7 @@ typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
class Writer {
public:
- Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
+ Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
: block_(position), gram_(block_->Get(), order),
dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
@@ -140,7 +140,7 @@ class Writer {
dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
- // Add special words. AdjustCounts is responsible if order != 1.
+ // Add special words. AdjustCounts is responsible if order != 1.
AddUnigramWord(kUNK);
AddUnigramWord(kBOS);
}
@@ -170,16 +170,16 @@ class Writer {
memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
return;
}
- // Complete the write.
+ // Complete the write.
gram_.Count() = 1;
- // Prepare the next n-gram.
+ // Prepare the next n-gram.
if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
NGram last(gram_);
gram_.NextInMemory();
std::copy(last.begin() + 1, last.end(), gram_.begin());
return;
}
- // Block end. Need to store the context in a temporary buffer.
+ // Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
dedupe_.Clear();
block_->SetValidSize(block_size_);
@@ -207,7 +207,7 @@ class Writer {
// Hash table combiner implementation.
Dedupe dedupe_;
- // Small buffer to hold existing ngrams when shifting across a block boundary.
+ // Small buffer to hold existing ngrams when shifting across a block boundary.
boost::scoped_array<WordIndex> buffer_;
const std::size_t block_size_;
@@ -223,7 +223,7 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return VocabHandout::MemUsage(vocab_estimate);
}
-CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
@@ -240,7 +240,10 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
uint64_t count = 0;
bool delimiters[256];
memset(delimiters, 0, sizeof(delimiters));
- delimiters['\0'] = delimiters['\t'] = delimiters['\n'] = delimiters['\r'] = delimiters[' '] = true;
+ const char kDelimiterSet[] = "\0\t\n\r ";
+ for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
+ delimiters[static_cast<unsigned char>(*i)] = true;
+ }
try {
while(true) {
StringPiece line(from_.ReadLine());
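
The delimiter hunk builds the same 256-entry lookup table as the old chained assignment; because sizeof on the string literal counts the terminating NUL, '\0' is covered by the loop bound as well as by its explicit spot at the front of the literal. A hypothetical helper showing how such a table is consumed during tokenization:

    #include <cstddef>

    // Advance to the next delimiter byte using the 256-entry table.
    std::size_t NextDelimiter(const bool (&delim)[256], const char *buf,
                              std::size_t len, std::size_t pos) {
      while (pos < len && !delim[static_cast<unsigned char>(buf[pos])]) ++pos;
      return pos;
    }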
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index 52e69f02e..500268069 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -33,12 +33,12 @@ class Callback {
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
pay.complete.prob = log10(pay.complete.prob);
- // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
- if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1].Get()) { // check valid pointer at tht end
+ // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
++backoffs_[order_minus_1];
} else {
- // Not a context.
+ // Not a context.
pay.complete.backoff = 0.0;
}
}
@@ -52,7 +52,7 @@ class Callback {
};
} // namespace
-Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
: uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
// perform order-wise interpolation
diff --git a/lm/config.cc b/lm/config.cc
index dc3365319..9520c41c8 100644
--- a/lm/config.cc
+++ b/lm/config.cc
@@ -11,11 +11,7 @@ Config::Config() :
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
sentence_marker_missing(THROW_UP),
-#if defined(_WIN32) || defined(_WIN64)
- positive_log_probability(SILENT),
-#else
positive_log_probability(THROW_UP),
-#endif
unknown_missing_logprob(-100.0),
probing_multiplier(1.5),
building_memory(1073741824ULL), // 1 GB
diff --git a/lm/facade.hh b/lm/facade.hh
index 760e839e0..de1551f12 100644
--- a/lm/facade.hh
+++ b/lm/facade.hh
@@ -17,14 +17,14 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
typedef VocabularyT Vocabulary;
/* Translate from void* to State */
- FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScore(
*reinterpret_cast<const State*>(in_state),
new_word,
*reinterpret_cast<State*>(out_state));
}
- FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScoreForgotState(
context_rbegin,
context_rend,
@@ -37,7 +37,7 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob;
}
- float Score(const void *in_state, const WordIndex new_word, void *out_state) const {
+ float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->Score(
*reinterpret_cast<const State*>(in_state),
new_word,
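
The facade.hh renames give the void*-based adapters a Base prefix so they no longer collide with the typed FullScore/Score that each derived model defines; ModelFacade is a CRTP shim between the virtual interface and the concrete model. A reduced sketch of the pattern with illustrative types:

    struct State { unsigned char length; };

    template <class Child> class Facade {
     public:
      // The Base prefix keeps this from hiding Child's own Score overloads.
      float BaseScore(const void *in, unsigned word, void *out) const {
        return static_cast<const Child*>(this)->Score(
            *static_cast<const State*>(in), word, *static_cast<State*>(out));
      }
    };

    class ToyModel : public Facade<ToyModel> {
     public:
      float Score(const State &, unsigned, State &out) const {
        out.length = 1;
        return -1.0f;
      }
    };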
diff --git a/lm/filter/arpa_io.hh b/lm/filter/arpa_io.hh
index 08e658666..602b5b31b 100644
--- a/lm/filter/arpa_io.hh
+++ b/lm/filter/arpa_io.hh
@@ -14,10 +14,6 @@
#include <string>
#include <vector>
-#if !defined __MINGW32__
-#include <err.h>
-#endif
-
#include <string.h>
#include <stdint.h>
diff --git a/lm/filter/count_io.hh b/lm/filter/count_io.hh
index 740b8d50e..d992026ff 100644
--- a/lm/filter/count_io.hh
+++ b/lm/filter/count_io.hh
@@ -5,27 +5,18 @@
#include <iostream>
#include <string>
-#if !defined __MINGW32__
-#include <err.h>
-#endif
-
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
namespace lm {
class CountOutput : boost::noncopyable {
public:
- explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+ explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {}
void AddNGram(const StringPiece &line) {
- if (!(file_ << line << '\n')) {
-#if defined __MINGW32__
- std::cerr<<"Writing counts file failed"<<std::endl;
- exit(3);
-#else
- err(3, "Writing counts file failed");
-#endif
- }
+ file_ << line << '\n';
}
template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
@@ -37,12 +28,12 @@ class CountOutput : boost::noncopyable {
}
private:
- std::fstream file_;
+ util::FakeOFStream file_;
};
class CountBatch {
public:
- explicit CountBatch(std::streamsize initial_read)
+ explicit CountBatch(std::streamsize initial_read)
: initial_read_(initial_read) {
buffer_.reserve(initial_read);
}
@@ -75,7 +66,7 @@ class CountBatch {
private:
std::streamsize initial_read_;
- // This could have been a std::string but that's less happy with raw writes.
+ // This could have been a std::string but that's less happy with raw writes.
std::vector<char> buffer_;
};
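
CountOutput now writes through util::FakeOFStream, which throws on failure, replacing the err()-based checks that did not compile under MinGW. The shape of such a throw-on-failure writer, assuming POSIX write (a hypothetical class; the real FakeOFStream also buffers):

    #include <stdexcept>
    #include <string>
    #include <unistd.h>

    class ThrowingWriter {
     public:
      explicit ThrowingWriter(int fd) : fd_(fd) {}
      ThrowingWriter &operator<<(const std::string &s) {
        if (write(fd_, s.data(), s.size()) != static_cast<ssize_t>(s.size()))
          throw std::runtime_error("Writing counts file failed");
        return *this;
      }
     private:
      int fd_;
    };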
diff --git a/lm/filter/filter_main.cc b/lm/filter/filter_main.cc
index f89ac4df3..82fdc1ef7 100644
--- a/lm/filter/filter_main.cc
+++ b/lm/filter/filter_main.cc
@@ -6,6 +6,7 @@
#endif
#include "lm/filter/vocab.hh"
#include "lm/filter/wrapper.hh"
+#include "util/exception.hh"
#include "util/file_piece.hh"
#include <boost/ptr_container/ptr_vector.hpp>
@@ -57,7 +58,7 @@ typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION, MODE_UNSET} Fil
typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format;
struct Config {
- Config() :
+ Config() :
#ifndef NTHREAD
batch_size(25000),
threads(boost::thread::hardware_concurrency()),
@@ -157,102 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr
} // namespace lm
int main(int argc, char *argv[]) {
- if (argc < 4) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
+ try {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
- // I used to have boost::program_options, but some users didn't want to compile boost.
- lm::Config config;
- config.mode = lm::MODE_UNSET;
- for (int i = 1; i < argc - 2; ++i) {
- const char *str = argv[i];
- if (!std::strcmp(str, "copy")) {
- config.mode = lm::MODE_COPY;
- } else if (!std::strcmp(str, "single")) {
- config.mode = lm::MODE_SINGLE;
- } else if (!std::strcmp(str, "multiple")) {
- config.mode = lm::MODE_MULTIPLE;
- } else if (!std::strcmp(str, "union")) {
- config.mode = lm::MODE_UNION;
- } else if (!std::strcmp(str, "phrase")) {
- config.phrase = true;
- } else if (!std::strcmp(str, "context")) {
- config.context = true;
- } else if (!std::strcmp(str, "arpa")) {
- config.format = lm::FORMAT_ARPA;
- } else if (!std::strcmp(str, "raw")) {
- config.format = lm::FORMAT_COUNT;
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ config.mode = lm::MODE_UNSET;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ config.mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ config.mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ config.mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ config.mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
#ifndef NTHREAD
- } else if (!std::strncmp(str, "threads:", 8)) {
- config.threads = boost::lexical_cast<size_t>(str + 8);
- if (!config.threads) {
- std::cerr << "Specify at least one thread." << std::endl;
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
return 1;
}
- } else if (!std::strncmp(str, "batch_size:", 11)) {
- config.batch_size = boost::lexical_cast<size_t>(str + 11);
- if (config.batch_size < 5000) {
- std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
- if (!config.batch_size) return 1;
- }
-#endif
- } else {
+ }
+
+ if (config.mode == lm::MODE_UNSET) {
lm::DisplayHelp(argv[0]);
return 1;
}
- }
- if (config.mode == lm::MODE_UNSET) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
-
- if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
- std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
- return 1;
- }
+ if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
- bool cmd_is_model = true;
- const char *cmd_input = argv[argc - 2];
- if (!strncmp(cmd_input, "vocab:", 6)) {
- cmd_is_model = false;
- cmd_input += 6;
- } else if (!strncmp(cmd_input, "model:", 6)) {
- cmd_input += 6;
- } else if (strchr(cmd_input, ':')) {
-#if defined __MINGW32__
- std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
- exit(1);
-#else
- errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
-#endif // defined
- } else {
- std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
- }
- std::ifstream cmd_file;
- std::istream *vocab;
- if (cmd_is_model) {
- vocab = &std::cin;
- } else {
- cmd_file.open(cmd_input, std::ios::in);
- if (!cmd_file) {
-#if defined __MINGW32__
- std::cerr << "Could not open input file " << cmd_input << std::endl;
- exit(2);
-#else
- err(2, "Could not open input file %s", cmd_input);
-#endif // defined
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
+ return 1;
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input);
+ vocab = &cmd_file;
}
- vocab = &cmd_file;
- }
- util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
- if (config.format == lm::FORMAT_ARPA) {
- lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
- } else if (config.format == lm::FORMAT_COUNT) {
- lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
}
- return 0;
}
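
The net effect of the filter_main.cc hunk is that the whole of main now runs inside one try block, so util::Exception (and any other std::exception) surfaces as a message plus exit code instead of an err()/errx() call, which MinGW lacks. The resulting control flow, reduced to its skeleton:

    #include <exception>
    #include <iostream>

    int main(int argc, char *argv[]) {
      try {
        // ... parse arguments, open vocab/model, dispatch the filter ...
        return 0;
      } catch (const std::exception &e) {
        std::cerr << e.what() << std::endl;
        return 1;
      }
    }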
diff --git a/lm/filter/format.hh b/lm/filter/format.hh
index 7f945b0d6..7d8c28dbc 100644
--- a/lm/filter/format.hh
+++ b/lm/filter/format.hh
@@ -1,5 +1,5 @@
#ifndef LM_FILTER_FORMAT_H__
-#define LM_FITLER_FORMAT_H__
+#define LM_FILTER_FORMAT_H__
#include "lm/filter/arpa_io.hh"
#include "lm/filter/count_io.hh"
diff --git a/lm/filter/vocab.cc b/lm/filter/vocab.cc
index 7ed5d92fb..011ab5992 100644
--- a/lm/filter/vocab.cc
+++ b/lm/filter/vocab.cc
@@ -5,10 +5,6 @@
#include <ctype.h>
-#if !defined __MINGW32__
-#include <err.h>
-#endif
-
namespace lm {
namespace vocab {
@@ -34,7 +30,7 @@ bool IsLineEnd(std::istream &in) {
}// namespace
// Read space separated words in enter separated lines. These lines can be
-// very long, so don't read an entire line at a time.
+// very long, so don't read an entire line at a time.
unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) {
in.exceptions(std::istream::badbit);
unsigned int sentence = 0;
diff --git a/lm/model.cc b/lm/model.cc
index a26654a6f..a5a16bf8e 100644
--- a/lm/model.cc
+++ b/lm/model.cc
@@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
-template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
- LoadLM(file, config, *this);
-
- // g++ prints warnings unless these are fully initialized.
- State begin_sentence = State();
- begin_sentence.length = 1;
- begin_sentence.words[0] = vocab_.BeginSentence();
- typename Search::Node ignored_node;
- bool ignored_independent_left;
- uint64_t ignored_extend_left;
- begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
- State null_context = State();
- null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.Order());
+namespace {
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
}
-namespace {
void CheckCounts(const std::vector<uint64_t> &counts) {
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
if (sizeof(uint64_t) > sizeof(std::size_t)) {
@@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) {
}
}
}
+
} // namespace
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
- CheckCounts(params.counts);
- SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
- search_.LoadedBinary();
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (IsBinaryFormat(fd.get())) {
+ Parameters parameters;
+ int fd_shallow = fd.release();
+ backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
+ CheckCounts(parameters.counts);
+
+ Config new_config(init_config);
+ new_config.probing_multiplier = parameters.fixed.probing_multiplier;
+ Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
+ UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
+ vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
+ } else {
+ ComplainAboutARPA(init_config, kModelType);
+ InitializeFromARPA(fd.release(), file, init_config);
+ }
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.length = 1;
+ begin_sentence.words[0] = vocab_.BeginSentence();
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
+ State null_context = State();
+ null_context.length = 0;
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
+ // Backing file is the ARPA.
+ util::FilePiece f(fd, file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+ vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);
- if (config.write_mmap) {
+ if (config.write_mmap && config.include_vocab) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
+ void *vocab_rebase, *search_rebase;
+ backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
+ // Due to writing at the end of file, mmap may have relocated data. So remap.
+ vocab_.Relocate(vocab_rebase);
+ search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
+ backing_.FinishFile(config, kModelType, kVersion, counts);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
- Search::UpdateConfigFromBinary(fd, counts, config);
-}
-
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
diff --git a/lm/model.hh b/lm/model.hh
index c9c17c4b3..e75da93bf 100644
--- a/lm/model.hh
+++ b/lm/model.hh
@@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
}
private:
- friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
-
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
-
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
// Score bigrams and above. Do not include backoff.
@@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
-
- void InitializeFromARPA(const char *file, const Config &config);
+ void InitializeFromARPA(int fd, const char *file, const Config &config);
float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
- Backing &MutableBacking() { return backing_; }
-
- Backing backing_;
+ BinaryFormat backing_;
VocabularyT vocab_;
diff --git a/lm/model_test.cc b/lm/model_test.cc
index eb1590942..7005b05ea 100644
--- a/lm/model_test.cc
+++ b/lm/model_test.cc
@@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
LoadingTest<QuantArrayTrieModel>();
}
-template <class ModelT> void BinaryTest() {
+template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
Config config;
config.write_mmap = "test.binary";
config.messages = NULL;
+ config.write_method = write_method;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
@@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() {
unlink("test_nounk.binary");
}
+template <class ModelT> void BinaryTest() {
+ BinaryTest<ModelT>(Config::WRITE_MMAP);
+ BinaryTest<ModelT>(Config::WRITE_AFTER);
+}
+
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<ProbingModel>();
}
diff --git a/lm/quantize.cc b/lm/quantize.cc
index b58c3f3f6..273ea3989 100644
--- a/lm/quantize.cc
+++ b/lm/quantize.cc
@@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2;
} // namespace
-void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) {
- char version;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &config.prob_bits, 1);
- util::ReadOrThrow(fd, &config.backoff_bits, 1);
+void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ unsigned char buffer[3];
+ file.ReadForConfig(buffer, 3, offset);
+ char version = buffer[0];
+ config.prob_bits = buffer[1];
+ config.backoff_bits = buffer[2];
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
- util::AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) {
diff --git a/lm/quantize.hh b/lm/quantize.hh
index 8ce2378a7..9d3a2f439 100644
--- a/lm/quantize.hh
+++ b/lm/quantize.hh
@@ -18,12 +18,13 @@ namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@@ -136,7 +137,7 @@ class SeparatelyQuantize {
public:
static const ModelType kModelTypeAdd = kQuantAdd;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint8_t order, const Config &config) {
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
index 62275d277..354a56b46 100644
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams(
namespace detail {
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t allocated = Unigram::Size(counts[0]);
- unigram_ = Unigram(start, counts[0], allocated);
- start += allocated;
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ std::size_t allocated;
+ middle_.clear();
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
@@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start,
return start;
}
-template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) {
- // TODO: fix sorted.
- SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
+/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ middle[n-2].Relocate(start);
+ start += Middle::Size(counts[n - 1], config.probing_multiplier)
+ }
+ longest_.Relocate(start);
+}*/
+
+template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) {
+ void *vocab_rebase;
+ void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase);
+ vocab.Relocate(vocab_rebase);
+ SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
@@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui
ReadEnd(f);
}
-template <class Value> void HashedSearch<Value>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
template class HashedSearch<BackoffValue>;
template class HashedSearch<RestValue>;
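
With the LoadedBinary no-ops deleted, the contract that remains is pointer rebasing: GrowForSearch may move the shared mapping, so callers immediately recompute their pointers from the new base via Relocate, as the hunk above does for the vocabulary. A minimal sketch of that contract (illustrative):

    #include <cstdint>

    // Structures holding pointers into the mapping must recompute them
    // from the new base whenever the region moves.
    struct VocabView {
      uint8_t *base;
      void Relocate(void *new_base) { base = static_cast<uint8_t*>(new_base); }
    };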
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index 9d067bc2e..8193262b0 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -18,7 +18,7 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class ProbingVocabulary;
namespace detail {
@@ -72,7 +72,7 @@ template <class Value> class HashedSearch {
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Unigram::Size(counts[0]);
@@ -84,9 +84,7 @@ template <class Value> class HashedSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
-
- void LoadedBinary();
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_.size() + 2;
@@ -148,7 +146,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
@@ -168,8 +166,6 @@ template <class Value> class HashedSearch {
typename Value::Weights &Unknown() { return unigram_[0]; }
- void LoadedBinary() {}
-
// For building.
typename Value::Weights *Raw() { return unigram_; }
diff --git a/lm/search_trie.cc b/lm/search_trie.cc
index 27605e548..4a88194e8 100644
--- a/lm/search_trie.cc
+++ b/lm/search_trie.cc
@@ -459,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {
RecordReader inputs[KENLM_MAX_ORDER - 1];
RecordReader contexts[KENLM_MAX_ORDER - 1];
@@ -488,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
- out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+ void *vocab_relocate;
+ void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate);
+ vocab.Relocate(vocab_relocate);
+ out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -571,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (Middle *i = middle_begin_; i != middle_end_; ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
std::string temporary_prefix;
if (config.temporary_directory_prefix) {
temporary_prefix = config.temporary_directory_prefix;
diff --git a/lm/search_trie.hh b/lm/search_trie.hh
index 763fd1a72..299262a5d 100644
--- a/lm/search_trie.hh
+++ b/lm/search_trie.hh
@@ -17,13 +17,13 @@
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
@@ -39,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch {
static const unsigned int kVersion = 1;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- Quant::UpdateConfigFromBinary(fd, counts, config);
- util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
+ Quant::UpdateConfigFromBinary(file, offset, config);
// Currently the unigram pointers are not compressed, so there will only be a header for order > 2.
- if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);
+ if (counts.size() > 2)
+ Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -60,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void LoadedBinary();
-
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_end_ - middle_begin_ + 2;
@@ -103,7 +101,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
}
private:
- friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
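UpdateConfigFromBinary now receives the whole BinaryFormat plus an absolute offset rather than a bare fd, so each component computes where its header lives instead of depending on a descriptor's cursor; the AdvanceOrThrow seek disappears, and the Bhiksha header's position becomes explicit arithmetic (offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0])). A sketch of the reading idiom, assuming BinaryFormat exposes a pread-style ReadForConfig(void *to, std::size_t amount, uint64_t offset) that reads at an absolute position without moving any cursor:

#include <cstdint>
#include "lm/binary_format.hh"

// Read a one-byte header field at an absolute position.  Nothing here seeks,
// so other positioned reads are unaffected regardless of call order.
void ReadVersionByte(const lm::ngram::BinaryFormat &file, uint64_t offset,
                     uint8_t &version) {
  file.ReadForConfig(&version, 1, offset);
}

Because every read is positioned, components can inspect their headers in any order as long as each is handed the correct offset.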
diff --git a/lm/trie.hh b/lm/trie.hh
index 9ea3c5466..d858ab5e4 100644
--- a/lm/trie.hh
+++ b/lm/trie.hh
@@ -62,8 +62,6 @@ class Unigram {
return unigram_;
}
- void LoadedBinary() {}
-
UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
next.begin = val->next;
@@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
void FinishedLoading(uint64_t next_end, const Config &config);
- void LoadedBinary() { bhiksha_.LoadedBinary(); }
-
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
@@ -138,14 +134,9 @@ class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_bits);
}
- void LoadedBinary() {}
-
util::BitAddress Insert(WordIndex word);
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
-
- private:
- uint8_t quant_bits_;
};
} // namespace trie
diff --git a/lm/trie_sort.cc b/lm/trie_sort.cc
index dc542bb32..126d43aba 100644
--- a/lm/trie_sort.cc
+++ b/lm/trie_sort.cc
@@ -50,6 +50,10 @@ class PartialViewProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
+ friend void swap(PartialViewProxy first, PartialViewProxy second) {
+ std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data()));
+ }
+
private:
friend class util::ProxyIterator<PartialViewProxy>;
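The added swap exchanges the bytes the proxies reference rather than the proxy objects themselves, which is what sorting through util::ProxyIterator requires. A reduced, self-contained illustration of the same technique, where ByteProxy is a hypothetical stand-in for PartialViewProxy:

#include <algorithm>
#include <cstddef>

struct ByteProxy {   // hypothetical stand-in for PartialViewProxy
  char *data;
  std::size_t size;  // both proxies are assumed to view records of equal size
  // By-value parameters are deliberate: the proxies are cheap handles, and
  // the swap must exchange the underlying record bytes, not the handles.
  friend void swap(ByteProxy first, ByteProxy second) {
    std::swap_ranges(first.data, first.data + first.size, second.data);
  }
};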
diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh
index ff4a388e7..7a3e23796 100644
--- a/lm/virtual_interface.hh
+++ b/lm/virtual_interface.hh
@@ -125,13 +125,13 @@ class Model {
void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
- virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Requires in_state != out_state
- virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Prefer to use FullScore. The context words should be provided in reverse order.
- virtual FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
unsigned char Order() const { return order_; }
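Renaming the virtuals to BaseScore, BaseFullScore, and BaseFullScoreForgotState frees the plain names Score and FullScore for non-virtual use in the concrete models, so the virtual hook and the statically dispatched fast path no longer collide. A minimal sketch of one plausible arrangement (the actual split between lm/virtual_interface.hh and lm/facade.hh may differ; only BaseScore is taken from this diff):

// The base forwards a stable public name to a renamed virtual; a derived
// model can then define its own non-virtual Score without hiding the hook.
class ModelSketch {
 public:
  virtual ~ModelSketch() {}
  float Score(const void *in_state, unsigned word, void *out_state) const {
    return BaseScore(in_state, word, out_state);  // exactly one virtual call
  }
  virtual float BaseScore(const void *in_state, unsigned word,
                          void *out_state) const = 0;
};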
diff --git a/lm/vocab.cc b/lm/vocab.cc
index fd7f96dc4..7f0878f40 100644
--- a/lm/vocab.cc
+++ b/lm/vocab.cc
@@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
+ util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
char check_unk[6];
util::ReadOrThrow(fd, check_unk, 6);
@@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd, uint64_t start) {
- util::SeekOrThrow(fd, start);
- util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
-}
-
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size
saw_unk_ = false;
}
+void SortedVocabulary::Relocate(void *new_start) {
+ std::size_t delta = end_ - begin_;
+ begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
+ end_ = begin_ + delta;
+}
+
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
enumerate_ = to;
if (enumerate_) {
@@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
namespace {
@@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz
saw_unk_ = false;
}
+void ProbingVocabulary::Relocate(void *new_start) {
+ header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
+ lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
+}
+
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
enumerate_ = to;
if (enumerate_) {
@@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
- lookup_.LoadedBinary();
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
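Relocate only rebases pointers; the on-disk layout is unchanged. For SortedVocabulary the mapped region begins with a stored entry count followed by the sorted hash array, which is why begin_ lands one uint64_t past the new base and LoadedBinary recovers end_ from *(begin_ - 1). A sketch of that layout logic, with names chosen here for illustration:

#include <cstddef>
#include <cstdint>

// Rebase a [begin, end) span of uint64_t hashes onto a remapped region whose
// first word stores the entry count, matching the layout assumed above.
void RebaseSpan(void *new_start, uint64_t *&begin, uint64_t *&end) {
  std::size_t entries = end - begin;                   // length is invariant
  begin = reinterpret_cast<uint64_t*>(new_start) + 1;  // skip the count word
  end = begin + entries;
}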
diff --git a/lm/vocab.hh b/lm/vocab.hh
index 226ae4385..074b74d86 100644
--- a/lm/vocab.hh
+++ b/lm/vocab.hh
@@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd, uint64_t start);
+ const std::string &Buffer() const { return buffer_; }
private:
EnumerateVocab *inner_;
@@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
uint64_t *begin_, *end_;
WordIndex bound_;
- WordIndex highest_value_;
-
bool saw_unk_;
EnumerateVocab *enumerate_;
@@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
void InternalFinishedLoading();