diff options
author | Kenneth Heafield <github@kheafield.com> | 2014-04-08 03:17:11 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2014-04-08 03:17:11 +0400 |
commit | 8d41ee98530e2941e9bb50f9a62e09afdb35f3bf (patch) | |
tree | 7f001935207178a286acb9c66230f010ca527c08 | |
parent | 7e4d1bb7893021f21ae263e7d474dabc469e5d5d (diff) | |
parent | 38d40aa509af2329eda2aff65191f4606598f516 (diff) |
Merge branch 'master' into pruning2
Conflicts:
lm/builder/interpolate.cc
lm/builder/interpolate.hh
lm/builder/lmplz_main.cc
lm/builder/pipeline.cc
lm/builder/pipeline.hh
128 files changed, 2183 insertions, 1342 deletions
@@ -19,5 +19,5 @@ util/read_compressed_test util/sorted_uniform_test previous.sh jam-files/bjam -jam-files/engine/bin.linuxx86/ +jam-files/engine/bin.*/ jam-files/engine/bootstrap/ @@ -52,7 +52,7 @@ include $(TOP)/jam-files/sanity.jam ; boost 103600 ; project : requirements $(requirements) <include>. ; -project : default-build <threading>multi <warnings>on <variant>release ; +project : default-build <threading>multi <warnings>on <variant>release <link>static ; external-lib z ; @@ -61,7 +61,7 @@ build-project util ; lib kenlm : lm//kenlm ; -install-bin-libs lm//programs kenlm ; +install-bin-libs lm//programs util//programs kenlm ; install-headers headers : [ glob-tree *.hh : dist include ] : . ; alias install : prefix-bin prefix-lib prefix-include ; explicit headers ; @@ -2,6 +2,7 @@ Most of the code here is licensed under the LGPL. There are exceptions that have their own licenses, listed below. See comments in those files for more details. +util/getopt.* is getopt for Windows util/murmur_hash.cc util/string_piece.hh and util/string_piece.cc util/double-conversion/LICENSE covers util/double-conversion except Jamfile diff --git a/compile_query_only.sh b/compile_query_only.sh index 7c6d3f6..7a82f49 100755 --- a/compile_query_only.sh +++ b/compile_query_only.sh @@ -30,5 +30,5 @@ mkdir -p bin if [ "$(uname)" != Darwin ]; then CXXFLAGS="$CXXFLAGS -lrt" fi -$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS -$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS +$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS $LDFLAGS +$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS $LDFLAGS diff --git a/jam-files/sanity.jam b/jam-files/sanity.jam index 7f9c45d..bc07945 100644 --- a/jam-files/sanity.jam +++ b/jam-files/sanity.jam @@ -140,10 +140,9 @@ rule boost-lib ( name macro : deps * ) { } if $(boost-auto-shared) = "<link>shared" { - alias boost_$(name) : inner_boost_$(name) : <link>shared ; - requirements += <define>BOOST_$(macro) ; + alias boost_$(name) : inner_boost_$(name) : <link>shared : : <define>BOOST_$(macro) ; } else { - alias boost_$(name) : inner_boost_$(name) : <link>static ; + alias boost_$(name) : inner_boost_$(name) : : : <link>shared:<define>BOOST_$(macro) ; } } @@ -171,9 +170,10 @@ rule boost ( min-version ) { boost-auto-shared = [ auto-shared "boost_program_options"$(boost-lib-version) : $(L-boost-search) ] ; #See tools/build/v2/contrib/boost.jam in a boost distribution for a table of macros to define. + boost-lib exception EXCEPTION_DYN_LINK ; boost-lib system SYSTEM_DYN_LINK ; - boost-lib thread THREAD_DYN_DLL : boost_system ; - boost-lib program_options PROGRAM_OPTIONS_DYN_LINK ; + boost-lib thread THREAD_DYN_DLL : boost_system boost_exception ; + boost-lib program_options PROGRAM_OPTIONS_DYN_LINK : boost_exception ; boost-lib unit_test_framework TEST_DYN_LINK ; boost-lib iostreams IOSTREAMS_DYN_LINK ; boost-lib filesystem FILE_SYSTEM_DYN_LINK ; diff --git a/lm/bhiksha.cc b/lm/bhiksha.cc index 088ea98..c8a18df 100644 --- a/lm/bhiksha.cc +++ b/lm/bhiksha.cc @@ -1,4 +1,6 @@ #include "lm/bhiksha.hh" + +#include "lm/binary_format.hh" #include "lm/config.hh" #include "util/file.hh" #include "util/exception.hh" @@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; // TODO: put this in binary file header instead when I change the binary file format again. -void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { - uint8_t version; - uint8_t configured_bits; - util::ReadOrThrow(fd, &version, 1); - util::ReadOrThrow(fd, &configured_bits, 1); +void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { + uint8_t buffer[2]; + file.ReadForConfig(buffer, 2, offset); + uint8_t version = buffer[0]; + uint8_t configured_bits = buffer[1]; if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); config.pointer_bhiksha_bits = configured_bits; } @@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) { *(head_write++) = config.pointer_bhiksha_bits; } -void ArrayBhiksha::LoadedBinary() { -} - } // namespace trie } // namespace ngram } // namespace lm diff --git a/lm/bhiksha.hh b/lm/bhiksha.hh index 8ff8865..db71766 100644 --- a/lm/bhiksha.hh +++ b/lm/bhiksha.hh @@ -10,8 +10,8 @@ * Currently only used for next pointers. */ -#ifndef LM_BHIKSHA__ -#define LM_BHIKSHA__ +#ifndef LM_BHIKSHA_H +#define LM_BHIKSHA_H #include <stdint.h> #include <assert.h> @@ -24,6 +24,7 @@ namespace lm { namespace ngram { struct Config; +class BinaryFormat; namespace trie { @@ -31,7 +32,7 @@ class DontBhiksha { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); - static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {} static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } @@ -53,8 +54,6 @@ class DontBhiksha { void FinishedLoading(const Config &/*config*/) {} - void LoadedBinary() {} - uint8_t InlineBits() const { return next_.bits; } private: @@ -65,7 +64,7 @@ class ArrayBhiksha { public: static const ModelType kModelTypeAdd = kArrayAdd; - static void UpdateConfigFromBinary(int fd, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); @@ -93,8 +92,6 @@ class ArrayBhiksha { void FinishedLoading(const Config &config); - void LoadedBinary(); - uint8_t InlineBits() const { return next_inline_.bits; } private: @@ -112,4 +109,4 @@ class ArrayBhiksha { } // namespace ngram } // namespace lm -#endif // LM_BHIKSHA__ +#endif // LM_BHIKSHA_H diff --git a/lm/binary_format.cc b/lm/binary_format.cc index bef51eb..9c744b1 100644 --- a/lm/binary_format.cc +++ b/lm/binary_format.cc @@ -14,6 +14,9 @@ namespace lm { namespace ngram { + +const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; + namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; @@ -58,8 +61,6 @@ struct Sanity { } }; -const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; - std::size_t TotalHeaderSize(unsigned char order) { return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -81,83 +82,6 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { - if (config.write_mmap) { - std::size_t total = TotalHeaderSize(order) + memory_size; - backing.file.reset(util::CreateOrThrow(config.write_mmap)); - if (config.write_method == Config::WRITE_MMAP) { - backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); - } else { - util::ResizeOrThrow(backing.file.get(), 0); - util::MapAnonymous(total, backing.vocab); - } - strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); - return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); - } else { - util::MapAnonymous(memory_size, backing.vocab); - return reinterpret_cast<uint8_t*>(backing.vocab.get()); - } -} - -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { - std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; - if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. - try { - util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); - } catch (util::ErrnoException &e) { - e << " for file " << config.write_mmap; - throw e; - } - - if (config.write_method == Config::WRITE_AFTER) { - util::MapAnonymous(memory_size, backing.search); - return reinterpret_cast<uint8_t*>(backing.search.get()); - } - // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. - std::size_t page_size = util::SizePage(); - std::size_t alignment_cruft = adjusted_vocab % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; - } else { - util::MapAnonymous(memory_size, backing.search); - return reinterpret_cast<uint8_t*>(backing.search.get()); - } -} - -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { - if (!config.write_mmap) return; - switch (config.write_method) { - case Config::WRITE_MMAP: - util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); - util::SyncOrThrow(backing.search.get(), backing.search.size()); - break; - case Config::WRITE_AFTER: - util::SeekOrThrow(backing.file.get(), 0); - util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); - util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); - util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); - util::FSyncOrThrow(backing.file.get()); - break; - } - // header and vocab share the same mmap. The header is written here because we know the counts. - Parameters params = Parameters(); - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - params.fixed.search_version = search_version; - WriteHeader(backing.vocab.get(), params); - if (config.write_method == Config::WRITE_AFTER) { - util::SeekOrThrow(backing.file.get(), 0); - util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); - } -} - -namespace detail { - bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; @@ -209,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } -void SeekPastHeader(int fd, const Parameters ¶ms) { - util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +const std::size_t kInvalidSize = static_cast<std::size_t>(-1); + +BinaryFormat::BinaryFormat(const Config &config) + : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method), + header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {} + +void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms) { + file_.reset(fd); + write_mmap_ = NULL; // Ignore write requests; this is already in binary format. + ReadHeader(fd, params); + MatchCheck(model_type, search_version, params); + header_size_ = TotalHeaderSize(params.counts.size()); +} + +void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { + assert(header_size_ != kInvalidSize); + util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_); } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { - const uint64_t file_size = util::SizeFile(backing.file.get()); +void *BinaryFormat::LoadBinary(std::size_t size) { + assert(header_size_ != kInvalidSize); + const uint64_t file_size = util::SizeFile(file_.get()); // The header is smaller than a page, so we have to map the whole header as well. - std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); - if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) - UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size); + UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); - util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search); + util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_); - if (config.enumerate_vocab && !params.fixed.has_vocabulary) - UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + vocab_string_offset_ = total_map; + return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; +} + +void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { + vocab_size_ = memory_size; + if (!write_mmap_) { + header_size_ = 0; + util::MapAnonymous(memory_size, memory_vocab_); + return reinterpret_cast<uint8_t*>(memory_vocab_.get()); + } + header_size_ = TotalHeaderSize(order); + std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size)); + file_.reset(util::CreateOrThrow(write_mmap_)); + // some gccs complain about uninitialized variables even though all enum values are covered. + void *vocab_base = NULL; + switch (write_method_) { + case Config::WRITE_MMAP: + mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); + vocab_base = mapping_.get(); + break; + case Config::WRITE_AFTER: + util::ResizeOrThrow(file_.get(), 0); + util::MapAnonymous(total, memory_vocab_); + vocab_base = memory_vocab_.get(); + break; + } + strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_); + return reinterpret_cast<uint8_t*>(vocab_base) + header_size_; +} - // Seek to vocabulary words - util::SeekOrThrow(backing.file.get(), total_map); - return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); +void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) { + assert(vocab_size_ != kInvalidSize); + vocab_pad_ = vocab_pad; + std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size; + vocab_string_offset_ = new_size; + if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) { + util::MapAnonymous(memory_size, memory_search_); + assert(header_size_ == 0 || write_mmap_); + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; + return reinterpret_cast<uint8_t*>(memory_search_.get()); + } + + assert(write_method_ == Config::WRITE_MMAP); + // Also known as total size without vocab words. + // Grow the file to accomodate the search, using zeros. + // According to man mmap, behavior is undefined when the file is resized + // underneath a mmap that is not a multiple of the page size. So to be + // safe, we'll unmap it and map it again. + mapping_.reset(); + util::ResizeOrThrow(file_.get(), new_size); + void *ret; + MapFile(vocab_base, ret); + return ret; } -void ComplainAboutARPA(const Config &config, ModelType model_type) { - if (config.write_mmap || !config.messages) return; - if (config.arpa_complain == Config::ALL) { - *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && - (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { - *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; +void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) { + // Checking Config's include_vocab is the responsibility of the caller. + assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize); + if (!write_mmap_) { + // Unchanged base. + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()); + search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); + return; + } + if (write_method_ == Config::WRITE_MMAP) { + mapping_.reset(); + } + util::SeekOrThrow(file_.get(), VocabStringReadingOffset()); + util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); + if (write_method_ == Config::WRITE_MMAP) { + MapFile(vocab_base, search_base); + } else { + vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; + search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); + } +} + +void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) { + if (!write_mmap_) return; + switch (write_method_) { + case Config::WRITE_MMAP: + util::SyncOrThrow(mapping_.get(), mapping_.size()); + break; + case Config::WRITE_AFTER: + util::SeekOrThrow(file_.get(), 0); + util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size()); + util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_); + util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size()); + util::FSyncOrThrow(file_.get()); + break; + } + // header and vocab share the same mmap. + Parameters params = Parameters(); + memset(¶ms, 0, sizeof(Parameters)); + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; + switch (write_method_) { + case Config::WRITE_MMAP: + WriteHeader(mapping_.get(), params); + util::SyncOrThrow(mapping_.get(), mapping_.size()); + break; + case Config::WRITE_AFTER: + { + std::vector<uint8_t> buffer(TotalHeaderSize(counts.size())); + WriteHeader(&buffer[0], params); + util::SeekOrThrow(file_.get(), 0); + util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); + } + break; } } -} // namespace detail +void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { + mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); + vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; + search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; +} bool RecognizeBinary(const char *file, ModelType &recognized) { util::scoped_fd fd(util::OpenReadOrThrow(file)); - if (!detail::IsBinaryFormat(fd.get())) return false; + if (!IsBinaryFormat(fd.get())) { + return false; + } Parameters params; - detail::ReadHeader(fd.get(), params); + ReadHeader(fd.get(), params); recognized = params.fixed.model_type; return true; } diff --git a/lm/binary_format.hh b/lm/binary_format.hh index bf699d5..136d6b1 100644 --- a/lm/binary_format.hh +++ b/lm/binary_format.hh @@ -1,5 +1,5 @@ -#ifndef LM_BINARY_FORMAT__ -#define LM_BINARY_FORMAT__ +#ifndef LM_BINARY_FORMAT_H +#define LM_BINARY_FORMAT_H #include "lm/config.hh" #include "lm/model_type.hh" @@ -17,6 +17,8 @@ namespace lm { namespace ngram { +extern const char *kModelNames[6]; + /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -42,67 +44,63 @@ struct Parameters { std::vector<uint64_t> counts; }; -struct Backing { - // File behind memory, if any. - util::scoped_fd file; - // Vocabulary lookup table. Not to be confused with the vocab words themselves. - util::scoped_memory vocab; - // Raw block of memory backing the language model data structures - util::scoped_memory search; -}; - -// Create just enough of a binary file to write vocabulary to it. -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); -// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); - -// Write header to binary file. This is done last to prevent incomplete files -// from loading. -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing); +class BinaryFormat { + public: + explicit BinaryFormat(const Config &config); + + // Reading a binary file: + // Takes ownership of fd + void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms); + // Used to read parts of the file to update the config object before figuring out full size. + void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const; + // Actually load the binary file and return a pointer to the beginning of the search area. + void *LoadBinary(std::size_t size); + + uint64_t VocabStringReadingOffset() const { + assert(vocab_string_offset_ != kInvalidOffset); + return vocab_string_offset_; + } -namespace detail { + // Writing a binary file or initializing in RAM from ARPA: + // Size for vocabulary. + void *SetupJustVocab(std::size_t memory_size, uint8_t order); + // Warning: can change the vocaulary base pointer. + void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base); + // Warning: can change vocabulary and search base addresses. + void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base); + // Write the header at the beginning of the file. + void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts); + + private: + void MapFile(void *&vocab_base, void *&search_base); + + // Copied from configuration. + const Config::WriteMethod write_method_; + const char *write_mmap_; + util::LoadMethod load_method_; + + // File behind memory, if any. + util::scoped_fd file_; + + // If there is a file involved, a single mapping. + util::scoped_memory mapping_; + + // If the data is only in memory, separately allocate each because the trie + // knows vocab's size before it knows search's size (because SRILM might + // have pruned). + util::scoped_memory memory_vocab_, memory_search_; + + // Memory ranges. Note that these may not be contiguous and may not all + // exist. + std::size_t header_size_, vocab_size_, vocab_pad_; + // aka end of search. + uint64_t vocab_string_offset_; + + static const uint64_t kInvalidOffset = (uint64_t)-1; +}; bool IsBinaryFormat(int fd); -void ReadHeader(int fd, Parameters ¶ms); - -void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); - -void SeekPastHeader(int fd, const Parameters ¶ms); - -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); - -void ComplainAboutARPA(const Config &config, ModelType model_type); - -} // namespace detail - -template <class To> void LoadLM(const char *file, const Config &config, To &to) { - Backing &backing = to.MutableBacking(); - backing.file.reset(util::OpenReadOrThrow(file)); - - try { - if (detail::IsBinaryFormat(backing.file.get())) { - Parameters params; - detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, To::kVersion, params); - // Replace the run-time configured probing_multiplier with the one in the file. - Config new_config(config); - new_config.probing_multiplier = params.fixed.probing_multiplier; - detail::SeekPastHeader(backing.file.get(), params); - To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); - uint64_t memory_size = To::Size(params.counts, new_config); - uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); - to.InitializeFromBinary(start, params, new_config, backing.file.get()); - } else { - detail::ComplainAboutARPA(config, To::kModelType); - to.InitializeFromARPA(file, config); - } - } catch (util::Exception &e) { - e << " File: " << file; - throw; - } -} - } // namespace ngram } // namespace lm -#endif // LM_BINARY_FORMAT__ +#endif // LM_BINARY_FORMAT_H diff --git a/lm/blank.hh b/lm/blank.hh index 4da8120..94a71ad 100644 --- a/lm/blank.hh +++ b/lm/blank.hh @@ -1,5 +1,5 @@ -#ifndef LM_BLANK__ -#define LM_BLANK__ +#ifndef LM_BLANK_H +#define LM_BLANK_H #include <limits> @@ -40,4 +40,4 @@ inline bool HasExtension(const float &backoff) { } // namespace ngram } // namespace lm -#endif // LM_BLANK__ +#endif // LM_BLANK_H diff --git a/lm/build_binary_main.cc b/lm/build_binary_main.cc index 425a123..15b421e 100644 --- a/lm/build_binary_main.cc +++ b/lm/build_binary_main.cc @@ -52,6 +52,7 @@ void Usage(const char *name, const char *default_mem) { "-a compresses pointers using an array of offsets. The parameter is the\n" " maximum number of bits encoded by the array. Memory is minimized subject\n" " to the maximum, so pick 255 to minimize memory.\n\n" +"-h print this help message.\n\n" "Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -104,12 +105,15 @@ int main(int argc, char *argv[]) { const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; + if (argc == 2 && !strcmp(argv[1], "--help")) + Usage(argv[0], default_mem); + try { bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; lm::ngram::Config config; config.building_memory = util::ParseSize(default_mem); int opt; - while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -161,6 +165,7 @@ int main(int argc, char *argv[]) { ParseFileList(optarg, config.rest_lower_files); config.rest_function = Config::REST_LOWER; break; + case 'h': // help default: Usage(argv[0], default_mem); } diff --git a/lm/builder/adjust_counts.hh b/lm/builder/adjust_counts.hh index ea8a1e2..00b5d43 100644 --- a/lm/builder/adjust_counts.hh +++ b/lm/builder/adjust_counts.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" #include "util/exception.hh" @@ -43,5 +43,5 @@ class AdjustCounts { } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc index aea93ad..b99edd0 100644 --- a/lm/builder/corpus_count.cc +++ b/lm/builder/corpus_count.cc @@ -223,29 +223,47 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { return VocabHandout::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), - dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), + disallowed_symbol_action_(disallowed_symbol) { } -void CorpusCount::Run(const util::stream::ChainPosition &position) { - UTIL_TIMER("(%w s) Counted n-grams\n"); +namespace { + void ComplainDisallowed(StringPiece word, WarningAction &action) { + switch (action) { + case SILENT: + return; + case COMPLAIN: + std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl; + action = SILENT; + return; + case THROW_UP: + UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace."); + } + } +} // namespace +void CorpusCount::Run(const util::stream::ChainPosition &position) { VocabHandout vocab(vocab_write_, type_count_); token_count_ = 0; type_count_ = 0; const WordIndex end_sentence = vocab.Lookup("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; - StringPiece delimiters("\0\t\r ", 4); + bool delimiters[256]; + util::BoolCharacter::Build("\0\t\n\r ", delimiters); try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); - for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) { + for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) { WordIndex word = vocab.Lookup(*w); - UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future."); + if (word <= 2) { + ComplainDisallowed(*w, disallowed_symbol_action_); + continue; + } writer.Append(word); ++count; } diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh index aa0ed8e..da4ff9f 100644 --- a/lm/builder/corpus_count.hh +++ b/lm/builder/corpus_count.hh @@ -1,6 +1,7 @@ -#ifndef LM_BUILDER_CORPUS_COUNT__ -#define LM_BUILDER_CORPUS_COUNT__ +#ifndef LM_BUILDER_CORPUS_COUNT_H +#define LM_BUILDER_CORPUS_COUNT_H +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/scoped.hh" @@ -28,7 +29,7 @@ class CorpusCount { // token_count: out. // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. - CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol); void Run(const util::stream::ChainPosition &position); @@ -40,8 +41,10 @@ class CorpusCount { std::size_t dedupe_mem_size_; util::scoped_malloc dedupe_mem_; + + WarningAction disallowed_symbol_action_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_CORPUS_COUNT__ +#endif // LM_BUILDER_CORPUS_COUNT_H diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc index 6d325ef..26cb634 100644 --- a/lm/builder/corpus_count_test.cc +++ b/lm/builder/corpus_count_test.cc @@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(Short) { NGramStream stream; uint64_t token_count; WordIndex type_count = 10; - CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize(), SILENT); chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; diff --git a/lm/builder/discount.hh b/lm/builder/discount.hh index 4d0aa4f..e2f4084 100644 --- a/lm/builder/discount.hh +++ b/lm/builder/discount.hh @@ -1,5 +1,5 @@ -#ifndef BUILDER_DISCOUNT__ -#define BUILDER_DISCOUNT__ +#ifndef LM_BUILDER_DISCOUNT_H +#define LM_BUILDER_DISCOUNT_H #include <algorithm> @@ -23,4 +23,4 @@ struct Discount { } // namespace builder } // namespace lm -#endif // BUILDER_DISCOUNT__ +#endif // LM_BUILDER_DISCOUNT_H diff --git a/lm/builder/header_info.hh b/lm/builder/header_info.hh index ccca145..16f3f60 100644 --- a/lm/builder/header_info.hh +++ b/lm/builder/header_info.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_HEADER_INFO__ -#define LM_BUILDER_HEADER_INFO__ +#ifndef LM_BUILDER_HEADER_INFO_H +#define LM_BUILDER_HEADER_INFO_H #include <string> #include <stdint.h> diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh index 0729a6a..f2a2975 100644 --- a/lm/builder/initial_probabilities.hh +++ b/lm/builder/initial_probabilities.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ -#define LM_BUILDER_INITIAL_PROBABILITIES__ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES_H +#define LM_BUILDER_INITIAL_PROBABILITIES_H #include "lm/builder/discount.hh" #include "util/stream/config.hh" @@ -31,4 +31,4 @@ void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::v } // namespace builder } // namespace lm -#endif // LM_BUILDER_INITIAL_PROBABILITIES__ +#endif // LM_BUILDER_INITIAL_PROBABILITIES_H diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc index cbf4bb4..4f8ffed 100644 --- a/lm/builder/interpolate.cc +++ b/lm/builder/interpolate.cc @@ -5,6 +5,7 @@ #include "lm/builder/multi_stream.hh" #include "lm/builder/sort.hh" #include "lm/lm_exception.hh" +#include "util/fixed_array.hh" #include "util/murmur_hash.hh" #include <assert.h> @@ -74,15 +75,17 @@ class Callback { void Exit(unsigned, const NGram &) const {} private: - FixedArray<util::stream::Stream> backoffs_; + util::FixedArray<util::stream::Stream> backoffs_; std::vector<float> probs_; const std::vector<uint64_t>& prune_thresholds_; }; } // namespace -Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds) - : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs), prune_thresholds_(prune_thresholds) {} +Interpolate::Interpolate(uint64_t vocab_size, const ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds) + : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>. + backoffs_(backoffs), + prune_thresholds_(prune_thresholds) {} // perform order-wise interpolation void Interpolate::Run(const ChainPositions &positions) { diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh index cc372b4..d266191 100644 --- a/lm/builder/interpolate.hh +++ b/lm/builder/interpolate.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_INTERPOLATE__ -#define LM_BUILDER_INTERPOLATE__ +#ifndef LM_BUILDER_INTERPOLATE_H +#define LM_BUILDER_INTERPOLATE_H #include <stdint.h> @@ -14,8 +14,9 @@ namespace lm { namespace builder { */ class Interpolate { public: - explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs, - const std::vector<uint64_t> &prune_thresholds_); + // Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might + // be larger when the user specifies a consistent vocabulary size. + explicit Interpolate(uint64_t vocab_size, const ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds); void Run(const ChainPositions &positions); @@ -26,4 +27,4 @@ class Interpolate { }; }} // namespaces -#endif // LM_BUILDER_INTERPOLATE__ +#endif // LM_BUILDER_INTERPOLATE_H diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh index b562014..b9c22a0 100644 --- a/lm/builder/joint_order.hh +++ b/lm/builder/joint_order.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_JOINT_ORDER__ -#define LM_BUILDER_JOINT_ORDER__ +#ifndef LM_BUILDER_JOINT_ORDER_H +#define LM_BUILDER_JOINT_ORDER_H #include "lm/builder/multi_stream.hh" #include "lm/lm_exception.hh" @@ -40,4 +40,4 @@ template <class Callback, class Compare> void JointOrder(const ChainPositions &p }} // namespaces -#endif // LM_BUILDER_JOINT_ORDER__ +#endif // LM_BUILDER_JOINT_ORDER_H diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc index d09028c..e09f9df 100644 --- a/lm/builder/lmplz_main.cc +++ b/lm/builder/lmplz_main.cc @@ -1,4 +1,5 @@ #include "lm/builder/pipeline.hh" +#include "lm/lm_exception.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/usage.hh" @@ -37,24 +38,30 @@ int main(int argc, char *argv[]) { std::string text, arpa; options.add_options() + ("help,h", po::bool_switch(), "Show this help message") ("order,o", po::value<std::size_t>(&pipeline.order) #if BOOST_VERSION >= 104200 ->required() #endif , "Order of the model") ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception") ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") - ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)") - ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") + ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") + ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes") + ("vocab_pad", po::value<std::size_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams") ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") ("text", po::value<std::string>(&text), "Read text from a file instead of stdin") ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout") ("prune_thresholds,P", po::value<std::vector<uint64_t> >(&pipeline.prune_thresholds), "Prune n-grams of count equal to or lower than threshold. 0 means no pruning"); - if (argc == 1) { + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, options), vm); + + if (argc == 1 || vm["help"].as<bool>()) { std::cerr << "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" "Please cite:\n" @@ -72,12 +79,17 @@ int main(int argc, char *argv[]) { "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" "Valid units are \% for percentage of memory (supported platforms only) and (in\n" - "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n"; + "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n"; + uint64_t mem = util::GuessPhysicalMemory(); + if (mem) { + std::cerr << "This machine has " << mem << " bytes of memory.\n\n"; + } else { + std::cerr << "Unable to determine the amount of memory on this machine.\n\n"; + } std::cerr << options << std::endl; return 1; } - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, options), vm); + po::notify(vm); //std::cerr << "vector: " << pipeline.counts_threshold.size() << std::endl; @@ -120,6 +132,17 @@ int main(int argc, char *argv[]) { } #endif + if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) { + std::cerr << "--vocab_pad requires --interpolate_unigrams" << std::endl; + return 1; + } + + if (vm["skip_symbols"].as<bool>()) { + pipeline.disallowed_symbol_action = lm::COMPLAIN; + } else { + pipeline.disallowed_symbol_action = lm::THROW_UP; + } + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; diff --git a/lm/builder/multi_stream.hh b/lm/builder/multi_stream.hh index 707a98c..1a8eb8b 100644 --- a/lm/builder/multi_stream.hh +++ b/lm/builder/multi_stream.hh @@ -1,7 +1,8 @@ -#ifndef LM_BUILDER_MULTI_STREAM__ -#define LM_BUILDER_MULTI_STREAM__ +#ifndef LM_BUILDER_MULTI_STREAM_H +#define LM_BUILDER_MULTI_STREAM_H #include "lm/builder/ngram_stream.hh" +#include "util/fixed_array.hh" #include "util/scoped.hh" #include "util/stream/chain.hh" @@ -13,72 +14,9 @@ namespace lm { namespace builder { -template <class T> class FixedArray { - public: - explicit FixedArray(std::size_t count) { - Init(count); - } - - FixedArray() : newed_end_(NULL) {} - - void Init(std::size_t count) { - assert(!block_.get()); - block_.reset(malloc(sizeof(T) * count)); - if (!block_.get()) throw std::bad_alloc(); - newed_end_ = begin(); - } - - FixedArray(const FixedArray &from) { - std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get()); - Init(size); - for (std::size_t i = 0; i < size; ++i) { - new(end()) T(from[i]); - Constructed(); - } - } - - ~FixedArray() { clear(); } - - T *begin() { return static_cast<T*>(block_.get()); } - const T *begin() const { return static_cast<const T*>(block_.get()); } - // Always call Constructed after successful completion of new. - T *end() { return newed_end_; } - const T *end() const { return newed_end_; } - - T &back() { return *(end() - 1); } - const T &back() const { return *(end() - 1); } - - std::size_t size() const { return end() - begin(); } - bool empty() const { return begin() == end(); } - - T &operator[](std::size_t i) { return begin()[i]; } - const T &operator[](std::size_t i) const { return begin()[i]; } - - template <class C> void push_back(const C &c) { - new (end()) T(c); - Constructed(); - } - - void clear() { - for (T *i = begin(); i != end(); ++i) - i->~T(); - newed_end_ = begin(); - } - - protected: - void Constructed() { - ++newed_end_; - } - - private: - util::scoped_malloc block_; - - T *newed_end_; -}; - class Chains; -class ChainPositions : public FixedArray<util::stream::ChainPosition> { +class ChainPositions : public util::FixedArray<util::stream::ChainPosition> { public: ChainPositions() {} @@ -89,14 +27,14 @@ class ChainPositions : public FixedArray<util::stream::ChainPosition> { } }; -class Chains : public FixedArray<util::stream::Chain> { +class Chains : public util::FixedArray<util::stream::Chain> { private: template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun { typedef Chains type; }; public: - explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {} + explicit Chains(std::size_t limit) : util::FixedArray<util::stream::Chain>(limit) {} template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); @@ -129,7 +67,7 @@ class Chains : public FixedArray<util::stream::Chain> { }; inline void ChainPositions::Init(Chains &chains) { - FixedArray<util::stream::ChainPosition>::Init(chains.size()); + util::FixedArray<util::stream::ChainPosition>::Init(chains.size()); for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) { new (end()) util::stream::ChainPosition(i->Add()); Constructed(); } @@ -140,13 +78,13 @@ inline Chains &operator>>(Chains &chains, ChainPositions &positions) { return chains; } -class NGramStreams : public FixedArray<NGramStream> { +class NGramStreams : public util::FixedArray<NGramStream> { public: NGramStreams() {} // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning). void InitWithDummy(const ChainPositions &positions) { - FixedArray<NGramStream>::Init(positions.size() + 1); + util::FixedArray<NGramStream>::Init(positions.size() + 1); new (end()) NGramStream(); Constructed(); for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) { push_back(*i); @@ -155,7 +93,7 @@ class NGramStreams : public FixedArray<NGramStream> { // Limit restricts to positions[0,limit) void Init(const ChainPositions &positions, std::size_t limit) { - FixedArray<NGramStream>::Init(limit); + util::FixedArray<NGramStream>::Init(limit); for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) { push_back(*i); } @@ -177,4 +115,4 @@ inline Chains &operator>>(Chains &chains, NGramStreams &streams) { } }} // namespaces -#endif // LM_BUILDER_MULTI_STREAM__ +#endif // LM_BUILDER_MULTI_STREAM_H diff --git a/lm/builder/ngram.hh b/lm/builder/ngram.hh index 756eaa6..0472bcb 100644 --- a/lm/builder/ngram.hh +++ b/lm/builder/ngram.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_NGRAM__ -#define LM_BUILDER_NGRAM__ +#ifndef LM_BUILDER_NGRAM_H +#define LM_BUILDER_NGRAM_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -106,4 +106,4 @@ const WordIndex kEOS = 2; } // namespace builder } // namespace lm -#endif // LM_BUILDER_NGRAM__ +#endif // LM_BUILDER_NGRAM_H diff --git a/lm/builder/ngram_stream.hh b/lm/builder/ngram_stream.hh index 3c99466..d7bf23a 100644 --- a/lm/builder/ngram_stream.hh +++ b/lm/builder/ngram_stream.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_NGRAM_STREAM__ -#define LM_BUILDER_NGRAM_STREAM__ +#ifndef LM_BUILDER_NGRAM_STREAM_H +#define LM_BUILDER_NGRAM_STREAM_H #include "lm/builder/ngram.hh" #include "util/stream/chain.hh" @@ -52,4 +52,4 @@ inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream & } }} // namespaces -#endif // LM_BUILDER_NGRAM_STREAM__ +#endif // LM_BUILDER_NGRAM_STREAM_H diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc index f5548f7..cede3c7 100644 --- a/lm/builder/pipeline.cc +++ b/lm/builder/pipeline.cc @@ -204,7 +204,7 @@ class Master { Chains chains_; // Often only unigrams, but sometimes all orders. - FixedArray<util::stream::FileBuffer> files_; + util::FixedArray<util::stream::FileBuffer> files_; }; void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { @@ -225,17 +225,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m WordIndex type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); - CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action); chain >> boost::ref(counter); util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); chain.Wait(true); + std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl; std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; master.InitForAdjust(sorter, type_count); } void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, - FixedArray<util::stream::FileBuffer> &gammas, std::vector<uint64_t> &prune_thresholds) { + util::FixedArray<util::stream::FileBuffer> &gammas, std::vector<uint64_t> &prune_thresholds) { const PipelineConfig &config = master.Config(); Chains second(config.order); @@ -261,7 +262,7 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector master.SetupSorts(primary); } -void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { +void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) { std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; const PipelineConfig &config = master.Config(); master.MaximumLazyInput(counts, primary); @@ -279,7 +280,7 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste gamma_chains.push_back(read_backoffs); gamma_chains.back() >> gammas[i].Source(); } - master >> Interpolate(counts[0], ChainPositions(gamma_chains), config.prune_thresholds); + master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), ChainPositions(gamma_chains), config.prune_thresholds); gamma_chains >> util::stream::kRecycle; master.BufferFinal(counts); } @@ -316,7 +317,7 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds); { - FixedArray<util::stream::FileBuffer> gammas; + util::FixedArray<util::stream::FileBuffer> gammas; Sorts<SuffixOrder> primary; InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds); InterpolateProbabilities(counts_pruned, master, primary, gammas); diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh index a937169..4395622 100644 --- a/lm/builder/pipeline.hh +++ b/lm/builder/pipeline.hh @@ -1,8 +1,9 @@ -#ifndef LM_BUILDER_PIPELINE__ -#define LM_BUILDER_PIPELINE__ +#ifndef LM_BUILDER_PIPELINE_H +#define LM_BUILDER_PIPELINE_H #include "lm/builder/initial_probabilities.hh" #include "lm/builder/header_info.hh" +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/stream/config.hh" #include "util/file_piece.hh" @@ -34,6 +35,24 @@ struct PipelineConfig { // corresponding n-gram order std::vector<uint64_t> prune_thresholds; //mjd + /* Computing the perplexity of LMs with different vocabularies is hard. For + * example, the lowest perplexity is attained by a unigram model that + * predicts p(<unk>) = 1 and has no other vocabulary. Also, linearly + * interpolated models will sum to more than 1 because <unk> is duplicated + * (SRI just pretends p(<unk>) = 0 for these purposes, which makes it sum to + * 1 but comes with its own problems). This option will make the vocabulary + * a particular size by replicating <unk> multiple times for purposes of + * computing vocabulary size. It has no effect if the actual vocabulary is + * larger. This parameter serves the same purpose as IRSTLM's "dub". + */ + uint64_t vocab_size_for_unk; + + /* What to do the first time <s>, </s>, or <unk> appears in the input. If + * this is anything but THROW_UP, then the symbol will always be treated as + * whitespace. + */ + WarningAction disallowed_symbol_action; + const std::string &TempPrefix() const { return sort.temp_prefix; } std::size_t TotalMemory() const { return sort.total_memory; } }; @@ -42,4 +61,4 @@ struct PipelineConfig { void Pipeline(PipelineConfig config, int text_file, int out_arpa); }} // namespaces -#endif // LM_BUILDER_PIPELINE__ +#endif // LM_BUILDER_PIPELINE_H diff --git a/lm/builder/print.hh b/lm/builder/print.hh index adbbb94..397ca95 100644 --- a/lm/builder/print.hh +++ b/lm/builder/print.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_PRINT__ -#define LM_BUILDER_PRINT__ +#ifndef LM_BUILDER_PRINT_H +#define LM_BUILDER_PRINT_H #include "lm/builder/ngram.hh" #include "lm/builder/multi_stream.hh" @@ -100,4 +100,4 @@ class PrintARPA { }; }} // namespaces -#endif // LM_BUILDER_PRINT__ +#endif // LM_BUILDER_PRINT_H diff --git a/lm/builder/sort.hh b/lm/builder/sort.hh index 9989389..c7f2ff8 100644 --- a/lm/builder/sort.hh +++ b/lm/builder/sort.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_SORT__ -#define LM_BUILDER_SORT__ +#ifndef LM_BUILDER_SORT_H +#define LM_BUILDER_SORT_H #include "lm/builder/multi_stream.hh" #include "lm/builder/ngram.hh" @@ -85,10 +85,10 @@ struct AddCombiner { // The combiner is only used on a single chain, so I didn't bother to allow // that template. -template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > { +template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > { private: typedef util::stream::Sort<Compare> S; - typedef FixedArray<S> P; + typedef util::FixedArray<S> P; public: void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { @@ -100,4 +100,4 @@ template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Comp } // namespace builder } // namespace lm -#endif // LM_BUILDER_SORT__ +#endif // LM_BUILDER_SORT_H diff --git a/lm/config.hh b/lm/config.hh index 0de7b7c..dab2812 100644 --- a/lm/config.hh +++ b/lm/config.hh @@ -1,5 +1,5 @@ -#ifndef LM_CONFIG__ -#define LM_CONFIG__ +#ifndef LM_CONFIG_H +#define LM_CONFIG_H #include "lm/lm_exception.hh" #include "util/mmap.hh" @@ -120,4 +120,4 @@ struct Config { } /* namespace ngram */ } /* namespace lm */ -#endif // LM_CONFIG__ +#endif // LM_CONFIG_H diff --git a/lm/enumerate_vocab.hh b/lm/enumerate_vocab.hh index 2726362..f5ce789 100644 --- a/lm/enumerate_vocab.hh +++ b/lm/enumerate_vocab.hh @@ -1,5 +1,5 @@ -#ifndef LM_ENUMERATE_VOCAB__ -#define LM_ENUMERATE_VOCAB__ +#ifndef LM_ENUMERATE_VOCAB_H +#define LM_ENUMERATE_VOCAB_H #include "lm/word_index.hh" #include "util/string_piece.hh" @@ -24,5 +24,5 @@ class EnumerateVocab { } // namespace lm -#endif // LM_ENUMERATE_VOCAB__ +#endif // LM_ENUMERATE_VOCAB_H diff --git a/lm/facade.hh b/lm/facade.hh index 760e839..8e12b62 100644 --- a/lm/facade.hh +++ b/lm/facade.hh @@ -1,5 +1,5 @@ -#ifndef LM_FACADE__ -#define LM_FACADE__ +#ifndef LM_FACADE_H +#define LM_FACADE_H #include "lm/virtual_interface.hh" #include "util/string_piece.hh" @@ -17,14 +17,14 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ typedef VocabularyT Vocabulary; /* Translate from void* to State */ - FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { + FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->FullScore( *reinterpret_cast<const State*>(in_state), new_word, *reinterpret_cast<State*>(out_state)); } - FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { + FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->FullScoreForgotState( context_rbegin, context_rend, @@ -37,7 +37,7 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; } - float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const { return static_cast<const Child*>(this)->Score( *reinterpret_cast<const State*>(in_state), new_word, @@ -70,4 +70,4 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ } // mamespace base } // namespace lm -#endif // LM_FACADE__ +#endif // LM_FACADE_H diff --git a/lm/filter/arpa_io.hh b/lm/filter/arpa_io.hh index 5b31620..99c97b1 100644 --- a/lm/filter/arpa_io.hh +++ b/lm/filter/arpa_io.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_ARPA_IO__ -#define LM_FILTER_ARPA_IO__ +#ifndef LM_FILTER_ARPA_IO_H +#define LM_FILTER_ARPA_IO_H /* Input and output for ARPA format language model files. */ #include "lm/read_arpa.hh" @@ -14,7 +14,6 @@ #include <string> #include <vector> -#include <err.h> #include <string.h> #include <stdint.h> @@ -112,4 +111,4 @@ template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) { } // namespace lm -#endif // LM_FILTER_ARPA_IO__ +#endif // LM_FILTER_ARPA_IO_H diff --git a/lm/filter/count_io.hh b/lm/filter/count_io.hh index 97c0fa2..de894ba 100644 --- a/lm/filter/count_io.hh +++ b/lm/filter/count_io.hh @@ -1,24 +1,22 @@ -#ifndef LM_FILTER_COUNT_IO__ -#define LM_FILTER_COUNT_IO__ +#ifndef LM_FILTER_COUNT_IO_H +#define LM_FILTER_COUNT_IO_H #include <fstream> #include <iostream> #include <string> -#include <err.h> - +#include "util/fake_ofstream.hh" +#include "util/file.hh" #include "util/file_piece.hh" namespace lm { class CountOutput : boost::noncopyable { public: - explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {} void AddNGram(const StringPiece &line) { - if (!(file_ << line << '\n')) { - err(3, "Writing counts file failed"); - } + file_ << line << '\n'; } template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { @@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable { } private: - std::fstream file_; + util::FakeOFStream file_; }; class CountBatch { @@ -88,4 +86,4 @@ template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) { } // namespace lm -#endif // LM_FILTER_COUNT_IO__ +#endif // LM_FILTER_COUNT_IO_H diff --git a/lm/filter/filter_main.cc b/lm/filter/filter_main.cc index 1736bc4..82fdc1e 100644 --- a/lm/filter/filter_main.cc +++ b/lm/filter/filter_main.cc @@ -6,6 +6,7 @@ #endif #include "lm/filter/vocab.hh" #include "lm/filter/wrapper.hh" +#include "util/exception.hh" #include "util/file_piece.hh" #include <boost/ptr_container/ptr_vector.hpp> @@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr } // namespace lm int main(int argc, char *argv[]) { - if (argc < 4) { - lm::DisplayHelp(argv[0]); - return 1; - } + try { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } - // I used to have boost::program_options, but some users didn't want to compile boost. - lm::Config config; - config.mode = lm::MODE_UNSET; - for (int i = 1; i < argc - 2; ++i) { - const char *str = argv[i]; - if (!std::strcmp(str, "copy")) { - config.mode = lm::MODE_COPY; - } else if (!std::strcmp(str, "single")) { - config.mode = lm::MODE_SINGLE; - } else if (!std::strcmp(str, "multiple")) { - config.mode = lm::MODE_MULTIPLE; - } else if (!std::strcmp(str, "union")) { - config.mode = lm::MODE_UNION; - } else if (!std::strcmp(str, "phrase")) { - config.phrase = true; - } else if (!std::strcmp(str, "context")) { - config.context = true; - } else if (!std::strcmp(str, "arpa")) { - config.format = lm::FORMAT_ARPA; - } else if (!std::strcmp(str, "raw")) { - config.format = lm::FORMAT_COUNT; + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; #ifndef NTHREAD - } else if (!std::strncmp(str, "threads:", 8)) { - config.threads = boost::lexical_cast<size_t>(str + 8); - if (!config.threads) { - std::cerr << "Specify at least one thread." << std::endl; + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast<size_t>(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast<size_t>(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); return 1; } - } else if (!std::strncmp(str, "batch_size:", 11)) { - config.batch_size = boost::lexical_cast<size_t>(str + 11); - if (config.batch_size < 5000) { - std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; - if (!config.batch_size) return 1; - } -#endif - } else { + } + + if (config.mode == lm::MODE_UNSET) { lm::DisplayHelp(argv[0]); return 1; } - } - - if (config.mode == lm::MODE_UNSET) { - lm::DisplayHelp(argv[0]); - return 1; - } - if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { - std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; - return 1; - } + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } - bool cmd_is_model = true; - const char *cmd_input = argv[argc - 2]; - if (!strncmp(cmd_input, "vocab:", 6)) { - cmd_is_model = false; - cmd_input += 6; - } else if (!strncmp(cmd_input, "model:", 6)) { - cmd_input += 6; - } else if (strchr(cmd_input, ':')) { - errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); - } else { - std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; - } - std::ifstream cmd_file; - std::istream *vocab; - if (cmd_is_model) { - vocab = &std::cin; - } else { - cmd_file.open(cmd_input, std::ios::in); - if (!cmd_file) { - err(2, "Could not open input file %s", cmd_input); + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; + return 1; + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input); + vocab = &cmd_file; } - vocab = &cmd_file; - } - util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); - if (config.format == lm::FORMAT_ARPA) { - lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); - } else if (config.format == lm::FORMAT_COUNT) { - lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + } + return 0; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; } - return 0; } diff --git a/lm/filter/format.hh b/lm/filter/format.hh index 7f945b0..5a2e2db 100644 --- a/lm/filter/format.hh +++ b/lm/filter/format.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_FORMAT_H__ -#define LM_FITLER_FORMAT_H__ +#ifndef LM_FILTER_FORMAT_H +#define LM_FILTER_FORMAT_H #include "lm/filter/arpa_io.hh" #include "lm/filter/count_io.hh" @@ -247,4 +247,4 @@ class MultipleOutputBuffer { } // namespace lm -#endif // LM_FILTER_FORMAT_H__ +#endif // LM_FILTER_FORMAT_H diff --git a/lm/filter/phrase.hh b/lm/filter/phrase.hh index e8e8583..e5898c9 100644 --- a/lm/filter/phrase.hh +++ b/lm/filter/phrase.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_PHRASE_H__ -#define LM_FILTER_PHRASE_H__ +#ifndef LM_FILTER_PHRASE_H +#define LM_FILTER_PHRASE_H #include "util/murmur_hash.hh" #include "util/string_piece.hh" @@ -165,4 +165,4 @@ class Multiple : public detail::ConditionCommon { } // namespace phrase } // namespace lm -#endif // LM_FILTER_PHRASE_H__ +#endif // LM_FILTER_PHRASE_H diff --git a/lm/filter/thread.hh b/lm/filter/thread.hh index e785b26..6a6523f 100644 --- a/lm/filter/thread.hh +++ b/lm/filter/thread.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_THREAD_H__ -#define LM_FILTER_THREAD_H__ +#ifndef LM_FILTER_THREAD_H +#define LM_FILTER_THREAD_H #include "util/thread_pool.hh" @@ -164,4 +164,4 @@ template <class Filter, class OutputBuffer, class RealOutput> class Controller : } // namespace lm -#endif // LM_FILTER_THREAD_H__ +#endif // LM_FILTER_THREAD_H diff --git a/lm/filter/vocab.cc b/lm/filter/vocab.cc index 7ee4e84..011ab59 100644 --- a/lm/filter/vocab.cc +++ b/lm/filter/vocab.cc @@ -4,7 +4,6 @@ #include <iostream> #include <ctype.h> -#include <err.h> namespace lm { namespace vocab { diff --git a/lm/filter/vocab.hh b/lm/filter/vocab.hh index 7f0fada..2ee6e1f 100644 --- a/lm/filter/vocab.hh +++ b/lm/filter/vocab.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_VOCAB_H__ -#define LM_FILTER_VOCAB_H__ +#ifndef LM_FILTER_VOCAB_H +#define LM_FILTER_VOCAB_H // Vocabulary-based filters for language models. @@ -130,4 +130,4 @@ class Multiple { } // namespace vocab } // namespace lm -#endif // LM_FILTER_VOCAB_H__ +#endif // LM_FILTER_VOCAB_H diff --git a/lm/filter/wrapper.hh b/lm/filter/wrapper.hh index 90b07a0..822c5c2 100644 --- a/lm/filter/wrapper.hh +++ b/lm/filter/wrapper.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_WRAPPER_H__ -#define LM_FILTER_WRAPPER_H__ +#ifndef LM_FILTER_WRAPPER_H +#define LM_FILTER_WRAPPER_H #include "util/string_piece.hh" @@ -39,20 +39,18 @@ template <class FilterT> class ContextFilter { explicit ContextFilter(Filter &backend) : backend_(backend) {} template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { - pieces_.clear(); - // TODO: this copy could be avoided by a lookahead iterator. - std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_)); - backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); + // Find beginning of string or last space. + const char *last_space; + for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {} + backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output); } void Flush() const {} private: - std::vector<StringPiece> pieces_; - Filter backend_; }; } // namespace lm -#endif // LM_FILTER_WRAPPER_H__ +#endif // LM_FILTER_WRAPPER_H @@ -35,8 +35,8 @@ * phrase, even if hypotheses are generated left-to-right. */ -#ifndef LM_LEFT__ -#define LM_LEFT__ +#ifndef LM_LEFT_H +#define LM_LEFT_H #include "lm/max_order.hh" #include "lm/state.hh" @@ -213,4 +213,4 @@ template <class M> class RuleScore { } // namespace ngram } // namespace lm -#endif // LM_LEFT__ +#endif // LM_LEFT_H diff --git a/lm/lm_exception.hh b/lm/lm_exception.hh index f607ced..8bb6108 100644 --- a/lm/lm_exception.hh +++ b/lm/lm_exception.hh @@ -1,5 +1,5 @@ -#ifndef LM_LM_EXCEPTION__ -#define LM_LM_EXCEPTION__ +#ifndef LM_LM_EXCEPTION_H +#define LM_LM_EXCEPTION_H // Named to avoid conflict with util/exception.hh. diff --git a/lm/max_order.hh b/lm/max_order.hh index 3eb97cc..f7344cd 100644 --- a/lm/max_order.hh +++ b/lm/max_order.hh @@ -1,9 +1,13 @@ -/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM. +#ifndef LM_MAX_ORDER_H +#define LM_MAX_ORDER_H +/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER_H, THEN CHANGE THE BUILD SYSTEM. * If not, this is the default maximum order. * Having this limit means that State can be * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif + +#endif // LM_MAX_ORDER_H diff --git a/lm/model.cc b/lm/model.cc index a26654a..a5a16bf 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size); } -template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { - LoadLM(file, config, *this); - - // g++ prints warnings unless these are fully initialized. - State begin_sentence = State(); - begin_sentence.length = 1; - begin_sentence.words[0] = vocab_.BeginSentence(); - typename Search::Node ignored_node; - bool ignored_independent_left; - uint64_t ignored_extend_left; - begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); - State null_context = State(); - null_context.length = 0; - P::Init(begin_sentence, null_context, vocab_, search_.Order()); +namespace { +void ComplainAboutARPA(const Config &config, ModelType model_type) { + if (config.write_mmap || !config.messages) return; + if (config.arpa_complain == Config::ALL) { + *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { + *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; + } } -namespace { void CheckCounts(const std::vector<uint64_t> &counts) { UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE); if (sizeof(uint64_t) > sizeof(std::size_t)) { @@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) { } } } + } // namespace -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { - CheckCounts(params.counts); - SetupMemory(start, params.counts, config); - vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); - search_.LoadedBinary(); +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) { + util::scoped_fd fd(util::OpenReadOrThrow(file)); + if (IsBinaryFormat(fd.get())) { + Parameters parameters; + int fd_shallow = fd.release(); + backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters); + CheckCounts(parameters.counts); + + Config new_config(init_config); + new_config.probing_multiplier = parameters.fixed.probing_multiplier; + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config); + UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config); + vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset()); + } else { + ComplainAboutARPA(init_config, kModelType); + InitializeFromARPA(fd.release(), file, init_config); + } + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); + State null_context = State(); + null_context.length = 0; + P::Init(begin_sentence, null_context, vocab_, search_.Order()); } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) { + // Backing file is the ARPA. + util::FilePiece f(fd, file, config.ProgressMessages()); try { std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config); - if (config.write_mmap) { + if (config.write_mmap && config.include_vocab) { WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); + void *vocab_rebase, *search_rebase; + backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase); + // Due to writing at the end of file, mmap may have relocated data. So remap. + vocab_.Relocate(vocab_rebase); + search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.UnknownUnigram().backoff = 0.0; search_.UnknownUnigram().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); + backing_.FinishFile(config, kModelType, kVersion, counts); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; } } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { - util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); - Search::UpdateConfigFromBinary(fd, counts, config); -} - template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { diff --git a/lm/model.hh b/lm/model.hh index c9c17c4..6925a56 100644 --- a/lm/model.hh +++ b/lm/model.hh @@ -1,5 +1,5 @@ -#ifndef LM_MODEL__ -#define LM_MODEL__ +#ifndef LM_MODEL_H +#define LM_MODEL_H #include "lm/bhiksha.hh" #include "lm/binary_format.hh" @@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod } private: - friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); - - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); - FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; // Score bigrams and above. Do not include backoff. @@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // Appears after Size in the cc file. void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config); - void InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd); - - void InitializeFromARPA(const char *file, const Config &config); + void InitializeFromARPA(int fd, const char *file, const Config &config); float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const; - Backing &MutableBacking() { return backing_; } - - Backing backing_; + BinaryFormat backing_; VocabularyT vocab_; @@ -161,4 +153,4 @@ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), } // namespace ngram } // namespace lm -#endif // LM_MODEL__ +#endif // LM_MODEL_H diff --git a/lm/model_test.cc b/lm/model_test.cc index eb15909..7005b05 100644 --- a/lm/model_test.cc +++ b/lm/model_test.cc @@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { LoadingTest<QuantArrayTrieModel>(); } -template <class ModelT> void BinaryTest() { +template <class ModelT> void BinaryTest(Config::WriteMethod write_method) { Config config; config.write_mmap = "test.binary"; config.messages = NULL; + config.write_method = write_method; ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; @@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() { unlink("test_nounk.binary"); } +template <class ModelT> void BinaryTest() { + BinaryTest<ModelT>(Config::WRITE_MMAP); + BinaryTest<ModelT>(Config::WRITE_AFTER); +} + BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest<ProbingModel>(); } diff --git a/lm/model_type.hh b/lm/model_type.hh index 8b35c79..fbe1117 100644 --- a/lm/model_type.hh +++ b/lm/model_type.hh @@ -1,5 +1,5 @@ -#ifndef LM_MODEL_TYPE__ -#define LM_MODEL_TYPE__ +#ifndef LM_MODEL_TYPE_H +#define LM_MODEL_TYPE_H namespace lm { namespace ngram { @@ -20,4 +20,4 @@ const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE); } // namespace ngram } // namespace lm -#endif // LM_MODEL_TYPE__ +#endif // LM_MODEL_TYPE_H diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh index ec2590f..454e856 100644 --- a/lm/ngram_query.hh +++ b/lm/ngram_query.hh @@ -1,8 +1,9 @@ -#ifndef LM_NGRAM_QUERY__ -#define LM_NGRAM_QUERY__ +#ifndef LM_NGRAM_QUERY_H +#define LM_NGRAM_QUERY_H #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "util/file_piece.hh" #include "util/usage.hh" #include <cstdlib> @@ -16,42 +17,41 @@ namespace lm { namespace ngram { -template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { +template <class Model> void Query(const Model &model, bool sentence_context) { typename Model::State state, out; lm::FullScoreReturn ret; - std::string word; + StringPiece word; + + util::FilePiece in(0); + std::ostream &out_stream = std::cout; double corpus_total = 0.0; + double corpus_total_oov_only = 0.0; uint64_t corpus_oov = 0; uint64_t corpus_tokens = 0; - while (in_stream) { + while (true) { state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); float total = 0.0; - bool got = false; uint64_t oov = 0; - while (in_stream >> word) { - got = true; + + while (in.ReadWordSameLine(word)) { lm::WordIndex vocab = model.GetVocabulary().Index(word); - if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); + if (vocab == model.GetVocabulary().NotFound()) { + ++oov; + corpus_total_oov_only += ret.prob; + } total += ret.prob; out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; ++corpus_tokens; state = out; - char c; - while (true) { - c = in_stream.get(); - if (!in_stream) break; - if (c == '\n') break; - if (!isspace(c)) { - in_stream.unget(); - break; - } - } - if (c == '\n') break; } - if (!got && !in_stream) break; + // If people don't have a newline after their last query, this won't add a </s>. + // Sue me. + try { + UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused."); + } catch (const util::EndOfFileException &e) { break; } if (sentence_context) { ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; @@ -62,18 +62,22 @@ template <class Model> void Query(const Model &model, bool sentence_context, std corpus_total += total; corpus_oov += oov; } - out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl; + out_stream << + "Perplexity including OOVs:\t" << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << "\n" + "Perplexity excluding OOVs:\t" << pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))) << "\n" + "OOVs:\t" << corpus_oov << "\n" + ; } -template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { +template <class M> void Query(const char *file, bool sentence_context) { Config config; M model(file, config); - Query(model, sentence_context, in_stream, out_stream); + Query(model, sentence_context); } } // namespace ngram } // namespace lm -#endif // LM_NGRAM_QUERY__ +#endif // LM_NGRAM_QUERY_H diff --git a/lm/partial.hh b/lm/partial.hh index 1dede35..d8adc69 100644 --- a/lm/partial.hh +++ b/lm/partial.hh @@ -1,5 +1,5 @@ -#ifndef LM_PARTIAL__ -#define LM_PARTIAL__ +#ifndef LM_PARTIAL_H +#define LM_PARTIAL_H #include "lm/return.hh" #include "lm/state.hh" @@ -164,4 +164,4 @@ template <class Model> float Subsume(const Model &model, Left &first_left, const } // namespace ngram } // namespace lm -#endif // LM_PARTIAL__ +#endif // LM_PARTIAL_H diff --git a/lm/quantize.cc b/lm/quantize.cc index b58c3f3..273ea39 100644 --- a/lm/quantize.cc +++ b/lm/quantize.cc @@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2; } // namespace -void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) { - char version; - util::ReadOrThrow(fd, &version, 1); - util::ReadOrThrow(fd, &config.prob_bits, 1); - util::ReadOrThrow(fd, &config.backoff_bits, 1); +void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { + unsigned char buffer[3]; + file.ReadForConfig(buffer, 3, offset); + char version = buffer[0]; + config.prob_bits = buffer[1]; + config.backoff_bits = buffer[2]; if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); - util::AdvanceOrThrow(fd, -3); } void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) { diff --git a/lm/quantize.hh b/lm/quantize.hh index 8ce2378..84a3087 100644 --- a/lm/quantize.hh +++ b/lm/quantize.hh @@ -1,5 +1,5 @@ -#ifndef LM_QUANTIZE_H__ -#define LM_QUANTIZE_H__ +#ifndef LM_QUANTIZE_H +#define LM_QUANTIZE_H #include "lm/blank.hh" #include "lm/config.hh" @@ -18,12 +18,13 @@ namespace lm { namespace ngram { struct Config; +class BinaryFormat; /* Store values directly and don't quantize. */ class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast<ModelType>(0); - static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -136,7 +137,7 @@ class SeparatelyQuantize { public: static const ModelType kModelTypeAdd = kQuantAdd; - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); static uint64_t Size(uint8_t order, const Config &config) { uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); @@ -229,4 +230,4 @@ class SeparatelyQuantize { } // namespace ngram } // namespace lm -#endif // LM_QUANTIZE_H__ +#endif // LM_QUANTIZE_H diff --git a/lm/query_main.cc b/lm/query_main.cc index b9db7b0..cd661f7 100644 --- a/lm/query_main.cc +++ b/lm/query_main.cc @@ -4,48 +4,62 @@ #include "lm/wrappers/nplm.hh" #endif +#include <stdlib.h> + +void Usage(const char *name) { + std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; + std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl; + std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl; + exit(1); +} + int main(int argc, char *argv[]) { - if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { - std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; - std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; - std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl; - return 1; + bool sentence_context = true; + const char *file = NULL; + for (char **arg = argv + 1; arg != argv + argc; ++arg) { + if (!strcmp(*arg, "-n")) { + sentence_context = false; + } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) { + Usage(argv[0]); + } else { + file = *arg; + } } + if (!file) Usage(argv[0]); try { - bool sentence_context = (argc == 2); using namespace lm::ngram; ModelType model_type; - if (RecognizeBinary(argv[1], model_type)) { + if (RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: - Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout); + Query<lm::ngram::ProbingModel>(file, sentence_context); break; case REST_PROBING: - Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout); + Query<lm::ngram::RestProbingModel>(file, sentence_context); break; case TRIE: - Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<TrieModel>(file, sentence_context); break; case QUANT_TRIE: - Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<QuantTrieModel>(file, sentence_context); break; case ARRAY_TRIE: - Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<ArrayTrieModel>(file, sentence_context); break; case QUANT_ARRAY_TRIE: - Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); + Query<QuantArrayTrieModel>(file, sentence_context); break; default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; abort(); } #ifdef WITH_NPLM - } else if (lm::np::Model::Recognize(argv[1])) { - lm::np::Model model(argv[1]); - Query(model, sentence_context, std::cin, std::cout); + } else if (lm::np::Model::Recognize(file)) { + lm::np::Model model(file); + Query(model, sentence_context); #endif } else { - Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout); + Query<ProbingModel>(file, sentence_context); } std::cerr << "Total time including destruction:\n"; util::PrintUsage(std::cerr); diff --git a/lm/read_arpa.cc b/lm/read_arpa.cc index 5ccba71..fb8bbfa 100644 --- a/lm/read_arpa.cc +++ b/lm/read_arpa.cc @@ -150,7 +150,7 @@ void PositiveProbWarn::Warn(float prob) { case THROW_UP: UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error"); case COMPLAIN: - std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl; + std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapped to 0 log probability." << std::endl; action_ = SILENT; break; case SILENT: diff --git a/lm/read_arpa.hh b/lm/read_arpa.hh index 234d130..16e1fc1 100644 --- a/lm/read_arpa.hh +++ b/lm/read_arpa.hh @@ -1,5 +1,5 @@ -#ifndef LM_READ_ARPA__ -#define LM_READ_ARPA__ +#ifndef LM_READ_ARPA_H +#define LM_READ_ARPA_H #include "lm/lm_exception.hh" #include "lm/word_index.hh" @@ -87,4 +87,4 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns } // namespace lm -#endif // LM_READ_ARPA__ +#endif // LM_READ_ARPA_H diff --git a/lm/return.hh b/lm/return.hh index 622320c..982ffd6 100644 --- a/lm/return.hh +++ b/lm/return.hh @@ -1,5 +1,5 @@ -#ifndef LM_RETURN__ -#define LM_RETURN__ +#ifndef LM_RETURN_H +#define LM_RETURN_H #include <stdint.h> @@ -39,4 +39,4 @@ struct FullScoreReturn { }; } // namespace lm -#endif // LM_RETURN__ +#endif // LM_RETURN_H diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc index 62275d2..354a56b 100644 --- a/lm/search_hashed.cc +++ b/lm/search_hashed.cc @@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams( namespace detail { template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram_ = Unigram(start, counts[0], allocated); - start += allocated; + unigram_ = Unigram(start, counts[0]); + start += Unigram::Size(counts[0]); + std::size_t allocated; + middle_.clear(); for (unsigned int n = 2; n < counts.size(); ++n) { allocated = Middle::Size(counts[n - 1], config.probing_multiplier); middle_.push_back(Middle(start, allocated)); @@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, return start; } -template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) { - // TODO: fix sorted. - SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); +/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { + unigram_ = Unigram(start, counts[0]); + start += Unigram::Size(counts[0]); + for (unsigned int n = 2; n < counts.size(); ++n) { + middle[n-2].Relocate(start); + start += Middle::Size(counts[n - 1], config.probing_multiplier) + } + longest_.Relocate(start); +}*/ + +template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) { + void *vocab_rebase; + void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase); + vocab.Relocate(vocab_rebase); + SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config); PositiveProbWarn warn(config.positive_log_probability); Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn); @@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui ReadEnd(f); } -template <class Value> void HashedSearch<Value>::LoadedBinary() { - unigram_.LoadedBinary(); - for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - template class HashedSearch<BackoffValue>; template class HashedSearch<RestValue>; diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh index 9d067bc..9dc8445 100644 --- a/lm/search_hashed.hh +++ b/lm/search_hashed.hh @@ -1,5 +1,5 @@ -#ifndef LM_SEARCH_HASHED__ -#define LM_SEARCH_HASHED__ +#ifndef LM_SEARCH_HASHED_H +#define LM_SEARCH_HASHED_H #include "lm/model_type.hh" #include "lm/config.hh" @@ -18,7 +18,7 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { -struct Backing; +class BinaryFormat; class ProbingVocabulary; namespace detail { @@ -72,7 +72,7 @@ template <class Value> class HashedSearch { static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. - static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {} static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { uint64_t ret = Unigram::Size(counts[0]); @@ -84,9 +84,7 @@ template <class Value> class HashedSearch { uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); - - void LoadedBinary(); + void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing); unsigned char Order() const { return middle_.size() + 2; @@ -148,7 +146,7 @@ template <class Value> class HashedSearch { public: Unigram() {} - Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : + Unigram(void *start, uint64_t count) : unigram_(static_cast<typename Value::Weights*>(start)) #ifdef DEBUG , count_(count) @@ -168,8 +166,6 @@ template <class Value> class HashedSearch { typename Value::Weights &Unknown() { return unigram_[0]; } - void LoadedBinary() {} - // For building. typename Value::Weights *Raw() { return unigram_; } @@ -193,4 +189,4 @@ template <class Value> class HashedSearch { } // namespace ngram } // namespace lm -#endif // LM_SEARCH_HASHED__ +#endif // LM_SEARCH_HASHED_H diff --git a/lm/search_trie.cc b/lm/search_trie.cc index 1b0d9b2..4a88194 100644 --- a/lm/search_trie.cc +++ b/lm/search_trie.cc @@ -253,11 +253,6 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. - void Cleanup() { - --counts_[0]; - } - const std::vector<uint64_t> &Counts() const { return counts_; } @@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries { typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob); } - void Cleanup() {} - private: RecordReader *contexts_; const Quant &quant_; @@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue<Gram> grams; - grams.push(Gram(&unigram, 1)); + if (unigram_count) grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i)); } BlankManager<Doing> blank(total_order, doing); - while (true) { + while (!grams.empty()) { Gram top = grams.top(); grams.pop(); unsigned char order = top.end - top.begin; @@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); doing.Unigram(unigram); progress.Set(unigram); - if (++unigram == unigram_count + 1) break; - grams.push(top); + if (++unigram < unigram_count) grams.push(top); } else { if (order == total_order) { blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob); @@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con if (++reader) grams.push(top); } } - assert(grams.empty()); - doing.Cleanup(); } void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { @@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) { RecordReader inputs[KENLM_MAX_ORDER - 1]; RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); - out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); + void *vocab_relocate; + void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); + vocab.Relocate(vocab_relocate); + out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve { WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); + // Write the last unigram entry, which is the end pointer for the bigrams. + writer.Unigram(counts[0]); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { - unigram_.LoadedBinary(); - for (Middle *i = middle_begin_; i != middle_end_; ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) { std::string temporary_prefix; if (config.temporary_directory_prefix) { temporary_prefix = config.temporary_directory_prefix; diff --git a/lm/search_trie.hh b/lm/search_trie.hh index 763fd1a..d8838d2 100644 --- a/lm/search_trie.hh +++ b/lm/search_trie.hh @@ -1,5 +1,5 @@ -#ifndef LM_SEARCH_TRIE__ -#define LM_SEARCH_TRIE__ +#ifndef LM_SEARCH_TRIE_H +#define LM_SEARCH_TRIE_H #include "lm/config.hh" #include "lm/model_type.hh" @@ -17,13 +17,13 @@ namespace lm { namespace ngram { -struct Backing; +class BinaryFormat; class SortedVocabulary; namespace trie { template <class Quant, class Bhiksha> class TrieSearch; class SortedFiles; -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); template <class Quant, class Bhiksha> class TrieSearch { public: @@ -39,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch { static const unsigned int kVersion = 1; - static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { - Quant::UpdateConfigFromBinary(fd, counts, config); - util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) { + Quant::UpdateConfigFromBinary(file, offset, config); // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. - if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config); + if (counts.size() > 2) + Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config); } static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { @@ -60,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch { uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - void LoadedBinary(); - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing); unsigned char Order() const { return middle_end_ - middle_begin_ + 2; @@ -103,7 +101,7 @@ template <class Quant, class Bhiksha> class TrieSearch { } private: - friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); + friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { @@ -129,4 +127,4 @@ template <class Quant, class Bhiksha> class TrieSearch { } // namespace ngram } // namespace lm -#endif // LM_SEARCH_TRIE__ +#endif // LM_SEARCH_TRIE_H diff --git a/lm/sizes.hh b/lm/sizes.hh index 85abade..eb7e99d 100644 --- a/lm/sizes.hh +++ b/lm/sizes.hh @@ -1,5 +1,5 @@ -#ifndef LM_SIZES__ -#define LM_SIZES__ +#ifndef LM_SIZES_H +#define LM_SIZES_H #include <vector> @@ -14,4 +14,4 @@ void ShowSizes(const std::vector<uint64_t> &counts); void ShowSizes(const char *file, const lm::ngram::Config &config); }} // namespaces -#endif // LM_SIZES__ +#endif // LM_SIZES_H diff --git a/lm/state.hh b/lm/state.hh index d8e6c13..f6c51d6 100644 --- a/lm/state.hh +++ b/lm/state.hh @@ -1,5 +1,5 @@ -#ifndef LM_STATE__ -#define LM_STATE__ +#ifndef LM_STATE_H +#define LM_STATE_H #include "lm/max_order.hh" #include "lm/word_index.hh" @@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) { } struct ChartState { - bool operator==(const ChartState &other) { + bool operator==(const ChartState &other) const { return (right == other.right) && (left == other.left); } @@ -102,7 +102,7 @@ struct ChartState { } bool operator<(const ChartState &other) const { - return Compare(other) == -1; + return Compare(other) < 0; } void ZeroRemaining() { @@ -122,4 +122,4 @@ inline uint64_t hash_value(const ChartState &state) { } // namespace ngram } // namespace lm -#endif // LM_STATE__ +#endif // LM_STATE_H @@ -1,5 +1,5 @@ -#ifndef LM_TRIE__ -#define LM_TRIE__ +#ifndef LM_TRIE_H +#define LM_TRIE_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -62,8 +62,6 @@ class Unigram { return unigram_; } - void LoadedBinary() {} - UnigramPointer Find(WordIndex word, NodeRange &next) const { UnigramValue *val = unigram_ + word; next.begin = val->next; @@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end, const Config &config); - void LoadedBinary() { bhiksha_.LoadedBinary(); } - util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { @@ -138,18 +134,13 @@ class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_bits); } - void LoadedBinary() {} - util::BitAddress Insert(WordIndex word); util::BitAddress Find(WordIndex word, const NodeRange &node) const; - - private: - uint8_t quant_bits_; }; } // namespace trie } // namespace ngram } // namespace lm -#endif // LM_TRIE__ +#endif // LM_TRIE_H diff --git a/lm/trie_sort.cc b/lm/trie_sort.cc index dc542bb..126d43a 100644 --- a/lm/trie_sort.cc +++ b/lm/trie_sort.cc @@ -50,6 +50,10 @@ class PartialViewProxy { const void *Data() const { return inner_.Data(); } void *Data() { return inner_.Data(); } + friend void swap(PartialViewProxy first, PartialViewProxy second) { + std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data())); + } + private: friend class util::ProxyIterator<PartialViewProxy>; diff --git a/lm/trie_sort.hh b/lm/trie_sort.hh index 1afd956..e5406d9 100644 --- a/lm/trie_sort.hh +++ b/lm/trie_sort.hh @@ -1,7 +1,7 @@ // Step of trie builder: create sorted files. -#ifndef LM_TRIE_SORT__ -#define LM_TRIE_SORT__ +#ifndef LM_TRIE_SORT_H +#define LM_TRIE_SORT_H #include "lm/max_order.hh" #include "lm/word_index.hh" @@ -111,4 +111,4 @@ class SortedFiles { } // namespace ngram } // namespace lm -#endif // LM_TRIE_SORT__ +#endif // LM_TRIE_SORT_H diff --git a/lm/value.hh b/lm/value.hh index ba71671..36e8708 100644 --- a/lm/value.hh +++ b/lm/value.hh @@ -1,5 +1,5 @@ -#ifndef LM_VALUE__ -#define LM_VALUE__ +#ifndef LM_VALUE_H +#define LM_VALUE_H #include "lm/model_type.hh" #include "lm/value_build.hh" @@ -154,4 +154,4 @@ struct RestValue { } // namespace ngram } // namespace lm -#endif // LM_VALUE__ +#endif // LM_VALUE_H diff --git a/lm/value_build.hh b/lm/value_build.hh index 461e6a5..6fd26ef 100644 --- a/lm/value_build.hh +++ b/lm/value_build.hh @@ -1,5 +1,5 @@ -#ifndef LM_VALUE_BUILD__ -#define LM_VALUE_BUILD__ +#ifndef LM_VALUE_BUILD_H +#define LM_VALUE_BUILD_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -94,4 +94,4 @@ template <class Model> class LowerRestBuild { } // namespace ngram } // namespace lm -#endif // LM_VALUE_BUILD__ +#endif // LM_VALUE_BUILD_H diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh index ff4a388..2a2690e 100644 --- a/lm/virtual_interface.hh +++ b/lm/virtual_interface.hh @@ -1,5 +1,5 @@ -#ifndef LM_VIRTUAL_INTERFACE__ -#define LM_VIRTUAL_INTERFACE__ +#ifndef LM_VIRTUAL_INTERFACE_H +#define LM_VIRTUAL_INTERFACE_H #include "lm/return.hh" #include "lm/word_index.hh" @@ -125,13 +125,13 @@ class Model { void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); } // Requires in_state != out_state - virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; // Requires in_state != out_state - virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; // Prefer to use FullScore. The context words should be provided in reverse order. - virtual FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0; + virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0; unsigned char Order() const { return order_; } @@ -157,4 +157,4 @@ class Model { } // mamespace base } // namespace lm -#endif // LM_VIRTUAL_INTERFACE__ +#endif // LM_VIRTUAL_INTERFACE_H diff --git a/lm/vocab.cc b/lm/vocab.cc index fd7f96d..7f0878f 100644 --- a/lm/vocab.cc +++ b/lm/vocab.cc @@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { + util::SeekOrThrow(fd, offset); // Check that we're at the right place by reading <unk> which is always first. char check_unk[6]; util::ReadOrThrow(fd, check_unk, 6); @@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } -void WriteWordsWrapper::Write(int fd, uint64_t start) { - util::SeekOrThrow(fd, start); - util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); -} - SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { @@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size saw_unk_ = false; } +void SortedVocabulary::Relocate(void *new_start) { + std::size_t delta = end_ - begin_; + begin_ = reinterpret_cast<uint64_t*>(new_start) + 1; + end_ = begin_ + delta; +} + void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) { enumerate_ = to; if (enumerate_) { @@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); SetSpecial(Index("<s>"), Index("</s>"), 0); bound_ = end_ - begin_ + 1; - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } namespace { @@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz saw_unk_ = false; } +void ProbingVocabulary::Relocate(void *new_start) { + header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start); + lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader))); +} + void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) { enumerate_ = to; if (enumerate_) { @@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() { SetSpecial(Index("<s>"), Index("</s>"), 0); } -void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); - lookup_.LoadedBinary(); bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { diff --git a/lm/vocab.hh b/lm/vocab.hh index 226ae43..dcd298a 100644 --- a/lm/vocab.hh +++ b/lm/vocab.hh @@ -1,5 +1,5 @@ -#ifndef LM_VOCAB__ -#define LM_VOCAB__ +#ifndef LM_VOCAB_H +#define LM_VOCAB_H #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" @@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab { void Add(WordIndex index, const StringPiece &str); - void Write(int fd, uint64_t start); + const std::string &Buffer() const { return buffer_; } private: EnumerateVocab *inner_; @@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary { // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void Relocate(void *new_start); + void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); WordIndex Insert(const StringPiece &str); @@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: uint64_t *begin_, *end_; WordIndex bound_; - WordIndex highest_value_; - bool saw_unk_; EnumerateVocab *enumerate_; @@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary { // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void Relocate(void *new_start); + void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); WordIndex Insert(const StringPiece &str); @@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } - void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); + void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: void InternalFinishedLoading(); @@ -182,4 +184,4 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc } // namespace ngram } // namespace lm -#endif // LM_VOCAB__ +#endif // LM_VOCAB_H diff --git a/lm/weights.hh b/lm/weights.hh index bd5d803..da1963d 100644 --- a/lm/weights.hh +++ b/lm/weights.hh @@ -1,5 +1,5 @@ -#ifndef LM_WEIGHTS__ -#define LM_WEIGHTS__ +#ifndef LM_WEIGHTS_H +#define LM_WEIGHTS_H // Weights for n-grams. Probability and possibly a backoff. @@ -19,4 +19,4 @@ struct RestWeights { }; } // namespace lm -#endif // LM_WEIGHTS__ +#endif // LM_WEIGHTS_H diff --git a/lm/word_index.hh b/lm/word_index.hh index e09557a..a5a0fda 100644 --- a/lm/word_index.hh +++ b/lm/word_index.hh @@ -1,6 +1,6 @@ // Separate header because this is used often. -#ifndef LM_WORD_INDEX__ -#define LM_WORD_INDEX__ +#ifndef LM_WORD_INDEX_H +#define LM_WORD_INDEX_H #include <limits.h> diff --git a/lm/wrappers/nplm.cc b/lm/wrappers/nplm.cc index 6a3fa0d..70622bd 100644 --- a/lm/wrappers/nplm.cc +++ b/lm/wrappers/nplm.cc @@ -13,7 +13,7 @@ namespace np { Vocabulary::Vocabulary(const nplm::vocabulary &vocab) : base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")), - vocab_(vocab) {} + vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {} Vocabulary::~Vocabulary() {} @@ -33,8 +33,11 @@ bool Model::Recognize(const std::string &name) { } } -Model::Model(const std::string &file) : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()) { +Model::Model(const std::string &file, std::size_t cache) + : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) { UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile."); + // log10 compatible with backoff models. + base_instance_->set_log_base(10.0); State begin_sentence, null_context; std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>")); null_word_ = base_instance_->lookup_word("<null>"); @@ -50,6 +53,7 @@ FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, St if (!lm) { lm = new nplm::neuralLM(*base_instance_); backend_.reset(lm); + lm->set_cache(cache_size_); } // State is in natural word order. FullScoreReturn ret; diff --git a/lm/wrappers/nplm.hh b/lm/wrappers/nplm.hh index 90f1d49..b7dd4a2 100644 --- a/lm/wrappers/nplm.hh +++ b/lm/wrappers/nplm.hh @@ -1,5 +1,5 @@ -#ifndef LM_WRAPPER_NPLM__ -#define LM_WRAPPER_NPLM__ +#ifndef LM_WRAPPERS_NPLM_H +#define LM_WRAPPERS_NPLM_H #include "lm/facade.hh" #include "lm/max_order.hh" @@ -34,8 +34,12 @@ class Vocabulary : public base::Vocabulary { return Index(std::string(str.data(), str.size())); } + lm::WordIndex NullWord() const { return null_word_; } + private: const nplm::vocabulary &vocab_; + + const lm::WordIndex null_word_; }; // Sorry for imposing my limitations on your code. @@ -53,7 +57,7 @@ class Model : public lm::base::ModelFacade<Model, State, Vocabulary> { // Does this look like an NPLM? static bool Recognize(const std::string &file); - explicit Model(const std::string &file); + explicit Model(const std::string &file, std::size_t cache_size = 1 << 20); ~Model(); @@ -69,9 +73,11 @@ class Model : public lm::base::ModelFacade<Model, State, Vocabulary> { Vocabulary vocab_; lm::WordIndex null_word_; + + const std::size_t cache_size_; }; } // namespace np } // namespace lm -#endif // LM_WRAPPER_NPLM__ +#endif // LM_WRAPPERS_NPLM_H diff --git a/python/kenlm.cpp b/python/kenlm.cpp index d401047..c7052fc 100644 --- a/python/kenlm.cpp +++ b/python/kenlm.cpp @@ -1,4 +1,4 @@ -/* Generated by Cython 0.19.1 on Wed Jun 5 13:00:19 2013 */ +/* Generated by Cython 0.19.1 on Tue Jan 28 09:32:20 2014 */ #define PY_SSIZE_T_CLEAN #ifndef CYTHON_USE_PYLONG_INTERNALS @@ -493,23 +493,8 @@ static const char *__pyx_f[] = { }; /*--- Type declarations ---*/ -struct __pyx_obj_5kenlm_LanguageModel; struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores; - -/* "kenlm.pyx":10 - * raise TypeError('Cannot convert %s to string' % type(data)) - * - * cdef class LanguageModel: # <<<<<<<<<<<<<< - * cdef Model* model - * cdef public bytes path - */ -struct __pyx_obj_5kenlm_LanguageModel { - PyObject_HEAD - lm::base::Model *model; - PyObject *path; - const lm::base::Vocabulary *vocab; -}; - +struct __pyx_obj_5kenlm_LanguageModel; /* "kenlm.pyx":44 * return total @@ -532,6 +517,21 @@ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores { Py_ssize_t __pyx_t_1; }; + +/* "kenlm.pyx":10 + * raise TypeError('Cannot convert %s to string' % type(data)) + * + * cdef class LanguageModel: # <<<<<<<<<<<<<< + * cdef Model* model + * cdef public bytes path + */ +struct __pyx_obj_5kenlm_LanguageModel { + PyObject_HEAD + lm::base::Model *model; + PyObject *path; + const lm::base::Vocabulary *vocab; +}; + #ifndef CYTHON_REFNANNY #define CYTHON_REFNANNY 0 #endif @@ -771,8 +771,8 @@ static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); /*proto*/ /* Module declarations from 'kenlm' */ -static PyTypeObject *__pyx_ptype_5kenlm_LanguageModel = 0; static PyTypeObject *__pyx_ptype_5kenlm___pyx_scope_struct__full_scores = 0; +static PyTypeObject *__pyx_ptype_5kenlm_LanguageModel = 0; static PyObject *__pyx_f_5kenlm_as_str(PyObject *); /*proto*/ #define __Pyx_MODULE_NAME "kenlm" int __pyx_module_is_main_kenlm = 0; @@ -792,8 +792,8 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_13__reduce__(struct __pyx_obj_5 static PyObject *__pyx_pf_5kenlm_13LanguageModel_4path___get__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self); /* proto */ static int __pyx_pf_5kenlm_13LanguageModel_4path_2__set__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self, PyObject *__pyx_v_value); /* proto */ static int __pyx_pf_5kenlm_13LanguageModel_4path_4__del__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self); /* proto */ -static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ static char __pyx_k_2[] = "Cannot convert %s to string"; static char __pyx_k_3[] = "\n"; static char __pyx_k_4[] = " "; @@ -1392,7 +1392,7 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_ * cdef State out_state * cdef float total = 0 # <<<<<<<<<<<<<< * for word in words: - * total += self.model.Score(&state, self.vocab.Index(word), &out_state) + * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) */ __pyx_v_total = 0.0; @@ -1400,7 +1400,7 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_ * cdef State out_state * cdef float total = 0 * for word in words: # <<<<<<<<<<<<<< - * total += self.model.Score(&state, self.vocab.Index(word), &out_state) + * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) * state = out_state */ if (unlikely(((PyObject *)__pyx_v_words) == Py_None)) { @@ -1422,18 +1422,18 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_ /* "kenlm.pyx":39 * cdef float total = 0 * for word in words: - * total += self.model.Score(&state, self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<< + * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<< * state = out_state - * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) + * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) */ __pyx_t_4 = __Pyx_PyObject_AsString(__pyx_v_word); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 39; __pyx_clineno = __LINE__; goto __pyx_L1_error;} - __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->Score((&__pyx_v_state), __pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_v_out_state))); + __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->BaseScore((&__pyx_v_state), __pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_v_out_state))); /* "kenlm.pyx":40 * for word in words: - * total += self.model.Score(&state, self.vocab.Index(word), &out_state) + * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) * state = out_state # <<<<<<<<<<<<<< - * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) + * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) * return total */ __pyx_v_state = __pyx_v_out_state; @@ -1441,17 +1441,17 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_ __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; /* "kenlm.pyx":41 - * total += self.model.Score(&state, self.vocab.Index(word), &out_state) + * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) * state = out_state - * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<< + * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<< * return total * */ - __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->Score((&__pyx_v_state), __pyx_v_self->vocab->EndSentence(), (&__pyx_v_out_state))); + __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->BaseScore((&__pyx_v_state), __pyx_v_self->vocab->EndSentence(), (&__pyx_v_out_state))); /* "kenlm.pyx":42 * state = out_state - * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) + * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) * return total # <<<<<<<<<<<<<< * * def full_scores(self, sentence): @@ -1597,7 +1597,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec * cdef FullScoreReturn ret * cdef float total = 0 # <<<<<<<<<<<<<< * for word in words: - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, */ __pyx_cur_scope->__pyx_v_total = 0.0; @@ -1605,7 +1605,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec * cdef FullScoreReturn ret * cdef float total = 0 * for word in words: # <<<<<<<<<<<<<< - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, * self.vocab.Index(word), &out_state) */ if (unlikely(((PyObject *)__pyx_cur_scope->__pyx_v_words) == Py_None)) { @@ -1628,20 +1628,20 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec /* "kenlm.pyx":53 * for word in words: - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, * self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<< * yield (ret.prob, ret.ngram_length) * state = out_state */ __pyx_t_4 = __Pyx_PyObject_AsString(__pyx_cur_scope->__pyx_v_word); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 53; __pyx_clineno = __LINE__; goto __pyx_L1_error;} - __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->FullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_cur_scope->__pyx_v_out_state)); + __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->BaseFullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_cur_scope->__pyx_v_out_state)); /* "kenlm.pyx":54 - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, * self.vocab.Index(word), &out_state) * yield (ret.prob, ret.ngram_length) # <<<<<<<<<<<<<< * state = out_state - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, */ __pyx_t_2 = PyFloat_FromDouble(__pyx_cur_scope->__pyx_v_ret.prob); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 54; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __Pyx_GOTREF(__pyx_t_2); @@ -1676,7 +1676,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec * self.vocab.Index(word), &out_state) * yield (ret.prob, ret.ngram_length) * state = out_state # <<<<<<<<<<<<<< - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, * self.vocab.EndSentence(), &out_state) */ __pyx_cur_scope->__pyx_v_state = __pyx_cur_scope->__pyx_v_out_state; @@ -1685,15 +1685,15 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec /* "kenlm.pyx":57 * state = out_state - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, * self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<< * yield (ret.prob, ret.ngram_length) * */ - __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->FullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->EndSentence(), (&__pyx_cur_scope->__pyx_v_out_state)); + __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->BaseFullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->EndSentence(), (&__pyx_cur_scope->__pyx_v_out_state)); /* "kenlm.pyx":58 - * ret = self.model.FullScore(&state, + * ret = self.model.BaseFullScore(&state, * self.vocab.EndSentence(), &out_state) * yield (ret.prob, ret.ngram_length) # <<<<<<<<<<<<<< * @@ -2047,6 +2047,147 @@ static int __pyx_pf_5kenlm_13LanguageModel_4path_4__del__(struct __pyx_obj_5kenl return __pyx_r; } +static struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[8]; +static int __pyx_freecount_5kenlm___pyx_scope_struct__full_scores = 0; + +static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { + struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p; + PyObject *o; + if (likely((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)))) { + o = (PyObject*)__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[--__pyx_freecount_5kenlm___pyx_scope_struct__full_scores]; + memset(o, 0, sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)); + PyObject_INIT(o, t); + PyObject_GC_Track(o); + } else { + o = (*t->tp_alloc)(t, 0); + if (unlikely(!o)) return 0; + } + p = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o); + p->__pyx_v_self = 0; + p->__pyx_v_sentence = 0; + p->__pyx_v_word = 0; + p->__pyx_v_words = 0; + p->__pyx_t_0 = 0; + return o; +} + +static void __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores(PyObject *o) { + struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o; + PyObject_GC_UnTrack(o); + Py_CLEAR(p->__pyx_v_self); + Py_CLEAR(p->__pyx_v_sentence); + Py_CLEAR(p->__pyx_v_word); + Py_CLEAR(p->__pyx_v_words); + Py_CLEAR(p->__pyx_t_0); + if ((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores))) { + __pyx_freelist_5kenlm___pyx_scope_struct__full_scores[__pyx_freecount_5kenlm___pyx_scope_struct__full_scores++] = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o); + } else { + (*Py_TYPE(o)->tp_free)(o); + } +} + +static int __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o; + if (p->__pyx_v_self) { + e = (*v)(((PyObject*)p->__pyx_v_self), a); if (e) return e; + } + if (p->__pyx_v_sentence) { + e = (*v)(p->__pyx_v_sentence, a); if (e) return e; + } + if (p->__pyx_v_word) { + e = (*v)(p->__pyx_v_word, a); if (e) return e; + } + if (p->__pyx_v_words) { + e = (*v)(p->__pyx_v_words, a); if (e) return e; + } + if (p->__pyx_t_0) { + e = (*v)(p->__pyx_t_0, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores(PyObject *o) { + struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o; + PyObject* tmp; + tmp = ((PyObject*)p->__pyx_v_self); + p->__pyx_v_self = ((struct __pyx_obj_5kenlm_LanguageModel *)Py_None); Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->__pyx_v_sentence); + p->__pyx_v_sentence = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->__pyx_v_word); + p->__pyx_v_word = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->__pyx_v_words); + p->__pyx_v_words = ((PyObject*)Py_None); Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->__pyx_t_0); + p->__pyx_t_0 = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + return 0; +} + +static PyMethodDef __pyx_methods_5kenlm___pyx_scope_struct__full_scores[] = { + {0, 0, 0, 0} +}; + +static PyTypeObject __pyx_type_5kenlm___pyx_scope_struct__full_scores = { + PyVarObject_HEAD_INIT(0, 0) + __Pyx_NAMESTR("kenlm.__pyx_scope_struct__full_scores"), /*tp_name*/ + sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #else + 0, /*reserved*/ + #endif + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores, /*tp_traverse*/ + __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_5kenlm___pyx_scope_struct__full_scores, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_5kenlm___pyx_scope_struct__full_scores, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + #if PY_VERSION_HEX >= 0x02060000 + 0, /*tp_version_tag*/ + #endif +}; + static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { struct __pyx_obj_5kenlm_LanguageModel *p; PyObject *o; @@ -2190,147 +2331,6 @@ static PyTypeObject __pyx_type_5kenlm_LanguageModel = { #endif }; -static struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[8]; -static int __pyx_freecount_5kenlm___pyx_scope_struct__full_scores = 0; - -static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { - struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p; - PyObject *o; - if (likely((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)))) { - o = (PyObject*)__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[--__pyx_freecount_5kenlm___pyx_scope_struct__full_scores]; - memset(o, 0, sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)); - PyObject_INIT(o, t); - PyObject_GC_Track(o); - } else { - o = (*t->tp_alloc)(t, 0); - if (unlikely(!o)) return 0; - } - p = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o); - p->__pyx_v_self = 0; - p->__pyx_v_sentence = 0; - p->__pyx_v_word = 0; - p->__pyx_v_words = 0; - p->__pyx_t_0 = 0; - return o; -} - -static void __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores(PyObject *o) { - struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o; - PyObject_GC_UnTrack(o); - Py_CLEAR(p->__pyx_v_self); - Py_CLEAR(p->__pyx_v_sentence); - Py_CLEAR(p->__pyx_v_word); - Py_CLEAR(p->__pyx_v_words); - Py_CLEAR(p->__pyx_t_0); - if ((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores))) { - __pyx_freelist_5kenlm___pyx_scope_struct__full_scores[__pyx_freecount_5kenlm___pyx_scope_struct__full_scores++] = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o); - } else { - (*Py_TYPE(o)->tp_free)(o); - } -} - -static int __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores(PyObject *o, visitproc v, void *a) { - int e; - struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o; - if (p->__pyx_v_self) { - e = (*v)(((PyObject*)p->__pyx_v_self), a); if (e) return e; - } - if (p->__pyx_v_sentence) { - e = (*v)(p->__pyx_v_sentence, a); if (e) return e; - } - if (p->__pyx_v_word) { - e = (*v)(p->__pyx_v_word, a); if (e) return e; - } - if (p->__pyx_v_words) { - e = (*v)(p->__pyx_v_words, a); if (e) return e; - } - if (p->__pyx_t_0) { - e = (*v)(p->__pyx_t_0, a); if (e) return e; - } - return 0; -} - -static int __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores(PyObject *o) { - struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o; - PyObject* tmp; - tmp = ((PyObject*)p->__pyx_v_self); - p->__pyx_v_self = ((struct __pyx_obj_5kenlm_LanguageModel *)Py_None); Py_INCREF(Py_None); - Py_XDECREF(tmp); - tmp = ((PyObject*)p->__pyx_v_sentence); - p->__pyx_v_sentence = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - tmp = ((PyObject*)p->__pyx_v_word); - p->__pyx_v_word = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - tmp = ((PyObject*)p->__pyx_v_words); - p->__pyx_v_words = ((PyObject*)Py_None); Py_INCREF(Py_None); - Py_XDECREF(tmp); - tmp = ((PyObject*)p->__pyx_t_0); - p->__pyx_t_0 = Py_None; Py_INCREF(Py_None); - Py_XDECREF(tmp); - return 0; -} - -static PyMethodDef __pyx_methods_5kenlm___pyx_scope_struct__full_scores[] = { - {0, 0, 0, 0} -}; - -static PyTypeObject __pyx_type_5kenlm___pyx_scope_struct__full_scores = { - PyVarObject_HEAD_INIT(0, 0) - __Pyx_NAMESTR("kenlm.__pyx_scope_struct__full_scores"), /*tp_name*/ - sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores, /*tp_dealloc*/ - 0, /*tp_print*/ - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - #if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ - #else - 0, /*reserved*/ - #endif - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash*/ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ - 0, /*tp_doc*/ - __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores, /*tp_traverse*/ - __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - __pyx_methods_5kenlm___pyx_scope_struct__full_scores, /*tp_methods*/ - 0, /*tp_members*/ - 0, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - 0, /*tp_init*/ - 0, /*tp_alloc*/ - __pyx_tp_new_5kenlm___pyx_scope_struct__full_scores, /*tp_new*/ - 0, /*tp_free*/ - 0, /*tp_is_gc*/ - 0, /*tp_bases*/ - 0, /*tp_mro*/ - 0, /*tp_cache*/ - 0, /*tp_subclasses*/ - 0, /*tp_weaklist*/ - 0, /*tp_del*/ - #if PY_VERSION_HEX >= 0x02060000 - 0, /*tp_version_tag*/ - #endif -}; - static PyMethodDef __pyx_methods[] = { {0, 0, 0, 0} }; @@ -2508,11 +2508,11 @@ PyMODINIT_FUNC PyInit_kenlm(void) /*--- Variable export code ---*/ /*--- Function export code ---*/ /*--- Type init code ---*/ + if (PyType_Ready(&__pyx_type_5kenlm___pyx_scope_struct__full_scores) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 44; __pyx_clineno = __LINE__; goto __pyx_L1_error;} + __pyx_ptype_5kenlm___pyx_scope_struct__full_scores = &__pyx_type_5kenlm___pyx_scope_struct__full_scores; if (PyType_Ready(&__pyx_type_5kenlm_LanguageModel) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;} if (__Pyx_SetAttrString(__pyx_m, "LanguageModel", (PyObject *)&__pyx_type_5kenlm_LanguageModel) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;} __pyx_ptype_5kenlm_LanguageModel = &__pyx_type_5kenlm_LanguageModel; - if (PyType_Ready(&__pyx_type_5kenlm___pyx_scope_struct__full_scores) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 44; __pyx_clineno = __LINE__; goto __pyx_L1_error;} - __pyx_ptype_5kenlm___pyx_scope_struct__full_scores = &__pyx_type_5kenlm___pyx_scope_struct__full_scores; /*--- Type import code ---*/ /*--- Variable import code ---*/ /*--- Function import code ---*/ diff --git a/python/kenlm.pxd b/python/kenlm.pxd index 9a397f3..7d68fc4 100644 --- a/python/kenlm.pxd +++ b/python/kenlm.pxd @@ -24,8 +24,8 @@ cdef extern from "lm/virtual_interface.hh" namespace "lm::base": void NullContextWrite(void *) unsigned int Order() const_Vocabulary& BaseVocabulary() - float Score(void *in_state, WordIndex new_word, void *out_state) - FullScoreReturn FullScore(void *in_state, WordIndex new_word, void *out_state) + float BaseScore(void *in_state, WordIndex new_word, void *out_state) + FullScoreReturn BaseFullScore(void *in_state, WordIndex new_word, void *out_state) cdef extern from "lm/model.hh" namespace "lm::ngram": cdef Model *LoadVirtual(char *) except + diff --git a/python/kenlm.pyx b/python/kenlm.pyx index 7f965d4..eb45169 100644 --- a/python/kenlm.pyx +++ b/python/kenlm.pyx @@ -36,9 +36,9 @@ cdef class LanguageModel: cdef State out_state cdef float total = 0 for word in words: - total += self.model.Score(&state, self.vocab.Index(word), &out_state) + total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) state = out_state - total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) + total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) return total def full_scores(self, sentence): @@ -49,11 +49,11 @@ cdef class LanguageModel: cdef FullScoreReturn ret cdef float total = 0 for word in words: - ret = self.model.FullScore(&state, + ret = self.model.BaseFullScore(&state, self.vocab.Index(word), &out_state) yield (ret.prob, ret.ngram_length) state = out_state - ret = self.model.FullScore(&state, + ret = self.model.BaseFullScore(&state, self.vocab.EndSentence(), &out_state) yield (ret.prob, ret.ngram_length) diff --git a/util/Jamfile b/util/Jamfile index 910b305..18b20a3 100644 --- a/util/Jamfile +++ b/util/Jamfile @@ -19,15 +19,18 @@ alias read_compressed : read_compressed.o $(compressed_deps) ; obj read_compressed_test.o : read_compressed_test.cc /top//boost_unit_test_framework : $(compressed_flags) ; obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(compressed_flags) ; -fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ; +fakelib parallel_read : parallel_read.cc : <threading>multi:<source>/top//boost_thread <threading>multi:<define>WITH_THREADS : : <include>.. ; + +fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc parallel_read pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ; + +exe cat_compressed : cat_compressed_main.cc kenutil ; + +alias programs : cat_compressed ; import testing ; -unit-test bit_packing_test : bit_packing_test.cc kenutil /top//boost_unit_test_framework ; run file_piece_test.o kenutil /top//boost_unit_test_framework : : file_piece.cc ; -unit-test read_compressed_test : read_compressed_test.o kenutil /top//boost_unit_test_framework ; -unit-test joint_sort_test : joint_sort_test.cc kenutil /top//boost_unit_test_framework ; -unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil /top//boost_unit_test_framework ; -unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil /top//boost_unit_test_framework ; -unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil /top//boost_unit_test_framework ; -unit-test multi_intersection_test : multi_intersection_test.cc kenutil /top//boost_unit_test_framework ; +for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc ] { + local name = [ MATCH "(.*)\.cc" : $(t) ] ; + unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_system ; +} diff --git a/util/bit_packing.hh b/util/bit_packing.hh index dcbd814..1e34d9a 100644 --- a/util/bit_packing.hh +++ b/util/bit_packing.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_BIT_PACKING__ -#define UTIL_BIT_PACKING__ +#ifndef UTIL_BIT_PACKING_H +#define UTIL_BIT_PACKING_H /* Bit-level packing routines * @@ -183,4 +183,4 @@ struct BitAddress { } // namespace util -#endif // UTIL_BIT_PACKING__ +#endif // UTIL_BIT_PACKING_H diff --git a/util/cat_compressed_main.cc b/util/cat_compressed_main.cc new file mode 100644 index 0000000..2b4d729 --- /dev/null +++ b/util/cat_compressed_main.cc @@ -0,0 +1,47 @@ +// Like cat but interprets compressed files. +#include "util/file.hh" +#include "util/read_compressed.hh" + +#include <string.h> +#include <iostream> + +namespace { +const std::size_t kBufSize = 16384; +void Copy(util::ReadCompressed &from, int to) { + util::scoped_malloc buffer(util::MallocOrThrow(kBufSize)); + while (std::size_t amount = from.Read(buffer.get(), kBufSize)) { + util::WriteOrThrow(to, buffer.get(), amount); + } +} +} // namespace + +int main(int argc, char *argv[]) { + // Lane Schwartz likes -h and --help + for (int i = 1; i < argc; ++i) { + char *arg = argv[i]; + if (!strcmp(arg, "--")) break; + if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) { + std::cerr << + "A cat implementation that interprets compressed files.\n" + "Usage: " << argv[0] << " [file1] [file2] ...\n" + "If no file is provided, then stdin is read.\n"; + return 1; + } + } + + try { + if (argc == 1) { + util::ReadCompressed in(0); + Copy(in, 1); + } else { + for (int i = 1; i < argc; ++i) { + util::ReadCompressed in(util::OpenReadOrThrow(argv[i])); + Copy(in, 1); + } + } + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 2; + } + return 0; +} diff --git a/util/ersatz_progress.hh b/util/ersatz_progress.hh index b94399a..535dbde 100644 --- a/util/ersatz_progress.hh +++ b/util/ersatz_progress.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_ERSATZ_PROGRESS__ -#define UTIL_ERSATZ_PROGRESS__ +#ifndef UTIL_ERSATZ_PROGRESS_H +#define UTIL_ERSATZ_PROGRESS_H #include <iostream> #include <string> @@ -55,4 +55,4 @@ class ErsatzProgress { } // namespace util -#endif // UTIL_ERSATZ_PROGRESS__ +#endif // UTIL_ERSATZ_PROGRESS_H diff --git a/util/exception.cc b/util/exception.cc index 557c398..083bac2 100644 --- a/util/exception.cc +++ b/util/exception.cc @@ -51,6 +51,11 @@ void Exception::SetLocation(const char *file, unsigned int line, const char *fun } namespace { +// At least one of these functions will not be called. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-function" +#endif // The XOPEN version. const char *HandleStrerror(int ret, const char *buf) { if (!ret) return buf; @@ -61,6 +66,9 @@ const char *HandleStrerror(int ret, const char *buf) { const char *HandleStrerror(const char *ret, const char * /*buf*/) { return ret; } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif } // namespace ErrnoException::ErrnoException() throw() : errno_(errno) { diff --git a/util/exception.hh b/util/exception.hh index 74046cf..0966740 100644 --- a/util/exception.hh +++ b/util/exception.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_EXCEPTION__ -#define UTIL_EXCEPTION__ +#ifndef UTIL_EXCEPTION_H +#define UTIL_EXCEPTION_H #include <exception> #include <limits> @@ -98,6 +98,9 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep #define UTIL_THROW_IF(Condition, Exception, Modify) \ UTIL_THROW_IF_ARG(Condition, Exception, , Modify) +#define UTIL_THROW_IF2(Condition, Modify) \ + UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify) + // Exception that records errno and adds it to the message. class ErrnoException : public Exception { public: @@ -111,6 +114,13 @@ class ErrnoException : public Exception { int errno_; }; +// file wasn't there, or couldn't be open for some reason +class FileOpenException : public Exception { + public: + FileOpenException() throw() {} + ~FileOpenException() throw() {} +}; + // Utilities for overflow checking. class OverflowException : public Exception { public: @@ -133,4 +143,4 @@ inline std::size_t CheckOverflow(uint64_t value) { } // namespace util -#endif // UTIL_EXCEPTION__ +#endif // UTIL_EXCEPTION_H diff --git a/util/fake_ofstream.hh b/util/fake_ofstream.hh index bcdebe4..eefb1ed 100644 --- a/util/fake_ofstream.hh +++ b/util/fake_ofstream.hh @@ -2,6 +2,9 @@ * Does not support many data types. Currently, it's targeted at writing ARPA * files quickly. */ +#ifndef UTIL_FAKE_OFSTREAM_H +#define UTIL_FAKE_OFSTREAM_H + #include "util/double-conversion/double-conversion.h" #include "util/double-conversion/utils.h" #include "util/file.hh" @@ -17,7 +20,8 @@ class FakeOFStream { static const std::size_t kOutBuf = 1048576; // Does not take ownership of out. - explicit FakeOFStream(int out) + // Allows default constructor, but must call SetFD. + explicit FakeOFStream(int out = -1) : buf_(util::MallocOrThrow(kOutBuf)), builder_(static_cast<char*>(buf_.get()), kOutBuf), // Mostly the default but with inf instead. And no flags. @@ -28,6 +32,11 @@ class FakeOFStream { if (buf_.get()) Flush(); } + void SetFD(int to) { + if (builder_.position()) Flush(); + fd_ = to; + } + FakeOFStream &operator<<(float value) { // Odd, but this is the largest number found in the comments. EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); @@ -92,3 +101,5 @@ class FakeOFStream { }; } // namespace + +#endif diff --git a/util/file.cc b/util/file.cc index bef04cb..0d9adf2 100644 --- a/util/file.cc +++ b/util/file.cc @@ -17,7 +17,11 @@ #include <fcntl.h> #include <stdint.h> -#if defined(_WIN32) || defined(_WIN64) +#if defined __MINGW32__ +#include <windows.h> +#include <unistd.h> +#warning "The file functions on MinGW have not been tested for file sizes above 2^31 - 1. Please read https://stackoverflow.com/questions/12539488/determine-64-bit-file-size-in-c-on-mingw-32-bit and fix" +#elif defined(_WIN32) || defined(_WIN64) #include <windows.h> #include <io.h> #include <algorithm> @@ -76,7 +80,13 @@ int CreateOrThrow(const char *name) { } uint64_t SizeFile(int fd) { -#if defined(_WIN32) || defined(_WIN64) +#if defined __MINGW32__ + struct stat sb; + // Does this handle 64-bit? + int ret = fstat(fd, &sb); + if (ret == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + return sb.st_size; +#elif defined(_WIN32) || defined(_WIN64) __int64 ret = _filelengthi64(fd); return (ret == -1) ? kBadSize : ret; #else // Not windows. @@ -100,7 +110,10 @@ uint64_t SizeOrThrow(int fd) { } void ResizeOrThrow(int fd, uint64_t to) { -#if defined(_WIN32) || defined(_WIN64) +#if defined __MINGW32__ + // Does this handle 64-bit? + int ret = ftruncate +#elif defined(_WIN32) || defined(_WIN64) errno_t ret = _chsize_s #elif defined(OS_ANDROID) int ret = ftruncate64 @@ -115,8 +128,10 @@ namespace { std::size_t GuardLarge(std::size_t size) { // The following operating systems have broken read/write/pread/pwrite that // only supports up to 2^31. -#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) - return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size); + // OS X man pages claim to support 64-bit, but Kareem M. Darwish had problems + // building with larger files, so APPLE is also here. +#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) || defined(__MINGW32__) + return size < INT_MAX ? size : INT_MAX; #else return size; #endif @@ -179,16 +194,15 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { #else ssize_t ret; errno = 0; - do { - ret = + ret = #ifdef OS_ANDROID - pread64 + pread64 #else - pread + pread #endif - (fd, to, GuardLarge(size), off); - } while (ret == -1 && errno == EINTR); + (fd, to, GuardLarge(size), off); if (ret <= 0) { + if (ret == -1 && errno == EINTR) continue; UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd)); UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off); } @@ -251,7 +265,10 @@ typedef CheckOffT<sizeof(off_t)>::True IgnoredType; // Can't we all just get along? void InternalSeek(int fd, int64_t off, int whence) { if ( -#if defined(_WIN32) || defined(_WIN64) +#if defined __MINGW32__ + // Does this handle 64-bit? + (off_t)-1 == lseek(fd, off, whence) +#elif defined(_WIN32) || defined(_WIN64) (__int64)-1 == _lseeki64(fd, off, whence) #elif defined(OS_ANDROID) (off64_t)-1 == lseek64(fd, off, whence) @@ -427,8 +444,8 @@ void NormalizeTempPrefix(std::string &base) { ) base += '/'; } -int MakeTemp(const std::string &base) { - std::string name(base); +int MakeTemp(const StringPiece &base) { + std::string name(base.data(), base.size()); name += "XXXXXX"; name.push_back(0); int ret; @@ -436,7 +453,7 @@ int MakeTemp(const std::string &base) { return ret; } -std::FILE *FMakeTemp(const std::string &base) { +std::FILE *FMakeTemp(const StringPiece &base) { util::scoped_fd file(MakeTemp(base)); return FDOpenOrThrow(file); } @@ -462,14 +479,18 @@ bool TryName(int fd, std::string &out) { if (-1 == lstat(name.c_str(), &sb)) return false; out.resize(sb.st_size + 1); - ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1); - if (-1 == ret) - return false; - if (ret > sb.st_size) { - // Increased in size?! - return false; + // lstat gave us a size, but I've seen it grow, possibly due to symlinks on top of symlinks. + while (true) { + ssize_t ret = readlink(name.c_str(), &out[0], out.size()); + if (-1 == ret) + return false; + if ((size_t)ret < out.size()) { + out.resize(ret); + break; + } + // Exponential growth. + out.resize(out.size() * 2); } - out.resize(ret); // Don't use the non-file names. if (!out.empty() && out[0] != '/') return false; diff --git a/util/file.hh b/util/file.hh index be88431..170a7c7 100644 --- a/util/file.hh +++ b/util/file.hh @@ -1,7 +1,8 @@ -#ifndef UTIL_FILE__ -#define UTIL_FILE__ +#ifndef UTIL_FILE_H +#define UTIL_FILE_H #include "util/exception.hh" +#include "util/string_piece.hh" #include <cstddef> #include <cstdio> @@ -125,8 +126,8 @@ std::FILE *FDOpenReadOrThrow(scoped_fd &file); // Temporary files // Append a / if base is a directory. void NormalizeTempPrefix(std::string &base); -int MakeTemp(const std::string &prefix); -std::FILE *FMakeTemp(const std::string &prefix); +int MakeTemp(const StringPiece &prefix); +std::FILE *FMakeTemp(const StringPiece &prefix); // dup an fd. int DupOrThrow(int fd); @@ -139,4 +140,4 @@ std::string NameFromFD(int fd); } // namespace util -#endif // UTIL_FILE__ +#endif // UTIL_FILE_H diff --git a/util/file_piece.cc b/util/file_piece.cc index 9c7e00c..4aaa250 100644 --- a/util/file_piece.cc +++ b/util/file_piece.cc @@ -84,6 +84,13 @@ StringPiece FilePiece::ReadLine(char delim) { } } +bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) { + try { + to = ReadLine(delim); + } catch (const util::EndOfFileException &e) { return false; } + return true; +} + float FilePiece::ReadFloat() { return ReadNumber<float>(); } diff --git a/util/file_piece.hh b/util/file_piece.hh index ed3dc5a..5495ddc 100644 --- a/util/file_piece.hh +++ b/util/file_piece.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_FILE_PIECE__ -#define UTIL_FILE_PIECE__ +#ifndef UTIL_FILE_PIECE_H +#define UTIL_FILE_PIECE_H #include "util/ersatz_progress.hh" #include "util/exception.hh" @@ -56,10 +56,33 @@ class FilePiece { return Consume(FindDelimiterOrEOF(delim)); } + // Read word until the line or file ends. + bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) { + assert(delim[static_cast<unsigned char>('\n')]); + // Skip non-enter spaces. + for (; ; ++position_) { + if (position_ == position_end_) { + try { + Shift(); + } catch (const util::EndOfFileException &e) { return false; } + // And break out at end of file. + if (position_ == position_end_) return false; + } + if (!delim[static_cast<unsigned char>(*position_)]) break; + if (*position_ == '\n') return false; + } + // We can't be at the end of file because there's at least one character open. + to = Consume(FindDelimiterOrEOF(delim)); + return true; + } + // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n'); + // Doesn't throw EndOfFileException, just returns false. + bool ReadLineOrEOF(StringPiece &to, char delim = '\n'); + float ReadFloat(); double ReadDouble(); long int ReadLong(); @@ -132,4 +155,4 @@ class FilePiece { } // namespace util -#endif // UTIL_FILE_PIECE__ +#endif // UTIL_FILE_PIECE_H diff --git a/util/file_piece_test.cc b/util/file_piece_test.cc index 7336007..4361877 100644 --- a/util/file_piece_test.cc +++ b/util/file_piece_test.cc @@ -1,4 +1,4 @@ -// Tests might fail if you have creative characters in your path. Sue me. +// Tests might fail if you have creative characters in your path. Sue me. #include "util/file_piece.hh" #include "util/file.hh" @@ -55,7 +55,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) { #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__) /* Apple isn't happy with the popen, fileno, dup. And I don't want to - * reimplement popen. This is an issue with the test. + * reimplement popen. This is an issue with the test. */ /* read() implementation */ BOOST_AUTO_TEST_CASE(StreamReadLine) { @@ -67,7 +67,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) { FILE *catter = popen(popen_args.c_str(), "r"); BOOST_REQUIRE(catter); - + FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { @@ -107,8 +107,8 @@ BOOST_AUTO_TEST_CASE(PlainZipReadLine) { } // gzip stream. Apple doesn't like popen, fileno, dup. This is an issue with -// the test. -#ifndef __APPLE__ +// the test. +#if !defined __APPLE__ && !defined __MINGW32__ BOOST_AUTO_TEST_CASE(StreamZipReadLine) { std::fstream ref(FileLocation().c_str(), std::ios::in); @@ -117,7 +117,7 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) { FILE * catter = popen(command.c_str(), "r"); BOOST_REQUIRE(catter); - + FilePiece test(dup(fileno(catter)), "file_piece.cc.gz", NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { diff --git a/util/fixed_array.hh b/util/fixed_array.hh new file mode 100644 index 0000000..bae13de --- /dev/null +++ b/util/fixed_array.hh @@ -0,0 +1,94 @@ +#ifndef UTIL_FIXED_ARRAY_H +#define UTIL_FIXED_ARRAY_H + +// Ever want an array of things by they don't have a default constructor or are +// non-copyable? FixedArray allows constructing one at a time. +#include "util/scoped.hh" + +#include <cstddef> + +#include <assert.h> + +namespace util { + +template <class T> class FixedArray { + public: + // Initialize with a given size bound but do not construct the objects. + explicit FixedArray(std::size_t limit) { + Init(limit); + } + + FixedArray() + : newed_end_(NULL) +#ifndef NDEBUG + , allocated_end_(NULL) +#endif + {} + + void Init(std::size_t count) { + assert(!block_.get()); + block_.reset(malloc(sizeof(T) * count)); + if (!block_.get()) throw std::bad_alloc(); + newed_end_ = begin(); +#ifndef NDEBUG + allocated_end_ = begin() + count; +#endif + } + + FixedArray(const FixedArray &from) { + std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get()); + Init(size); + for (std::size_t i = 0; i < size; ++i) { + push_back(from[i]); + } + } + + ~FixedArray() { clear(); } + + T *begin() { return static_cast<T*>(block_.get()); } + const T *begin() const { return static_cast<const T*>(block_.get()); } + // Always call Constructed after successful completion of new. + T *end() { return newed_end_; } + const T *end() const { return newed_end_; } + + T &back() { return *(end() - 1); } + const T &back() const { return *(end() - 1); } + + std::size_t size() const { return end() - begin(); } + bool empty() const { return begin() == end(); } + + T &operator[](std::size_t i) { return begin()[i]; } + const T &operator[](std::size_t i) const { return begin()[i]; } + + template <class C> void push_back(const C &c) { + new (end()) T(c); + Constructed(); + } + + void clear() { + for (T *i = begin(); i != end(); ++i) + i->~T(); + newed_end_ = begin(); + } + + protected: + void Constructed() { + ++newed_end_; +#ifndef NDEBUG + assert(newed_end_ <= allocated_end_); +#endif + } + + private: + util::scoped_malloc block_; + + T *newed_end_; + +#ifndef NDEBUG + T *allocated_end_; +#endif +}; + +} // namespace util + +#endif // UTIL_FIXED_ARRAY_H diff --git a/util/getopt.hh b/util/getopt.hh index 6ad9773..50eab56 100644 --- a/util/getopt.hh +++ b/util/getopt.hh @@ -11,8 +11,8 @@ Code given out at the 1985 UNIFORUM conference in Dallas. #endif #ifndef __GNUC__ -#ifndef _WINGETOPT_H_ -#define _WINGETOPT_H_ +#ifndef UTIL_GETOPT_H +#define UTIL_GETOPT_H #ifdef __cplusplus extern "C" { @@ -28,6 +28,6 @@ extern int getopt(int argc, char **argv, char *opts); } #endif -#endif /* _GETOPT_H_ */ +#endif /* UTIL_GETOPT_H */ #endif /* __GNUC__ */ diff --git a/util/have.hh b/util/have.hh index 6e18529..dc3f633 100644 --- a/util/have.hh +++ b/util/have.hh @@ -1,6 +1,6 @@ /* Optional packages. You might want to integrate this with your build system e.g. config.h from ./configure. */ -#ifndef UTIL_HAVE__ -#define UTIL_HAVE__ +#ifndef UTIL_HAVE_H +#define UTIL_HAVE_H #ifdef HAVE_CONFIG_H #include "config.h" @@ -10,4 +10,4 @@ //#define HAVE_ICU #endif -#endif // UTIL_HAVE__ +#endif // UTIL_HAVE_H diff --git a/util/joint_sort.hh b/util/joint_sort.hh index 1b43ddc..de4b554 100644 --- a/util/joint_sort.hh +++ b/util/joint_sort.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_JOINT_SORT__ -#define UTIL_JOINT_SORT__ +#ifndef UTIL_JOINT_SORT_H +#define UTIL_JOINT_SORT_H /* A terrifying amount of C++ to coax std::sort into soring one range while * also permuting another range the same way. @@ -9,7 +9,6 @@ #include <algorithm> #include <functional> -#include <iostream> namespace util { @@ -35,9 +34,16 @@ template <class KeyIter, class ValueIter> class JointIter { return *this; } - void swap(const JointIter &other) { - std::swap(key_, other.key_); - std::swap(value_, other.value_); + friend void swap(JointIter &first, JointIter &second) { + using std::swap; + swap(first.key_, second.key_); + swap(first.value_, second.value_); + } + + void DeepSwap(JointIter &other) { + using std::swap; + swap(*key_, *other.key_); + swap(*value_, *other.value_); } private: @@ -83,9 +89,8 @@ template <class KeyIter, class ValueIter> class JointProxy { return *(inner_.key_); } - void swap(JointProxy<KeyIter, ValueIter> &other) { - std::swap(*inner_.key_, *other.inner_.key_); - std::swap(*inner_.value_, *other.inner_.value_); + friend void swap(JointProxy<KeyIter, ValueIter> first, JointProxy<KeyIter, ValueIter> second) { + first.Inner().DeepSwap(second.Inner()); } private: @@ -138,14 +143,4 @@ template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begi } // namespace util -namespace std { -template <class KeyIter, class ValueIter> void swap(util::detail::JointIter<KeyIter, ValueIter> &left, util::detail::JointIter<KeyIter, ValueIter> &right) { - left.swap(right); -} - -template <class KeyIter, class ValueIter> void swap(util::detail::JointProxy<KeyIter, ValueIter> &left, util::detail::JointProxy<KeyIter, ValueIter> &right) { - left.swap(right); -} -} // namespace std - -#endif // UTIL_JOINT_SORT__ +#endif // UTIL_JOINT_SORT_H diff --git a/util/joint_sort_test.cc b/util/joint_sort_test.cc index 4dc8591..b24c602 100644 --- a/util/joint_sort_test.cc +++ b/util/joint_sort_test.cc @@ -47,4 +47,16 @@ BOOST_AUTO_TEST_CASE(char_int) { BOOST_CHECK_EQUAL(327, values[3]); } +BOOST_AUTO_TEST_CASE(swap_proxy) { + char keys[2] = {0, 1}; + int values[2] = {2, 3}; + detail::JointProxy<char *, int *> first(keys, values); + detail::JointProxy<char *, int *> second(keys + 1, values + 1); + swap(first, second); + BOOST_CHECK_EQUAL(1, keys[0]); + BOOST_CHECK_EQUAL(0, keys[1]); + BOOST_CHECK_EQUAL(3, values[0]); + BOOST_CHECK_EQUAL(2, values[1]); +} + }} // namespace anonymous util diff --git a/util/mmap.cc b/util/mmap.cc index cee6a97..a3c8a02 100644 --- a/util/mmap.cc +++ b/util/mmap.cc @@ -6,6 +6,7 @@ #include "util/exception.hh" #include "util/file.hh" +#include "util/parallel_read.hh" #include "util/scoped.hh" #include <iostream> @@ -40,7 +41,7 @@ void SyncOrThrow(void *start, size_t length) { #if defined(_WIN32) || defined(_WIN64) UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap"); #else - UTIL_THROW_IF(msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap"); + UTIL_THROW_IF(length && msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap"); #endif } @@ -154,6 +155,10 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope SeekOrThrow(fd, offset); ReadOrThrow(fd, out.get(), size); break; + case PARALLEL_READ: + out.reset(MallocOrThrow(size), size, scoped_memory::MALLOC_ALLOCATED); + ParallelRead(fd, out.get(), size, offset); + break; } } @@ -189,4 +194,66 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { } } +Rolling::Rolling(const Rolling ©_from, uint64_t increase) { + *this = copy_from; + IncreaseBase(increase); +} + +Rolling &Rolling::operator=(const Rolling ©_from) { + fd_ = copy_from.fd_; + file_begin_ = copy_from.file_begin_; + file_end_ = copy_from.file_end_; + for_write_ = copy_from.for_write_; + block_ = copy_from.block_; + read_bound_ = copy_from.read_bound_; + + current_begin_ = 0; + if (copy_from.IsPassthrough()) { + current_end_ = copy_from.current_end_; + ptr_ = copy_from.ptr_; + } else { + // Force call on next mmap. + current_end_ = 0; + ptr_ = NULL; + } + return *this; +} + +Rolling::Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount) { + current_begin_ = 0; + current_end_ = 0; + fd_ = fd; + file_begin_ = offset; + file_end_ = offset + amount; + for_write_ = for_write; + block_ = block; + read_bound_ = read_bound; +} + +void *Rolling::ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size) { + out.reset(); + if (IsPassthrough()) return static_cast<uint8_t*>(get()) + index; + uint64_t offset = index + file_begin_; + // Round down to multiple of page size. + uint64_t cruft = offset % static_cast<uint64_t>(SizePage()); + std::size_t map_size = static_cast<std::size_t>(size + cruft); + out.reset(MapOrThrow(map_size, for_write_, kFileFlags, true, fd_, offset - cruft), map_size, scoped_memory::MMAP_ALLOCATED); + return static_cast<uint8_t*>(out.get()) + static_cast<std::size_t>(cruft); +} + +void Rolling::Roll(uint64_t index) { + assert(!IsPassthrough()); + std::size_t amount; + if (file_end_ - (index + file_begin_) > static_cast<uint64_t>(block_)) { + amount = block_; + current_end_ = index + amount - read_bound_; + } else { + amount = file_end_ - (index + file_begin_); + current_end_ = index + amount; + } + ptr_ = static_cast<uint8_t*>(ExtractNonRolling(mem_, index, amount)) - index; + + current_begin_ = index; +} + } // namespace util diff --git a/util/mmap.hh b/util/mmap.hh index b218c4d..9b1e120 100644 --- a/util/mmap.hh +++ b/util/mmap.hh @@ -1,8 +1,9 @@ -#ifndef UTIL_MMAP__ -#define UTIL_MMAP__ +#ifndef UTIL_MMAP_H +#define UTIL_MMAP_H // Utilities for mmaped files. #include <cstddef> +#include <limits> #include <stdint.h> #include <sys/types.h> @@ -52,6 +53,9 @@ class scoped_memory { public: typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc; + scoped_memory(void *data, std::size_t size, Alloc source) + : data_(data), size_(size), source_(source) {} + scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} ~scoped_memory() { reset(); } @@ -72,7 +76,6 @@ class scoped_memory { void call_realloc(std::size_t to); private: - void *data_; std::size_t size_; @@ -90,7 +93,9 @@ typedef enum { // Populate on Linux. malloc and read on non-Linux. POPULATE_OR_READ, // malloc and read. - READ + READ, + // malloc and read in parallel (recommended for Lustre) + PARALLEL_READ, } LoadMethod; extern const int kFileFlags; @@ -109,6 +114,79 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file); // msync wrapper void SyncOrThrow(void *start, size_t length); +// Forward rolling memory map with no overlap. +class Rolling { + public: + Rolling() {} + + explicit Rolling(void *data) { Init(data); } + + Rolling(const Rolling ©_from, uint64_t increase = 0); + Rolling &operator=(const Rolling ©_from); + + // For an actual rolling mmap. + explicit Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount); + + // For a static mapping + void Init(void *data) { + ptr_ = data; + current_end_ = std::numeric_limits<uint64_t>::max(); + current_begin_ = 0; + // Mark as a pass-through. + fd_ = -1; + } + + void IncreaseBase(uint64_t by) { + file_begin_ += by; + ptr_ = static_cast<uint8_t*>(ptr_) + by; + if (!IsPassthrough()) current_end_ = 0; + } + + void DecreaseBase(uint64_t by) { + file_begin_ -= by; + ptr_ = static_cast<uint8_t*>(ptr_) - by; + if (!IsPassthrough()) current_end_ = 0; + } + + void *ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size); + + // Returns base pointer + void *get() const { return ptr_; } + + // Returns base pointer. + void *CheckedBase(uint64_t index) { + if (index >= current_end_ || index < current_begin_) { + Roll(index); + } + return ptr_; + } + + // Returns indexed pointer. + void *CheckedIndex(uint64_t index) { + return static_cast<uint8_t*>(CheckedBase(index)) + index; + } + + private: + void Roll(uint64_t index); + + // True if this is just a thin wrapper on a pointer. + bool IsPassthrough() const { return fd_ == -1; } + + void *ptr_; + uint64_t current_begin_; + uint64_t current_end_; + + scoped_memory mem_; + + int fd_; + uint64_t file_begin_; + uint64_t file_end_; + + bool for_write_; + std::size_t block_; + std::size_t read_bound_; +}; + } // namespace util -#endif // UTIL_MMAP__ +#endif // UTIL_MMAP_H diff --git a/util/multi_intersection.hh b/util/multi_intersection.hh index 8334d39..2955acc 100644 --- a/util/multi_intersection.hh +++ b/util/multi_intersection.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_MULTI_INTERSECTION__ -#define UTIL_MULTI_INTERSECTION__ +#ifndef UTIL_MULTI_INTERSECTION_H +#define UTIL_MULTI_INTERSECTION_H #include <boost/optional.hpp> #include <boost/range/iterator_range.hpp> @@ -66,7 +66,7 @@ template <class Iterator, class Output, class Less> void AllIntersection(std::ve std::sort(sets.begin(), sets.end(), detail::RangeLessBySize<boost::iterator_range<Iterator> >()); boost::optional<Value> ret; - for (boost::optional<Value> ret; ret = detail::FirstIntersectionSorted(sets, less); sets.front().advance_begin(1)) { + for (boost::optional<Value> ret; (ret = detail::FirstIntersectionSorted(sets, less)); sets.front().advance_begin(1)) { out(*ret); } } @@ -77,4 +77,4 @@ template <class Iterator, class Output> void AllIntersection(std::vector<boost:: } // namespace util -#endif // UTIL_MULTI_INTERSECTION__ +#endif // UTIL_MULTI_INTERSECTION_H diff --git a/util/murmur_hash.cc b/util/murmur_hash.cc index 4f51931..189668c 100644 --- a/util/murmur_hash.cc +++ b/util/murmur_hash.cc @@ -153,12 +153,19 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, uint64_t seed ) // Trick to test for 64-bit architecture at compile time. namespace { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-function" +#endif template <unsigned L> inline uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, uint64_t seed) { return MurmurHash64A(key, len, seed); } template <> inline uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, uint64_t seed) { return MurmurHash64B(key, len, seed); } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif } // namespace uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed) { diff --git a/util/murmur_hash.hh b/util/murmur_hash.hh index ae7e88d..f17157c 100644 --- a/util/murmur_hash.hh +++ b/util/murmur_hash.hh @@ -1,14 +1,18 @@ -#ifndef UTIL_MURMUR_HASH__ -#define UTIL_MURMUR_HASH__ +#ifndef UTIL_MURMUR_HASH_H +#define UTIL_MURMUR_HASH_H #include <cstddef> #include <stdint.h> namespace util { +// 64-bit machine version uint64_t MurmurHash64A(const void * key, std::size_t len, uint64_t seed = 0); +// 32-bit machine version (not the same function as above) uint64_t MurmurHash64B(const void * key, std::size_t len, uint64_t seed = 0); +// Use the version for this arch. Because the values differ across +// architectures, really only use it for in-memory structures. uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0); } // namespace util -#endif // UTIL_MURMUR_HASH__ +#endif // UTIL_MURMUR_HASH_H diff --git a/util/parallel_read.cc b/util/parallel_read.cc new file mode 100644 index 0000000..10972d7 --- /dev/null +++ b/util/parallel_read.cc @@ -0,0 +1,69 @@ +#include "util/parallel_read.hh" + +#include "util/file.hh" + +#ifdef WITH_THREADS +#include "util/thread_pool.hh" + +namespace util { +namespace { + +class Reader { + public: + explicit Reader(int fd) : fd_(fd) {} + + struct Request { + void *to; + std::size_t size; + uint64_t offset; + + bool operator==(const Request &other) const { + return (to == other.to) && (size == other.size) && (offset == other.offset); + } + }; + + void operator()(const Request &request) { + util::PReadOrThrow(fd_, request.to, request.size, request.offset); + } + + private: + int fd_; +}; + +} // namespace + +void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) { + Reader::Request poison; + poison.to = NULL; + poison.size = 0; + poison.offset = 0; + unsigned threads = boost::thread::hardware_concurrency(); + if (!threads) threads = 2; + ThreadPool<Reader> pool(2 /* don't need much of a queue */, threads, fd, poison); + const std::size_t kBatch = 1ULL << 25; // 32 MB + Reader::Request request; + request.to = to; + request.size = kBatch; + request.offset = offset; + for (; amount > kBatch; amount -= kBatch) { + pool.Produce(request); + request.to = reinterpret_cast<uint8_t*>(request.to) + kBatch; + request.offset += kBatch; + } + request.size = amount; + if (request.size) { + pool.Produce(request); + } +} + +} // namespace util + +#else // WITH_THREADS + +namespace util { +void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) { + util::PReadOrThrow(fd, to, amount, offset); +} +} // namespace util + +#endif diff --git a/util/parallel_read.hh b/util/parallel_read.hh new file mode 100644 index 0000000..1e96e79 --- /dev/null +++ b/util/parallel_read.hh @@ -0,0 +1,16 @@ +#ifndef UTIL_PARALLEL_READ__ +#define UTIL_PARALLEL_READ__ + +/* Read pieces of a file in parallel. This has a very specific use case: + * reading files from Lustre is CPU bound so multiple threads actually + * increases throughput. Speed matters when an LM takes a terabyte. + */ + +#include <cstddef> +#include <stdint.h> + +namespace util { +void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset); +} // namespace util + +#endif // UTIL_PARALLEL_READ__ diff --git a/util/pcqueue.hh b/util/pcqueue.hh index 3df8749..312a66f 100644 --- a/util/pcqueue.hh +++ b/util/pcqueue.hh @@ -1,5 +1,7 @@ -#ifndef UTIL_PCQUEUE__ -#define UTIL_PCQUEUE__ +#ifndef UTIL_PCQUEUE_H +#define UTIL_PCQUEUE_H + +#include "util/exception.hh" #include <boost/interprocess/sync/interprocess_semaphore.hpp> #include <boost/scoped_array.hpp> @@ -8,20 +10,68 @@ #include <errno.h> +#ifdef __APPLE__ +#include <mach/semaphore.h> +#include <mach/task.h> +#include <mach/mach_traps.h> +#include <mach/mach.h> +#endif // __APPLE__ + namespace util { -inline void WaitSemaphore (boost::interprocess::interprocess_semaphore &on) { +/* OS X Maverick and Boost interprocess were doing "Function not implemented." + * So this is my own wrapper around the mach kernel APIs. + */ +#ifdef __APPLE__ + +#define MACH_CALL(call) UTIL_THROW_IF(KERN_SUCCESS != (call), Exception, "Mach call failure") + +class Semaphore { + public: + explicit Semaphore(int value) : task_(mach_task_self()) { + MACH_CALL(semaphore_create(task_, &back_, SYNC_POLICY_FIFO, value)); + } + + ~Semaphore() { + MACH_CALL(semaphore_destroy(task_, back_)); + } + + void wait() { + MACH_CALL(semaphore_wait(back_)); + } + + void post() { + MACH_CALL(semaphore_signal(back_)); + } + + private: + semaphore_t back_; + task_t task_; +}; + +inline void WaitSemaphore(Semaphore &semaphore) { + semaphore.wait(); +} + +#else +typedef boost::interprocess::interprocess_semaphore Semaphore; + +inline void WaitSemaphore (Semaphore &on) { while (1) { try { on.wait(); break; } catch (boost::interprocess::interprocess_exception &e) { - if (e.get_native_error() != EINTR) throw; + if (e.get_native_error() != EINTR) { + throw; + } } } } +#endif // __APPLE__ + /* Producer consumer queue safe for multiple producers and multiple consumers. * T must be default constructable and have operator=. * The value is copied twice for Consume(T &out) or three times for Consume(), @@ -82,9 +132,9 @@ template <class T> class PCQueue : boost::noncopyable { private: // Number of empty spaces in storage_. - boost::interprocess::interprocess_semaphore empty_; + Semaphore empty_; // Number of occupied spaces in storage_. - boost::interprocess::interprocess_semaphore used_; + Semaphore used_; boost::scoped_array<T> storage_; @@ -102,4 +152,4 @@ template <class T> class PCQueue : boost::noncopyable { } // namespace util -#endif // UTIL_PCQUEUE__ +#endif // UTIL_PCQUEUE_H diff --git a/util/pcqueue_test.cc b/util/pcqueue_test.cc new file mode 100644 index 0000000..22ed2c6 --- /dev/null +++ b/util/pcqueue_test.cc @@ -0,0 +1,20 @@ +#include "util/pcqueue.hh" + +#define BOOST_TEST_MODULE PCQueueTest +#include <boost/test/unit_test.hpp> + +namespace util { +namespace { + +BOOST_AUTO_TEST_CASE(SingleThread) { + PCQueue<int> queue(10); + for (int i = 0; i < 10; ++i) { + queue.Produce(i); + } + for (int i = 0; i < 10; ++i) { + BOOST_CHECK_EQUAL(i, queue.Consume()); + } +} + +} +} // namespace util diff --git a/util/pool.hh b/util/pool.hh index 72f8a0c..89e793d 100644 --- a/util/pool.hh +++ b/util/pool.hh @@ -1,8 +1,8 @@ // Very simple pool. It can only allocate memory. And all of the memory it // allocates must be freed at the same time. -#ifndef UTIL_POOL__ -#define UTIL_POOL__ +#ifndef UTIL_POOL_H +#define UTIL_POOL_H #include <vector> @@ -42,4 +42,4 @@ class Pool { } // namespace util -#endif // UTIL_POOL__ +#endif // UTIL_POOL_H diff --git a/util/probing_hash_table.hh b/util/probing_hash_table.hh index 51a2944..c1fe917 100644 --- a/util/probing_hash_table.hh +++ b/util/probing_hash_table.hh @@ -1,7 +1,8 @@ -#ifndef UTIL_PROBING_HASH_TABLE__ -#define UTIL_PROBING_HASH_TABLE__ +#ifndef UTIL_PROBING_HASH_TABLE_H +#define UTIL_PROBING_HASH_TABLE_H #include "util/exception.hh" +#include "util/scoped.hh" #include <algorithm> #include <cstddef> @@ -25,6 +26,8 @@ struct IdentityHash { template <class T> T operator()(T arg) const { return arg; } }; +template <class EntryT, class HashT, class EqualT> class AutoProbing; + /* Non-standard hash table * Buckets must be set at the beginning and must be greater than maximum number * of elements, else it throws ProbingSizeException. @@ -33,7 +36,6 @@ struct IdentityHash { * Uses linear probing to find value. * Only insert and lookup operations. */ - template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable { public: typedef EntryT Entry; @@ -43,7 +45,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry typedef HashT Hash; typedef EqualT Equal; - public: static uint64_t Size(uint64_t entries, float multiplier) { uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries))); return buckets * sizeof(Entry); @@ -69,6 +70,11 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry #endif {} + void Relocate(void *new_base) { + begin_ = reinterpret_cast<MutableIterator>(new_base); + end_ = begin_ + buckets_; + } + template <class T> MutableIterator Insert(const T &t) { #ifdef DEBUG assert(initialized_); @@ -82,7 +88,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry #ifdef DEBUG assert(initialized_); #endif - for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { + for (MutableIterator i = Ideal(t);;) { Key got(i->GetKey()); if (equal_(got, t.GetKey())) { out = i; return true; } if (equal_(got, invalid_)) { @@ -97,8 +103,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry void FinishedInserting() {} - void LoadedBinary() {} - // Don't change anything related to GetKey, template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { #ifdef DEBUG @@ -224,6 +228,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry } private: + friend class AutoProbing<Entry, Hash, Equal>; + template <class T> MutableIterator Ideal(const T &t) { return begin_ + (hash_(t.GetKey()) % buckets_); } @@ -247,6 +253,75 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry #endif }; +// Resizable linear probing hash table. This owns the memory. +template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing { + private: + typedef ProbingHashTable<EntryT, HashT, EqualT> Backend; + public: + typedef EntryT Entry; + typedef typename Entry::Key Key; + typedef const Entry *ConstIterator; + typedef Entry *MutableIterator; + typedef HashT Hash; + typedef EqualT Equal; + + AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) : + allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) { + threshold_ = initial_size * 1.2; + Clear(); + } + + // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup. + template <class T> MutableIterator Insert(const T &t) { + DoubleIfNeeded(); + return backend_.UncheckedInsert(t); + } + + template <class T> bool FindOrInsert(const T &t, MutableIterator &out) { + DoubleIfNeeded(); + return backend_.FindOrInsert(t, out); + } + + template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { + return backend_.UnsafeMutableFind(key, out); + } + + template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { + return backend_.UnsafeMutableMustFind(key); + } + + template <class Key> bool Find(const Key key, ConstIterator &out) const { + return backend_.Find(key, out); + } + + template <class Key> ConstIterator MustFind(const Key key) const { + return backend_.MustFind(key); + } + + std::size_t Size() const { + return backend_.SizeNoSerialization(); + } + + void Clear() { + backend_.Clear(); + } + + private: + void DoubleIfNeeded() { + if (Size() < threshold_) + return; + mem_.call_realloc(backend_.DoubleTo()); + allocated_ = backend_.DoubleTo(); + backend_.Double(mem_.get()); + threshold_ *= 2; + } + + std::size_t allocated_; + util::scoped_malloc mem_; + Backend backend_; + std::size_t threshold_; +}; + } // namespace util -#endif // UTIL_PROBING_HASH_TABLE__ +#endif // UTIL_PROBING_HASH_TABLE_H diff --git a/util/proxy_iterator.hh b/util/proxy_iterator.hh index 0ee1716..8aa697b 100644 --- a/util/proxy_iterator.hh +++ b/util/proxy_iterator.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_PROXY_ITERATOR__ -#define UTIL_PROXY_ITERATOR__ +#ifndef UTIL_PROXY_ITERATOR_H +#define UTIL_PROXY_ITERATOR_H #include <cstddef> #include <iterator> @@ -38,8 +38,8 @@ template <class Proxy> class ProxyIterator { typedef std::random_access_iterator_tag iterator_category; typedef typename Proxy::value_type value_type; typedef std::ptrdiff_t difference_type; - typedef Proxy & reference; - typedef Proxy * pointer; + typedef Proxy reference; + typedef ProxyIterator<Proxy> * pointer; ProxyIterator() {} @@ -47,10 +47,10 @@ template <class Proxy> class ProxyIterator { template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {} explicit ProxyIterator(const Proxy &p) : p_(p) {} - // p_'s swap does value swapping, but here we want iterator swapping +/* // p_'s swap does value swapping, but here we want iterator swapping friend inline void swap(ProxyIterator<Proxy> &first, ProxyIterator<Proxy> &second) { swap(first.I(), second.I()); - } + }*/ // p_'s operator= does value copying, but here we want iterator copying. S &operator=(const S &other) { @@ -77,8 +77,8 @@ template <class Proxy> class ProxyIterator { std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); } - Proxy &operator*() { return p_; } - const Proxy &operator*() const { return p_; } + Proxy operator*() { return p_; } + const Proxy operator*() const { return p_; } Proxy *operator->() { return &p_; } const Proxy *operator->() const { return &p_; } Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); } @@ -98,4 +98,4 @@ template <class Proxy> ProxyIterator<Proxy> operator+(std::ptrdiff_t amount, con } // namespace util -#endif // UTIL_PROXY_ITERATOR__ +#endif // UTIL_PROXY_ITERATOR_H diff --git a/util/read_compressed.cc b/util/read_compressed.cc index b62a6e8..71ef0e2 100644 --- a/util/read_compressed.cc +++ b/util/read_compressed.cc @@ -49,6 +49,8 @@ class ReadBase { thunk.internal_.reset(with); } + ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); } + static uint64_t &ReadCount(ReadCompressed &thunk) { return thunk.raw_amount_; } @@ -56,6 +58,8 @@ class ReadBase { namespace { +ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed); + // Completed file that other classes can thunk to. class Complete : public ReadBase { public: @@ -80,7 +84,7 @@ class Uncompressed : public ReadBase { class UncompressedWithHeader : public ReadBase { public: - UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) { + UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) { assert(already_size); buf_.reset(malloc(already_size)); if (!buf_.get()) throw std::bad_alloc(); @@ -91,6 +95,7 @@ class UncompressedWithHeader : public ReadBase { std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { assert(buf_.get()); + assert(remain_ != end_); std::size_t sending = std::min<std::size_t>(amount, end_ - remain_); memcpy(to, remain_, sending); remain_ += sending; @@ -108,23 +113,51 @@ class UncompressedWithHeader : public ReadBase { scoped_fd fd_; }; -#ifdef HAVE_ZLIB -class GZip : public ReadBase { +static const std::size_t kInputBuffer = 16384; + +template <class Compression> class StreamCompressed : public ReadBase { + public: + StreamCompressed(int fd, const void *already_data, std::size_t already_size) + : file_(fd), + in_buffer_(MallocOrThrow(kInputBuffer)), + back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + back_.SetOutput(to, amount); + do { + if (!back_.Stream().avail_in) ReadInput(thunk); + if (!back_.Process()) { + // reached end, at least for the compressed portion. + std::size_t ret = static_cast<const uint8_t *>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to); + ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk); + if (ret) return ret; + // We did not read anything this round, so clients might think EOF. Transfer responsibility to the next reader. + return Current(thunk)->Read(to, amount, thunk); + } + } while (back_.Stream().next_out == to); + return static_cast<const uint8_t*>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to); + } + private: - static const std::size_t kInputBuffer = 16384; + void ReadInput(ReadCompressed &thunk) { + assert(!back_.Stream().avail_in); + std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + back_.SetInput(in_buffer_.get(), got); + ReadCount(thunk) += got; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + + Compression back_; +}; + +#ifdef HAVE_ZLIB +class GZip { public: - GZip(int fd, void *already_data, std::size_t already_size) - : file_(fd), in_buffer_(malloc(kInputBuffer)) { - if (!in_buffer_.get()) throw std::bad_alloc(); - assert(already_size < kInputBuffer); - if (already_size) { - memcpy(in_buffer_.get(), already_data, already_size); - stream_.next_in = static_cast<Bytef *>(in_buffer_.get()); - stream_.avail_in = already_size; - stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size); - } else { - stream_.avail_in = 0; - } + GZip(const void *base, std::size_t amount) { + SetInput(base, amount); stream_.zalloc = Z_NULL; stream_.zfree = Z_NULL; stream_.opaque = Z_NULL; @@ -141,227 +174,154 @@ class GZip : public ReadBase { } } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - if (amount == 0) return 0; + void SetOutput(void *to, std::size_t amount) { stream_.next_out = static_cast<Bytef*>(to); stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount); - do { - if (!stream_.avail_in) ReadInput(thunk); - int result = inflate(&stream_, 0); - switch (result) { - case Z_OK: - break; - case Z_STREAM_END: - { - std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); - ReplaceThis(new Complete(), thunk); - return ret; - } - case Z_ERRNO: - UTIL_THROW(ErrnoException, "zlib error"); - default: - UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); - } - } while (stream_.next_out == to); - return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); } - private: - void ReadInput(ReadCompressed &thunk) { - assert(!stream_.avail_in); - stream_.next_in = static_cast<Bytef *>(in_buffer_.get()); - stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); - ReadCount(thunk) += stream_.avail_in; + void SetInput(const void *base, std::size_t amount) { + assert(amount < static_cast<std::size_t>(std::numeric_limits<uInt>::max())); + stream_.next_in = const_cast<Bytef*>(static_cast<const Bytef*>(base)); + stream_.avail_in = amount; } - scoped_fd file_; - scoped_malloc in_buffer_; + const z_stream &Stream() const { return stream_; } + + bool Process() { + int result = inflate(&stream_, 0); + switch (result) { + case Z_OK: + return true; + case Z_STREAM_END: + return false; + case Z_ERRNO: + UTIL_THROW(ErrnoException, "zlib error"); + default: + UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); + } + } + + private: z_stream stream_; }; #endif // HAVE_ZLIB -const uint8_t kBZMagic[3] = {'B', 'Z', 'h'}; - #ifdef HAVE_BZLIB -class BZip : public ReadBase { +class BZip { public: - BZip(int fd, void *already_data, std::size_t already_size) { - scoped_fd hold(fd); - closer_.reset(FDOpenReadOrThrow(hold)); - file_ = NULL; - Open(already_data, already_size); + BZip(const void *base, std::size_t amount) { + memset(&stream_, 0, sizeof(stream_)); + SetInput(base, amount); + HandleError(BZ2_bzDecompressInit(&stream_, 0, 0)); } - BZip(FILE *file, void *already_data, std::size_t already_size) { - closer_.reset(file); - file_ = NULL; - Open(already_data, already_size); + ~BZip() { + try { + HandleError(BZ2_bzDecompressEnd(&stream_)); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } } - ~BZip() { - Close(file_); + bool Process() { + int ret = BZ2_bzDecompress(&stream_); + if (ret == BZ_STREAM_END) return false; + HandleError(ret); + return true; } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - assert(file_); - int bzerror = BZ_OK; - int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount)); - long pos = ftell(closer_.get()); - if (pos != -1) ReadCount(thunk) = pos; - switch (bzerror) { - case BZ_STREAM_END: - /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the - * library doesn't handle this internally. This gets the trailing - * data, grows it up to magic as needed, validates the magic, and - * reopens. - */ - { - bzerror = BZ_OK; - void *trailing_data; - int trailing_size; - BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size); - UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); - std::string trailing(static_cast<const char*>(trailing_data), trailing_size); - Close(file_); - - if (trailing_size < (int)sizeof(kBZMagic)) { - trailing.resize(sizeof(kBZMagic)); - if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) { - UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft"); - // Legitimate end of file. - ReplaceThis(new Complete(), thunk); - return ret; - } - } - UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream"); - Open(&trailing[0], trailing.size()); - } - return ret; - case BZ_OK: - return ret; - default: - UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); - } + void SetOutput(void *base, std::size_t amount) { + stream_.next_out = static_cast<char*>(base); + stream_.avail_out = std::min<std::size_t>(std::numeric_limits<unsigned int>::max(), amount); + } + + void SetInput(const void *base, std::size_t amount) { + stream_.next_in = const_cast<char*>(static_cast<const char*>(base)); + stream_.avail_in = amount; } + const bz_stream &Stream() const { return stream_; } + private: - void Open(void *already_data, std::size_t already_size) { - assert(!file_); - int bzerror = BZ_OK; - file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size); - switch (bzerror) { + void HandleError(int value) { + switch(value) { case BZ_OK: return; case BZ_CONFIG_ERROR: - UTIL_THROW(BZException, "Looks like bzip2 was miscompiled."); + UTIL_THROW(BZException, "bzip2 seems to be miscompiled."); case BZ_PARAM_ERROR: - UTIL_THROW(BZException, "Parameter error"); - case BZ_IO_ERROR: - UTIL_THROW(BZException, "IO error reading file"); + UTIL_THROW(BZException, "bzip2 Parameter error"); + case BZ_DATA_ERROR: + UTIL_THROW(BZException, "bzip2 detected a corrupt file"); + case BZ_DATA_ERROR_MAGIC: + UTIL_THROW(BZException, "bzip2 detected bad magic bytes. Perhaps this was not a bzip2 file after all?"); case BZ_MEM_ERROR: throw std::bad_alloc(); default: - UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror); - } - assert(file_); - } - - static void Close(BZFILE *&file) { - if (file == NULL) return; - int bzerror = BZ_OK; - BZ2_bzReadClose(&bzerror, file); - if (bzerror != BZ_OK) { - std::cerr << "bz2 readclose error number " << bzerror << std::endl; - abort(); + UTIL_THROW(BZException, "Unknown bzip2 error code " << value); } - file = NULL; } - scoped_FILE closer_; - BZFILE *file_; + bz_stream stream_; }; #endif // HAVE_BZLIB #ifdef HAVE_XZLIB -class XZip : public ReadBase { - private: - static const std::size_t kInputBuffer = 16384; +class XZip { public: - XZip(int fd, void *already_data, std::size_t already_size) - : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) { - if (!in_buffer_.get()) throw std::bad_alloc(); - assert(already_size < kInputBuffer); - if (already_size) { - memcpy(in_buffer_.get(), already_data, already_size); - stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get()); - stream_.avail_in = already_size; - stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size); - } else { - stream_.avail_in = 0; - } - stream_.allocator = NULL; - lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED); - switch (ret) { - case LZMA_OK: - break; - case LZMA_MEM_ERROR: - UTIL_THROW(ErrnoException, "xz open error"); - default: - UTIL_THROW(XZException, "xz error code " << ret); - } + XZip(const void *base, std::size_t amount) + : stream_(), action_(LZMA_RUN) { + memset(&stream_, 0, sizeof(stream_)); + SetInput(base, amount); + HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0)); } ~XZip() { lzma_end(&stream_); } - std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { - if (amount == 0) return 0; - stream_.next_out = static_cast<uint8_t*>(to); + void SetOutput(void *base, std::size_t amount) { + stream_.next_out = static_cast<uint8_t*>(base); stream_.avail_out = amount; - do { - if (!stream_.avail_in) ReadInput(thunk); - lzma_ret status = lzma_code(&stream_, action_); - switch (status) { - case LZMA_OK: - break; - case LZMA_STREAM_END: - UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet."); - { - std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); - ReplaceThis(new Complete(), thunk); - return ret; - } - case LZMA_MEM_ERROR: - throw std::bad_alloc(); - case LZMA_FORMAT_ERROR: - UTIL_THROW(XZException, "xzlib says file format not recognized"); - case LZMA_OPTIONS_ERROR: - UTIL_THROW(XZException, "xzlib says unsupported compression options"); - case LZMA_DATA_ERROR: - UTIL_THROW(XZException, "xzlib says this file is corrupt"); - case LZMA_BUF_ERROR: - UTIL_THROW(XZException, "xzlib says unexpected end of input"); - default: - UTIL_THROW(XZException, "unrecognized xzlib error " << status); - } - } while (stream_.next_out == to); - return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to); + } + + void SetInput(const void *base, std::size_t amount) { + stream_.next_in = static_cast<const uint8_t*>(base); + stream_.avail_in = amount; + if (!amount) action_ = LZMA_FINISH; + } + + const lzma_stream &Stream() const { return stream_; } + + bool Process() { + lzma_ret status = lzma_code(&stream_, action_); + if (status == LZMA_STREAM_END) return false; + HandleError(status); + return true; } private: - void ReadInput(ReadCompressed &thunk) { - assert(!stream_.avail_in); - stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get()); - stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); - if (!stream_.avail_in) action_ = LZMA_FINISH; - ReadCount(thunk) += stream_.avail_in; + void HandleError(lzma_ret value) { + switch (value) { + case LZMA_OK: + return; + case LZMA_MEM_ERROR: + throw std::bad_alloc(); + case LZMA_FORMAT_ERROR: + UTIL_THROW(XZException, "xzlib says file format not recognized"); + case LZMA_OPTIONS_ERROR: + UTIL_THROW(XZException, "xzlib says unsupported compression options"); + case LZMA_DATA_ERROR: + UTIL_THROW(XZException, "xzlib says this file is corrupt"); + case LZMA_BUF_ERROR: + UTIL_THROW(XZException, "xzlib says unexpected end of input"); + default: + UTIL_THROW(XZException, "unrecognized xzlib error " << value); + } } - scoped_fd file_; - scoped_malloc in_buffer_; lzma_stream stream_; - lzma_action action_; }; #endif // HAVE_XZLIB @@ -384,66 +344,68 @@ class IStreamReader : public ReadBase { }; enum MagicResult { - UNKNOWN, GZIP, BZIP, XZIP + UTIL_UNKNOWN, UTIL_GZIP, UTIL_BZIP, UTIL_XZIP }; -MagicResult DetectMagic(const void *from_void) { +MagicResult DetectMagic(const void *from_void, std::size_t length) { const uint8_t *header = static_cast<const uint8_t*>(from_void); - if (header[0] == 0x1f && header[1] == 0x8b) { - return GZIP; + if (length >= 2 && header[0] == 0x1f && header[1] == 0x8b) { + return UTIL_GZIP; } - if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) { - return BZIP; + const uint8_t kBZMagic[3] = {'B', 'Z', 'h'}; + if (length >= sizeof(kBZMagic) && !memcmp(header, kBZMagic, sizeof(kBZMagic))) { + return UTIL_BZIP; } const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; - if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) { - return XZIP; + if (length >= sizeof(kXZMagic) && !memcmp(header, kXZMagic, sizeof(kXZMagic))) { + return UTIL_XZIP; } - return UNKNOWN; + return UTIL_UNKNOWN; } -ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { +ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) { scoped_fd hold(fd); - unsigned char header[ReadCompressed::kMagicSize]; - raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); - if (!raw_amount) - return new Uncompressed(hold.release()); - if (raw_amount != ReadCompressed::kMagicSize) - return new UncompressedWithHeader(hold.release(), header, raw_amount); - switch (DetectMagic(header)) { - case GZIP: + std::string header(reinterpret_cast<const char*>(already_data), already_size); + if (header.size() < ReadCompressed::kMagicSize) { + std::size_t original = header.size(); + header.resize(ReadCompressed::kMagicSize); + std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original); + raw_amount += got; + header.resize(original + got); + } + if (header.empty()) { + hold.release(); + return new Complete(); + } + switch (DetectMagic(&header[0], header.size())) { + case UTIL_GZIP: #ifdef HAVE_ZLIB - return new GZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed<GZip>(hold.release(), header.data(), header.size()); #else UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); #endif - case BZIP: + case UTIL_BZIP: #ifdef HAVE_BZLIB - return new BZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed<BZip>(hold.release(), &header[0], header.size()); #else - UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); + UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in."); #endif - case XZIP: + case UTIL_XZIP: #ifdef HAVE_XZLIB - return new XZip(hold.release(), header, ReadCompressed::kMagicSize); + return new StreamCompressed<XZip>(hold.release(), header.data(), header.size()); #else UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); #endif - case UNKNOWN: - break; - } - try { - SeekOrThrow(fd, 0); - } catch (const util::ErrnoException &e) { - return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); + default: + UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file. This could be supported but usually indicates an error."); + return new UncompressedWithHeader(hold.release(), header.data(), header.size()); } - return new Uncompressed(hold.release()); } } // namespace bool ReadCompressed::DetectCompressedMagic(const void *from_void) { - return DetectMagic(from_void) != UNKNOWN; + return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN; } ReadCompressed::ReadCompressed(int fd) { @@ -459,8 +421,9 @@ ReadCompressed::ReadCompressed() {} ReadCompressed::~ReadCompressed() {} void ReadCompressed::Reset(int fd) { + raw_amount_ = 0; internal_.reset(); - internal_.reset(ReadFactory(fd, raw_amount_)); + internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false)); } void ReadCompressed::Reset(std::istream &in) { diff --git a/util/read_compressed.hh b/util/read_compressed.hh index 8b54c9e..763e6bb 100644 --- a/util/read_compressed.hh +++ b/util/read_compressed.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_READ_COMPRESSED__ -#define UTIL_READ_COMPRESSED__ +#ifndef UTIL_READ_COMPRESSED_H +#define UTIL_READ_COMPRESSED_H #include "util/exception.hh" #include "util/scoped.hh" @@ -78,4 +78,4 @@ class ReadCompressed { } // namespace util -#endif // UTIL_READ_COMPRESSED__ +#endif // UTIL_READ_COMPRESSED_H diff --git a/util/read_compressed_test.cc b/util/read_compressed_test.cc index 9cb4a4b..301e8f4 100644 --- a/util/read_compressed_test.cc +++ b/util/read_compressed_test.cc @@ -12,6 +12,23 @@ #include <stdlib.h> +#if defined __MINGW32__ +#include <time.h> +#include <fcntl.h> + +#if !defined mkstemp +// TODO insecure +int mkstemp(char * stemplate) +{ + char *filename = mktemp(stemplate); + if (filename == NULL) + return -1; + return open(filename, O_RDWR | O_CREAT, 0600); +} +#endif + +#endif // defined + namespace util { namespace { @@ -96,6 +113,11 @@ BOOST_AUTO_TEST_CASE(ReadXZ) { } #endif +#ifdef HAVE_ZLIB +BOOST_AUTO_TEST_CASE(AppendGZ) { +} +#endif + BOOST_AUTO_TEST_CASE(IStream) { std::string name(WriteRandom()); std::fstream stream(name.c_str(), std::ios::in); diff --git a/util/scoped.hh b/util/scoped.hh index b642d06..ae70b6b 100644 --- a/util/scoped.hh +++ b/util/scoped.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_SCOPED__ -#define UTIL_SCOPED__ +#ifndef UTIL_SCOPED_H +#define UTIL_SCOPED_H /* Other scoped objects in the style of scoped_ptr. */ #include "util/exception.hh" @@ -101,4 +101,4 @@ template <class T> class scoped_ptr { } // namespace util -#endif // UTIL_SCOPED__ +#endif // UTIL_SCOPED_H diff --git a/util/sized_iterator.hh b/util/sized_iterator.hh index dce8f22..75f6886 100644 --- a/util/sized_iterator.hh +++ b/util/sized_iterator.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_SIZED_ITERATOR__ -#define UTIL_SIZED_ITERATOR__ +#ifndef UTIL_SIZED_ITERATOR_H +#define UTIL_SIZED_ITERATOR_H #include "util/proxy_iterator.hh" @@ -36,7 +36,7 @@ class SizedInnerIterator { void *Data() { return ptr_; } std::size_t EntrySize() const { return size_; } - friend inline void swap(SizedInnerIterator &first, SizedInnerIterator &second) { + friend void swap(SizedInnerIterator &first, SizedInnerIterator &second) { std::swap(first.ptr_, second.ptr_); std::swap(first.size_, second.size_); } @@ -69,17 +69,7 @@ class SizedProxy { const void *Data() const { return inner_.Data(); } void *Data() { return inner_.Data(); } - /** - // TODO: this (deep) swap was recently added. why? if any std heap sort etc - // algs are using swap, that's going to be worse performance than using - // =. i'm not sure why we *want* a deep swap. if C++11 compilers are - // choosing between move constructor and swap, then we'd better implement a - // (deep) move constructor. it may also be that this is moot since i made - // ProxyIterator a reference and added a shallow ProxyIterator swap? (I - // need Ken or someone competent to judge whether that's correct also. - - // let me know at graehl@gmail.com - */ - friend void swap(SizedProxy &first, SizedProxy &second) { + friend void swap(SizedProxy first, SizedProxy second) { std::swap_ranges( static_cast<char*>(first.inner_.Data()), static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(), @@ -127,4 +117,4 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public }; } // namespace util -#endif // UTIL_SIZED_ITERATOR__ +#endif // UTIL_SIZED_ITERATOR_H diff --git a/util/sized_iterator_test.cc b/util/sized_iterator_test.cc new file mode 100644 index 0000000..c36bcb2 --- /dev/null +++ b/util/sized_iterator_test.cc @@ -0,0 +1,16 @@ +#include "util/sized_iterator.hh" + +#define BOOST_TEST_MODULE SizedIteratorTest +#include <boost/test/unit_test.hpp> + +namespace util { namespace { + +BOOST_AUTO_TEST_CASE(swap_works) { + char str[2] = { 0, 1 }; + SizedProxy first(str, 1), second(str + 1, 1); + swap(first, second); + BOOST_CHECK_EQUAL(1, str[0]); + BOOST_CHECK_EQUAL(0, str[1]); +} + +}} // namespace anonymous util diff --git a/util/sorted_uniform.hh b/util/sorted_uniform.hh index 7700d9e..a7dba5e 100644 --- a/util/sorted_uniform.hh +++ b/util/sorted_uniform.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_SORTED_UNIFORM__ -#define UTIL_SORTED_UNIFORM__ +#ifndef UTIL_SORTED_UNIFORM_H +#define UTIL_SORTED_UNIFORM_H #include <algorithm> #include <cstddef> @@ -124,4 +124,4 @@ template <class Iterator, class Accessor> Iterator BinaryBelow( } // namespace util -#endif // UTIL_SORTED_UNIFORM__ +#endif // UTIL_SORTED_UNIFORM_H diff --git a/util/stream/block.hh b/util/stream/block.hh index 11aa991..b9f1d49 100644 --- a/util/stream/block.hh +++ b/util/stream/block.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_BLOCK__ -#define UTIL_STREAM_BLOCK__ +#ifndef UTIL_STREAM_BLOCK_H +#define UTIL_STREAM_BLOCK_H #include <cstddef> #include <stdint.h> @@ -40,4 +40,4 @@ class Block { } // namespace stream } // namespace util -#endif // UTIL_STREAM_BLOCK__ +#endif // UTIL_STREAM_BLOCK_H diff --git a/util/stream/chain.hh b/util/stream/chain.hh index 0cc83a8..bda653a 100644 --- a/util/stream/chain.hh +++ b/util/stream/chain.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_CHAIN__ -#define UTIL_STREAM_CHAIN__ +#ifndef UTIL_STREAM_CHAIN_H +#define UTIL_STREAM_CHAIN_H #include "util/stream/block.hh" #include "util/stream/config.hh" @@ -195,4 +195,4 @@ inline Chain &operator>>(Chain &chain, Link &link) { } // namespace stream } // namespace util -#endif // UTIL_STREAM_CHAIN__ +#endif // UTIL_STREAM_CHAIN_H diff --git a/util/stream/config.hh b/util/stream/config.hh index 1eeb3a8..052a05b 100644 --- a/util/stream/config.hh +++ b/util/stream/config.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_CONFIG__ -#define UTIL_STREAM_CONFIG__ +#ifndef UTIL_STREAM_CONFIG_H +#define UTIL_STREAM_CONFIG_H #include <cstddef> #include <string> @@ -29,4 +29,4 @@ struct SortConfig { }; }} // namespaces -#endif // UTIL_STREAM_CONFIG__ +#endif // UTIL_STREAM_CONFIG_H diff --git a/util/stream/io.hh b/util/stream/io.hh index 934b6b3..65918a9 100644 --- a/util/stream/io.hh +++ b/util/stream/io.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_IO__ -#define UTIL_STREAM_IO__ +#ifndef UTIL_STREAM_IO_H +#define UTIL_STREAM_IO_H #include "util/exception.hh" #include "util/file.hh" @@ -73,4 +73,4 @@ class FileBuffer { } // namespace stream } // namespace util -#endif // UTIL_STREAM_IO__ +#endif // UTIL_STREAM_IO_H diff --git a/util/stream/line_input.hh b/util/stream/line_input.hh index 86db1dd..a870a66 100644 --- a/util/stream/line_input.hh +++ b/util/stream/line_input.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_LINE_INPUT__ -#define UTIL_STREAM_LINE_INPUT__ +#ifndef UTIL_STREAM_LINE_INPUT_H +#define UTIL_STREAM_LINE_INPUT_H namespace util {namespace stream { class ChainPosition; @@ -19,4 +19,4 @@ class LineInput { }; }} // namespaces -#endif // UTIL_STREAM_LINE_INPUT__ +#endif // UTIL_STREAM_LINE_INPUT_H diff --git a/util/stream/multi_progress.hh b/util/stream/multi_progress.hh index c4dd45a..82e698a 100644 --- a/util/stream/multi_progress.hh +++ b/util/stream/multi_progress.hh @@ -1,6 +1,6 @@ /* Progress bar suitable for chains of workers */ -#ifndef UTIL_MULTI_PROGRESS__ -#define UTIL_MULTI_PROGRESS__ +#ifndef UTIL_STREAM_MULTI_PROGRESS_H +#define UTIL_STREAM_MULTI_PROGRESS_H #include <boost/thread/mutex.hpp> @@ -87,4 +87,4 @@ class WorkerProgress { }} // namespaces -#endif // UTIL_MULTI_PROGRESS__ +#endif // UTIL_STREAM_MULTI_PROGRESS_H diff --git a/util/stream/sort.hh b/util/stream/sort.hh index 16aa6a0..e18c6ae 100644 --- a/util/stream/sort.hh +++ b/util/stream/sort.hh @@ -15,8 +15,8 @@ * sort. Use a hash table for that. */ -#ifndef UTIL_STREAM_SORT__ -#define UTIL_STREAM_SORT__ +#ifndef UTIL_STREAM_SORT_H +#define UTIL_STREAM_SORT_H #include "util/stream/chain.hh" #include "util/stream/config.hh" @@ -545,4 +545,4 @@ template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, cons } // namespace stream } // namespace util -#endif // UTIL_STREAM_SORT__ +#endif // UTIL_STREAM_SORT_H diff --git a/util/stream/stream.hh b/util/stream/stream.hh index 6ff45b8..5cd1bdd 100644 --- a/util/stream/stream.hh +++ b/util/stream/stream.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_STREAM__ -#define UTIL_STREAM_STREAM__ +#ifndef UTIL_STREAM_STREAM_H +#define UTIL_STREAM_STREAM_H #include "util/stream/chain.hh" @@ -71,4 +71,4 @@ inline Chain &operator>>(Chain &chain, Stream &stream) { } // namespace stream } // namespace util -#endif // UTIL_STREAM_STREAM__ +#endif // UTIL_STREAM_STREAM_H diff --git a/util/stream/timer.hh b/util/stream/timer.hh index 7e1a588..06488a1 100644 --- a/util/stream/timer.hh +++ b/util/stream/timer.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STREAM_TIMER__ -#define UTIL_STREAM_TIMER__ +#ifndef UTIL_STREAM_TIMER_H +#define UTIL_STREAM_TIMER_H // Sorry Jon, this was adding library dependencies in Moses and people complained. @@ -13,4 +13,4 @@ #define UTIL_TIMER(str) //#endif -#endif // UTIL_STREAM_TIMER__ +#endif // UTIL_STREAM_TIMER_H diff --git a/util/string_piece.hh b/util/string_piece.hh index 84431db..114e254 100644 --- a/util/string_piece.hh +++ b/util/string_piece.hh @@ -45,8 +45,8 @@ // conversions from "const char*" to "string" and back again. // -#ifndef BASE_STRING_PIECE_H__ -#define BASE_STRING_PIECE_H__ +#ifndef UTIL_STRING_PIECE_H +#define UTIL_STRING_PIECE_H #include "util/have.hh" @@ -267,4 +267,4 @@ U_NAMESPACE_END using U_NAMESPACE_QUALIFIER StringPiece; #endif -#endif // BASE_STRING_PIECE_H__ +#endif // UTIL_STRING_PIECE_H diff --git a/util/string_piece_hash.hh b/util/string_piece_hash.hh index f206b1d..5c8c525 100644 --- a/util/string_piece_hash.hh +++ b/util/string_piece_hash.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_STRING_PIECE_HASH__ -#define UTIL_STRING_PIECE_HASH__ +#ifndef UTIL_STRING_PIECE_HASH_H +#define UTIL_STRING_PIECE_HASH_H #include "util/string_piece.hh" @@ -40,4 +40,4 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece #endif } -#endif // UTIL_STRING_PIECE_HASH__ +#endif // UTIL_STRING_PIECE_HASH_H diff --git a/util/thread_pool.hh b/util/thread_pool.hh index 84e257e..d1a883a 100644 --- a/util/thread_pool.hh +++ b/util/thread_pool.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_THREAD_POOL__ -#define UTIL_THREAD_POOL__ +#ifndef UTIL_THREAD_POOL_H +#define UTIL_THREAD_POOL_H #include "util/pcqueue.hh" @@ -18,8 +18,8 @@ template <class HandlerT> class Worker : boost::noncopyable { typedef HandlerT Handler; typedef typename Handler::Request Request; - template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, Request &poison) - : in_(in), handler_(construct), thread_(boost::ref(*this)), poison_(poison) {} + template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, const Request &poison) + : in_(in), handler_(construct), poison_(poison), thread_(boost::ref(*this)) {} // Only call from thread. void operator()() { @@ -30,7 +30,7 @@ template <class HandlerT> class Worker : boost::noncopyable { try { (*handler_)(request); } - catch(std::exception &e) { + catch(const std::exception &e) { std::cerr << "Handler threw " << e.what() << std::endl; abort(); } @@ -49,10 +49,10 @@ template <class HandlerT> class Worker : boost::noncopyable { PCQueue<Request> &in_; boost::optional<Handler> handler_; + + const Request poison_; boost::thread thread_; - - Request poison_; }; template <class HandlerT> class ThreadPool : boost::noncopyable { @@ -92,4 +92,4 @@ template <class HandlerT> class ThreadPool : boost::noncopyable { } // namespace util -#endif // UTIL_THREAD_POOL__ +#endif // UTIL_THREAD_POOL_H diff --git a/util/tokenize_piece.hh b/util/tokenize_piece.hh index a588c3f..908c8da 100644 --- a/util/tokenize_piece.hh +++ b/util/tokenize_piece.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_TOKENIZE_PIECE__ -#define UTIL_TOKENIZE_PIECE__ +#ifndef UTIL_TOKENIZE_PIECE_H +#define UTIL_TOKENIZE_PIECE_H #include "util/exception.hh" #include "util/string_piece.hh" @@ -7,7 +7,8 @@ #include <boost/iterator/iterator_facade.hpp> #include <algorithm> -#include <iostream> + +#include <string.h> namespace util { @@ -58,6 +59,30 @@ class AnyCharacter { StringPiece chars_; }; +class BoolCharacter { + public: + BoolCharacter() {} + + explicit BoolCharacter(const bool *delimiter) { delimiter_ = delimiter; } + + StringPiece Find(const StringPiece &in) const { + for (const char *i = in.data(); i != in.data() + in.size(); ++i) { + if (delimiter_[static_cast<unsigned char>(*i)]) return StringPiece(i, 1); + } + return StringPiece(in.data() + in.size(), 0); + } + + template <unsigned Length> static void Build(const char (&characters)[Length], bool (&out)[256]) { + memset(out, 0, sizeof(out)); + for (const char *i = characters; i != characters + Length; ++i) { + out[static_cast<unsigned char>(*i)] = true; + } + } + + private: + const bool *delimiter_; +}; + class AnyCharacterLast { public: AnyCharacterLast() {} @@ -123,4 +148,4 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it } // namespace util -#endif // UTIL_TOKENIZE_PIECE__ +#endif // UTIL_TOKENIZE_PIECE_H diff --git a/util/usage.cc b/util/usage.cc index 25fc097..5912f90 100644 --- a/util/usage.cc +++ b/util/usage.cc @@ -10,61 +10,119 @@ #include <string.h> #include <ctype.h> -#if !defined(_WIN32) && !defined(_WIN64) +#include <time.h> +#if defined(_WIN32) || defined(_WIN64) +// This code lifted from physmem.c in gnulib. See the copyright statement +// below. +# define WIN32_LEAN_AND_MEAN +# include <windows.h> +/* MEMORYSTATUSEX is missing from older windows headers, so define + a local replacement. */ +typedef struct +{ + DWORD dwLength; + DWORD dwMemoryLoad; + DWORDLONG ullTotalPhys; + DWORDLONG ullAvailPhys; + DWORDLONG ullTotalPageFile; + DWORDLONG ullAvailPageFile; + DWORDLONG ullTotalVirtual; + DWORDLONG ullAvailVirtual; + DWORDLONG ullAvailExtendedVirtual; +} lMEMORYSTATUSEX; +// Is this really supposed to be defined like this? +typedef int WINBOOL; +typedef WINBOOL (WINAPI *PFN_MS_EX) (lMEMORYSTATUSEX*); +#else #include <sys/resource.h> #include <sys/time.h> -#include <time.h> #include <unistd.h> #endif -namespace util { +#if defined(__MACH__) || defined(__FreeBSD__) || defined(__APPLE__) +#include <sys/types.h> +#include <sys/sysctl.h> +#endif -#if !defined(_WIN32) && !defined(_WIN64) +namespace util { namespace { -// On Mac OS X, clock_gettime is not implemented. -// CLOCK_MONOTONIC is not defined either. -#ifdef __MACH__ -#define CLOCK_MONOTONIC 0 - -int clock_gettime(int clk_id, struct timespec *tp) { +#if defined(__MACH__) +typedef struct timeval Wall; +Wall GetWall() { struct timeval tv; gettimeofday(&tv, NULL); - tp->tv_sec = tv.tv_sec; - tp->tv_nsec = tv.tv_usec * 1000; - return 0; + return tv; } -#endif // __MACH__ - -float FloatSec(const struct timeval &tv) { - return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0); +#elif defined(_WIN32) || defined(_WIN64) +typedef time_t Wall; +Wall GetWall() { + return time(NULL); } -float FloatSec(const struct timespec &tv) { - return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_nsec) / 1000000000.0); +#else +typedef struct timespec Wall; +Wall GetWall() { + Wall ret; + clock_gettime(CLOCK_MONOTONIC, &ret); + return ret; } +#endif -const char *SkipSpaces(const char *at) { - for (; *at == ' ' || *at == '\t'; ++at) {} - return at; +// Some of these functions are only used on some platforms. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-function" +#endif +// These all assume first > second +double Subtract(time_t first, time_t second) { + return difftime(first, second); +} +double DoubleSec(time_t tv) { + return static_cast<double>(tv); +} +#if !defined(_WIN32) && !defined(_WIN64) +double Subtract(const struct timeval &first, const struct timeval &second) { + return static_cast<double>(first.tv_sec - second.tv_sec) + static_cast<double>(first.tv_usec - second.tv_usec) / 1000000.0; } +double Subtract(const struct timespec &first, const struct timespec &second) { + return static_cast<double>(first.tv_sec - second.tv_sec) + static_cast<double>(first.tv_nsec - second.tv_nsec) / 1000000000.0; +} +double DoubleSec(const struct timeval &tv) { + return static_cast<double>(tv.tv_sec) + (static_cast<double>(tv.tv_usec) / 1000000.0); +} +double DoubleSec(const struct timespec &tv) { + return static_cast<double>(tv.tv_sec) + (static_cast<double>(tv.tv_nsec) / 1000000000.0); +} +#endif +#ifdef __clang__ +#pragma clang diagnostic pop +#endif class RecordStart { public: RecordStart() { - clock_gettime(CLOCK_MONOTONIC, &started_); + started_ = GetWall(); } - const struct timespec &Started() const { + const Wall &Started() const { return started_; } private: - struct timespec started_; + Wall started_; }; const RecordStart kRecordStart; + +const char *SkipSpaces(const char *at) { + for (; *at == ' ' || *at == '\t'; ++at) {} + return at; +} } // namespace -#endif + +double WallTime() { + return Subtract(GetWall(), kRecordStart.Started()); +} void PrintUsage(std::ostream &out) { #if !defined(_WIN32) && !defined(_WIN64) @@ -88,27 +146,84 @@ void PrintUsage(std::ostream &out) { return; } out << "RSSMax:" << usage.ru_maxrss << " kB" << '\t'; - out << "user:" << FloatSec(usage.ru_utime) << "\tsys:" << FloatSec(usage.ru_stime) << '\t'; - out << "CPU:" << (FloatSec(usage.ru_utime) + FloatSec(usage.ru_stime)); - - struct timespec current; - clock_gettime(CLOCK_MONOTONIC, ¤t); - out << "\treal:" << (FloatSec(current) - FloatSec(kRecordStart.Started())) << '\n'; + out << "user:" << DoubleSec(usage.ru_utime) << "\tsys:" << DoubleSec(usage.ru_stime) << '\t'; + out << "CPU:" << (DoubleSec(usage.ru_utime) + DoubleSec(usage.ru_stime)); + out << '\t'; #endif + + out << "real:" << WallTime() << '\n'; } +/* Adapted from physmem.c in gnulib 831b84c59ef413c57a36b67344467d66a8a2ba70 */ +/* Calculate the size of physical memory. + + Copyright (C) 2000-2001, 2003, 2005-2006, 2009-2013 Free Software + Foundation, Inc. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Lesser General Public License as published by + the Free Software Foundation; either version 2.1 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. */ + +/* Written by Paul Eggert. */ uint64_t GuessPhysicalMemory() { +#if defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE) + { + long pages = sysconf(_SC_PHYS_PAGES); + long page_size = sysconf(_SC_PAGESIZE); + if (pages != -1 && page_size != -1) + return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size); + } +#endif +#ifdef HW_PHYSMEM + { /* This works on *bsd and darwin. */ + unsigned int physmem; + size_t len = sizeof physmem; + static int mib[2] = { CTL_HW, HW_PHYSMEM }; + + if (sysctl (mib, sizeof(mib) / sizeof(mib[0]), &physmem, &len, NULL, 0) == 0 + && len == sizeof (physmem)) + return static_cast<uint64_t>(physmem); + } +#endif + #if defined(_WIN32) || defined(_WIN64) - return 0; -#elif defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE) - long pages = sysconf(_SC_PHYS_PAGES); - if (pages == -1) return 0; - long page_size = sysconf(_SC_PAGESIZE); - if (page_size == -1) return 0; - return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size); -#else - return 0; + { /* this works on windows */ + PFN_MS_EX pfnex; + HMODULE h = GetModuleHandle ("kernel32.dll"); + + if (!h) + return 0; + + /* Use GlobalMemoryStatusEx if available. */ + if ((pfnex = (PFN_MS_EX) GetProcAddress (h, "GlobalMemoryStatusEx"))) + { + lMEMORYSTATUSEX lms_ex; + lms_ex.dwLength = sizeof lms_ex; + if (!pfnex (&lms_ex)) + return 0; + return lms_ex.ullTotalPhys; + } + + /* Fall back to GlobalMemoryStatus which is always available. + but returns wrong results for physical memory > 4GB. */ + else + { + MEMORYSTATUS ms; + GlobalMemoryStatus (&ms); + return ms.dwTotalPhys; + } + } #endif + return 0; } namespace { diff --git a/util/usage.hh b/util/usage.hh index e19eda7..e578b0a 100644 --- a/util/usage.hh +++ b/util/usage.hh @@ -1,5 +1,5 @@ -#ifndef UTIL_USAGE__ -#define UTIL_USAGE__ +#ifndef UTIL_USAGE_H +#define UTIL_USAGE_H #include <cstddef> #include <iosfwd> #include <string> @@ -7,6 +7,9 @@ #include <stdint.h> namespace util { +// Time in seconds since process started. Zero on unsupported platforms. +double WallTime(); + void PrintUsage(std::ostream &to); // Determine how much physical memory there is. Return 0 on failure. @@ -15,4 +18,4 @@ uint64_t GuessPhysicalMemory(); // Parse a size like unix sort. Sadly, this means the default multiplier is K. uint64_t ParseSize(const std::string &arg); } // namespace util -#endif // UTIL_USAGE__ +#endif // UTIL_USAGE_H |