Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2014-04-08 03:17:11 +0400
committerKenneth Heafield <github@kheafield.com>2014-04-08 03:17:11 +0400
commit8d41ee98530e2941e9bb50f9a62e09afdb35f3bf (patch)
tree7f001935207178a286acb9c66230f010ca527c08
parent7e4d1bb7893021f21ae263e7d474dabc469e5d5d (diff)
parent38d40aa509af2329eda2aff65191f4606598f516 (diff)
Merge branch 'master' into pruning2
Conflicts: lm/builder/interpolate.cc lm/builder/interpolate.hh lm/builder/lmplz_main.cc lm/builder/pipeline.cc lm/builder/pipeline.hh
-rw-r--r--.gitignore2
-rw-r--r--Jamroot4
-rw-r--r--LICENSE1
-rwxr-xr-xcompile_query_only.sh4
-rw-r--r--jam-files/sanity.jam10
-rw-r--r--lm/bhiksha.cc15
-rw-r--r--lm/bhiksha.hh15
-rw-r--r--lm/binary_format.cc248
-rw-r--r--lm/binary_format.hh118
-rw-r--r--lm/blank.hh6
-rw-r--r--lm/build_binary_main.cc7
-rw-r--r--lm/builder/adjust_counts.hh6
-rw-r--r--lm/builder/corpus_count.cc32
-rw-r--r--lm/builder/corpus_count.hh11
-rw-r--r--lm/builder/corpus_count_test.cc2
-rw-r--r--lm/builder/discount.hh6
-rw-r--r--lm/builder/header_info.hh4
-rw-r--r--lm/builder/initial_probabilities.hh6
-rw-r--r--lm/builder/interpolate.cc9
-rw-r--r--lm/builder/interpolate.hh11
-rw-r--r--lm/builder/joint_order.hh6
-rw-r--r--lm/builder/lmplz_main.cc35
-rw-r--r--lm/builder/multi_stream.hh84
-rw-r--r--lm/builder/ngram.hh6
-rw-r--r--lm/builder/ngram_stream.hh6
-rw-r--r--lm/builder/pipeline.cc13
-rw-r--r--lm/builder/pipeline.hh25
-rw-r--r--lm/builder/print.hh6
-rw-r--r--lm/builder/sort.hh10
-rw-r--r--lm/config.hh6
-rw-r--r--lm/enumerate_vocab.hh6
-rw-r--r--lm/facade.hh12
-rw-r--r--lm/filter/arpa_io.hh7
-rw-r--r--lm/filter/count_io.hh18
-rw-r--r--lm/filter/filter_main.cc155
-rw-r--r--lm/filter/format.hh6
-rw-r--r--lm/filter/phrase.hh6
-rw-r--r--lm/filter/thread.hh6
-rw-r--r--lm/filter/vocab.cc1
-rw-r--r--lm/filter/vocab.hh6
-rw-r--r--lm/filter/wrapper.hh16
-rw-r--r--lm/left.hh6
-rw-r--r--lm/lm_exception.hh4
-rw-r--r--lm/max_order.hh8
-rw-r--r--lm/model.cc84
-rw-r--r--lm/model.hh18
-rw-r--r--lm/model_test.cc8
-rw-r--r--lm/model_type.hh6
-rw-r--r--lm/ngram_query.hh54
-rw-r--r--lm/partial.hh6
-rw-r--r--lm/quantize.cc12
-rw-r--r--lm/quantize.hh11
-rw-r--r--lm/query_main.cc48
-rw-r--r--lm/read_arpa.cc2
-rw-r--r--lm/read_arpa.hh6
-rw-r--r--lm/return.hh6
-rw-r--r--lm/search_hashed.cc33
-rw-r--r--lm/search_hashed.hh18
-rw-r--r--lm/search_trie.cc35
-rw-r--r--lm/search_trie.hh24
-rw-r--r--lm/sizes.hh6
-rw-r--r--lm/state.hh10
-rw-r--r--lm/trie.hh15
-rw-r--r--lm/trie_sort.cc4
-rw-r--r--lm/trie_sort.hh6
-rw-r--r--lm/value.hh6
-rw-r--r--lm/value_build.hh6
-rw-r--r--lm/virtual_interface.hh12
-rw-r--r--lm/vocab.cc28
-rw-r--r--lm/vocab.hh18
-rw-r--r--lm/weights.hh6
-rw-r--r--lm/word_index.hh4
-rw-r--r--lm/wrappers/nplm.cc8
-rw-r--r--lm/wrappers/nplm.hh14
-rw-r--r--python/kenlm.cpp366
-rw-r--r--python/kenlm.pxd4
-rw-r--r--python/kenlm.pyx8
-rw-r--r--util/Jamfile19
-rw-r--r--util/bit_packing.hh6
-rw-r--r--util/cat_compressed_main.cc47
-rw-r--r--util/ersatz_progress.hh6
-rw-r--r--util/exception.cc8
-rw-r--r--util/exception.hh16
-rw-r--r--util/fake_ofstream.hh13
-rw-r--r--util/file.cc65
-rw-r--r--util/file.hh11
-rw-r--r--util/file_piece.cc7
-rw-r--r--util/file_piece.hh29
-rw-r--r--util/file_piece_test.cc12
-rw-r--r--util/fixed_array.hh94
-rw-r--r--util/getopt.hh6
-rw-r--r--util/have.hh6
-rw-r--r--util/joint_sort.hh35
-rw-r--r--util/joint_sort_test.cc12
-rw-r--r--util/mmap.cc69
-rw-r--r--util/mmap.hh88
-rw-r--r--util/multi_intersection.hh8
-rw-r--r--util/murmur_hash.cc7
-rw-r--r--util/murmur_hash.hh10
-rw-r--r--util/parallel_read.cc69
-rw-r--r--util/parallel_read.hh16
-rw-r--r--util/pcqueue.hh64
-rw-r--r--util/pcqueue_test.cc20
-rw-r--r--util/pool.hh6
-rw-r--r--util/probing_hash_table.hh91
-rw-r--r--util/proxy_iterator.hh18
-rw-r--r--util/read_compressed.cc407
-rw-r--r--util/read_compressed.hh6
-rw-r--r--util/read_compressed_test.cc22
-rw-r--r--util/scoped.hh6
-rw-r--r--util/sized_iterator.hh20
-rw-r--r--util/sized_iterator_test.cc16
-rw-r--r--util/sorted_uniform.hh6
-rw-r--r--util/stream/block.hh6
-rw-r--r--util/stream/chain.hh6
-rw-r--r--util/stream/config.hh6
-rw-r--r--util/stream/io.hh6
-rw-r--r--util/stream/line_input.hh6
-rw-r--r--util/stream/multi_progress.hh6
-rw-r--r--util/stream/sort.hh6
-rw-r--r--util/stream/stream.hh6
-rw-r--r--util/stream/timer.hh6
-rw-r--r--util/string_piece.hh6
-rw-r--r--util/string_piece_hash.hh6
-rw-r--r--util/thread_pool.hh16
-rw-r--r--util/tokenize_piece.hh33
-rw-r--r--util/usage.cc197
-rw-r--r--util/usage.hh9
128 files changed, 2183 insertions, 1342 deletions
diff --git a/.gitignore b/.gitignore
index b0067e3..c11f7d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,5 +19,5 @@ util/read_compressed_test
util/sorted_uniform_test
previous.sh
jam-files/bjam
-jam-files/engine/bin.linuxx86/
+jam-files/engine/bin.*/
jam-files/engine/bootstrap/
diff --git a/Jamroot b/Jamroot
index 1ff09ba..54c7a55 100644
--- a/Jamroot
+++ b/Jamroot
@@ -52,7 +52,7 @@ include $(TOP)/jam-files/sanity.jam ;
boost 103600 ;
project : requirements $(requirements) <include>. ;
-project : default-build <threading>multi <warnings>on <variant>release ;
+project : default-build <threading>multi <warnings>on <variant>release <link>static ;
external-lib z ;
@@ -61,7 +61,7 @@ build-project util ;
lib kenlm : lm//kenlm ;
-install-bin-libs lm//programs kenlm ;
+install-bin-libs lm//programs util//programs kenlm ;
install-headers headers : [ glob-tree *.hh : dist include ] : . ;
alias install : prefix-bin prefix-lib prefix-include ;
explicit headers ;
diff --git a/LICENSE b/LICENSE
index 9e2556e..e88a7e2 100644
--- a/LICENSE
+++ b/LICENSE
@@ -2,6 +2,7 @@ Most of the code here is licensed under the LGPL. There are exceptions that
have their own licenses, listed below. See comments in those files for more
details.
+util/getopt.* is getopt for Windows
util/murmur_hash.cc
util/string_piece.hh and util/string_piece.cc
util/double-conversion/LICENSE covers util/double-conversion except Jamfile
diff --git a/compile_query_only.sh b/compile_query_only.sh
index 7c6d3f6..7a82f49 100755
--- a/compile_query_only.sh
+++ b/compile_query_only.sh
@@ -30,5 +30,5 @@ mkdir -p bin
if [ "$(uname)" != Darwin ]; then
CXXFLAGS="$CXXFLAGS -lrt"
fi
-$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS
-$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS
+$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS $LDFLAGS
+$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS $LDFLAGS
diff --git a/jam-files/sanity.jam b/jam-files/sanity.jam
index 7f9c45d..bc07945 100644
--- a/jam-files/sanity.jam
+++ b/jam-files/sanity.jam
@@ -140,10 +140,9 @@ rule boost-lib ( name macro : deps * ) {
}
if $(boost-auto-shared) = "<link>shared" {
- alias boost_$(name) : inner_boost_$(name) : <link>shared ;
- requirements += <define>BOOST_$(macro) ;
+ alias boost_$(name) : inner_boost_$(name) : <link>shared : : <define>BOOST_$(macro) ;
} else {
- alias boost_$(name) : inner_boost_$(name) : <link>static ;
+ alias boost_$(name) : inner_boost_$(name) : : : <link>shared:<define>BOOST_$(macro) ;
}
}
@@ -171,9 +170,10 @@ rule boost ( min-version ) {
boost-auto-shared = [ auto-shared "boost_program_options"$(boost-lib-version) : $(L-boost-search) ] ;
#See tools/build/v2/contrib/boost.jam in a boost distribution for a table of macros to define.
+ boost-lib exception EXCEPTION_DYN_LINK ;
boost-lib system SYSTEM_DYN_LINK ;
- boost-lib thread THREAD_DYN_DLL : boost_system ;
- boost-lib program_options PROGRAM_OPTIONS_DYN_LINK ;
+ boost-lib thread THREAD_DYN_DLL : boost_system boost_exception ;
+ boost-lib program_options PROGRAM_OPTIONS_DYN_LINK : boost_exception ;
boost-lib unit_test_framework TEST_DYN_LINK ;
boost-lib iostreams IOSTREAMS_DYN_LINK ;
boost-lib filesystem FILE_SYSTEM_DYN_LINK ;
diff --git a/lm/bhiksha.cc b/lm/bhiksha.cc
index 088ea98..c8a18df 100644
--- a/lm/bhiksha.cc
+++ b/lm/bhiksha.cc
@@ -1,4 +1,6 @@
#include "lm/bhiksha.hh"
+
+#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "util/file.hh"
#include "util/exception.hh"
@@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_
const uint8_t kArrayBhikshaVersion = 0;
// TODO: put this in binary file header instead when I change the binary file format again.
-void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
- uint8_t version;
- uint8_t configured_bits;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &configured_bits, 1);
+void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ uint8_t buffer[2];
+ file.ReadForConfig(buffer, 2, offset);
+ uint8_t version = buffer[0];
+ uint8_t configured_bits = buffer[1];
if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
config.pointer_bhiksha_bits = configured_bits;
}
@@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) {
*(head_write++) = config.pointer_bhiksha_bits;
}
-void ArrayBhiksha::LoadedBinary() {
-}
-
} // namespace trie
} // namespace ngram
} // namespace lm
diff --git a/lm/bhiksha.hh b/lm/bhiksha.hh
index 8ff8865..db71766 100644
--- a/lm/bhiksha.hh
+++ b/lm/bhiksha.hh
@@ -10,8 +10,8 @@
* Currently only used for next pointers.
*/
-#ifndef LM_BHIKSHA__
-#define LM_BHIKSHA__
+#ifndef LM_BHIKSHA_H
+#define LM_BHIKSHA_H
#include <stdint.h>
#include <assert.h>
@@ -24,6 +24,7 @@
namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
namespace trie {
@@ -31,7 +32,7 @@ class DontBhiksha {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {}
static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
@@ -53,8 +54,6 @@ class DontBhiksha {
void FinishedLoading(const Config &/*config*/) {}
- void LoadedBinary() {}
-
uint8_t InlineBits() const { return next_.bits; }
private:
@@ -65,7 +64,7 @@ class ArrayBhiksha {
public:
static const ModelType kModelTypeAdd = kArrayAdd;
- static void UpdateConfigFromBinary(int fd, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
@@ -93,8 +92,6 @@ class ArrayBhiksha {
void FinishedLoading(const Config &config);
- void LoadedBinary();
-
uint8_t InlineBits() const { return next_inline_.bits; }
private:
@@ -112,4 +109,4 @@ class ArrayBhiksha {
} // namespace ngram
} // namespace lm
-#endif // LM_BHIKSHA__
+#endif // LM_BHIKSHA_H
diff --git a/lm/binary_format.cc b/lm/binary_format.cc
index bef51eb..9c744b1 100644
--- a/lm/binary_format.cc
+++ b/lm/binary_format.cc
@@ -14,6 +14,9 @@
namespace lm {
namespace ngram {
+
+const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
@@ -58,8 +61,6 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
@@ -81,83 +82,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
- if (config.write_mmap) {
- std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.file.reset(util::CreateOrThrow(config.write_mmap));
- if (config.write_method == Config::WRITE_MMAP) {
- backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
- } else {
- util::ResizeOrThrow(backing.file.get(), 0);
- util::MapAnonymous(total, backing.vocab);
- }
- strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
- return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
- } else {
- util::MapAnonymous(memory_size, backing.vocab);
- return reinterpret_cast<uint8_t*>(backing.vocab.get());
- }
-}
-
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
- std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
- if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
- try {
- util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
- } catch (util::ErrnoException &e) {
- e << " for file " << config.write_mmap;
- throw e;
- }
-
- if (config.write_method == Config::WRITE_AFTER) {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
- // mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- std::size_t page_size = util::SizePage();
- std::size_t alignment_cruft = adjusted_vocab % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
- } else {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
-}
-
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
- if (!config.write_mmap) return;
- switch (config.write_method) {
- case Config::WRITE_MMAP:
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
- util::SyncOrThrow(backing.search.get(), backing.search.size());
- break;
- case Config::WRITE_AFTER:
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
- util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
- util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
- util::FSyncOrThrow(backing.file.get());
- break;
- }
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params = Parameters();
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
- if (config.write_method == Config::WRITE_AFTER) {
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
- }
-}
-
-namespace detail {
-
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
@@ -209,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
-void SeekPastHeader(int fd, const Parameters &params) {
- util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+const std::size_t kInvalidSize = static_cast<std::size_t>(-1);
+
+BinaryFormat::BinaryFormat(const Config &config)
+ : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
+ header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {}
+
+void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params) {
+ file_.reset(fd);
+ write_mmap_ = NULL; // Ignore write requests; this is already in binary format.
+ ReadHeader(fd, params);
+ MatchCheck(model_type, search_version, params);
+ header_size_ = TotalHeaderSize(params.counts.size());
+}
+
+void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
+ assert(header_size_ != kInvalidSize);
+ util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_);
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
- const uint64_t file_size = util::SizeFile(backing.file.get());
+void *BinaryFormat::LoadBinary(std::size_t size) {
+ assert(header_size_ != kInvalidSize);
+ const uint64_t file_size = util::SizeFile(file_.get());
// The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
- if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
- UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+ uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
+ UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
- util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
+ util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
- if (config.enumerate_vocab && !params.fixed.has_vocabulary)
- UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+ vocab_string_offset_ = total_map;
+ return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+}
+
+void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
+ vocab_size_ = memory_size;
+ if (!write_mmap_) {
+ header_size_ = 0;
+ util::MapAnonymous(memory_size, memory_vocab_);
+ return reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ }
+ header_size_ = TotalHeaderSize(order);
+ std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
+ file_.reset(util::CreateOrThrow(write_mmap_));
+ // some gccs complain about uninitialized variables even though all enum values are covered.
+ void *vocab_base = NULL;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = mapping_.get();
+ break;
+ case Config::WRITE_AFTER:
+ util::ResizeOrThrow(file_.get(), 0);
+ util::MapAnonymous(total, memory_vocab_);
+ vocab_base = memory_vocab_.get();
+ break;
+ }
+ strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
+ return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
+}
- // Seek to vocabulary words
- util::SeekOrThrow(backing.file.get(), total_map);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
+void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) {
+ assert(vocab_size_ != kInvalidSize);
+ vocab_pad_ = vocab_pad;
+ std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
+ vocab_string_offset_ = new_size;
+ if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, memory_search_);
+ assert(header_size_ == 0 || write_mmap_);
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ return reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+
+ assert(write_method_ == Config::WRITE_MMAP);
+ // Also known as total size without vocab words.
+ // Grow the file to accomodate the search, using zeros.
+ // According to man mmap, behavior is undefined when the file is resized
+ // underneath a mmap that is not a multiple of the page size. So to be
+ // safe, we'll unmap it and map it again.
+ mapping_.reset();
+ util::ResizeOrThrow(file_.get(), new_size);
+ void *ret;
+ MapFile(vocab_base, ret);
+ return ret;
}
-void ComplainAboutARPA(const Config &config, ModelType model_type) {
- if (config.write_mmap || !config.messages) return;
- if (config.arpa_complain == Config::ALL) {
- *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE &&
- (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
- *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) {
+ // Checking Config's include_vocab is the responsibility of the caller.
+ assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
+ if (!write_mmap_) {
+ // Unchanged base.
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ return;
+ }
+ if (write_method_ == Config::WRITE_MMAP) {
+ mapping_.reset();
+ }
+ util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ if (write_method_ == Config::WRITE_MMAP) {
+ MapFile(vocab_base, search_base);
+ } else {
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+}
+
+void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
+ if (!write_mmap_) return;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
+ util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
+ util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
+ util::FSyncOrThrow(file_.get());
+ break;
+ }
+ // header and vocab share the same mmap.
+ Parameters params = Parameters();
+ memset(&params, 0, sizeof(Parameters));
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ WriteHeader(mapping_.get(), params);
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ {
+ std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
+ WriteHeader(&buffer[0], params);
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ }
+ break;
}
}
-} // namespace detail
+void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) {
+ mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_;
+}
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
- if (!detail::IsBinaryFormat(fd.get())) return false;
+ if (!IsBinaryFormat(fd.get())) {
+ return false;
+ }
Parameters params;
- detail::ReadHeader(fd.get(), params);
+ ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}
diff --git a/lm/binary_format.hh b/lm/binary_format.hh
index bf699d5..136d6b1 100644
--- a/lm/binary_format.hh
+++ b/lm/binary_format.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BINARY_FORMAT__
-#define LM_BINARY_FORMAT__
+#ifndef LM_BINARY_FORMAT_H
+#define LM_BINARY_FORMAT_H
#include "lm/config.hh"
#include "lm/model_type.hh"
@@ -17,6 +17,8 @@
namespace lm {
namespace ngram {
+extern const char *kModelNames[6];
+
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
* this header designed for use by decoder authors.
@@ -42,67 +44,63 @@ struct Parameters {
std::vector<uint64_t> counts;
};
-struct Backing {
- // File behind memory, if any.
- util::scoped_fd file;
- // Vocabulary lookup table. Not to be confused with the vocab words themselves.
- util::scoped_memory vocab;
- // Raw block of memory backing the language model data structures
- util::scoped_memory search;
-};
-
-// Create just enough of a binary file to write vocabulary to it.
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
-// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
-
-// Write header to binary file. This is done last to prevent incomplete files
-// from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
+class BinaryFormat {
+ public:
+ explicit BinaryFormat(const Config &config);
+
+ // Reading a binary file:
+ // Takes ownership of fd
+ void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params);
+ // Used to read parts of the file to update the config object before figuring out full size.
+ void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const;
+ // Actually load the binary file and return a pointer to the beginning of the search area.
+ void *LoadBinary(std::size_t size);
+
+ uint64_t VocabStringReadingOffset() const {
+ assert(vocab_string_offset_ != kInvalidOffset);
+ return vocab_string_offset_;
+ }
-namespace detail {
+ // Writing a binary file or initializing in RAM from ARPA:
+ // Size for vocabulary.
+ void *SetupJustVocab(std::size_t memory_size, uint8_t order);
+ // Warning: can change the vocaulary base pointer.
+ void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base);
+ // Warning: can change vocabulary and search base addresses.
+ void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base);
+ // Write the header at the beginning of the file.
+ void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts);
+
+ private:
+ void MapFile(void *&vocab_base, void *&search_base);
+
+ // Copied from configuration.
+ const Config::WriteMethod write_method_;
+ const char *write_mmap_;
+ util::LoadMethod load_method_;
+
+ // File behind memory, if any.
+ util::scoped_fd file_;
+
+ // If there is a file involved, a single mapping.
+ util::scoped_memory mapping_;
+
+ // If the data is only in memory, separately allocate each because the trie
+ // knows vocab's size before it knows search's size (because SRILM might
+ // have pruned).
+ util::scoped_memory memory_vocab_, memory_search_;
+
+ // Memory ranges. Note that these may not be contiguous and may not all
+ // exist.
+ std::size_t header_size_, vocab_size_, vocab_pad_;
+ // aka end of search.
+ uint64_t vocab_string_offset_;
+
+ static const uint64_t kInvalidOffset = (uint64_t)-1;
+};
bool IsBinaryFormat(int fd);
-void ReadHeader(int fd, Parameters &params);
-
-void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params);
-
-void SeekPastHeader(int fd, const Parameters &params);
-
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
-
-void ComplainAboutARPA(const Config &config, ModelType model_type);
-
-} // namespace detail
-
-template <class To> void LoadLM(const char *file, const Config &config, To &to) {
- Backing &backing = to.MutableBacking();
- backing.file.reset(util::OpenReadOrThrow(file));
-
- try {
- if (detail::IsBinaryFormat(backing.file.get())) {
- Parameters params;
- detail::ReadHeader(backing.file.get(), params);
- detail::MatchCheck(To::kModelType, To::kVersion, params);
- // Replace the run-time configured probing_multiplier with the one in the file.
- Config new_config(config);
- new_config.probing_multiplier = params.fixed.probing_multiplier;
- detail::SeekPastHeader(backing.file.get(), params);
- To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
- uint64_t memory_size = To::Size(params.counts, new_config);
- uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
- to.InitializeFromBinary(start, params, new_config, backing.file.get());
- } else {
- detail::ComplainAboutARPA(config, To::kModelType);
- to.InitializeFromARPA(file, config);
- }
- } catch (util::Exception &e) {
- e << " File: " << file;
- throw;
- }
-}
-
} // namespace ngram
} // namespace lm
-#endif // LM_BINARY_FORMAT__
+#endif // LM_BINARY_FORMAT_H
diff --git a/lm/blank.hh b/lm/blank.hh
index 4da8120..94a71ad 100644
--- a/lm/blank.hh
+++ b/lm/blank.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BLANK__
-#define LM_BLANK__
+#ifndef LM_BLANK_H
+#define LM_BLANK_H
#include <limits>
@@ -40,4 +40,4 @@ inline bool HasExtension(const float &backoff) {
} // namespace ngram
} // namespace lm
-#endif // LM_BLANK__
+#endif // LM_BLANK_H
diff --git a/lm/build_binary_main.cc b/lm/build_binary_main.cc
index 425a123..15b421e 100644
--- a/lm/build_binary_main.cc
+++ b/lm/build_binary_main.cc
@@ -52,6 +52,7 @@ void Usage(const char *name, const char *default_mem) {
"-a compresses pointers using an array of offsets. The parameter is the\n"
" maximum number of bits encoded by the array. Memory is minimized subject\n"
" to the maximum, so pick 255 to minimize memory.\n\n"
+"-h print this help message.\n\n"
"Get a memory estimate by passing an ARPA file without an output file name.\n";
exit(1);
}
@@ -104,12 +105,15 @@ int main(int argc, char *argv[]) {
const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+ if (argc == 2 && !strcmp(argv[1], "--help"))
+ Usage(argv[0], default_mem);
+
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
config.building_memory = util::ParseSize(default_mem);
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -161,6 +165,7 @@ int main(int argc, char *argv[]) {
ParseFileList(optarg, config.rest_lower_files);
config.rest_function = Config::REST_LOWER;
break;
+ case 'h': // help
default:
Usage(argv[0], default_mem);
}
diff --git a/lm/builder/adjust_counts.hh b/lm/builder/adjust_counts.hh
index ea8a1e2..00b5d43 100644
--- a/lm/builder/adjust_counts.hh
+++ b/lm/builder/adjust_counts.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_ADJUST_COUNTS__
-#define LM_BUILDER_ADJUST_COUNTS__
+#ifndef LM_BUILDER_ADJUST_COUNTS_H
+#define LM_BUILDER_ADJUST_COUNTS_H
#include "lm/builder/discount.hh"
#include "util/exception.hh"
@@ -43,5 +43,5 @@ class AdjustCounts {
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_ADJUST_COUNTS__
+#endif // LM_BUILDER_ADJUST_COUNTS_H
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index aea93ad..b99edd0 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -223,29 +223,47 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return VocabHandout::MemUsage(vocab_estimate);
}
-CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
- dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
+ disallowed_symbol_action_(disallowed_symbol) {
}
-void CorpusCount::Run(const util::stream::ChainPosition &position) {
- UTIL_TIMER("(%w s) Counted n-grams\n");
+namespace {
+ void ComplainDisallowed(StringPiece word, WarningAction &action) {
+ switch (action) {
+ case SILENT:
+ return;
+ case COMPLAIN:
+ std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
+ action = SILENT;
+ return;
+ case THROW_UP:
+ UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
+ }
+ }
+} // namespace
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
VocabHandout vocab(vocab_write_, type_count_);
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
- StringPiece delimiters("\0\t\r ", 4);
+ bool delimiters[256];
+ util::BoolCharacter::Build("\0\t\n\r ", delimiters);
try {
while(true) {
StringPiece line(from_.ReadLine());
writer.StartSentence();
- for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
+ for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
WordIndex word = vocab.Lookup(*w);
- UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ if (word <= 2) {
+ ComplainDisallowed(*w, disallowed_symbol_action_);
+ continue;
+ }
writer.Append(word);
++count;
}
diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh
index aa0ed8e..da4ff9f 100644
--- a/lm/builder/corpus_count.hh
+++ b/lm/builder/corpus_count.hh
@@ -1,6 +1,7 @@
-#ifndef LM_BUILDER_CORPUS_COUNT__
-#define LM_BUILDER_CORPUS_COUNT__
+#ifndef LM_BUILDER_CORPUS_COUNT_H
+#define LM_BUILDER_CORPUS_COUNT_H
+#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/scoped.hh"
@@ -28,7 +29,7 @@ class CorpusCount {
// token_count: out.
// type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
- CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol);
void Run(const util::stream::ChainPosition &position);
@@ -40,8 +41,10 @@ class CorpusCount {
std::size_t dedupe_mem_size_;
util::scoped_malloc dedupe_mem_;
+
+ WarningAction disallowed_symbol_action_;
};
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_CORPUS_COUNT__
+#endif // LM_BUILDER_CORPUS_COUNT_H
diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc
index 6d325ef..26cb634 100644
--- a/lm/builder/corpus_count_test.cc
+++ b/lm/builder/corpus_count_test.cc
@@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(Short) {
NGramStream stream;
uint64_t token_count;
WordIndex type_count = 10;
- CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize(), SILENT);
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
diff --git a/lm/builder/discount.hh b/lm/builder/discount.hh
index 4d0aa4f..e2f4084 100644
--- a/lm/builder/discount.hh
+++ b/lm/builder/discount.hh
@@ -1,5 +1,5 @@
-#ifndef BUILDER_DISCOUNT__
-#define BUILDER_DISCOUNT__
+#ifndef LM_BUILDER_DISCOUNT_H
+#define LM_BUILDER_DISCOUNT_H
#include <algorithm>
@@ -23,4 +23,4 @@ struct Discount {
} // namespace builder
} // namespace lm
-#endif // BUILDER_DISCOUNT__
+#endif // LM_BUILDER_DISCOUNT_H
diff --git a/lm/builder/header_info.hh b/lm/builder/header_info.hh
index ccca145..16f3f60 100644
--- a/lm/builder/header_info.hh
+++ b/lm/builder/header_info.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_HEADER_INFO__
-#define LM_BUILDER_HEADER_INFO__
+#ifndef LM_BUILDER_HEADER_INFO_H
+#define LM_BUILDER_HEADER_INFO_H
#include <string>
#include <stdint.h>
diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh
index 0729a6a..f2a2975 100644
--- a/lm/builder/initial_probabilities.hh
+++ b/lm/builder/initial_probabilities.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_INITIAL_PROBABILITIES__
-#define LM_BUILDER_INITIAL_PROBABILITIES__
+#ifndef LM_BUILDER_INITIAL_PROBABILITIES_H
+#define LM_BUILDER_INITIAL_PROBABILITIES_H
#include "lm/builder/discount.hh"
#include "util/stream/config.hh"
@@ -31,4 +31,4 @@ void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::v
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_INITIAL_PROBABILITIES__
+#endif // LM_BUILDER_INITIAL_PROBABILITIES_H
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index cbf4bb4..4f8ffed 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -5,6 +5,7 @@
#include "lm/builder/multi_stream.hh"
#include "lm/builder/sort.hh"
#include "lm/lm_exception.hh"
+#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
#include <assert.h>
@@ -74,15 +75,17 @@ class Callback {
void Exit(unsigned, const NGram &) const {}
private:
- FixedArray<util::stream::Stream> backoffs_;
+ util::FixedArray<util::stream::Stream> backoffs_;
std::vector<float> probs_;
const std::vector<uint64_t>& prune_thresholds_;
};
} // namespace
-Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
- : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs), prune_thresholds_(prune_thresholds) {}
+Interpolate::Interpolate(uint64_t vocab_size, const ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
+ : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
+ backoffs_(backoffs),
+ prune_thresholds_(prune_thresholds) {}
// perform order-wise interpolation
void Interpolate::Run(const ChainPositions &positions) {
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
index cc372b4..d266191 100644
--- a/lm/builder/interpolate.hh
+++ b/lm/builder/interpolate.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_INTERPOLATE__
-#define LM_BUILDER_INTERPOLATE__
+#ifndef LM_BUILDER_INTERPOLATE_H
+#define LM_BUILDER_INTERPOLATE_H
#include <stdint.h>
@@ -14,8 +14,9 @@ namespace lm { namespace builder {
*/
class Interpolate {
public:
- explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs,
- const std::vector<uint64_t> &prune_thresholds_);
+ // Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
+ // be larger when the user specifies a consistent vocabulary size.
+ explicit Interpolate(uint64_t vocab_size, const ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds);
void Run(const ChainPositions &positions);
@@ -26,4 +27,4 @@ class Interpolate {
};
}} // namespaces
-#endif // LM_BUILDER_INTERPOLATE__
+#endif // LM_BUILDER_INTERPOLATE_H
diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh
index b562014..b9c22a0 100644
--- a/lm/builder/joint_order.hh
+++ b/lm/builder/joint_order.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_JOINT_ORDER__
-#define LM_BUILDER_JOINT_ORDER__
+#ifndef LM_BUILDER_JOINT_ORDER_H
+#define LM_BUILDER_JOINT_ORDER_H
#include "lm/builder/multi_stream.hh"
#include "lm/lm_exception.hh"
@@ -40,4 +40,4 @@ template <class Callback, class Compare> void JointOrder(const ChainPositions &p
}} // namespaces
-#endif // LM_BUILDER_JOINT_ORDER__
+#endif // LM_BUILDER_JOINT_ORDER_H
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index d09028c..e09f9df 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -1,4 +1,5 @@
#include "lm/builder/pipeline.hh"
+#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
@@ -37,24 +38,30 @@ int main(int argc, char *argv[]) {
std::string text, arpa;
options.add_options()
+ ("help,h", po::bool_switch(), "Show this help message")
("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200
->required()
#endif
, "Order of the model")
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
- ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
- ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
+ ("vocab_pad", po::value<std::size_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
("prune_thresholds,P", po::value<std::vector<uint64_t> >(&pipeline.prune_thresholds), "Prune n-grams of count equal to or lower than threshold. 0 means no pruning");
- if (argc == 1) {
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+
+ if (argc == 1 || vm["help"].as<bool>()) {
std::cerr <<
"Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
"Please cite:\n"
@@ -72,12 +79,17 @@ int main(int argc, char *argv[]) {
"setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
"Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
"Valid units are \% for percentage of memory (supported platforms only) and (in\n"
- "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n";
+ uint64_t mem = util::GuessPhysicalMemory();
+ if (mem) {
+ std::cerr << "This machine has " << mem << " bytes of memory.\n\n";
+ } else {
+ std::cerr << "Unable to determine the amount of memory on this machine.\n\n";
+ }
std::cerr << options << std::endl;
return 1;
}
- po::variables_map vm;
- po::store(po::parse_command_line(argc, argv, options), vm);
+
po::notify(vm);
//std::cerr << "vector: " << pipeline.counts_threshold.size() << std::endl;
@@ -120,6 +132,17 @@ int main(int argc, char *argv[]) {
}
#endif
+ if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) {
+ std::cerr << "--vocab_pad requires --interpolate_unigrams" << std::endl;
+ return 1;
+ }
+
+ if (vm["skip_symbols"].as<bool>()) {
+ pipeline.disallowed_symbol_action = lm::COMPLAIN;
+ } else {
+ pipeline.disallowed_symbol_action = lm::THROW_UP;
+ }
+
util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
diff --git a/lm/builder/multi_stream.hh b/lm/builder/multi_stream.hh
index 707a98c..1a8eb8b 100644
--- a/lm/builder/multi_stream.hh
+++ b/lm/builder/multi_stream.hh
@@ -1,7 +1,8 @@
-#ifndef LM_BUILDER_MULTI_STREAM__
-#define LM_BUILDER_MULTI_STREAM__
+#ifndef LM_BUILDER_MULTI_STREAM_H
+#define LM_BUILDER_MULTI_STREAM_H
#include "lm/builder/ngram_stream.hh"
+#include "util/fixed_array.hh"
#include "util/scoped.hh"
#include "util/stream/chain.hh"
@@ -13,72 +14,9 @@
namespace lm { namespace builder {
-template <class T> class FixedArray {
- public:
- explicit FixedArray(std::size_t count) {
- Init(count);
- }
-
- FixedArray() : newed_end_(NULL) {}
-
- void Init(std::size_t count) {
- assert(!block_.get());
- block_.reset(malloc(sizeof(T) * count));
- if (!block_.get()) throw std::bad_alloc();
- newed_end_ = begin();
- }
-
- FixedArray(const FixedArray &from) {
- std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
- Init(size);
- for (std::size_t i = 0; i < size; ++i) {
- new(end()) T(from[i]);
- Constructed();
- }
- }
-
- ~FixedArray() { clear(); }
-
- T *begin() { return static_cast<T*>(block_.get()); }
- const T *begin() const { return static_cast<const T*>(block_.get()); }
- // Always call Constructed after successful completion of new.
- T *end() { return newed_end_; }
- const T *end() const { return newed_end_; }
-
- T &back() { return *(end() - 1); }
- const T &back() const { return *(end() - 1); }
-
- std::size_t size() const { return end() - begin(); }
- bool empty() const { return begin() == end(); }
-
- T &operator[](std::size_t i) { return begin()[i]; }
- const T &operator[](std::size_t i) const { return begin()[i]; }
-
- template <class C> void push_back(const C &c) {
- new (end()) T(c);
- Constructed();
- }
-
- void clear() {
- for (T *i = begin(); i != end(); ++i)
- i->~T();
- newed_end_ = begin();
- }
-
- protected:
- void Constructed() {
- ++newed_end_;
- }
-
- private:
- util::scoped_malloc block_;
-
- T *newed_end_;
-};
-
class Chains;
-class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+class ChainPositions : public util::FixedArray<util::stream::ChainPosition> {
public:
ChainPositions() {}
@@ -89,14 +27,14 @@ class ChainPositions : public FixedArray<util::stream::ChainPosition> {
}
};
-class Chains : public FixedArray<util::stream::Chain> {
+class Chains : public util::FixedArray<util::stream::Chain> {
private:
template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
typedef Chains type;
};
public:
- explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+ explicit Chains(std::size_t limit) : util::FixedArray<util::stream::Chain>(limit) {}
template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
@@ -129,7 +67,7 @@ class Chains : public FixedArray<util::stream::Chain> {
};
inline void ChainPositions::Init(Chains &chains) {
- FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ util::FixedArray<util::stream::ChainPosition>::Init(chains.size());
for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
new (end()) util::stream::ChainPosition(i->Add()); Constructed();
}
@@ -140,13 +78,13 @@ inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
return chains;
}
-class NGramStreams : public FixedArray<NGramStream> {
+class NGramStreams : public util::FixedArray<NGramStream> {
public:
NGramStreams() {}
// This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
void InitWithDummy(const ChainPositions &positions) {
- FixedArray<NGramStream>::Init(positions.size() + 1);
+ util::FixedArray<NGramStream>::Init(positions.size() + 1);
new (end()) NGramStream(); Constructed();
for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
push_back(*i);
@@ -155,7 +93,7 @@ class NGramStreams : public FixedArray<NGramStream> {
// Limit restricts to positions[0,limit)
void Init(const ChainPositions &positions, std::size_t limit) {
- FixedArray<NGramStream>::Init(limit);
+ util::FixedArray<NGramStream>::Init(limit);
for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
push_back(*i);
}
@@ -177,4 +115,4 @@ inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
}
}} // namespaces
-#endif // LM_BUILDER_MULTI_STREAM__
+#endif // LM_BUILDER_MULTI_STREAM_H
diff --git a/lm/builder/ngram.hh b/lm/builder/ngram.hh
index 756eaa6..0472bcb 100644
--- a/lm/builder/ngram.hh
+++ b/lm/builder/ngram.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_NGRAM__
-#define LM_BUILDER_NGRAM__
+#ifndef LM_BUILDER_NGRAM_H
+#define LM_BUILDER_NGRAM_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -106,4 +106,4 @@ const WordIndex kEOS = 2;
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_NGRAM__
+#endif // LM_BUILDER_NGRAM_H
diff --git a/lm/builder/ngram_stream.hh b/lm/builder/ngram_stream.hh
index 3c99466..d7bf23a 100644
--- a/lm/builder/ngram_stream.hh
+++ b/lm/builder/ngram_stream.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_NGRAM_STREAM__
-#define LM_BUILDER_NGRAM_STREAM__
+#ifndef LM_BUILDER_NGRAM_STREAM_H
+#define LM_BUILDER_NGRAM_STREAM_H
#include "lm/builder/ngram.hh"
#include "util/stream/chain.hh"
@@ -52,4 +52,4 @@ inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &
}
}} // namespaces
-#endif // LM_BUILDER_NGRAM_STREAM__
+#endif // LM_BUILDER_NGRAM_STREAM_H
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index f5548f7..cede3c7 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -204,7 +204,7 @@ class Master {
Chains chains_;
// Often only unigrams, but sometimes all orders.
- FixedArray<util::stream::FileBuffer> files_;
+ util::FixedArray<util::stream::FileBuffer> files_;
};
void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
@@ -225,17 +225,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
- CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
chain.Wait(true);
+ std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
master.InitForAdjust(sorter, type_count);
}
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
- FixedArray<util::stream::FileBuffer> &gammas, std::vector<uint64_t> &prune_thresholds) {
+ util::FixedArray<util::stream::FileBuffer> &gammas, std::vector<uint64_t> &prune_thresholds) {
const PipelineConfig &config = master.Config();
Chains second(config.order);
@@ -261,7 +262,7 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
master.SetupSorts(primary);
}
-void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) {
std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
const PipelineConfig &config = master.Config();
master.MaximumLazyInput(counts, primary);
@@ -279,7 +280,7 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
}
- master >> Interpolate(counts[0], ChainPositions(gamma_chains), config.prune_thresholds);
+ master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), ChainPositions(gamma_chains), config.prune_thresholds);
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
}
@@ -316,7 +317,7 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds);
{
- FixedArray<util::stream::FileBuffer> gammas;
+ util::FixedArray<util::stream::FileBuffer> gammas;
Sorts<SuffixOrder> primary;
InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds);
InterpolateProbabilities(counts_pruned, master, primary, gammas);
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index a937169..4395622 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -1,8 +1,9 @@
-#ifndef LM_BUILDER_PIPELINE__
-#define LM_BUILDER_PIPELINE__
+#ifndef LM_BUILDER_PIPELINE_H
+#define LM_BUILDER_PIPELINE_H
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
+#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
@@ -34,6 +35,24 @@ struct PipelineConfig {
// corresponding n-gram order
std::vector<uint64_t> prune_thresholds; //mjd
+ /* Computing the perplexity of LMs with different vocabularies is hard. For
+ * example, the lowest perplexity is attained by a unigram model that
+ * predicts p(<unk>) = 1 and has no other vocabulary. Also, linearly
+ * interpolated models will sum to more than 1 because <unk> is duplicated
+ * (SRI just pretends p(<unk>) = 0 for these purposes, which makes it sum to
+ * 1 but comes with its own problems). This option will make the vocabulary
+ * a particular size by replicating <unk> multiple times for purposes of
+ * computing vocabulary size. It has no effect if the actual vocabulary is
+ * larger. This parameter serves the same purpose as IRSTLM's "dub".
+ */
+ uint64_t vocab_size_for_unk;
+
+ /* What to do the first time <s>, </s>, or <unk> appears in the input. If
+ * this is anything but THROW_UP, then the symbol will always be treated as
+ * whitespace.
+ */
+ WarningAction disallowed_symbol_action;
+
const std::string &TempPrefix() const { return sort.temp_prefix; }
std::size_t TotalMemory() const { return sort.total_memory; }
};
@@ -42,4 +61,4 @@ struct PipelineConfig {
void Pipeline(PipelineConfig config, int text_file, int out_arpa);
}} // namespaces
-#endif // LM_BUILDER_PIPELINE__
+#endif // LM_BUILDER_PIPELINE_H
diff --git a/lm/builder/print.hh b/lm/builder/print.hh
index adbbb94..397ca95 100644
--- a/lm/builder/print.hh
+++ b/lm/builder/print.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_PRINT__
-#define LM_BUILDER_PRINT__
+#ifndef LM_BUILDER_PRINT_H
+#define LM_BUILDER_PRINT_H
#include "lm/builder/ngram.hh"
#include "lm/builder/multi_stream.hh"
@@ -100,4 +100,4 @@ class PrintARPA {
};
}} // namespaces
-#endif // LM_BUILDER_PRINT__
+#endif // LM_BUILDER_PRINT_H
diff --git a/lm/builder/sort.hh b/lm/builder/sort.hh
index 9989389..c7f2ff8 100644
--- a/lm/builder/sort.hh
+++ b/lm/builder/sort.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_SORT__
-#define LM_BUILDER_SORT__
+#ifndef LM_BUILDER_SORT_H
+#define LM_BUILDER_SORT_H
#include "lm/builder/multi_stream.hh"
#include "lm/builder/ngram.hh"
@@ -85,10 +85,10 @@ struct AddCombiner {
// The combiner is only used on a single chain, so I didn't bother to allow
// that template.
-template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > {
private:
typedef util::stream::Sort<Compare> S;
- typedef FixedArray<S> P;
+ typedef util::FixedArray<S> P;
public:
void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
@@ -100,4 +100,4 @@ template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Comp
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_SORT__
+#endif // LM_BUILDER_SORT_H
diff --git a/lm/config.hh b/lm/config.hh
index 0de7b7c..dab2812 100644
--- a/lm/config.hh
+++ b/lm/config.hh
@@ -1,5 +1,5 @@
-#ifndef LM_CONFIG__
-#define LM_CONFIG__
+#ifndef LM_CONFIG_H
+#define LM_CONFIG_H
#include "lm/lm_exception.hh"
#include "util/mmap.hh"
@@ -120,4 +120,4 @@ struct Config {
} /* namespace ngram */ } /* namespace lm */
-#endif // LM_CONFIG__
+#endif // LM_CONFIG_H
diff --git a/lm/enumerate_vocab.hh b/lm/enumerate_vocab.hh
index 2726362..f5ce789 100644
--- a/lm/enumerate_vocab.hh
+++ b/lm/enumerate_vocab.hh
@@ -1,5 +1,5 @@
-#ifndef LM_ENUMERATE_VOCAB__
-#define LM_ENUMERATE_VOCAB__
+#ifndef LM_ENUMERATE_VOCAB_H
+#define LM_ENUMERATE_VOCAB_H
#include "lm/word_index.hh"
#include "util/string_piece.hh"
@@ -24,5 +24,5 @@ class EnumerateVocab {
} // namespace lm
-#endif // LM_ENUMERATE_VOCAB__
+#endif // LM_ENUMERATE_VOCAB_H
diff --git a/lm/facade.hh b/lm/facade.hh
index 760e839..8e12b62 100644
--- a/lm/facade.hh
+++ b/lm/facade.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FACADE__
-#define LM_FACADE__
+#ifndef LM_FACADE_H
+#define LM_FACADE_H
#include "lm/virtual_interface.hh"
#include "util/string_piece.hh"
@@ -17,14 +17,14 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
typedef VocabularyT Vocabulary;
/* Translate from void* to State */
- FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScore(
*reinterpret_cast<const State*>(in_state),
new_word,
*reinterpret_cast<State*>(out_state));
}
- FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScoreForgotState(
context_rbegin,
context_rend,
@@ -37,7 +37,7 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob;
}
- float Score(const void *in_state, const WordIndex new_word, void *out_state) const {
+ float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->Score(
*reinterpret_cast<const State*>(in_state),
new_word,
@@ -70,4 +70,4 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
} // mamespace base
} // namespace lm
-#endif // LM_FACADE__
+#endif // LM_FACADE_H
diff --git a/lm/filter/arpa_io.hh b/lm/filter/arpa_io.hh
index 5b31620..99c97b1 100644
--- a/lm/filter/arpa_io.hh
+++ b/lm/filter/arpa_io.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_ARPA_IO__
-#define LM_FILTER_ARPA_IO__
+#ifndef LM_FILTER_ARPA_IO_H
+#define LM_FILTER_ARPA_IO_H
/* Input and output for ARPA format language model files.
*/
#include "lm/read_arpa.hh"
@@ -14,7 +14,6 @@
#include <string>
#include <vector>
-#include <err.h>
#include <string.h>
#include <stdint.h>
@@ -112,4 +111,4 @@ template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
} // namespace lm
-#endif // LM_FILTER_ARPA_IO__
+#endif // LM_FILTER_ARPA_IO_H
diff --git a/lm/filter/count_io.hh b/lm/filter/count_io.hh
index 97c0fa2..de894ba 100644
--- a/lm/filter/count_io.hh
+++ b/lm/filter/count_io.hh
@@ -1,24 +1,22 @@
-#ifndef LM_FILTER_COUNT_IO__
-#define LM_FILTER_COUNT_IO__
+#ifndef LM_FILTER_COUNT_IO_H
+#define LM_FILTER_COUNT_IO_H
#include <fstream>
#include <iostream>
#include <string>
-#include <err.h>
-
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
namespace lm {
class CountOutput : boost::noncopyable {
public:
- explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+ explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {}
void AddNGram(const StringPiece &line) {
- if (!(file_ << line << '\n')) {
- err(3, "Writing counts file failed");
- }
+ file_ << line << '\n';
}
template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
@@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable {
}
private:
- std::fstream file_;
+ util::FakeOFStream file_;
};
class CountBatch {
@@ -88,4 +86,4 @@ template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
} // namespace lm
-#endif // LM_FILTER_COUNT_IO__
+#endif // LM_FILTER_COUNT_IO_H
diff --git a/lm/filter/filter_main.cc b/lm/filter/filter_main.cc
index 1736bc4..82fdc1e 100644
--- a/lm/filter/filter_main.cc
+++ b/lm/filter/filter_main.cc
@@ -6,6 +6,7 @@
#endif
#include "lm/filter/vocab.hh"
#include "lm/filter/wrapper.hh"
+#include "util/exception.hh"
#include "util/file_piece.hh"
#include <boost/ptr_container/ptr_vector.hpp>
@@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr
} // namespace lm
int main(int argc, char *argv[]) {
- if (argc < 4) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
+ try {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
- // I used to have boost::program_options, but some users didn't want to compile boost.
- lm::Config config;
- config.mode = lm::MODE_UNSET;
- for (int i = 1; i < argc - 2; ++i) {
- const char *str = argv[i];
- if (!std::strcmp(str, "copy")) {
- config.mode = lm::MODE_COPY;
- } else if (!std::strcmp(str, "single")) {
- config.mode = lm::MODE_SINGLE;
- } else if (!std::strcmp(str, "multiple")) {
- config.mode = lm::MODE_MULTIPLE;
- } else if (!std::strcmp(str, "union")) {
- config.mode = lm::MODE_UNION;
- } else if (!std::strcmp(str, "phrase")) {
- config.phrase = true;
- } else if (!std::strcmp(str, "context")) {
- config.context = true;
- } else if (!std::strcmp(str, "arpa")) {
- config.format = lm::FORMAT_ARPA;
- } else if (!std::strcmp(str, "raw")) {
- config.format = lm::FORMAT_COUNT;
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ config.mode = lm::MODE_UNSET;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ config.mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ config.mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ config.mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ config.mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
#ifndef NTHREAD
- } else if (!std::strncmp(str, "threads:", 8)) {
- config.threads = boost::lexical_cast<size_t>(str + 8);
- if (!config.threads) {
- std::cerr << "Specify at least one thread." << std::endl;
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
return 1;
}
- } else if (!std::strncmp(str, "batch_size:", 11)) {
- config.batch_size = boost::lexical_cast<size_t>(str + 11);
- if (config.batch_size < 5000) {
- std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
- if (!config.batch_size) return 1;
- }
-#endif
- } else {
+ }
+
+ if (config.mode == lm::MODE_UNSET) {
lm::DisplayHelp(argv[0]);
return 1;
}
- }
-
- if (config.mode == lm::MODE_UNSET) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
- if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
- std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
- return 1;
- }
+ if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
- bool cmd_is_model = true;
- const char *cmd_input = argv[argc - 2];
- if (!strncmp(cmd_input, "vocab:", 6)) {
- cmd_is_model = false;
- cmd_input += 6;
- } else if (!strncmp(cmd_input, "model:", 6)) {
- cmd_input += 6;
- } else if (strchr(cmd_input, ':')) {
- errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
- } else {
- std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
- }
- std::ifstream cmd_file;
- std::istream *vocab;
- if (cmd_is_model) {
- vocab = &std::cin;
- } else {
- cmd_file.open(cmd_input, std::ios::in);
- if (!cmd_file) {
- err(2, "Could not open input file %s", cmd_input);
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
+ return 1;
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input);
+ vocab = &cmd_file;
}
- vocab = &cmd_file;
- }
- util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
- if (config.format == lm::FORMAT_ARPA) {
- lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
- } else if (config.format == lm::FORMAT_COUNT) {
- lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
}
- return 0;
}
diff --git a/lm/filter/format.hh b/lm/filter/format.hh
index 7f945b0..5a2e2db 100644
--- a/lm/filter/format.hh
+++ b/lm/filter/format.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_FORMAT_H__
-#define LM_FITLER_FORMAT_H__
+#ifndef LM_FILTER_FORMAT_H
+#define LM_FILTER_FORMAT_H
#include "lm/filter/arpa_io.hh"
#include "lm/filter/count_io.hh"
@@ -247,4 +247,4 @@ class MultipleOutputBuffer {
} // namespace lm
-#endif // LM_FILTER_FORMAT_H__
+#endif // LM_FILTER_FORMAT_H
diff --git a/lm/filter/phrase.hh b/lm/filter/phrase.hh
index e8e8583..e5898c9 100644
--- a/lm/filter/phrase.hh
+++ b/lm/filter/phrase.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_PHRASE_H__
-#define LM_FILTER_PHRASE_H__
+#ifndef LM_FILTER_PHRASE_H
+#define LM_FILTER_PHRASE_H
#include "util/murmur_hash.hh"
#include "util/string_piece.hh"
@@ -165,4 +165,4 @@ class Multiple : public detail::ConditionCommon {
} // namespace phrase
} // namespace lm
-#endif // LM_FILTER_PHRASE_H__
+#endif // LM_FILTER_PHRASE_H
diff --git a/lm/filter/thread.hh b/lm/filter/thread.hh
index e785b26..6a6523f 100644
--- a/lm/filter/thread.hh
+++ b/lm/filter/thread.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_THREAD_H__
-#define LM_FILTER_THREAD_H__
+#ifndef LM_FILTER_THREAD_H
+#define LM_FILTER_THREAD_H
#include "util/thread_pool.hh"
@@ -164,4 +164,4 @@ template <class Filter, class OutputBuffer, class RealOutput> class Controller :
} // namespace lm
-#endif // LM_FILTER_THREAD_H__
+#endif // LM_FILTER_THREAD_H
diff --git a/lm/filter/vocab.cc b/lm/filter/vocab.cc
index 7ee4e84..011ab59 100644
--- a/lm/filter/vocab.cc
+++ b/lm/filter/vocab.cc
@@ -4,7 +4,6 @@
#include <iostream>
#include <ctype.h>
-#include <err.h>
namespace lm {
namespace vocab {
diff --git a/lm/filter/vocab.hh b/lm/filter/vocab.hh
index 7f0fada..2ee6e1f 100644
--- a/lm/filter/vocab.hh
+++ b/lm/filter/vocab.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_VOCAB_H__
-#define LM_FILTER_VOCAB_H__
+#ifndef LM_FILTER_VOCAB_H
+#define LM_FILTER_VOCAB_H
// Vocabulary-based filters for language models.
@@ -130,4 +130,4 @@ class Multiple {
} // namespace vocab
} // namespace lm
-#endif // LM_FILTER_VOCAB_H__
+#endif // LM_FILTER_VOCAB_H
diff --git a/lm/filter/wrapper.hh b/lm/filter/wrapper.hh
index 90b07a0..822c5c2 100644
--- a/lm/filter/wrapper.hh
+++ b/lm/filter/wrapper.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_WRAPPER_H__
-#define LM_FILTER_WRAPPER_H__
+#ifndef LM_FILTER_WRAPPER_H
+#define LM_FILTER_WRAPPER_H
#include "util/string_piece.hh"
@@ -39,20 +39,18 @@ template <class FilterT> class ContextFilter {
explicit ContextFilter(Filter &backend) : backend_(backend) {}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
- pieces_.clear();
- // TODO: this copy could be avoided by a lookahead iterator.
- std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
- backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ // Find beginning of string or last space.
+ const char *last_space;
+ for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {}
+ backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output);
}
void Flush() const {}
private:
- std::vector<StringPiece> pieces_;
-
Filter backend_;
};
} // namespace lm
-#endif // LM_FILTER_WRAPPER_H__
+#endif // LM_FILTER_WRAPPER_H
diff --git a/lm/left.hh b/lm/left.hh
index 85c1ea3..36d6136 100644
--- a/lm/left.hh
+++ b/lm/left.hh
@@ -35,8 +35,8 @@
* phrase, even if hypotheses are generated left-to-right.
*/
-#ifndef LM_LEFT__
-#define LM_LEFT__
+#ifndef LM_LEFT_H
+#define LM_LEFT_H
#include "lm/max_order.hh"
#include "lm/state.hh"
@@ -213,4 +213,4 @@ template <class M> class RuleScore {
} // namespace ngram
} // namespace lm
-#endif // LM_LEFT__
+#endif // LM_LEFT_H
diff --git a/lm/lm_exception.hh b/lm/lm_exception.hh
index f607ced..8bb6108 100644
--- a/lm/lm_exception.hh
+++ b/lm/lm_exception.hh
@@ -1,5 +1,5 @@
-#ifndef LM_LM_EXCEPTION__
-#define LM_LM_EXCEPTION__
+#ifndef LM_LM_EXCEPTION_H
+#define LM_LM_EXCEPTION_H
// Named to avoid conflict with util/exception.hh.
diff --git a/lm/max_order.hh b/lm/max_order.hh
index 3eb97cc..f7344cd 100644
--- a/lm/max_order.hh
+++ b/lm/max_order.hh
@@ -1,9 +1,13 @@
-/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
+#ifndef LM_MAX_ORDER_H
+#define LM_MAX_ORDER_H
+/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER_H, THEN CHANGE THE BUILD SYSTEM.
* If not, this is the default maximum order.
* Having this limit means that State can be
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
#ifndef KENLM_ORDER_MESSAGE
-#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif
+
+#endif // LM_MAX_ORDER_H
diff --git a/lm/model.cc b/lm/model.cc
index a26654a..a5a16bf 100644
--- a/lm/model.cc
+++ b/lm/model.cc
@@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
-template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
- LoadLM(file, config, *this);
-
- // g++ prints warnings unless these are fully initialized.
- State begin_sentence = State();
- begin_sentence.length = 1;
- begin_sentence.words[0] = vocab_.BeginSentence();
- typename Search::Node ignored_node;
- bool ignored_independent_left;
- uint64_t ignored_extend_left;
- begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
- State null_context = State();
- null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.Order());
+namespace {
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
}
-namespace {
void CheckCounts(const std::vector<uint64_t> &counts) {
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
if (sizeof(uint64_t) > sizeof(std::size_t)) {
@@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) {
}
}
}
+
} // namespace
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
- CheckCounts(params.counts);
- SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
- search_.LoadedBinary();
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (IsBinaryFormat(fd.get())) {
+ Parameters parameters;
+ int fd_shallow = fd.release();
+ backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
+ CheckCounts(parameters.counts);
+
+ Config new_config(init_config);
+ new_config.probing_multiplier = parameters.fixed.probing_multiplier;
+ Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
+ UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
+ vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
+ } else {
+ ComplainAboutARPA(init_config, kModelType);
+ InitializeFromARPA(fd.release(), file, init_config);
+ }
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.length = 1;
+ begin_sentence.words[0] = vocab_.BeginSentence();
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
+ State null_context = State();
+ null_context.length = 0;
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
+ // Backing file is the ARPA.
+ util::FilePiece f(fd, file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+ vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);
- if (config.write_mmap) {
+ if (config.write_mmap && config.include_vocab) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
+ void *vocab_rebase, *search_rebase;
+ backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
+ // Due to writing at the end of file, mmap may have relocated data. So remap.
+ vocab_.Relocate(vocab_rebase);
+ search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
+ backing_.FinishFile(config, kModelType, kVersion, counts);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
- Search::UpdateConfigFromBinary(fd, counts, config);
-}
-
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
diff --git a/lm/model.hh b/lm/model.hh
index c9c17c4..6925a56 100644
--- a/lm/model.hh
+++ b/lm/model.hh
@@ -1,5 +1,5 @@
-#ifndef LM_MODEL__
-#define LM_MODEL__
+#ifndef LM_MODEL_H
+#define LM_MODEL_H
#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
@@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
}
private:
- friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
-
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
-
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
// Score bigrams and above. Do not include backoff.
@@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
-
- void InitializeFromARPA(const char *file, const Config &config);
+ void InitializeFromARPA(int fd, const char *file, const Config &config);
float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
- Backing &MutableBacking() { return backing_; }
-
- Backing backing_;
+ BinaryFormat backing_;
VocabularyT vocab_;
@@ -161,4 +153,4 @@ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(),
} // namespace ngram
} // namespace lm
-#endif // LM_MODEL__
+#endif // LM_MODEL_H
diff --git a/lm/model_test.cc b/lm/model_test.cc
index eb15909..7005b05 100644
--- a/lm/model_test.cc
+++ b/lm/model_test.cc
@@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
LoadingTest<QuantArrayTrieModel>();
}
-template <class ModelT> void BinaryTest() {
+template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
Config config;
config.write_mmap = "test.binary";
config.messages = NULL;
+ config.write_method = write_method;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
@@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() {
unlink("test_nounk.binary");
}
+template <class ModelT> void BinaryTest() {
+ BinaryTest<ModelT>(Config::WRITE_MMAP);
+ BinaryTest<ModelT>(Config::WRITE_AFTER);
+}
+
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<ProbingModel>();
}
diff --git a/lm/model_type.hh b/lm/model_type.hh
index 8b35c79..fbe1117 100644
--- a/lm/model_type.hh
+++ b/lm/model_type.hh
@@ -1,5 +1,5 @@
-#ifndef LM_MODEL_TYPE__
-#define LM_MODEL_TYPE__
+#ifndef LM_MODEL_TYPE_H
+#define LM_MODEL_TYPE_H
namespace lm {
namespace ngram {
@@ -20,4 +20,4 @@ const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE);
} // namespace ngram
} // namespace lm
-#endif // LM_MODEL_TYPE__
+#endif // LM_MODEL_TYPE_H
diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh
index ec2590f..454e856 100644
--- a/lm/ngram_query.hh
+++ b/lm/ngram_query.hh
@@ -1,8 +1,9 @@
-#ifndef LM_NGRAM_QUERY__
-#define LM_NGRAM_QUERY__
+#ifndef LM_NGRAM_QUERY_H
+#define LM_NGRAM_QUERY_H
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
+#include "util/file_piece.hh"
#include "util/usage.hh"
#include <cstdlib>
@@ -16,42 +17,41 @@
namespace lm {
namespace ngram {
-template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+template <class Model> void Query(const Model &model, bool sentence_context) {
typename Model::State state, out;
lm::FullScoreReturn ret;
- std::string word;
+ StringPiece word;
+
+ util::FilePiece in(0);
+ std::ostream &out_stream = std::cout;
double corpus_total = 0.0;
+ double corpus_total_oov_only = 0.0;
uint64_t corpus_oov = 0;
uint64_t corpus_tokens = 0;
- while (in_stream) {
+ while (true) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
- bool got = false;
uint64_t oov = 0;
- while (in_stream >> word) {
- got = true;
+
+ while (in.ReadWordSameLine(word)) {
lm::WordIndex vocab = model.GetVocabulary().Index(word);
- if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
+ if (vocab == model.GetVocabulary().NotFound()) {
+ ++oov;
+ corpus_total_oov_only += ret.prob;
+ }
total += ret.prob;
out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
++corpus_tokens;
state = out;
- char c;
- while (true) {
- c = in_stream.get();
- if (!in_stream) break;
- if (c == '\n') break;
- if (!isspace(c)) {
- in_stream.unget();
- break;
- }
- }
- if (c == '\n') break;
}
- if (!got && !in_stream) break;
+ // If people don't have a newline after their last query, this won't add a </s>.
+ // Sue me.
+ try {
+ UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused.");
+ } catch (const util::EndOfFileException &e) { break; }
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
@@ -62,18 +62,22 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
corpus_total += total;
corpus_oov += oov;
}
- out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl;
+ out_stream <<
+ "Perplexity including OOVs:\t" << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << "\n"
+ "Perplexity excluding OOVs:\t" << pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))) << "\n"
+ "OOVs:\t" << corpus_oov << "\n"
+ ;
}
-template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+template <class M> void Query(const char *file, bool sentence_context) {
Config config;
M model(file, config);
- Query(model, sentence_context, in_stream, out_stream);
+ Query(model, sentence_context);
}
} // namespace ngram
} // namespace lm
-#endif // LM_NGRAM_QUERY__
+#endif // LM_NGRAM_QUERY_H
diff --git a/lm/partial.hh b/lm/partial.hh
index 1dede35..d8adc69 100644
--- a/lm/partial.hh
+++ b/lm/partial.hh
@@ -1,5 +1,5 @@
-#ifndef LM_PARTIAL__
-#define LM_PARTIAL__
+#ifndef LM_PARTIAL_H
+#define LM_PARTIAL_H
#include "lm/return.hh"
#include "lm/state.hh"
@@ -164,4 +164,4 @@ template <class Model> float Subsume(const Model &model, Left &first_left, const
} // namespace ngram
} // namespace lm
-#endif // LM_PARTIAL__
+#endif // LM_PARTIAL_H
diff --git a/lm/quantize.cc b/lm/quantize.cc
index b58c3f3..273ea39 100644
--- a/lm/quantize.cc
+++ b/lm/quantize.cc
@@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2;
} // namespace
-void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) {
- char version;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &config.prob_bits, 1);
- util::ReadOrThrow(fd, &config.backoff_bits, 1);
+void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ unsigned char buffer[3];
+ file.ReadForConfig(buffer, 3, offset);
+ char version = buffer[0];
+ config.prob_bits = buffer[1];
+ config.backoff_bits = buffer[2];
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
- util::AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) {
diff --git a/lm/quantize.hh b/lm/quantize.hh
index 8ce2378..84a3087 100644
--- a/lm/quantize.hh
+++ b/lm/quantize.hh
@@ -1,5 +1,5 @@
-#ifndef LM_QUANTIZE_H__
-#define LM_QUANTIZE_H__
+#ifndef LM_QUANTIZE_H
+#define LM_QUANTIZE_H
#include "lm/blank.hh"
#include "lm/config.hh"
@@ -18,12 +18,13 @@ namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@@ -136,7 +137,7 @@ class SeparatelyQuantize {
public:
static const ModelType kModelTypeAdd = kQuantAdd;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint8_t order, const Config &config) {
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
@@ -229,4 +230,4 @@ class SeparatelyQuantize {
} // namespace ngram
} // namespace lm
-#endif // LM_QUANTIZE_H__
+#endif // LM_QUANTIZE_H
diff --git a/lm/query_main.cc b/lm/query_main.cc
index b9db7b0..cd661f7 100644
--- a/lm/query_main.cc
+++ b/lm/query_main.cc
@@ -4,48 +4,62 @@
#include "lm/wrappers/nplm.hh"
#endif
+#include <stdlib.h>
+
+void Usage(const char *name) {
+ std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
+ std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl;
+ std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl;
+ exit(1);
+}
+
int main(int argc, char *argv[]) {
- if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
- std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
- std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
- std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
- return 1;
+ bool sentence_context = true;
+ const char *file = NULL;
+ for (char **arg = argv + 1; arg != argv + argc; ++arg) {
+ if (!strcmp(*arg, "-n")) {
+ sentence_context = false;
+ } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) {
+ Usage(argv[0]);
+ } else {
+ file = *arg;
+ }
}
+ if (!file) Usage(argv[0]);
try {
- bool sentence_context = (argc == 2);
using namespace lm::ngram;
ModelType model_type;
- if (RecognizeBinary(argv[1], model_type)) {
+ if (RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<lm::ngram::ProbingModel>(file, sentence_context);
break;
case REST_PROBING:
- Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<lm::ngram::RestProbingModel>(file, sentence_context);
break;
case TRIE:
- Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<TrieModel>(file, sentence_context);
break;
case QUANT_TRIE:
- Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<QuantTrieModel>(file, sentence_context);
break;
case ARRAY_TRIE:
- Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<ArrayTrieModel>(file, sentence_context);
break;
case QUANT_ARRAY_TRIE:
- Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<QuantArrayTrieModel>(file, sentence_context);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
#ifdef WITH_NPLM
- } else if (lm::np::Model::Recognize(argv[1])) {
- lm::np::Model model(argv[1]);
- Query(model, sentence_context, std::cin, std::cout);
+ } else if (lm::np::Model::Recognize(file)) {
+ lm::np::Model model(file);
+ Query(model, sentence_context);
#endif
} else {
- Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<ProbingModel>(file, sentence_context);
}
std::cerr << "Total time including destruction:\n";
util::PrintUsage(std::cerr);
diff --git a/lm/read_arpa.cc b/lm/read_arpa.cc
index 5ccba71..fb8bbfa 100644
--- a/lm/read_arpa.cc
+++ b/lm/read_arpa.cc
@@ -150,7 +150,7 @@ void PositiveProbWarn::Warn(float prob) {
case THROW_UP:
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error");
case COMPLAIN:
- std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl;
+ std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapped to 0 log probability." << std::endl;
action_ = SILENT;
break;
case SILENT:
diff --git a/lm/read_arpa.hh b/lm/read_arpa.hh
index 234d130..16e1fc1 100644
--- a/lm/read_arpa.hh
+++ b/lm/read_arpa.hh
@@ -1,5 +1,5 @@
-#ifndef LM_READ_ARPA__
-#define LM_READ_ARPA__
+#ifndef LM_READ_ARPA_H
+#define LM_READ_ARPA_H
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
@@ -87,4 +87,4 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
} // namespace lm
-#endif // LM_READ_ARPA__
+#endif // LM_READ_ARPA_H
diff --git a/lm/return.hh b/lm/return.hh
index 622320c..982ffd6 100644
--- a/lm/return.hh
+++ b/lm/return.hh
@@ -1,5 +1,5 @@
-#ifndef LM_RETURN__
-#define LM_RETURN__
+#ifndef LM_RETURN_H
+#define LM_RETURN_H
#include <stdint.h>
@@ -39,4 +39,4 @@ struct FullScoreReturn {
};
} // namespace lm
-#endif // LM_RETURN__
+#endif // LM_RETURN_H
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
index 62275d2..354a56b 100644
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams(
namespace detail {
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t allocated = Unigram::Size(counts[0]);
- unigram_ = Unigram(start, counts[0], allocated);
- start += allocated;
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ std::size_t allocated;
+ middle_.clear();
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
@@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start,
return start;
}
-template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) {
- // TODO: fix sorted.
- SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
+/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ middle[n-2].Relocate(start);
+ start += Middle::Size(counts[n - 1], config.probing_multiplier)
+ }
+ longest_.Relocate(start);
+}*/
+
+template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) {
+ void *vocab_rebase;
+ void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase);
+ vocab.Relocate(vocab_rebase);
+ SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
@@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui
ReadEnd(f);
}
-template <class Value> void HashedSearch<Value>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
template class HashedSearch<BackoffValue>;
template class HashedSearch<RestValue>;
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index 9d067bc..9dc8445 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -1,5 +1,5 @@
-#ifndef LM_SEARCH_HASHED__
-#define LM_SEARCH_HASHED__
+#ifndef LM_SEARCH_HASHED_H
+#define LM_SEARCH_HASHED_H
#include "lm/model_type.hh"
#include "lm/config.hh"
@@ -18,7 +18,7 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class ProbingVocabulary;
namespace detail {
@@ -72,7 +72,7 @@ template <class Value> class HashedSearch {
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Unigram::Size(counts[0]);
@@ -84,9 +84,7 @@ template <class Value> class HashedSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
-
- void LoadedBinary();
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_.size() + 2;
@@ -148,7 +146,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
@@ -168,8 +166,6 @@ template <class Value> class HashedSearch {
typename Value::Weights &Unknown() { return unigram_[0]; }
- void LoadedBinary() {}
-
// For building.
typename Value::Weights *Raw() { return unigram_; }
@@ -193,4 +189,4 @@ template <class Value> class HashedSearch {
} // namespace ngram
} // namespace lm
-#endif // LM_SEARCH_HASHED__
+#endif // LM_SEARCH_HASHED_H
diff --git a/lm/search_trie.cc b/lm/search_trie.cc
index 1b0d9b2..4a88194 100644
--- a/lm/search_trie.cc
+++ b/lm/search_trie.cc
@@ -253,11 +253,6 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
- void Cleanup() {
- --counts_[0];
- }
-
const std::vector<uint64_t> &Counts() const {
return counts_;
}
@@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries {
typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
}
- void Cleanup() {}
-
private:
RecordReader *contexts_;
const Quant &quant_;
@@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
WordIndex unigram = 0;
std::priority_queue<Gram> grams;
- grams.push(Gram(&unigram, 1));
+ if (unigram_count) grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
}
BlankManager<Doing> blank(total_order, doing);
- while (true) {
+ while (!grams.empty()) {
Gram top = grams.top();
grams.pop();
unsigned char order = top.end - top.begin;
@@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
doing.Unigram(unigram);
progress.Set(unigram);
- if (++unigram == unigram_count + 1) break;
- grams.push(top);
+ if (++unigram < unigram_count) grams.push(top);
} else {
if (order == total_order) {
blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
@@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
if (++reader) grams.push(top);
}
}
- assert(grams.empty());
- doing.Cleanup();
}
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
@@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {
RecordReader inputs[KENLM_MAX_ORDER - 1];
RecordReader contexts[KENLM_MAX_ORDER - 1];
@@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
- out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+ void *vocab_relocate;
+ void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate);
+ vocab.Relocate(vocab_relocate);
+ out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
+ // Write the last unigram entry, which is the end pointer for the bigrams.
+ writer.Unigram(counts[0]);
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (Middle *i = middle_begin_; i != middle_end_; ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
std::string temporary_prefix;
if (config.temporary_directory_prefix) {
temporary_prefix = config.temporary_directory_prefix;
diff --git a/lm/search_trie.hh b/lm/search_trie.hh
index 763fd1a..d8838d2 100644
--- a/lm/search_trie.hh
+++ b/lm/search_trie.hh
@@ -1,5 +1,5 @@
-#ifndef LM_SEARCH_TRIE__
-#define LM_SEARCH_TRIE__
+#ifndef LM_SEARCH_TRIE_H
+#define LM_SEARCH_TRIE_H
#include "lm/config.hh"
#include "lm/model_type.hh"
@@ -17,13 +17,13 @@
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
@@ -39,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch {
static const unsigned int kVersion = 1;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- Quant::UpdateConfigFromBinary(fd, counts, config);
- util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
+ Quant::UpdateConfigFromBinary(file, offset, config);
// Currently the unigram pointers are not compresssed, so there will only be a header for order > 2.
- if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);
+ if (counts.size() > 2)
+ Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -60,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void LoadedBinary();
-
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_end_ - middle_begin_ + 2;
@@ -103,7 +101,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
}
private:
- friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
@@ -129,4 +127,4 @@ template <class Quant, class Bhiksha> class TrieSearch {
} // namespace ngram
} // namespace lm
-#endif // LM_SEARCH_TRIE__
+#endif // LM_SEARCH_TRIE_H
diff --git a/lm/sizes.hh b/lm/sizes.hh
index 85abade..eb7e99d 100644
--- a/lm/sizes.hh
+++ b/lm/sizes.hh
@@ -1,5 +1,5 @@
-#ifndef LM_SIZES__
-#define LM_SIZES__
+#ifndef LM_SIZES_H
+#define LM_SIZES_H
#include <vector>
@@ -14,4 +14,4 @@ void ShowSizes(const std::vector<uint64_t> &counts);
void ShowSizes(const char *file, const lm::ngram::Config &config);
}} // namespaces
-#endif // LM_SIZES__
+#endif // LM_SIZES_H
diff --git a/lm/state.hh b/lm/state.hh
index d8e6c13..f6c51d6 100644
--- a/lm/state.hh
+++ b/lm/state.hh
@@ -1,5 +1,5 @@
-#ifndef LM_STATE__
-#define LM_STATE__
+#ifndef LM_STATE_H
+#define LM_STATE_H
#include "lm/max_order.hh"
#include "lm/word_index.hh"
@@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) {
}
struct ChartState {
- bool operator==(const ChartState &other) {
+ bool operator==(const ChartState &other) const {
return (right == other.right) && (left == other.left);
}
@@ -102,7 +102,7 @@ struct ChartState {
}
bool operator<(const ChartState &other) const {
- return Compare(other) == -1;
+ return Compare(other) < 0;
}
void ZeroRemaining() {
@@ -122,4 +122,4 @@ inline uint64_t hash_value(const ChartState &state) {
} // namespace ngram
} // namespace lm
-#endif // LM_STATE__
+#endif // LM_STATE_H
diff --git a/lm/trie.hh b/lm/trie.hh
index 9ea3c54..cd39298 100644
--- a/lm/trie.hh
+++ b/lm/trie.hh
@@ -1,5 +1,5 @@
-#ifndef LM_TRIE__
-#define LM_TRIE__
+#ifndef LM_TRIE_H
+#define LM_TRIE_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -62,8 +62,6 @@ class Unigram {
return unigram_;
}
- void LoadedBinary() {}
-
UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
next.begin = val->next;
@@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
void FinishedLoading(uint64_t next_end, const Config &config);
- void LoadedBinary() { bhiksha_.LoadedBinary(); }
-
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
@@ -138,18 +134,13 @@ class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_bits);
}
- void LoadedBinary() {}
-
util::BitAddress Insert(WordIndex word);
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
-
- private:
- uint8_t quant_bits_;
};
} // namespace trie
} // namespace ngram
} // namespace lm
-#endif // LM_TRIE__
+#endif // LM_TRIE_H
diff --git a/lm/trie_sort.cc b/lm/trie_sort.cc
index dc542bb..126d43a 100644
--- a/lm/trie_sort.cc
+++ b/lm/trie_sort.cc
@@ -50,6 +50,10 @@ class PartialViewProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
+ friend void swap(PartialViewProxy first, PartialViewProxy second) {
+ std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data()));
+ }
+
private:
friend class util::ProxyIterator<PartialViewProxy>;
diff --git a/lm/trie_sort.hh b/lm/trie_sort.hh
index 1afd956..e5406d9 100644
--- a/lm/trie_sort.hh
+++ b/lm/trie_sort.hh
@@ -1,7 +1,7 @@
// Step of trie builder: create sorted files.
-#ifndef LM_TRIE_SORT__
-#define LM_TRIE_SORT__
+#ifndef LM_TRIE_SORT_H
+#define LM_TRIE_SORT_H
#include "lm/max_order.hh"
#include "lm/word_index.hh"
@@ -111,4 +111,4 @@ class SortedFiles {
} // namespace ngram
} // namespace lm
-#endif // LM_TRIE_SORT__
+#endif // LM_TRIE_SORT_H
diff --git a/lm/value.hh b/lm/value.hh
index ba71671..36e8708 100644
--- a/lm/value.hh
+++ b/lm/value.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VALUE__
-#define LM_VALUE__
+#ifndef LM_VALUE_H
+#define LM_VALUE_H
#include "lm/model_type.hh"
#include "lm/value_build.hh"
@@ -154,4 +154,4 @@ struct RestValue {
} // namespace ngram
} // namespace lm
-#endif // LM_VALUE__
+#endif // LM_VALUE_H
diff --git a/lm/value_build.hh b/lm/value_build.hh
index 461e6a5..6fd26ef 100644
--- a/lm/value_build.hh
+++ b/lm/value_build.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VALUE_BUILD__
-#define LM_VALUE_BUILD__
+#ifndef LM_VALUE_BUILD_H
+#define LM_VALUE_BUILD_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -94,4 +94,4 @@ template <class Model> class LowerRestBuild {
} // namespace ngram
} // namespace lm
-#endif // LM_VALUE_BUILD__
+#endif // LM_VALUE_BUILD_H
diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh
index ff4a388..2a2690e 100644
--- a/lm/virtual_interface.hh
+++ b/lm/virtual_interface.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VIRTUAL_INTERFACE__
-#define LM_VIRTUAL_INTERFACE__
+#ifndef LM_VIRTUAL_INTERFACE_H
+#define LM_VIRTUAL_INTERFACE_H
#include "lm/return.hh"
#include "lm/word_index.hh"
@@ -125,13 +125,13 @@ class Model {
void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
- virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Requires in_state != out_state
- virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Prefer to use FullScore. The context words should be provided in reverse order.
- virtual FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
unsigned char Order() const { return order_; }
@@ -157,4 +157,4 @@ class Model {
} // mamespace base
} // namespace lm
-#endif // LM_VIRTUAL_INTERFACE__
+#endif // LM_VIRTUAL_INTERFACE_H
diff --git a/lm/vocab.cc b/lm/vocab.cc
index fd7f96d..7f0878f 100644
--- a/lm/vocab.cc
+++ b/lm/vocab.cc
@@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
+ util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
char check_unk[6];
util::ReadOrThrow(fd, check_unk, 6);
@@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd, uint64_t start) {
- util::SeekOrThrow(fd, start);
- util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
-}
-
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size
saw_unk_ = false;
}
+void SortedVocabulary::Relocate(void *new_start) {
+ std::size_t delta = end_ - begin_;
+ begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
+ end_ = begin_ + delta;
+}
+
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
enumerate_ = to;
if (enumerate_) {
@@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
namespace {
@@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz
saw_unk_ = false;
}
+void ProbingVocabulary::Relocate(void *new_start) {
+ header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
+ lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
+}
+
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
enumerate_ = to;
if (enumerate_) {
@@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
- lookup_.LoadedBinary();
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
diff --git a/lm/vocab.hh b/lm/vocab.hh
index 226ae43..dcd298a 100644
--- a/lm/vocab.hh
+++ b/lm/vocab.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VOCAB__
-#define LM_VOCAB__
+#ifndef LM_VOCAB_H
+#define LM_VOCAB_H
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
@@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd, uint64_t start);
+ const std::string &Buffer() const { return buffer_; }
private:
EnumerateVocab *inner_;
@@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
uint64_t *begin_, *end_;
WordIndex bound_;
- WordIndex highest_value_;
-
bool saw_unk_;
EnumerateVocab *enumerate_;
@@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
void InternalFinishedLoading();
@@ -182,4 +184,4 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc
} // namespace ngram
} // namespace lm
-#endif // LM_VOCAB__
+#endif // LM_VOCAB_H
diff --git a/lm/weights.hh b/lm/weights.hh
index bd5d803..da1963d 100644
--- a/lm/weights.hh
+++ b/lm/weights.hh
@@ -1,5 +1,5 @@
-#ifndef LM_WEIGHTS__
-#define LM_WEIGHTS__
+#ifndef LM_WEIGHTS_H
+#define LM_WEIGHTS_H
// Weights for n-grams. Probability and possibly a backoff.
@@ -19,4 +19,4 @@ struct RestWeights {
};
} // namespace lm
-#endif // LM_WEIGHTS__
+#endif // LM_WEIGHTS_H
diff --git a/lm/word_index.hh b/lm/word_index.hh
index e09557a..a5a0fda 100644
--- a/lm/word_index.hh
+++ b/lm/word_index.hh
@@ -1,6 +1,6 @@
// Separate header because this is used often.
-#ifndef LM_WORD_INDEX__
-#define LM_WORD_INDEX__
+#ifndef LM_WORD_INDEX_H
+#define LM_WORD_INDEX_H
#include <limits.h>
diff --git a/lm/wrappers/nplm.cc b/lm/wrappers/nplm.cc
index 6a3fa0d..70622bd 100644
--- a/lm/wrappers/nplm.cc
+++ b/lm/wrappers/nplm.cc
@@ -13,7 +13,7 @@ namespace np {
Vocabulary::Vocabulary(const nplm::vocabulary &vocab)
: base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")),
- vocab_(vocab) {}
+ vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {}
Vocabulary::~Vocabulary() {}
@@ -33,8 +33,11 @@ bool Model::Recognize(const std::string &name) {
}
}
-Model::Model(const std::string &file) : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()) {
+Model::Model(const std::string &file, std::size_t cache)
+ : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) {
UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile.");
+ // log10 compatible with backoff models.
+ base_instance_->set_log_base(10.0);
State begin_sentence, null_context;
std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>"));
null_word_ = base_instance_->lookup_word("<null>");
@@ -50,6 +53,7 @@ FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, St
if (!lm) {
lm = new nplm::neuralLM(*base_instance_);
backend_.reset(lm);
+ lm->set_cache(cache_size_);
}
// State is in natural word order.
FullScoreReturn ret;
diff --git a/lm/wrappers/nplm.hh b/lm/wrappers/nplm.hh
index 90f1d49..b7dd4a2 100644
--- a/lm/wrappers/nplm.hh
+++ b/lm/wrappers/nplm.hh
@@ -1,5 +1,5 @@
-#ifndef LM_WRAPPER_NPLM__
-#define LM_WRAPPER_NPLM__
+#ifndef LM_WRAPPERS_NPLM_H
+#define LM_WRAPPERS_NPLM_H
#include "lm/facade.hh"
#include "lm/max_order.hh"
@@ -34,8 +34,12 @@ class Vocabulary : public base::Vocabulary {
return Index(std::string(str.data(), str.size()));
}
+ lm::WordIndex NullWord() const { return null_word_; }
+
private:
const nplm::vocabulary &vocab_;
+
+ const lm::WordIndex null_word_;
};
// Sorry for imposing my limitations on your code.
@@ -53,7 +57,7 @@ class Model : public lm::base::ModelFacade<Model, State, Vocabulary> {
// Does this look like an NPLM?
static bool Recognize(const std::string &file);
- explicit Model(const std::string &file);
+ explicit Model(const std::string &file, std::size_t cache_size = 1 << 20);
~Model();
@@ -69,9 +73,11 @@ class Model : public lm::base::ModelFacade<Model, State, Vocabulary> {
Vocabulary vocab_;
lm::WordIndex null_word_;
+
+ const std::size_t cache_size_;
};
} // namespace np
} // namespace lm
-#endif // LM_WRAPPER_NPLM__
+#endif // LM_WRAPPERS_NPLM_H
diff --git a/python/kenlm.cpp b/python/kenlm.cpp
index d401047..c7052fc 100644
--- a/python/kenlm.cpp
+++ b/python/kenlm.cpp
@@ -1,4 +1,4 @@
-/* Generated by Cython 0.19.1 on Wed Jun 5 13:00:19 2013 */
+/* Generated by Cython 0.19.1 on Tue Jan 28 09:32:20 2014 */
#define PY_SSIZE_T_CLEAN
#ifndef CYTHON_USE_PYLONG_INTERNALS
@@ -493,23 +493,8 @@ static const char *__pyx_f[] = {
};
/*--- Type declarations ---*/
-struct __pyx_obj_5kenlm_LanguageModel;
struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores;
-
-/* "kenlm.pyx":10
- * raise TypeError('Cannot convert %s to string' % type(data))
- *
- * cdef class LanguageModel: # <<<<<<<<<<<<<<
- * cdef Model* model
- * cdef public bytes path
- */
-struct __pyx_obj_5kenlm_LanguageModel {
- PyObject_HEAD
- lm::base::Model *model;
- PyObject *path;
- const lm::base::Vocabulary *vocab;
-};
-
+struct __pyx_obj_5kenlm_LanguageModel;
/* "kenlm.pyx":44
* return total
@@ -532,6 +517,21 @@ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores {
Py_ssize_t __pyx_t_1;
};
+
+/* "kenlm.pyx":10
+ * raise TypeError('Cannot convert %s to string' % type(data))
+ *
+ * cdef class LanguageModel: # <<<<<<<<<<<<<<
+ * cdef Model* model
+ * cdef public bytes path
+ */
+struct __pyx_obj_5kenlm_LanguageModel {
+ PyObject_HEAD
+ lm::base::Model *model;
+ PyObject *path;
+ const lm::base::Vocabulary *vocab;
+};
+
#ifndef CYTHON_REFNANNY
#define CYTHON_REFNANNY 0
#endif
@@ -771,8 +771,8 @@ static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); /*proto*/
/* Module declarations from 'kenlm' */
-static PyTypeObject *__pyx_ptype_5kenlm_LanguageModel = 0;
static PyTypeObject *__pyx_ptype_5kenlm___pyx_scope_struct__full_scores = 0;
+static PyTypeObject *__pyx_ptype_5kenlm_LanguageModel = 0;
static PyObject *__pyx_f_5kenlm_as_str(PyObject *); /*proto*/
#define __Pyx_MODULE_NAME "kenlm"
int __pyx_module_is_main_kenlm = 0;
@@ -792,8 +792,8 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_13__reduce__(struct __pyx_obj_5
static PyObject *__pyx_pf_5kenlm_13LanguageModel_4path___get__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self); /* proto */
static int __pyx_pf_5kenlm_13LanguageModel_4path_2__set__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self, PyObject *__pyx_v_value); /* proto */
static int __pyx_pf_5kenlm_13LanguageModel_4path_4__del__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self); /* proto */
-static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
static char __pyx_k_2[] = "Cannot convert %s to string";
static char __pyx_k_3[] = "\n";
static char __pyx_k_4[] = " ";
@@ -1392,7 +1392,7 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
* cdef State out_state
* cdef float total = 0 # <<<<<<<<<<<<<<
* for word in words:
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
*/
__pyx_v_total = 0.0;
@@ -1400,7 +1400,7 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
* cdef State out_state
* cdef float total = 0
* for word in words: # <<<<<<<<<<<<<<
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
* state = out_state
*/
if (unlikely(((PyObject *)__pyx_v_words) == Py_None)) {
@@ -1422,18 +1422,18 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
/* "kenlm.pyx":39
* cdef float total = 0
* for word in words:
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<<
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<<
* state = out_state
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
*/
__pyx_t_4 = __Pyx_PyObject_AsString(__pyx_v_word); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 39; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->Score((&__pyx_v_state), __pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_v_out_state)));
+ __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->BaseScore((&__pyx_v_state), __pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_v_out_state)));
/* "kenlm.pyx":40
* for word in words:
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
* state = out_state # <<<<<<<<<<<<<<
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
* return total
*/
__pyx_v_state = __pyx_v_out_state;
@@ -1441,17 +1441,17 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
/* "kenlm.pyx":41
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
* state = out_state
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<<
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<<
* return total
*
*/
- __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->Score((&__pyx_v_state), __pyx_v_self->vocab->EndSentence(), (&__pyx_v_out_state)));
+ __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->BaseScore((&__pyx_v_state), __pyx_v_self->vocab->EndSentence(), (&__pyx_v_out_state)));
/* "kenlm.pyx":42
* state = out_state
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
* return total # <<<<<<<<<<<<<<
*
* def full_scores(self, sentence):
@@ -1597,7 +1597,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
* cdef FullScoreReturn ret
* cdef float total = 0 # <<<<<<<<<<<<<<
* for word in words:
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
*/
__pyx_cur_scope->__pyx_v_total = 0.0;
@@ -1605,7 +1605,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
* cdef FullScoreReturn ret
* cdef float total = 0
* for word in words: # <<<<<<<<<<<<<<
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.Index(word), &out_state)
*/
if (unlikely(((PyObject *)__pyx_cur_scope->__pyx_v_words) == Py_None)) {
@@ -1628,20 +1628,20 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
/* "kenlm.pyx":53
* for word in words:
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<<
* yield (ret.prob, ret.ngram_length)
* state = out_state
*/
__pyx_t_4 = __Pyx_PyObject_AsString(__pyx_cur_scope->__pyx_v_word); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 53; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->FullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_cur_scope->__pyx_v_out_state));
+ __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->BaseFullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_cur_scope->__pyx_v_out_state));
/* "kenlm.pyx":54
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.Index(word), &out_state)
* yield (ret.prob, ret.ngram_length) # <<<<<<<<<<<<<<
* state = out_state
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
*/
__pyx_t_2 = PyFloat_FromDouble(__pyx_cur_scope->__pyx_v_ret.prob); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 54; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_2);
@@ -1676,7 +1676,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
* self.vocab.Index(word), &out_state)
* yield (ret.prob, ret.ngram_length)
* state = out_state # <<<<<<<<<<<<<<
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.EndSentence(), &out_state)
*/
__pyx_cur_scope->__pyx_v_state = __pyx_cur_scope->__pyx_v_out_state;
@@ -1685,15 +1685,15 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
/* "kenlm.pyx":57
* state = out_state
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<<
* yield (ret.prob, ret.ngram_length)
*
*/
- __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->FullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->EndSentence(), (&__pyx_cur_scope->__pyx_v_out_state));
+ __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->BaseFullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->EndSentence(), (&__pyx_cur_scope->__pyx_v_out_state));
/* "kenlm.pyx":58
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.EndSentence(), &out_state)
* yield (ret.prob, ret.ngram_length) # <<<<<<<<<<<<<<
*
@@ -2047,6 +2047,147 @@ static int __pyx_pf_5kenlm_13LanguageModel_4path_4__del__(struct __pyx_obj_5kenl
return __pyx_r;
}
+static struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[8];
+static int __pyx_freecount_5kenlm___pyx_scope_struct__full_scores = 0;
+
+static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) {
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p;
+ PyObject *o;
+ if (likely((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)))) {
+ o = (PyObject*)__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[--__pyx_freecount_5kenlm___pyx_scope_struct__full_scores];
+ memset(o, 0, sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores));
+ PyObject_INIT(o, t);
+ PyObject_GC_Track(o);
+ } else {
+ o = (*t->tp_alloc)(t, 0);
+ if (unlikely(!o)) return 0;
+ }
+ p = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
+ p->__pyx_v_self = 0;
+ p->__pyx_v_sentence = 0;
+ p->__pyx_v_word = 0;
+ p->__pyx_v_words = 0;
+ p->__pyx_t_0 = 0;
+ return o;
+}
+
+static void __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
+ PyObject_GC_UnTrack(o);
+ Py_CLEAR(p->__pyx_v_self);
+ Py_CLEAR(p->__pyx_v_sentence);
+ Py_CLEAR(p->__pyx_v_word);
+ Py_CLEAR(p->__pyx_v_words);
+ Py_CLEAR(p->__pyx_t_0);
+ if ((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores))) {
+ __pyx_freelist_5kenlm___pyx_scope_struct__full_scores[__pyx_freecount_5kenlm___pyx_scope_struct__full_scores++] = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
+ } else {
+ (*Py_TYPE(o)->tp_free)(o);
+ }
+}
+
+static int __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores(PyObject *o, visitproc v, void *a) {
+ int e;
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
+ if (p->__pyx_v_self) {
+ e = (*v)(((PyObject*)p->__pyx_v_self), a); if (e) return e;
+ }
+ if (p->__pyx_v_sentence) {
+ e = (*v)(p->__pyx_v_sentence, a); if (e) return e;
+ }
+ if (p->__pyx_v_word) {
+ e = (*v)(p->__pyx_v_word, a); if (e) return e;
+ }
+ if (p->__pyx_v_words) {
+ e = (*v)(p->__pyx_v_words, a); if (e) return e;
+ }
+ if (p->__pyx_t_0) {
+ e = (*v)(p->__pyx_t_0, a); if (e) return e;
+ }
+ return 0;
+}
+
+static int __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
+ PyObject* tmp;
+ tmp = ((PyObject*)p->__pyx_v_self);
+ p->__pyx_v_self = ((struct __pyx_obj_5kenlm_LanguageModel *)Py_None); Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_v_sentence);
+ p->__pyx_v_sentence = Py_None; Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_v_word);
+ p->__pyx_v_word = Py_None; Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_v_words);
+ p->__pyx_v_words = ((PyObject*)Py_None); Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_t_0);
+ p->__pyx_t_0 = Py_None; Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ return 0;
+}
+
+static PyMethodDef __pyx_methods_5kenlm___pyx_scope_struct__full_scores[] = {
+ {0, 0, 0, 0}
+};
+
+static PyTypeObject __pyx_type_5kenlm___pyx_scope_struct__full_scores = {
+ PyVarObject_HEAD_INIT(0, 0)
+ __Pyx_NAMESTR("kenlm.__pyx_scope_struct__full_scores"), /*tp_name*/
+ sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores), /*tp_basicsize*/
+ 0, /*tp_itemsize*/
+ __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores, /*tp_dealloc*/
+ 0, /*tp_print*/
+ 0, /*tp_getattr*/
+ 0, /*tp_setattr*/
+ #if PY_MAJOR_VERSION < 3
+ 0, /*tp_compare*/
+ #else
+ 0, /*reserved*/
+ #endif
+ 0, /*tp_repr*/
+ 0, /*tp_as_number*/
+ 0, /*tp_as_sequence*/
+ 0, /*tp_as_mapping*/
+ 0, /*tp_hash*/
+ 0, /*tp_call*/
+ 0, /*tp_str*/
+ 0, /*tp_getattro*/
+ 0, /*tp_setattro*/
+ 0, /*tp_as_buffer*/
+ Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/
+ 0, /*tp_doc*/
+ __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores, /*tp_traverse*/
+ __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores, /*tp_clear*/
+ 0, /*tp_richcompare*/
+ 0, /*tp_weaklistoffset*/
+ 0, /*tp_iter*/
+ 0, /*tp_iternext*/
+ __pyx_methods_5kenlm___pyx_scope_struct__full_scores, /*tp_methods*/
+ 0, /*tp_members*/
+ 0, /*tp_getset*/
+ 0, /*tp_base*/
+ 0, /*tp_dict*/
+ 0, /*tp_descr_get*/
+ 0, /*tp_descr_set*/
+ 0, /*tp_dictoffset*/
+ 0, /*tp_init*/
+ 0, /*tp_alloc*/
+ __pyx_tp_new_5kenlm___pyx_scope_struct__full_scores, /*tp_new*/
+ 0, /*tp_free*/
+ 0, /*tp_is_gc*/
+ 0, /*tp_bases*/
+ 0, /*tp_mro*/
+ 0, /*tp_cache*/
+ 0, /*tp_subclasses*/
+ 0, /*tp_weaklist*/
+ 0, /*tp_del*/
+ #if PY_VERSION_HEX >= 0x02060000
+ 0, /*tp_version_tag*/
+ #endif
+};
+
static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) {
struct __pyx_obj_5kenlm_LanguageModel *p;
PyObject *o;
@@ -2190,147 +2331,6 @@ static PyTypeObject __pyx_type_5kenlm_LanguageModel = {
#endif
};
-static struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[8];
-static int __pyx_freecount_5kenlm___pyx_scope_struct__full_scores = 0;
-
-static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) {
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p;
- PyObject *o;
- if (likely((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)))) {
- o = (PyObject*)__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[--__pyx_freecount_5kenlm___pyx_scope_struct__full_scores];
- memset(o, 0, sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores));
- PyObject_INIT(o, t);
- PyObject_GC_Track(o);
- } else {
- o = (*t->tp_alloc)(t, 0);
- if (unlikely(!o)) return 0;
- }
- p = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
- p->__pyx_v_self = 0;
- p->__pyx_v_sentence = 0;
- p->__pyx_v_word = 0;
- p->__pyx_v_words = 0;
- p->__pyx_t_0 = 0;
- return o;
-}
-
-static void __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
- PyObject_GC_UnTrack(o);
- Py_CLEAR(p->__pyx_v_self);
- Py_CLEAR(p->__pyx_v_sentence);
- Py_CLEAR(p->__pyx_v_word);
- Py_CLEAR(p->__pyx_v_words);
- Py_CLEAR(p->__pyx_t_0);
- if ((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores))) {
- __pyx_freelist_5kenlm___pyx_scope_struct__full_scores[__pyx_freecount_5kenlm___pyx_scope_struct__full_scores++] = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
- } else {
- (*Py_TYPE(o)->tp_free)(o);
- }
-}
-
-static int __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores(PyObject *o, visitproc v, void *a) {
- int e;
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
- if (p->__pyx_v_self) {
- e = (*v)(((PyObject*)p->__pyx_v_self), a); if (e) return e;
- }
- if (p->__pyx_v_sentence) {
- e = (*v)(p->__pyx_v_sentence, a); if (e) return e;
- }
- if (p->__pyx_v_word) {
- e = (*v)(p->__pyx_v_word, a); if (e) return e;
- }
- if (p->__pyx_v_words) {
- e = (*v)(p->__pyx_v_words, a); if (e) return e;
- }
- if (p->__pyx_t_0) {
- e = (*v)(p->__pyx_t_0, a); if (e) return e;
- }
- return 0;
-}
-
-static int __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
- PyObject* tmp;
- tmp = ((PyObject*)p->__pyx_v_self);
- p->__pyx_v_self = ((struct __pyx_obj_5kenlm_LanguageModel *)Py_None); Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_v_sentence);
- p->__pyx_v_sentence = Py_None; Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_v_word);
- p->__pyx_v_word = Py_None; Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_v_words);
- p->__pyx_v_words = ((PyObject*)Py_None); Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_t_0);
- p->__pyx_t_0 = Py_None; Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- return 0;
-}
-
-static PyMethodDef __pyx_methods_5kenlm___pyx_scope_struct__full_scores[] = {
- {0, 0, 0, 0}
-};
-
-static PyTypeObject __pyx_type_5kenlm___pyx_scope_struct__full_scores = {
- PyVarObject_HEAD_INIT(0, 0)
- __Pyx_NAMESTR("kenlm.__pyx_scope_struct__full_scores"), /*tp_name*/
- sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores), /*tp_basicsize*/
- 0, /*tp_itemsize*/
- __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores, /*tp_dealloc*/
- 0, /*tp_print*/
- 0, /*tp_getattr*/
- 0, /*tp_setattr*/
- #if PY_MAJOR_VERSION < 3
- 0, /*tp_compare*/
- #else
- 0, /*reserved*/
- #endif
- 0, /*tp_repr*/
- 0, /*tp_as_number*/
- 0, /*tp_as_sequence*/
- 0, /*tp_as_mapping*/
- 0, /*tp_hash*/
- 0, /*tp_call*/
- 0, /*tp_str*/
- 0, /*tp_getattro*/
- 0, /*tp_setattro*/
- 0, /*tp_as_buffer*/
- Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/
- 0, /*tp_doc*/
- __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores, /*tp_traverse*/
- __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores, /*tp_clear*/
- 0, /*tp_richcompare*/
- 0, /*tp_weaklistoffset*/
- 0, /*tp_iter*/
- 0, /*tp_iternext*/
- __pyx_methods_5kenlm___pyx_scope_struct__full_scores, /*tp_methods*/
- 0, /*tp_members*/
- 0, /*tp_getset*/
- 0, /*tp_base*/
- 0, /*tp_dict*/
- 0, /*tp_descr_get*/
- 0, /*tp_descr_set*/
- 0, /*tp_dictoffset*/
- 0, /*tp_init*/
- 0, /*tp_alloc*/
- __pyx_tp_new_5kenlm___pyx_scope_struct__full_scores, /*tp_new*/
- 0, /*tp_free*/
- 0, /*tp_is_gc*/
- 0, /*tp_bases*/
- 0, /*tp_mro*/
- 0, /*tp_cache*/
- 0, /*tp_subclasses*/
- 0, /*tp_weaklist*/
- 0, /*tp_del*/
- #if PY_VERSION_HEX >= 0x02060000
- 0, /*tp_version_tag*/
- #endif
-};
-
static PyMethodDef __pyx_methods[] = {
{0, 0, 0, 0}
};
@@ -2508,11 +2508,11 @@ PyMODINIT_FUNC PyInit_kenlm(void)
/*--- Variable export code ---*/
/*--- Function export code ---*/
/*--- Type init code ---*/
+ if (PyType_Ready(&__pyx_type_5kenlm___pyx_scope_struct__full_scores) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 44; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_ptype_5kenlm___pyx_scope_struct__full_scores = &__pyx_type_5kenlm___pyx_scope_struct__full_scores;
if (PyType_Ready(&__pyx_type_5kenlm_LanguageModel) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
if (__Pyx_SetAttrString(__pyx_m, "LanguageModel", (PyObject *)&__pyx_type_5kenlm_LanguageModel) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__pyx_ptype_5kenlm_LanguageModel = &__pyx_type_5kenlm_LanguageModel;
- if (PyType_Ready(&__pyx_type_5kenlm___pyx_scope_struct__full_scores) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 44; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_ptype_5kenlm___pyx_scope_struct__full_scores = &__pyx_type_5kenlm___pyx_scope_struct__full_scores;
/*--- Type import code ---*/
/*--- Variable import code ---*/
/*--- Function import code ---*/
diff --git a/python/kenlm.pxd b/python/kenlm.pxd
index 9a397f3..7d68fc4 100644
--- a/python/kenlm.pxd
+++ b/python/kenlm.pxd
@@ -24,8 +24,8 @@ cdef extern from "lm/virtual_interface.hh" namespace "lm::base":
void NullContextWrite(void *)
unsigned int Order()
const_Vocabulary& BaseVocabulary()
- float Score(void *in_state, WordIndex new_word, void *out_state)
- FullScoreReturn FullScore(void *in_state, WordIndex new_word, void *out_state)
+ float BaseScore(void *in_state, WordIndex new_word, void *out_state)
+ FullScoreReturn BaseFullScore(void *in_state, WordIndex new_word, void *out_state)
cdef extern from "lm/model.hh" namespace "lm::ngram":
cdef Model *LoadVirtual(char *) except +
diff --git a/python/kenlm.pyx b/python/kenlm.pyx
index 7f965d4..eb45169 100644
--- a/python/kenlm.pyx
+++ b/python/kenlm.pyx
@@ -36,9 +36,9 @@ cdef class LanguageModel:
cdef State out_state
cdef float total = 0
for word in words:
- total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
state = out_state
- total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
return total
def full_scores(self, sentence):
@@ -49,11 +49,11 @@ cdef class LanguageModel:
cdef FullScoreReturn ret
cdef float total = 0
for word in words:
- ret = self.model.FullScore(&state,
+ ret = self.model.BaseFullScore(&state,
self.vocab.Index(word), &out_state)
yield (ret.prob, ret.ngram_length)
state = out_state
- ret = self.model.FullScore(&state,
+ ret = self.model.BaseFullScore(&state,
self.vocab.EndSentence(), &out_state)
yield (ret.prob, ret.ngram_length)
diff --git a/util/Jamfile b/util/Jamfile
index 910b305..18b20a3 100644
--- a/util/Jamfile
+++ b/util/Jamfile
@@ -19,15 +19,18 @@ alias read_compressed : read_compressed.o $(compressed_deps) ;
obj read_compressed_test.o : read_compressed_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
-fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
+fakelib parallel_read : parallel_read.cc : <threading>multi:<source>/top//boost_thread <threading>multi:<define>WITH_THREADS : : <include>.. ;
+
+fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc parallel_read pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
+
+exe cat_compressed : cat_compressed_main.cc kenutil ;
+
+alias programs : cat_compressed ;
import testing ;
-unit-test bit_packing_test : bit_packing_test.cc kenutil /top//boost_unit_test_framework ;
run file_piece_test.o kenutil /top//boost_unit_test_framework : : file_piece.cc ;
-unit-test read_compressed_test : read_compressed_test.o kenutil /top//boost_unit_test_framework ;
-unit-test joint_sort_test : joint_sort_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test multi_intersection_test : multi_intersection_test.cc kenutil /top//boost_unit_test_framework ;
+for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc ] {
+ local name = [ MATCH "(.*)\.cc" : $(t) ] ;
+ unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_system ;
+}
diff --git a/util/bit_packing.hh b/util/bit_packing.hh
index dcbd814..1e34d9a 100644
--- a/util/bit_packing.hh
+++ b/util/bit_packing.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_BIT_PACKING__
-#define UTIL_BIT_PACKING__
+#ifndef UTIL_BIT_PACKING_H
+#define UTIL_BIT_PACKING_H
/* Bit-level packing routines
*
@@ -183,4 +183,4 @@ struct BitAddress {
} // namespace util
-#endif // UTIL_BIT_PACKING__
+#endif // UTIL_BIT_PACKING_H
diff --git a/util/cat_compressed_main.cc b/util/cat_compressed_main.cc
new file mode 100644
index 0000000..2b4d729
--- /dev/null
+++ b/util/cat_compressed_main.cc
@@ -0,0 +1,47 @@
+// Like cat but interprets compressed files.
+#include "util/file.hh"
+#include "util/read_compressed.hh"
+
+#include <string.h>
+#include <iostream>
+
+namespace {
+const std::size_t kBufSize = 16384;
+void Copy(util::ReadCompressed &from, int to) {
+ util::scoped_malloc buffer(util::MallocOrThrow(kBufSize));
+ while (std::size_t amount = from.Read(buffer.get(), kBufSize)) {
+ util::WriteOrThrow(to, buffer.get(), amount);
+ }
+}
+} // namespace
+
+int main(int argc, char *argv[]) {
+ // Lane Schwartz likes -h and --help
+ for (int i = 1; i < argc; ++i) {
+ char *arg = argv[i];
+ if (!strcmp(arg, "--")) break;
+ if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
+ std::cerr <<
+ "A cat implementation that interprets compressed files.\n"
+ "Usage: " << argv[0] << " [file1] [file2] ...\n"
+ "If no file is provided, then stdin is read.\n";
+ return 1;
+ }
+ }
+
+ try {
+ if (argc == 1) {
+ util::ReadCompressed in(0);
+ Copy(in, 1);
+ } else {
+ for (int i = 1; i < argc; ++i) {
+ util::ReadCompressed in(util::OpenReadOrThrow(argv[i]));
+ Copy(in, 1);
+ }
+ }
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 2;
+ }
+ return 0;
+}
diff --git a/util/ersatz_progress.hh b/util/ersatz_progress.hh
index b94399a..535dbde 100644
--- a/util/ersatz_progress.hh
+++ b/util/ersatz_progress.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_ERSATZ_PROGRESS__
-#define UTIL_ERSATZ_PROGRESS__
+#ifndef UTIL_ERSATZ_PROGRESS_H
+#define UTIL_ERSATZ_PROGRESS_H
#include <iostream>
#include <string>
@@ -55,4 +55,4 @@ class ErsatzProgress {
} // namespace util
-#endif // UTIL_ERSATZ_PROGRESS__
+#endif // UTIL_ERSATZ_PROGRESS_H
diff --git a/util/exception.cc b/util/exception.cc
index 557c398..083bac2 100644
--- a/util/exception.cc
+++ b/util/exception.cc
@@ -51,6 +51,11 @@ void Exception::SetLocation(const char *file, unsigned int line, const char *fun
}
namespace {
+// At least one of these functions will not be called.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
// The XOPEN version.
const char *HandleStrerror(int ret, const char *buf) {
if (!ret) return buf;
@@ -61,6 +66,9 @@ const char *HandleStrerror(int ret, const char *buf) {
const char *HandleStrerror(const char *ret, const char * /*buf*/) {
return ret;
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
} // namespace
ErrnoException::ErrnoException() throw() : errno_(errno) {
diff --git a/util/exception.hh b/util/exception.hh
index 74046cf..0966740 100644
--- a/util/exception.hh
+++ b/util/exception.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_EXCEPTION__
-#define UTIL_EXCEPTION__
+#ifndef UTIL_EXCEPTION_H
+#define UTIL_EXCEPTION_H
#include <exception>
#include <limits>
@@ -98,6 +98,9 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep
#define UTIL_THROW_IF(Condition, Exception, Modify) \
UTIL_THROW_IF_ARG(Condition, Exception, , Modify)
+#define UTIL_THROW_IF2(Condition, Modify) \
+ UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify)
+
// Exception that records errno and adds it to the message.
class ErrnoException : public Exception {
public:
@@ -111,6 +114,13 @@ class ErrnoException : public Exception {
int errno_;
};
+// file wasn't there, or couldn't be open for some reason
+class FileOpenException : public Exception {
+ public:
+ FileOpenException() throw() {}
+ ~FileOpenException() throw() {}
+};
+
// Utilities for overflow checking.
class OverflowException : public Exception {
public:
@@ -133,4 +143,4 @@ inline std::size_t CheckOverflow(uint64_t value) {
} // namespace util
-#endif // UTIL_EXCEPTION__
+#endif // UTIL_EXCEPTION_H
diff --git a/util/fake_ofstream.hh b/util/fake_ofstream.hh
index bcdebe4..eefb1ed 100644
--- a/util/fake_ofstream.hh
+++ b/util/fake_ofstream.hh
@@ -2,6 +2,9 @@
* Does not support many data types. Currently, it's targeted at writing ARPA
* files quickly.
*/
+#ifndef UTIL_FAKE_OFSTREAM_H
+#define UTIL_FAKE_OFSTREAM_H
+
#include "util/double-conversion/double-conversion.h"
#include "util/double-conversion/utils.h"
#include "util/file.hh"
@@ -17,7 +20,8 @@ class FakeOFStream {
static const std::size_t kOutBuf = 1048576;
// Does not take ownership of out.
- explicit FakeOFStream(int out)
+ // Allows default constructor, but must call SetFD.
+ explicit FakeOFStream(int out = -1)
: buf_(util::MallocOrThrow(kOutBuf)),
builder_(static_cast<char*>(buf_.get()), kOutBuf),
// Mostly the default but with inf instead. And no flags.
@@ -28,6 +32,11 @@ class FakeOFStream {
if (buf_.get()) Flush();
}
+ void SetFD(int to) {
+ if (builder_.position()) Flush();
+ fd_ = to;
+ }
+
FakeOFStream &operator<<(float value) {
// Odd, but this is the largest number found in the comments.
EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
@@ -92,3 +101,5 @@ class FakeOFStream {
};
} // namespace
+
+#endif
diff --git a/util/file.cc b/util/file.cc
index bef04cb..0d9adf2 100644
--- a/util/file.cc
+++ b/util/file.cc
@@ -17,7 +17,11 @@
#include <fcntl.h>
#include <stdint.h>
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+#include <windows.h>
+#include <unistd.h>
+#warning "The file functions on MinGW have not been tested for file sizes above 2^31 - 1. Please read https://stackoverflow.com/questions/12539488/determine-64-bit-file-size-in-c-on-mingw-32-bit and fix"
+#elif defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#include <io.h>
#include <algorithm>
@@ -76,7 +80,13 @@ int CreateOrThrow(const char *name) {
}
uint64_t SizeFile(int fd) {
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ struct stat sb;
+ // Does this handle 64-bit?
+ int ret = fstat(fd, &sb);
+ if (ret == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize;
+ return sb.st_size;
+#elif defined(_WIN32) || defined(_WIN64)
__int64 ret = _filelengthi64(fd);
return (ret == -1) ? kBadSize : ret;
#else // Not windows.
@@ -100,7 +110,10 @@ uint64_t SizeOrThrow(int fd) {
}
void ResizeOrThrow(int fd, uint64_t to) {
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ // Does this handle 64-bit?
+ int ret = ftruncate
+#elif defined(_WIN32) || defined(_WIN64)
errno_t ret = _chsize_s
#elif defined(OS_ANDROID)
int ret = ftruncate64
@@ -115,8 +128,10 @@ namespace {
std::size_t GuardLarge(std::size_t size) {
// The following operating systems have broken read/write/pread/pwrite that
// only supports up to 2^31.
-#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID)
- return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size);
+ // OS X man pages claim to support 64-bit, but Kareem M. Darwish had problems
+ // building with larger files, so APPLE is also here.
+#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) || defined(__MINGW32__)
+ return size < INT_MAX ? size : INT_MAX;
#else
return size;
#endif
@@ -179,16 +194,15 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) {
#else
ssize_t ret;
errno = 0;
- do {
- ret =
+ ret =
#ifdef OS_ANDROID
- pread64
+ pread64
#else
- pread
+ pread
#endif
- (fd, to, GuardLarge(size), off);
- } while (ret == -1 && errno == EINTR);
+ (fd, to, GuardLarge(size), off);
if (ret <= 0) {
+ if (ret == -1 && errno == EINTR) continue;
UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd));
UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off);
}
@@ -251,7 +265,10 @@ typedef CheckOffT<sizeof(off_t)>::True IgnoredType;
// Can't we all just get along?
void InternalSeek(int fd, int64_t off, int whence) {
if (
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ // Does this handle 64-bit?
+ (off_t)-1 == lseek(fd, off, whence)
+#elif defined(_WIN32) || defined(_WIN64)
(__int64)-1 == _lseeki64(fd, off, whence)
#elif defined(OS_ANDROID)
(off64_t)-1 == lseek64(fd, off, whence)
@@ -427,8 +444,8 @@ void NormalizeTempPrefix(std::string &base) {
) base += '/';
}
-int MakeTemp(const std::string &base) {
- std::string name(base);
+int MakeTemp(const StringPiece &base) {
+ std::string name(base.data(), base.size());
name += "XXXXXX";
name.push_back(0);
int ret;
@@ -436,7 +453,7 @@ int MakeTemp(const std::string &base) {
return ret;
}
-std::FILE *FMakeTemp(const std::string &base) {
+std::FILE *FMakeTemp(const StringPiece &base) {
util::scoped_fd file(MakeTemp(base));
return FDOpenOrThrow(file);
}
@@ -462,14 +479,18 @@ bool TryName(int fd, std::string &out) {
if (-1 == lstat(name.c_str(), &sb))
return false;
out.resize(sb.st_size + 1);
- ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1);
- if (-1 == ret)
- return false;
- if (ret > sb.st_size) {
- // Increased in size?!
- return false;
+ // lstat gave us a size, but I've seen it grow, possibly due to symlinks on top of symlinks.
+ while (true) {
+ ssize_t ret = readlink(name.c_str(), &out[0], out.size());
+ if (-1 == ret)
+ return false;
+ if ((size_t)ret < out.size()) {
+ out.resize(ret);
+ break;
+ }
+ // Exponential growth.
+ out.resize(out.size() * 2);
}
- out.resize(ret);
// Don't use the non-file names.
if (!out.empty() && out[0] != '/')
return false;
diff --git a/util/file.hh b/util/file.hh
index be88431..170a7c7 100644
--- a/util/file.hh
+++ b/util/file.hh
@@ -1,7 +1,8 @@
-#ifndef UTIL_FILE__
-#define UTIL_FILE__
+#ifndef UTIL_FILE_H
+#define UTIL_FILE_H
#include "util/exception.hh"
+#include "util/string_piece.hh"
#include <cstddef>
#include <cstdio>
@@ -125,8 +126,8 @@ std::FILE *FDOpenReadOrThrow(scoped_fd &file);
// Temporary files
// Append a / if base is a directory.
void NormalizeTempPrefix(std::string &base);
-int MakeTemp(const std::string &prefix);
-std::FILE *FMakeTemp(const std::string &prefix);
+int MakeTemp(const StringPiece &prefix);
+std::FILE *FMakeTemp(const StringPiece &prefix);
// dup an fd.
int DupOrThrow(int fd);
@@ -139,4 +140,4 @@ std::string NameFromFD(int fd);
} // namespace util
-#endif // UTIL_FILE__
+#endif // UTIL_FILE_H
diff --git a/util/file_piece.cc b/util/file_piece.cc
index 9c7e00c..4aaa250 100644
--- a/util/file_piece.cc
+++ b/util/file_piece.cc
@@ -84,6 +84,13 @@ StringPiece FilePiece::ReadLine(char delim) {
}
}
+bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) {
+ try {
+ to = ReadLine(delim);
+ } catch (const util::EndOfFileException &e) { return false; }
+ return true;
+}
+
float FilePiece::ReadFloat() {
return ReadNumber<float>();
}
diff --git a/util/file_piece.hh b/util/file_piece.hh
index ed3dc5a..5495ddc 100644
--- a/util/file_piece.hh
+++ b/util/file_piece.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_FILE_PIECE__
-#define UTIL_FILE_PIECE__
+#ifndef UTIL_FILE_PIECE_H
+#define UTIL_FILE_PIECE_H
#include "util/ersatz_progress.hh"
#include "util/exception.hh"
@@ -56,10 +56,33 @@ class FilePiece {
return Consume(FindDelimiterOrEOF(delim));
}
+ // Read word until the line or file ends.
+ bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) {
+ assert(delim[static_cast<unsigned char>('\n')]);
+ // Skip non-enter spaces.
+ for (; ; ++position_) {
+ if (position_ == position_end_) {
+ try {
+ Shift();
+ } catch (const util::EndOfFileException &e) { return false; }
+ // And break out at end of file.
+ if (position_ == position_end_) return false;
+ }
+ if (!delim[static_cast<unsigned char>(*position_)]) break;
+ if (*position_ == '\n') return false;
+ }
+ // We can't be at the end of file because there's at least one character open.
+ to = Consume(FindDelimiterOrEOF(delim));
+ return true;
+ }
+
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n');
+ // Doesn't throw EndOfFileException, just returns false.
+ bool ReadLineOrEOF(StringPiece &to, char delim = '\n');
+
float ReadFloat();
double ReadDouble();
long int ReadLong();
@@ -132,4 +155,4 @@ class FilePiece {
} // namespace util
-#endif // UTIL_FILE_PIECE__
+#endif // UTIL_FILE_PIECE_H
diff --git a/util/file_piece_test.cc b/util/file_piece_test.cc
index 7336007..4361877 100644
--- a/util/file_piece_test.cc
+++ b/util/file_piece_test.cc
@@ -1,4 +1,4 @@
-// Tests might fail if you have creative characters in your path. Sue me.
+// Tests might fail if you have creative characters in your path. Sue me.
#include "util/file_piece.hh"
#include "util/file.hh"
@@ -55,7 +55,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
/* Apple isn't happy with the popen, fileno, dup. And I don't want to
- * reimplement popen. This is an issue with the test.
+ * reimplement popen. This is an issue with the test.
*/
/* read() implementation */
BOOST_AUTO_TEST_CASE(StreamReadLine) {
@@ -67,7 +67,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) {
FILE *catter = popen(popen_args.c_str(), "r");
BOOST_REQUIRE(catter);
-
+
FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
@@ -107,8 +107,8 @@ BOOST_AUTO_TEST_CASE(PlainZipReadLine) {
}
// gzip stream. Apple doesn't like popen, fileno, dup. This is an issue with
-// the test.
-#ifndef __APPLE__
+// the test.
+#if !defined __APPLE__ && !defined __MINGW32__
BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
std::fstream ref(FileLocation().c_str(), std::ios::in);
@@ -117,7 +117,7 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
FILE * catter = popen(command.c_str(), "r");
BOOST_REQUIRE(catter);
-
+
FilePiece test(dup(fileno(catter)), "file_piece.cc.gz", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
diff --git a/util/fixed_array.hh b/util/fixed_array.hh
new file mode 100644
index 0000000..bae13de
--- /dev/null
+++ b/util/fixed_array.hh
@@ -0,0 +1,94 @@
+#ifndef UTIL_FIXED_ARRAY_H
+#define UTIL_FIXED_ARRAY_H
+
+// Ever want an array of things by they don't have a default constructor or are
+// non-copyable? FixedArray allows constructing one at a time.
+#include "util/scoped.hh"
+
+#include <cstddef>
+
+#include <assert.h>
+
+namespace util {
+
+template <class T> class FixedArray {
+ public:
+ // Initialize with a given size bound but do not construct the objects.
+ explicit FixedArray(std::size_t limit) {
+ Init(limit);
+ }
+
+ FixedArray()
+ : newed_end_(NULL)
+#ifndef NDEBUG
+ , allocated_end_(NULL)
+#endif
+ {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+#ifndef NDEBUG
+ allocated_end_ = begin() + count;
+#endif
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ push_back(from[i]);
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+#ifndef NDEBUG
+ assert(newed_end_ <= allocated_end_);
+#endif
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+
+#ifndef NDEBUG
+ T *allocated_end_;
+#endif
+};
+
+} // namespace util
+
+#endif // UTIL_FIXED_ARRAY_H
diff --git a/util/getopt.hh b/util/getopt.hh
index 6ad9773..50eab56 100644
--- a/util/getopt.hh
+++ b/util/getopt.hh
@@ -11,8 +11,8 @@ Code given out at the 1985 UNIFORUM conference in Dallas.
#endif
#ifndef __GNUC__
-#ifndef _WINGETOPT_H_
-#define _WINGETOPT_H_
+#ifndef UTIL_GETOPT_H
+#define UTIL_GETOPT_H
#ifdef __cplusplus
extern "C" {
@@ -28,6 +28,6 @@ extern int getopt(int argc, char **argv, char *opts);
}
#endif
-#endif /* _GETOPT_H_ */
+#endif /* UTIL_GETOPT_H */
#endif /* __GNUC__ */
diff --git a/util/have.hh b/util/have.hh
index 6e18529..dc3f633 100644
--- a/util/have.hh
+++ b/util/have.hh
@@ -1,6 +1,6 @@
/* Optional packages. You might want to integrate this with your build system e.g. config.h from ./configure. */
-#ifndef UTIL_HAVE__
-#define UTIL_HAVE__
+#ifndef UTIL_HAVE_H
+#define UTIL_HAVE_H
#ifdef HAVE_CONFIG_H
#include "config.h"
@@ -10,4 +10,4 @@
//#define HAVE_ICU
#endif
-#endif // UTIL_HAVE__
+#endif // UTIL_HAVE_H
diff --git a/util/joint_sort.hh b/util/joint_sort.hh
index 1b43ddc..de4b554 100644
--- a/util/joint_sort.hh
+++ b/util/joint_sort.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_JOINT_SORT__
-#define UTIL_JOINT_SORT__
+#ifndef UTIL_JOINT_SORT_H
+#define UTIL_JOINT_SORT_H
/* A terrifying amount of C++ to coax std::sort into soring one range while
* also permuting another range the same way.
@@ -9,7 +9,6 @@
#include <algorithm>
#include <functional>
-#include <iostream>
namespace util {
@@ -35,9 +34,16 @@ template <class KeyIter, class ValueIter> class JointIter {
return *this;
}
- void swap(const JointIter &other) {
- std::swap(key_, other.key_);
- std::swap(value_, other.value_);
+ friend void swap(JointIter &first, JointIter &second) {
+ using std::swap;
+ swap(first.key_, second.key_);
+ swap(first.value_, second.value_);
+ }
+
+ void DeepSwap(JointIter &other) {
+ using std::swap;
+ swap(*key_, *other.key_);
+ swap(*value_, *other.value_);
}
private:
@@ -83,9 +89,8 @@ template <class KeyIter, class ValueIter> class JointProxy {
return *(inner_.key_);
}
- void swap(JointProxy<KeyIter, ValueIter> &other) {
- std::swap(*inner_.key_, *other.inner_.key_);
- std::swap(*inner_.value_, *other.inner_.value_);
+ friend void swap(JointProxy<KeyIter, ValueIter> first, JointProxy<KeyIter, ValueIter> second) {
+ first.Inner().DeepSwap(second.Inner());
}
private:
@@ -138,14 +143,4 @@ template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begi
} // namespace util
-namespace std {
-template <class KeyIter, class ValueIter> void swap(util::detail::JointIter<KeyIter, ValueIter> &left, util::detail::JointIter<KeyIter, ValueIter> &right) {
- left.swap(right);
-}
-
-template <class KeyIter, class ValueIter> void swap(util::detail::JointProxy<KeyIter, ValueIter> &left, util::detail::JointProxy<KeyIter, ValueIter> &right) {
- left.swap(right);
-}
-} // namespace std
-
-#endif // UTIL_JOINT_SORT__
+#endif // UTIL_JOINT_SORT_H
diff --git a/util/joint_sort_test.cc b/util/joint_sort_test.cc
index 4dc8591..b24c602 100644
--- a/util/joint_sort_test.cc
+++ b/util/joint_sort_test.cc
@@ -47,4 +47,16 @@ BOOST_AUTO_TEST_CASE(char_int) {
BOOST_CHECK_EQUAL(327, values[3]);
}
+BOOST_AUTO_TEST_CASE(swap_proxy) {
+ char keys[2] = {0, 1};
+ int values[2] = {2, 3};
+ detail::JointProxy<char *, int *> first(keys, values);
+ detail::JointProxy<char *, int *> second(keys + 1, values + 1);
+ swap(first, second);
+ BOOST_CHECK_EQUAL(1, keys[0]);
+ BOOST_CHECK_EQUAL(0, keys[1]);
+ BOOST_CHECK_EQUAL(3, values[0]);
+ BOOST_CHECK_EQUAL(2, values[1]);
+}
+
}} // namespace anonymous util
diff --git a/util/mmap.cc b/util/mmap.cc
index cee6a97..a3c8a02 100644
--- a/util/mmap.cc
+++ b/util/mmap.cc
@@ -6,6 +6,7 @@
#include "util/exception.hh"
#include "util/file.hh"
+#include "util/parallel_read.hh"
#include "util/scoped.hh"
#include <iostream>
@@ -40,7 +41,7 @@ void SyncOrThrow(void *start, size_t length) {
#if defined(_WIN32) || defined(_WIN64)
UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap");
#else
- UTIL_THROW_IF(msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap");
+ UTIL_THROW_IF(length && msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap");
#endif
}
@@ -154,6 +155,10 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope
SeekOrThrow(fd, offset);
ReadOrThrow(fd, out.get(), size);
break;
+ case PARALLEL_READ:
+ out.reset(MallocOrThrow(size), size, scoped_memory::MALLOC_ALLOCATED);
+ ParallelRead(fd, out.get(), size, offset);
+ break;
}
}
@@ -189,4 +194,66 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
}
}
+Rolling::Rolling(const Rolling &copy_from, uint64_t increase) {
+ *this = copy_from;
+ IncreaseBase(increase);
+}
+
+Rolling &Rolling::operator=(const Rolling &copy_from) {
+ fd_ = copy_from.fd_;
+ file_begin_ = copy_from.file_begin_;
+ file_end_ = copy_from.file_end_;
+ for_write_ = copy_from.for_write_;
+ block_ = copy_from.block_;
+ read_bound_ = copy_from.read_bound_;
+
+ current_begin_ = 0;
+ if (copy_from.IsPassthrough()) {
+ current_end_ = copy_from.current_end_;
+ ptr_ = copy_from.ptr_;
+ } else {
+ // Force call on next mmap.
+ current_end_ = 0;
+ ptr_ = NULL;
+ }
+ return *this;
+}
+
+Rolling::Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount) {
+ current_begin_ = 0;
+ current_end_ = 0;
+ fd_ = fd;
+ file_begin_ = offset;
+ file_end_ = offset + amount;
+ for_write_ = for_write;
+ block_ = block;
+ read_bound_ = read_bound;
+}
+
+void *Rolling::ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size) {
+ out.reset();
+ if (IsPassthrough()) return static_cast<uint8_t*>(get()) + index;
+ uint64_t offset = index + file_begin_;
+ // Round down to multiple of page size.
+ uint64_t cruft = offset % static_cast<uint64_t>(SizePage());
+ std::size_t map_size = static_cast<std::size_t>(size + cruft);
+ out.reset(MapOrThrow(map_size, for_write_, kFileFlags, true, fd_, offset - cruft), map_size, scoped_memory::MMAP_ALLOCATED);
+ return static_cast<uint8_t*>(out.get()) + static_cast<std::size_t>(cruft);
+}
+
+void Rolling::Roll(uint64_t index) {
+ assert(!IsPassthrough());
+ std::size_t amount;
+ if (file_end_ - (index + file_begin_) > static_cast<uint64_t>(block_)) {
+ amount = block_;
+ current_end_ = index + amount - read_bound_;
+ } else {
+ amount = file_end_ - (index + file_begin_);
+ current_end_ = index + amount;
+ }
+ ptr_ = static_cast<uint8_t*>(ExtractNonRolling(mem_, index, amount)) - index;
+
+ current_begin_ = index;
+}
+
} // namespace util
diff --git a/util/mmap.hh b/util/mmap.hh
index b218c4d..9b1e120 100644
--- a/util/mmap.hh
+++ b/util/mmap.hh
@@ -1,8 +1,9 @@
-#ifndef UTIL_MMAP__
-#define UTIL_MMAP__
+#ifndef UTIL_MMAP_H
+#define UTIL_MMAP_H
// Utilities for mmaped files.
#include <cstddef>
+#include <limits>
#include <stdint.h>
#include <sys/types.h>
@@ -52,6 +53,9 @@ class scoped_memory {
public:
typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc;
+ scoped_memory(void *data, std::size_t size, Alloc source)
+ : data_(data), size_(size), source_(source) {}
+
scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {}
~scoped_memory() { reset(); }
@@ -72,7 +76,6 @@ class scoped_memory {
void call_realloc(std::size_t to);
private:
-
void *data_;
std::size_t size_;
@@ -90,7 +93,9 @@ typedef enum {
// Populate on Linux. malloc and read on non-Linux.
POPULATE_OR_READ,
// malloc and read.
- READ
+ READ,
+ // malloc and read in parallel (recommended for Lustre)
+ PARALLEL_READ,
} LoadMethod;
extern const int kFileFlags;
@@ -109,6 +114,79 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file);
// msync wrapper
void SyncOrThrow(void *start, size_t length);
+// Forward rolling memory map with no overlap.
+class Rolling {
+ public:
+ Rolling() {}
+
+ explicit Rolling(void *data) { Init(data); }
+
+ Rolling(const Rolling &copy_from, uint64_t increase = 0);
+ Rolling &operator=(const Rolling &copy_from);
+
+ // For an actual rolling mmap.
+ explicit Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount);
+
+ // For a static mapping
+ void Init(void *data) {
+ ptr_ = data;
+ current_end_ = std::numeric_limits<uint64_t>::max();
+ current_begin_ = 0;
+ // Mark as a pass-through.
+ fd_ = -1;
+ }
+
+ void IncreaseBase(uint64_t by) {
+ file_begin_ += by;
+ ptr_ = static_cast<uint8_t*>(ptr_) + by;
+ if (!IsPassthrough()) current_end_ = 0;
+ }
+
+ void DecreaseBase(uint64_t by) {
+ file_begin_ -= by;
+ ptr_ = static_cast<uint8_t*>(ptr_) - by;
+ if (!IsPassthrough()) current_end_ = 0;
+ }
+
+ void *ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size);
+
+ // Returns base pointer
+ void *get() const { return ptr_; }
+
+ // Returns base pointer.
+ void *CheckedBase(uint64_t index) {
+ if (index >= current_end_ || index < current_begin_) {
+ Roll(index);
+ }
+ return ptr_;
+ }
+
+ // Returns indexed pointer.
+ void *CheckedIndex(uint64_t index) {
+ return static_cast<uint8_t*>(CheckedBase(index)) + index;
+ }
+
+ private:
+ void Roll(uint64_t index);
+
+ // True if this is just a thin wrapper on a pointer.
+ bool IsPassthrough() const { return fd_ == -1; }
+
+ void *ptr_;
+ uint64_t current_begin_;
+ uint64_t current_end_;
+
+ scoped_memory mem_;
+
+ int fd_;
+ uint64_t file_begin_;
+ uint64_t file_end_;
+
+ bool for_write_;
+ std::size_t block_;
+ std::size_t read_bound_;
+};
+
} // namespace util
-#endif // UTIL_MMAP__
+#endif // UTIL_MMAP_H
diff --git a/util/multi_intersection.hh b/util/multi_intersection.hh
index 8334d39..2955acc 100644
--- a/util/multi_intersection.hh
+++ b/util/multi_intersection.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_MULTI_INTERSECTION__
-#define UTIL_MULTI_INTERSECTION__
+#ifndef UTIL_MULTI_INTERSECTION_H
+#define UTIL_MULTI_INTERSECTION_H
#include <boost/optional.hpp>
#include <boost/range/iterator_range.hpp>
@@ -66,7 +66,7 @@ template <class Iterator, class Output, class Less> void AllIntersection(std::ve
std::sort(sets.begin(), sets.end(), detail::RangeLessBySize<boost::iterator_range<Iterator> >());
boost::optional<Value> ret;
- for (boost::optional<Value> ret; ret = detail::FirstIntersectionSorted(sets, less); sets.front().advance_begin(1)) {
+ for (boost::optional<Value> ret; (ret = detail::FirstIntersectionSorted(sets, less)); sets.front().advance_begin(1)) {
out(*ret);
}
}
@@ -77,4 +77,4 @@ template <class Iterator, class Output> void AllIntersection(std::vector<boost::
} // namespace util
-#endif // UTIL_MULTI_INTERSECTION__
+#endif // UTIL_MULTI_INTERSECTION_H
diff --git a/util/murmur_hash.cc b/util/murmur_hash.cc
index 4f51931..189668c 100644
--- a/util/murmur_hash.cc
+++ b/util/murmur_hash.cc
@@ -153,12 +153,19 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, uint64_t seed )
// Trick to test for 64-bit architecture at compile time.
namespace {
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
template <unsigned L> inline uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, uint64_t seed) {
return MurmurHash64A(key, len, seed);
}
template <> inline uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, uint64_t seed) {
return MurmurHash64B(key, len, seed);
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
} // namespace
uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed) {
diff --git a/util/murmur_hash.hh b/util/murmur_hash.hh
index ae7e88d..f17157c 100644
--- a/util/murmur_hash.hh
+++ b/util/murmur_hash.hh
@@ -1,14 +1,18 @@
-#ifndef UTIL_MURMUR_HASH__
-#define UTIL_MURMUR_HASH__
+#ifndef UTIL_MURMUR_HASH_H
+#define UTIL_MURMUR_HASH_H
#include <cstddef>
#include <stdint.h>
namespace util {
+// 64-bit machine version
uint64_t MurmurHash64A(const void * key, std::size_t len, uint64_t seed = 0);
+// 32-bit machine version (not the same function as above)
uint64_t MurmurHash64B(const void * key, std::size_t len, uint64_t seed = 0);
+// Use the version for this arch. Because the values differ across
+// architectures, really only use it for in-memory structures.
uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0);
} // namespace util
-#endif // UTIL_MURMUR_HASH__
+#endif // UTIL_MURMUR_HASH_H
diff --git a/util/parallel_read.cc b/util/parallel_read.cc
new file mode 100644
index 0000000..10972d7
--- /dev/null
+++ b/util/parallel_read.cc
@@ -0,0 +1,69 @@
+#include "util/parallel_read.hh"
+
+#include "util/file.hh"
+
+#ifdef WITH_THREADS
+#include "util/thread_pool.hh"
+
+namespace util {
+namespace {
+
+class Reader {
+ public:
+ explicit Reader(int fd) : fd_(fd) {}
+
+ struct Request {
+ void *to;
+ std::size_t size;
+ uint64_t offset;
+
+ bool operator==(const Request &other) const {
+ return (to == other.to) && (size == other.size) && (offset == other.offset);
+ }
+ };
+
+ void operator()(const Request &request) {
+ util::PReadOrThrow(fd_, request.to, request.size, request.offset);
+ }
+
+ private:
+ int fd_;
+};
+
+} // namespace
+
+void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) {
+ Reader::Request poison;
+ poison.to = NULL;
+ poison.size = 0;
+ poison.offset = 0;
+ unsigned threads = boost::thread::hardware_concurrency();
+ if (!threads) threads = 2;
+ ThreadPool<Reader> pool(2 /* don't need much of a queue */, threads, fd, poison);
+ const std::size_t kBatch = 1ULL << 25; // 32 MB
+ Reader::Request request;
+ request.to = to;
+ request.size = kBatch;
+ request.offset = offset;
+ for (; amount > kBatch; amount -= kBatch) {
+ pool.Produce(request);
+ request.to = reinterpret_cast<uint8_t*>(request.to) + kBatch;
+ request.offset += kBatch;
+ }
+ request.size = amount;
+ if (request.size) {
+ pool.Produce(request);
+ }
+}
+
+} // namespace util
+
+#else // WITH_THREADS
+
+namespace util {
+void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) {
+ util::PReadOrThrow(fd, to, amount, offset);
+}
+} // namespace util
+
+#endif
diff --git a/util/parallel_read.hh b/util/parallel_read.hh
new file mode 100644
index 0000000..1e96e79
--- /dev/null
+++ b/util/parallel_read.hh
@@ -0,0 +1,16 @@
+#ifndef UTIL_PARALLEL_READ__
+#define UTIL_PARALLEL_READ__
+
+/* Read pieces of a file in parallel. This has a very specific use case:
+ * reading files from Lustre is CPU bound so multiple threads actually
+ * increases throughput. Speed matters when an LM takes a terabyte.
+ */
+
+#include <cstddef>
+#include <stdint.h>
+
+namespace util {
+void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset);
+} // namespace util
+
+#endif // UTIL_PARALLEL_READ__
diff --git a/util/pcqueue.hh b/util/pcqueue.hh
index 3df8749..312a66f 100644
--- a/util/pcqueue.hh
+++ b/util/pcqueue.hh
@@ -1,5 +1,7 @@
-#ifndef UTIL_PCQUEUE__
-#define UTIL_PCQUEUE__
+#ifndef UTIL_PCQUEUE_H
+#define UTIL_PCQUEUE_H
+
+#include "util/exception.hh"
#include <boost/interprocess/sync/interprocess_semaphore.hpp>
#include <boost/scoped_array.hpp>
@@ -8,20 +10,68 @@
#include <errno.h>
+#ifdef __APPLE__
+#include <mach/semaphore.h>
+#include <mach/task.h>
+#include <mach/mach_traps.h>
+#include <mach/mach.h>
+#endif // __APPLE__
+
namespace util {
-inline void WaitSemaphore (boost::interprocess::interprocess_semaphore &on) {
+/* OS X Maverick and Boost interprocess were doing "Function not implemented."
+ * So this is my own wrapper around the mach kernel APIs.
+ */
+#ifdef __APPLE__
+
+#define MACH_CALL(call) UTIL_THROW_IF(KERN_SUCCESS != (call), Exception, "Mach call failure")
+
+class Semaphore {
+ public:
+ explicit Semaphore(int value) : task_(mach_task_self()) {
+ MACH_CALL(semaphore_create(task_, &back_, SYNC_POLICY_FIFO, value));
+ }
+
+ ~Semaphore() {
+ MACH_CALL(semaphore_destroy(task_, back_));
+ }
+
+ void wait() {
+ MACH_CALL(semaphore_wait(back_));
+ }
+
+ void post() {
+ MACH_CALL(semaphore_signal(back_));
+ }
+
+ private:
+ semaphore_t back_;
+ task_t task_;
+};
+
+inline void WaitSemaphore(Semaphore &semaphore) {
+ semaphore.wait();
+}
+
+#else
+typedef boost::interprocess::interprocess_semaphore Semaphore;
+
+inline void WaitSemaphore (Semaphore &on) {
while (1) {
try {
on.wait();
break;
}
catch (boost::interprocess::interprocess_exception &e) {
- if (e.get_native_error() != EINTR) throw;
+ if (e.get_native_error() != EINTR) {
+ throw;
+ }
}
}
}
+#endif // __APPLE__
+
/* Producer consumer queue safe for multiple producers and multiple consumers.
* T must be default constructable and have operator=.
* The value is copied twice for Consume(T &out) or three times for Consume(),
@@ -82,9 +132,9 @@ template <class T> class PCQueue : boost::noncopyable {
private:
// Number of empty spaces in storage_.
- boost::interprocess::interprocess_semaphore empty_;
+ Semaphore empty_;
// Number of occupied spaces in storage_.
- boost::interprocess::interprocess_semaphore used_;
+ Semaphore used_;
boost::scoped_array<T> storage_;
@@ -102,4 +152,4 @@ template <class T> class PCQueue : boost::noncopyable {
} // namespace util
-#endif // UTIL_PCQUEUE__
+#endif // UTIL_PCQUEUE_H
diff --git a/util/pcqueue_test.cc b/util/pcqueue_test.cc
new file mode 100644
index 0000000..22ed2c6
--- /dev/null
+++ b/util/pcqueue_test.cc
@@ -0,0 +1,20 @@
+#include "util/pcqueue.hh"
+
+#define BOOST_TEST_MODULE PCQueueTest
+#include <boost/test/unit_test.hpp>
+
+namespace util {
+namespace {
+
+BOOST_AUTO_TEST_CASE(SingleThread) {
+ PCQueue<int> queue(10);
+ for (int i = 0; i < 10; ++i) {
+ queue.Produce(i);
+ }
+ for (int i = 0; i < 10; ++i) {
+ BOOST_CHECK_EQUAL(i, queue.Consume());
+ }
+}
+
+}
+} // namespace util
diff --git a/util/pool.hh b/util/pool.hh
index 72f8a0c..89e793d 100644
--- a/util/pool.hh
+++ b/util/pool.hh
@@ -1,8 +1,8 @@
// Very simple pool. It can only allocate memory. And all of the memory it
// allocates must be freed at the same time.
-#ifndef UTIL_POOL__
-#define UTIL_POOL__
+#ifndef UTIL_POOL_H
+#define UTIL_POOL_H
#include <vector>
@@ -42,4 +42,4 @@ class Pool {
} // namespace util
-#endif // UTIL_POOL__
+#endif // UTIL_POOL_H
diff --git a/util/probing_hash_table.hh b/util/probing_hash_table.hh
index 51a2944..c1fe917 100644
--- a/util/probing_hash_table.hh
+++ b/util/probing_hash_table.hh
@@ -1,7 +1,8 @@
-#ifndef UTIL_PROBING_HASH_TABLE__
-#define UTIL_PROBING_HASH_TABLE__
+#ifndef UTIL_PROBING_HASH_TABLE_H
+#define UTIL_PROBING_HASH_TABLE_H
#include "util/exception.hh"
+#include "util/scoped.hh"
#include <algorithm>
#include <cstddef>
@@ -25,6 +26,8 @@ struct IdentityHash {
template <class T> T operator()(T arg) const { return arg; }
};
+template <class EntryT, class HashT, class EqualT> class AutoProbing;
+
/* Non-standard hash table
* Buckets must be set at the beginning and must be greater than maximum number
* of elements, else it throws ProbingSizeException.
@@ -33,7 +36,6 @@ struct IdentityHash {
* Uses linear probing to find value.
* Only insert and lookup operations.
*/
-
template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable {
public:
typedef EntryT Entry;
@@ -43,7 +45,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
typedef HashT Hash;
typedef EqualT Equal;
- public:
static uint64_t Size(uint64_t entries, float multiplier) {
uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)));
return buckets * sizeof(Entry);
@@ -69,6 +70,11 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#endif
{}
+ void Relocate(void *new_base) {
+ begin_ = reinterpret_cast<MutableIterator>(new_base);
+ end_ = begin_ + buckets_;
+ }
+
template <class T> MutableIterator Insert(const T &t) {
#ifdef DEBUG
assert(initialized_);
@@ -82,7 +88,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#ifdef DEBUG
assert(initialized_);
#endif
- for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
+ for (MutableIterator i = Ideal(t);;) {
Key got(i->GetKey());
if (equal_(got, t.GetKey())) { out = i; return true; }
if (equal_(got, invalid_)) {
@@ -97,8 +103,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
void FinishedInserting() {}
- void LoadedBinary() {}
-
// Don't change anything related to GetKey,
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
#ifdef DEBUG
@@ -224,6 +228,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
private:
+ friend class AutoProbing<Entry, Hash, Equal>;
+
template <class T> MutableIterator Ideal(const T &t) {
return begin_ + (hash_(t.GetKey()) % buckets_);
}
@@ -247,6 +253,75 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#endif
};
+// Resizable linear probing hash table. This owns the memory.
+template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing {
+ private:
+ typedef ProbingHashTable<EntryT, HashT, EqualT> Backend;
+ public:
+ typedef EntryT Entry;
+ typedef typename Entry::Key Key;
+ typedef const Entry *ConstIterator;
+ typedef Entry *MutableIterator;
+ typedef HashT Hash;
+ typedef EqualT Equal;
+
+ AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) :
+ allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) {
+ threshold_ = initial_size * 1.2;
+ Clear();
+ }
+
+ // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup.
+ template <class T> MutableIterator Insert(const T &t) {
+ DoubleIfNeeded();
+ return backend_.UncheckedInsert(t);
+ }
+
+ template <class T> bool FindOrInsert(const T &t, MutableIterator &out) {
+ DoubleIfNeeded();
+ return backend_.FindOrInsert(t, out);
+ }
+
+ template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
+ return backend_.UnsafeMutableFind(key, out);
+ }
+
+ template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
+ return backend_.UnsafeMutableMustFind(key);
+ }
+
+ template <class Key> bool Find(const Key key, ConstIterator &out) const {
+ return backend_.Find(key, out);
+ }
+
+ template <class Key> ConstIterator MustFind(const Key key) const {
+ return backend_.MustFind(key);
+ }
+
+ std::size_t Size() const {
+ return backend_.SizeNoSerialization();
+ }
+
+ void Clear() {
+ backend_.Clear();
+ }
+
+ private:
+ void DoubleIfNeeded() {
+ if (Size() < threshold_)
+ return;
+ mem_.call_realloc(backend_.DoubleTo());
+ allocated_ = backend_.DoubleTo();
+ backend_.Double(mem_.get());
+ threshold_ *= 2;
+ }
+
+ std::size_t allocated_;
+ util::scoped_malloc mem_;
+ Backend backend_;
+ std::size_t threshold_;
+};
+
} // namespace util
-#endif // UTIL_PROBING_HASH_TABLE__
+#endif // UTIL_PROBING_HASH_TABLE_H
diff --git a/util/proxy_iterator.hh b/util/proxy_iterator.hh
index 0ee1716..8aa697b 100644
--- a/util/proxy_iterator.hh
+++ b/util/proxy_iterator.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_PROXY_ITERATOR__
-#define UTIL_PROXY_ITERATOR__
+#ifndef UTIL_PROXY_ITERATOR_H
+#define UTIL_PROXY_ITERATOR_H
#include <cstddef>
#include <iterator>
@@ -38,8 +38,8 @@ template <class Proxy> class ProxyIterator {
typedef std::random_access_iterator_tag iterator_category;
typedef typename Proxy::value_type value_type;
typedef std::ptrdiff_t difference_type;
- typedef Proxy & reference;
- typedef Proxy * pointer;
+ typedef Proxy reference;
+ typedef ProxyIterator<Proxy> * pointer;
ProxyIterator() {}
@@ -47,10 +47,10 @@ template <class Proxy> class ProxyIterator {
template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {}
explicit ProxyIterator(const Proxy &p) : p_(p) {}
- // p_'s swap does value swapping, but here we want iterator swapping
+/* // p_'s swap does value swapping, but here we want iterator swapping
friend inline void swap(ProxyIterator<Proxy> &first, ProxyIterator<Proxy> &second) {
swap(first.I(), second.I());
- }
+ }*/
// p_'s operator= does value copying, but here we want iterator copying.
S &operator=(const S &other) {
@@ -77,8 +77,8 @@ template <class Proxy> class ProxyIterator {
std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); }
- Proxy &operator*() { return p_; }
- const Proxy &operator*() const { return p_; }
+ Proxy operator*() { return p_; }
+ const Proxy operator*() const { return p_; }
Proxy *operator->() { return &p_; }
const Proxy *operator->() const { return &p_; }
Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
@@ -98,4 +98,4 @@ template <class Proxy> ProxyIterator<Proxy> operator+(std::ptrdiff_t amount, con
} // namespace util
-#endif // UTIL_PROXY_ITERATOR__
+#endif // UTIL_PROXY_ITERATOR_H
diff --git a/util/read_compressed.cc b/util/read_compressed.cc
index b62a6e8..71ef0e2 100644
--- a/util/read_compressed.cc
+++ b/util/read_compressed.cc
@@ -49,6 +49,8 @@ class ReadBase {
thunk.internal_.reset(with);
}
+ ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); }
+
static uint64_t &ReadCount(ReadCompressed &thunk) {
return thunk.raw_amount_;
}
@@ -56,6 +58,8 @@ class ReadBase {
namespace {
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed);
+
// Completed file that other classes can thunk to.
class Complete : public ReadBase {
public:
@@ -80,7 +84,7 @@ class Uncompressed : public ReadBase {
class UncompressedWithHeader : public ReadBase {
public:
- UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) {
+ UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) {
assert(already_size);
buf_.reset(malloc(already_size));
if (!buf_.get()) throw std::bad_alloc();
@@ -91,6 +95,7 @@ class UncompressedWithHeader : public ReadBase {
std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
assert(buf_.get());
+ assert(remain_ != end_);
std::size_t sending = std::min<std::size_t>(amount, end_ - remain_);
memcpy(to, remain_, sending);
remain_ += sending;
@@ -108,23 +113,51 @@ class UncompressedWithHeader : public ReadBase {
scoped_fd fd_;
};
-#ifdef HAVE_ZLIB
-class GZip : public ReadBase {
+static const std::size_t kInputBuffer = 16384;
+
+template <class Compression> class StreamCompressed : public ReadBase {
+ public:
+ StreamCompressed(int fd, const void *already_data, std::size_t already_size)
+ : file_(fd),
+ in_buffer_(MallocOrThrow(kInputBuffer)),
+ back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {}
+
+ std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+ if (amount == 0) return 0;
+ back_.SetOutput(to, amount);
+ do {
+ if (!back_.Stream().avail_in) ReadInput(thunk);
+ if (!back_.Process()) {
+ // reached end, at least for the compressed portion.
+ std::size_t ret = static_cast<const uint8_t *>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
+ ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk);
+ if (ret) return ret;
+ // We did not read anything this round, so clients might think EOF. Transfer responsibility to the next reader.
+ return Current(thunk)->Read(to, amount, thunk);
+ }
+ } while (back_.Stream().next_out == to);
+ return static_cast<const uint8_t*>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
+ }
+
private:
- static const std::size_t kInputBuffer = 16384;
+ void ReadInput(ReadCompressed &thunk) {
+ assert(!back_.Stream().avail_in);
+ std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
+ back_.SetInput(in_buffer_.get(), got);
+ ReadCount(thunk) += got;
+ }
+
+ scoped_fd file_;
+ scoped_malloc in_buffer_;
+
+ Compression back_;
+};
+
+#ifdef HAVE_ZLIB
+class GZip {
public:
- GZip(int fd, void *already_data, std::size_t already_size)
- : file_(fd), in_buffer_(malloc(kInputBuffer)) {
- if (!in_buffer_.get()) throw std::bad_alloc();
- assert(already_size < kInputBuffer);
- if (already_size) {
- memcpy(in_buffer_.get(), already_data, already_size);
- stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
- stream_.avail_in = already_size;
- stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
- } else {
- stream_.avail_in = 0;
- }
+ GZip(const void *base, std::size_t amount) {
+ SetInput(base, amount);
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
stream_.opaque = Z_NULL;
@@ -141,227 +174,154 @@ class GZip : public ReadBase {
}
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- if (amount == 0) return 0;
+ void SetOutput(void *to, std::size_t amount) {
stream_.next_out = static_cast<Bytef*>(to);
stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount);
- do {
- if (!stream_.avail_in) ReadInput(thunk);
- int result = inflate(&stream_, 0);
- switch (result) {
- case Z_OK:
- break;
- case Z_STREAM_END:
- {
- std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- case Z_ERRNO:
- UTIL_THROW(ErrnoException, "zlib error");
- default:
- UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
- }
- } while (stream_.next_out == to);
- return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
}
- private:
- void ReadInput(ReadCompressed &thunk) {
- assert(!stream_.avail_in);
- stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
- stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
- ReadCount(thunk) += stream_.avail_in;
+ void SetInput(const void *base, std::size_t amount) {
+ assert(amount < static_cast<std::size_t>(std::numeric_limits<uInt>::max()));
+ stream_.next_in = const_cast<Bytef*>(static_cast<const Bytef*>(base));
+ stream_.avail_in = amount;
}
- scoped_fd file_;
- scoped_malloc in_buffer_;
+ const z_stream &Stream() const { return stream_; }
+
+ bool Process() {
+ int result = inflate(&stream_, 0);
+ switch (result) {
+ case Z_OK:
+ return true;
+ case Z_STREAM_END:
+ return false;
+ case Z_ERRNO:
+ UTIL_THROW(ErrnoException, "zlib error");
+ default:
+ UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
+ }
+ }
+
+ private:
z_stream stream_;
};
#endif // HAVE_ZLIB
-const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
-
#ifdef HAVE_BZLIB
-class BZip : public ReadBase {
+class BZip {
public:
- BZip(int fd, void *already_data, std::size_t already_size) {
- scoped_fd hold(fd);
- closer_.reset(FDOpenReadOrThrow(hold));
- file_ = NULL;
- Open(already_data, already_size);
+ BZip(const void *base, std::size_t amount) {
+ memset(&stream_, 0, sizeof(stream_));
+ SetInput(base, amount);
+ HandleError(BZ2_bzDecompressInit(&stream_, 0, 0));
}
- BZip(FILE *file, void *already_data, std::size_t already_size) {
- closer_.reset(file);
- file_ = NULL;
- Open(already_data, already_size);
+ ~BZip() {
+ try {
+ HandleError(BZ2_bzDecompressEnd(&stream_));
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+ }
}
- ~BZip() {
- Close(file_);
+ bool Process() {
+ int ret = BZ2_bzDecompress(&stream_);
+ if (ret == BZ_STREAM_END) return false;
+ HandleError(ret);
+ return true;
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- assert(file_);
- int bzerror = BZ_OK;
- int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
- long pos = ftell(closer_.get());
- if (pos != -1) ReadCount(thunk) = pos;
- switch (bzerror) {
- case BZ_STREAM_END:
- /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the
- * library doesn't handle this internally. This gets the trailing
- * data, grows it up to magic as needed, validates the magic, and
- * reopens.
- */
- {
- bzerror = BZ_OK;
- void *trailing_data;
- int trailing_size;
- BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size);
- UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- std::string trailing(static_cast<const char*>(trailing_data), trailing_size);
- Close(file_);
-
- if (trailing_size < (int)sizeof(kBZMagic)) {
- trailing.resize(sizeof(kBZMagic));
- if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) {
- UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft");
- // Legitimate end of file.
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- }
- UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream");
- Open(&trailing[0], trailing.size());
- }
- return ret;
- case BZ_OK:
- return ret;
- default:
- UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- }
+ void SetOutput(void *base, std::size_t amount) {
+ stream_.next_out = static_cast<char*>(base);
+ stream_.avail_out = std::min<std::size_t>(std::numeric_limits<unsigned int>::max(), amount);
+ }
+
+ void SetInput(const void *base, std::size_t amount) {
+ stream_.next_in = const_cast<char*>(static_cast<const char*>(base));
+ stream_.avail_in = amount;
}
+ const bz_stream &Stream() const { return stream_; }
+
private:
- void Open(void *already_data, std::size_t already_size) {
- assert(!file_);
- int bzerror = BZ_OK;
- file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size);
- switch (bzerror) {
+ void HandleError(int value) {
+ switch(value) {
case BZ_OK:
return;
case BZ_CONFIG_ERROR:
- UTIL_THROW(BZException, "Looks like bzip2 was miscompiled.");
+ UTIL_THROW(BZException, "bzip2 seems to be miscompiled.");
case BZ_PARAM_ERROR:
- UTIL_THROW(BZException, "Parameter error");
- case BZ_IO_ERROR:
- UTIL_THROW(BZException, "IO error reading file");
+ UTIL_THROW(BZException, "bzip2 Parameter error");
+ case BZ_DATA_ERROR:
+ UTIL_THROW(BZException, "bzip2 detected a corrupt file");
+ case BZ_DATA_ERROR_MAGIC:
+ UTIL_THROW(BZException, "bzip2 detected bad magic bytes. Perhaps this was not a bzip2 file after all?");
case BZ_MEM_ERROR:
throw std::bad_alloc();
default:
- UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror);
- }
- assert(file_);
- }
-
- static void Close(BZFILE *&file) {
- if (file == NULL) return;
- int bzerror = BZ_OK;
- BZ2_bzReadClose(&bzerror, file);
- if (bzerror != BZ_OK) {
- std::cerr << "bz2 readclose error number " << bzerror << std::endl;
- abort();
+ UTIL_THROW(BZException, "Unknown bzip2 error code " << value);
}
- file = NULL;
}
- scoped_FILE closer_;
- BZFILE *file_;
+ bz_stream stream_;
};
#endif // HAVE_BZLIB
#ifdef HAVE_XZLIB
-class XZip : public ReadBase {
- private:
- static const std::size_t kInputBuffer = 16384;
+class XZip {
public:
- XZip(int fd, void *already_data, std::size_t already_size)
- : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) {
- if (!in_buffer_.get()) throw std::bad_alloc();
- assert(already_size < kInputBuffer);
- if (already_size) {
- memcpy(in_buffer_.get(), already_data, already_size);
- stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
- stream_.avail_in = already_size;
- stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
- } else {
- stream_.avail_in = 0;
- }
- stream_.allocator = NULL;
- lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED);
- switch (ret) {
- case LZMA_OK:
- break;
- case LZMA_MEM_ERROR:
- UTIL_THROW(ErrnoException, "xz open error");
- default:
- UTIL_THROW(XZException, "xz error code " << ret);
- }
+ XZip(const void *base, std::size_t amount)
+ : stream_(), action_(LZMA_RUN) {
+ memset(&stream_, 0, sizeof(stream_));
+ SetInput(base, amount);
+ HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0));
}
~XZip() {
lzma_end(&stream_);
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- if (amount == 0) return 0;
- stream_.next_out = static_cast<uint8_t*>(to);
+ void SetOutput(void *base, std::size_t amount) {
+ stream_.next_out = static_cast<uint8_t*>(base);
stream_.avail_out = amount;
- do {
- if (!stream_.avail_in) ReadInput(thunk);
- lzma_ret status = lzma_code(&stream_, action_);
- switch (status) {
- case LZMA_OK:
- break;
- case LZMA_STREAM_END:
- UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet.");
- {
- std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- case LZMA_MEM_ERROR:
- throw std::bad_alloc();
- case LZMA_FORMAT_ERROR:
- UTIL_THROW(XZException, "xzlib says file format not recognized");
- case LZMA_OPTIONS_ERROR:
- UTIL_THROW(XZException, "xzlib says unsupported compression options");
- case LZMA_DATA_ERROR:
- UTIL_THROW(XZException, "xzlib says this file is corrupt");
- case LZMA_BUF_ERROR:
- UTIL_THROW(XZException, "xzlib says unexpected end of input");
- default:
- UTIL_THROW(XZException, "unrecognized xzlib error " << status);
- }
- } while (stream_.next_out == to);
- return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+ }
+
+ void SetInput(const void *base, std::size_t amount) {
+ stream_.next_in = static_cast<const uint8_t*>(base);
+ stream_.avail_in = amount;
+ if (!amount) action_ = LZMA_FINISH;
+ }
+
+ const lzma_stream &Stream() const { return stream_; }
+
+ bool Process() {
+ lzma_ret status = lzma_code(&stream_, action_);
+ if (status == LZMA_STREAM_END) return false;
+ HandleError(status);
+ return true;
}
private:
- void ReadInput(ReadCompressed &thunk) {
- assert(!stream_.avail_in);
- stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
- stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
- if (!stream_.avail_in) action_ = LZMA_FINISH;
- ReadCount(thunk) += stream_.avail_in;
+ void HandleError(lzma_ret value) {
+ switch (value) {
+ case LZMA_OK:
+ return;
+ case LZMA_MEM_ERROR:
+ throw std::bad_alloc();
+ case LZMA_FORMAT_ERROR:
+ UTIL_THROW(XZException, "xzlib says file format not recognized");
+ case LZMA_OPTIONS_ERROR:
+ UTIL_THROW(XZException, "xzlib says unsupported compression options");
+ case LZMA_DATA_ERROR:
+ UTIL_THROW(XZException, "xzlib says this file is corrupt");
+ case LZMA_BUF_ERROR:
+ UTIL_THROW(XZException, "xzlib says unexpected end of input");
+ default:
+ UTIL_THROW(XZException, "unrecognized xzlib error " << value);
+ }
}
- scoped_fd file_;
- scoped_malloc in_buffer_;
lzma_stream stream_;
-
lzma_action action_;
};
#endif // HAVE_XZLIB
@@ -384,66 +344,68 @@ class IStreamReader : public ReadBase {
};
enum MagicResult {
- UNKNOWN, GZIP, BZIP, XZIP
+ UTIL_UNKNOWN, UTIL_GZIP, UTIL_BZIP, UTIL_XZIP
};
-MagicResult DetectMagic(const void *from_void) {
+MagicResult DetectMagic(const void *from_void, std::size_t length) {
const uint8_t *header = static_cast<const uint8_t*>(from_void);
- if (header[0] == 0x1f && header[1] == 0x8b) {
- return GZIP;
+ if (length >= 2 && header[0] == 0x1f && header[1] == 0x8b) {
+ return UTIL_GZIP;
}
- if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) {
- return BZIP;
+ const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
+ if (length >= sizeof(kBZMagic) && !memcmp(header, kBZMagic, sizeof(kBZMagic))) {
+ return UTIL_BZIP;
}
const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
- if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) {
- return XZIP;
+ if (length >= sizeof(kXZMagic) && !memcmp(header, kXZMagic, sizeof(kXZMagic))) {
+ return UTIL_XZIP;
}
- return UNKNOWN;
+ return UTIL_UNKNOWN;
}
-ReadBase *ReadFactory(int fd, uint64_t &raw_amount) {
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) {
scoped_fd hold(fd);
- unsigned char header[ReadCompressed::kMagicSize];
- raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize);
- if (!raw_amount)
- return new Uncompressed(hold.release());
- if (raw_amount != ReadCompressed::kMagicSize)
- return new UncompressedWithHeader(hold.release(), header, raw_amount);
- switch (DetectMagic(header)) {
- case GZIP:
+ std::string header(reinterpret_cast<const char*>(already_data), already_size);
+ if (header.size() < ReadCompressed::kMagicSize) {
+ std::size_t original = header.size();
+ header.resize(ReadCompressed::kMagicSize);
+ std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original);
+ raw_amount += got;
+ header.resize(original + got);
+ }
+ if (header.empty()) {
+ hold.release();
+ return new Complete();
+ }
+ switch (DetectMagic(&header[0], header.size())) {
+ case UTIL_GZIP:
#ifdef HAVE_ZLIB
- return new GZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<GZip>(hold.release(), header.data(), header.size());
#else
UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in.");
#endif
- case BZIP:
+ case UTIL_BZIP:
#ifdef HAVE_BZLIB
- return new BZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<BZip>(hold.release(), &header[0], header.size());
#else
- UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in.");
+ UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in.");
#endif
- case XZIP:
+ case UTIL_XZIP:
#ifdef HAVE_XZLIB
- return new XZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<XZip>(hold.release(), header.data(), header.size());
#else
UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in.");
#endif
- case UNKNOWN:
- break;
- }
- try {
- SeekOrThrow(fd, 0);
- } catch (const util::ErrnoException &e) {
- return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize);
+ default:
+ UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file. This could be supported but usually indicates an error.");
+ return new UncompressedWithHeader(hold.release(), header.data(), header.size());
}
- return new Uncompressed(hold.release());
}
} // namespace
bool ReadCompressed::DetectCompressedMagic(const void *from_void) {
- return DetectMagic(from_void) != UNKNOWN;
+ return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN;
}
ReadCompressed::ReadCompressed(int fd) {
@@ -459,8 +421,9 @@ ReadCompressed::ReadCompressed() {}
ReadCompressed::~ReadCompressed() {}
void ReadCompressed::Reset(int fd) {
+ raw_amount_ = 0;
internal_.reset();
- internal_.reset(ReadFactory(fd, raw_amount_));
+ internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false));
}
void ReadCompressed::Reset(std::istream &in) {
diff --git a/util/read_compressed.hh b/util/read_compressed.hh
index 8b54c9e..763e6bb 100644
--- a/util/read_compressed.hh
+++ b/util/read_compressed.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_READ_COMPRESSED__
-#define UTIL_READ_COMPRESSED__
+#ifndef UTIL_READ_COMPRESSED_H
+#define UTIL_READ_COMPRESSED_H
#include "util/exception.hh"
#include "util/scoped.hh"
@@ -78,4 +78,4 @@ class ReadCompressed {
} // namespace util
-#endif // UTIL_READ_COMPRESSED__
+#endif // UTIL_READ_COMPRESSED_H
diff --git a/util/read_compressed_test.cc b/util/read_compressed_test.cc
index 9cb4a4b..301e8f4 100644
--- a/util/read_compressed_test.cc
+++ b/util/read_compressed_test.cc
@@ -12,6 +12,23 @@
#include <stdlib.h>
+#if defined __MINGW32__
+#include <time.h>
+#include <fcntl.h>
+
+#if !defined mkstemp
+// TODO insecure
+int mkstemp(char * stemplate)
+{
+ char *filename = mktemp(stemplate);
+ if (filename == NULL)
+ return -1;
+ return open(filename, O_RDWR | O_CREAT, 0600);
+}
+#endif
+
+#endif // defined
+
namespace util {
namespace {
@@ -96,6 +113,11 @@ BOOST_AUTO_TEST_CASE(ReadXZ) {
}
#endif
+#ifdef HAVE_ZLIB
+BOOST_AUTO_TEST_CASE(AppendGZ) {
+}
+#endif
+
BOOST_AUTO_TEST_CASE(IStream) {
std::string name(WriteRandom());
std::fstream stream(name.c_str(), std::ios::in);
diff --git a/util/scoped.hh b/util/scoped.hh
index b642d06..ae70b6b 100644
--- a/util/scoped.hh
+++ b/util/scoped.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_SCOPED__
-#define UTIL_SCOPED__
+#ifndef UTIL_SCOPED_H
+#define UTIL_SCOPED_H
/* Other scoped objects in the style of scoped_ptr. */
#include "util/exception.hh"
@@ -101,4 +101,4 @@ template <class T> class scoped_ptr {
} // namespace util
-#endif // UTIL_SCOPED__
+#endif // UTIL_SCOPED_H
diff --git a/util/sized_iterator.hh b/util/sized_iterator.hh
index dce8f22..75f6886 100644
--- a/util/sized_iterator.hh
+++ b/util/sized_iterator.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_SIZED_ITERATOR__
-#define UTIL_SIZED_ITERATOR__
+#ifndef UTIL_SIZED_ITERATOR_H
+#define UTIL_SIZED_ITERATOR_H
#include "util/proxy_iterator.hh"
@@ -36,7 +36,7 @@ class SizedInnerIterator {
void *Data() { return ptr_; }
std::size_t EntrySize() const { return size_; }
- friend inline void swap(SizedInnerIterator &first, SizedInnerIterator &second) {
+ friend void swap(SizedInnerIterator &first, SizedInnerIterator &second) {
std::swap(first.ptr_, second.ptr_);
std::swap(first.size_, second.size_);
}
@@ -69,17 +69,7 @@ class SizedProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
- /**
- // TODO: this (deep) swap was recently added. why? if any std heap sort etc
- // algs are using swap, that's going to be worse performance than using
- // =. i'm not sure why we *want* a deep swap. if C++11 compilers are
- // choosing between move constructor and swap, then we'd better implement a
- // (deep) move constructor. it may also be that this is moot since i made
- // ProxyIterator a reference and added a shallow ProxyIterator swap? (I
- // need Ken or someone competent to judge whether that's correct also. -
- // let me know at graehl@gmail.com
- */
- friend void swap(SizedProxy &first, SizedProxy &second) {
+ friend void swap(SizedProxy first, SizedProxy second) {
std::swap_ranges(
static_cast<char*>(first.inner_.Data()),
static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(),
@@ -127,4 +117,4 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public
};
} // namespace util
-#endif // UTIL_SIZED_ITERATOR__
+#endif // UTIL_SIZED_ITERATOR_H
diff --git a/util/sized_iterator_test.cc b/util/sized_iterator_test.cc
new file mode 100644
index 0000000..c36bcb2
--- /dev/null
+++ b/util/sized_iterator_test.cc
@@ -0,0 +1,16 @@
+#include "util/sized_iterator.hh"
+
+#define BOOST_TEST_MODULE SizedIteratorTest
+#include <boost/test/unit_test.hpp>
+
+namespace util { namespace {
+
+BOOST_AUTO_TEST_CASE(swap_works) {
+ char str[2] = { 0, 1 };
+ SizedProxy first(str, 1), second(str + 1, 1);
+ swap(first, second);
+ BOOST_CHECK_EQUAL(1, str[0]);
+ BOOST_CHECK_EQUAL(0, str[1]);
+}
+
+}} // namespace anonymous util
diff --git a/util/sorted_uniform.hh b/util/sorted_uniform.hh
index 7700d9e..a7dba5e 100644
--- a/util/sorted_uniform.hh
+++ b/util/sorted_uniform.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_SORTED_UNIFORM__
-#define UTIL_SORTED_UNIFORM__
+#ifndef UTIL_SORTED_UNIFORM_H
+#define UTIL_SORTED_UNIFORM_H
#include <algorithm>
#include <cstddef>
@@ -124,4 +124,4 @@ template <class Iterator, class Accessor> Iterator BinaryBelow(
} // namespace util
-#endif // UTIL_SORTED_UNIFORM__
+#endif // UTIL_SORTED_UNIFORM_H
diff --git a/util/stream/block.hh b/util/stream/block.hh
index 11aa991..b9f1d49 100644
--- a/util/stream/block.hh
+++ b/util/stream/block.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_BLOCK__
-#define UTIL_STREAM_BLOCK__
+#ifndef UTIL_STREAM_BLOCK_H
+#define UTIL_STREAM_BLOCK_H
#include <cstddef>
#include <stdint.h>
@@ -40,4 +40,4 @@ class Block {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_BLOCK__
+#endif // UTIL_STREAM_BLOCK_H
diff --git a/util/stream/chain.hh b/util/stream/chain.hh
index 0cc83a8..bda653a 100644
--- a/util/stream/chain.hh
+++ b/util/stream/chain.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_CHAIN__
-#define UTIL_STREAM_CHAIN__
+#ifndef UTIL_STREAM_CHAIN_H
+#define UTIL_STREAM_CHAIN_H
#include "util/stream/block.hh"
#include "util/stream/config.hh"
@@ -195,4 +195,4 @@ inline Chain &operator>>(Chain &chain, Link &link) {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_CHAIN__
+#endif // UTIL_STREAM_CHAIN_H
diff --git a/util/stream/config.hh b/util/stream/config.hh
index 1eeb3a8..052a05b 100644
--- a/util/stream/config.hh
+++ b/util/stream/config.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_CONFIG__
-#define UTIL_STREAM_CONFIG__
+#ifndef UTIL_STREAM_CONFIG_H
+#define UTIL_STREAM_CONFIG_H
#include <cstddef>
#include <string>
@@ -29,4 +29,4 @@ struct SortConfig {
};
}} // namespaces
-#endif // UTIL_STREAM_CONFIG__
+#endif // UTIL_STREAM_CONFIG_H
diff --git a/util/stream/io.hh b/util/stream/io.hh
index 934b6b3..65918a9 100644
--- a/util/stream/io.hh
+++ b/util/stream/io.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_IO__
-#define UTIL_STREAM_IO__
+#ifndef UTIL_STREAM_IO_H
+#define UTIL_STREAM_IO_H
#include "util/exception.hh"
#include "util/file.hh"
@@ -73,4 +73,4 @@ class FileBuffer {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_IO__
+#endif // UTIL_STREAM_IO_H
diff --git a/util/stream/line_input.hh b/util/stream/line_input.hh
index 86db1dd..a870a66 100644
--- a/util/stream/line_input.hh
+++ b/util/stream/line_input.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_LINE_INPUT__
-#define UTIL_STREAM_LINE_INPUT__
+#ifndef UTIL_STREAM_LINE_INPUT_H
+#define UTIL_STREAM_LINE_INPUT_H
namespace util {namespace stream {
class ChainPosition;
@@ -19,4 +19,4 @@ class LineInput {
};
}} // namespaces
-#endif // UTIL_STREAM_LINE_INPUT__
+#endif // UTIL_STREAM_LINE_INPUT_H
diff --git a/util/stream/multi_progress.hh b/util/stream/multi_progress.hh
index c4dd45a..82e698a 100644
--- a/util/stream/multi_progress.hh
+++ b/util/stream/multi_progress.hh
@@ -1,6 +1,6 @@
/* Progress bar suitable for chains of workers */
-#ifndef UTIL_MULTI_PROGRESS__
-#define UTIL_MULTI_PROGRESS__
+#ifndef UTIL_STREAM_MULTI_PROGRESS_H
+#define UTIL_STREAM_MULTI_PROGRESS_H
#include <boost/thread/mutex.hpp>
@@ -87,4 +87,4 @@ class WorkerProgress {
}} // namespaces
-#endif // UTIL_MULTI_PROGRESS__
+#endif // UTIL_STREAM_MULTI_PROGRESS_H
diff --git a/util/stream/sort.hh b/util/stream/sort.hh
index 16aa6a0..e18c6ae 100644
--- a/util/stream/sort.hh
+++ b/util/stream/sort.hh
@@ -15,8 +15,8 @@
* sort. Use a hash table for that.
*/
-#ifndef UTIL_STREAM_SORT__
-#define UTIL_STREAM_SORT__
+#ifndef UTIL_STREAM_SORT_H
+#define UTIL_STREAM_SORT_H
#include "util/stream/chain.hh"
#include "util/stream/config.hh"
@@ -545,4 +545,4 @@ template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, cons
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_SORT__
+#endif // UTIL_STREAM_SORT_H
diff --git a/util/stream/stream.hh b/util/stream/stream.hh
index 6ff45b8..5cd1bdd 100644
--- a/util/stream/stream.hh
+++ b/util/stream/stream.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_STREAM__
-#define UTIL_STREAM_STREAM__
+#ifndef UTIL_STREAM_STREAM_H
+#define UTIL_STREAM_STREAM_H
#include "util/stream/chain.hh"
@@ -71,4 +71,4 @@ inline Chain &operator>>(Chain &chain, Stream &stream) {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_STREAM__
+#endif // UTIL_STREAM_STREAM_H
diff --git a/util/stream/timer.hh b/util/stream/timer.hh
index 7e1a588..06488a1 100644
--- a/util/stream/timer.hh
+++ b/util/stream/timer.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_TIMER__
-#define UTIL_STREAM_TIMER__
+#ifndef UTIL_STREAM_TIMER_H
+#define UTIL_STREAM_TIMER_H
// Sorry Jon, this was adding library dependencies in Moses and people complained.
@@ -13,4 +13,4 @@
#define UTIL_TIMER(str)
//#endif
-#endif // UTIL_STREAM_TIMER__
+#endif // UTIL_STREAM_TIMER_H
diff --git a/util/string_piece.hh b/util/string_piece.hh
index 84431db..114e254 100644
--- a/util/string_piece.hh
+++ b/util/string_piece.hh
@@ -45,8 +45,8 @@
// conversions from "const char*" to "string" and back again.
//
-#ifndef BASE_STRING_PIECE_H__
-#define BASE_STRING_PIECE_H__
+#ifndef UTIL_STRING_PIECE_H
+#define UTIL_STRING_PIECE_H
#include "util/have.hh"
@@ -267,4 +267,4 @@ U_NAMESPACE_END
using U_NAMESPACE_QUALIFIER StringPiece;
#endif
-#endif // BASE_STRING_PIECE_H__
+#endif // UTIL_STRING_PIECE_H
diff --git a/util/string_piece_hash.hh b/util/string_piece_hash.hh
index f206b1d..5c8c525 100644
--- a/util/string_piece_hash.hh
+++ b/util/string_piece_hash.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STRING_PIECE_HASH__
-#define UTIL_STRING_PIECE_HASH__
+#ifndef UTIL_STRING_PIECE_HASH_H
+#define UTIL_STRING_PIECE_HASH_H
#include "util/string_piece.hh"
@@ -40,4 +40,4 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece
#endif
}
-#endif // UTIL_STRING_PIECE_HASH__
+#endif // UTIL_STRING_PIECE_HASH_H
diff --git a/util/thread_pool.hh b/util/thread_pool.hh
index 84e257e..d1a883a 100644
--- a/util/thread_pool.hh
+++ b/util/thread_pool.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_THREAD_POOL__
-#define UTIL_THREAD_POOL__
+#ifndef UTIL_THREAD_POOL_H
+#define UTIL_THREAD_POOL_H
#include "util/pcqueue.hh"
@@ -18,8 +18,8 @@ template <class HandlerT> class Worker : boost::noncopyable {
typedef HandlerT Handler;
typedef typename Handler::Request Request;
- template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, Request &poison)
- : in_(in), handler_(construct), thread_(boost::ref(*this)), poison_(poison) {}
+ template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, const Request &poison)
+ : in_(in), handler_(construct), poison_(poison), thread_(boost::ref(*this)) {}
// Only call from thread.
void operator()() {
@@ -30,7 +30,7 @@ template <class HandlerT> class Worker : boost::noncopyable {
try {
(*handler_)(request);
}
- catch(std::exception &e) {
+ catch(const std::exception &e) {
std::cerr << "Handler threw " << e.what() << std::endl;
abort();
}
@@ -49,10 +49,10 @@ template <class HandlerT> class Worker : boost::noncopyable {
PCQueue<Request> &in_;
boost::optional<Handler> handler_;
+
+ const Request poison_;
boost::thread thread_;
-
- Request poison_;
};
template <class HandlerT> class ThreadPool : boost::noncopyable {
@@ -92,4 +92,4 @@ template <class HandlerT> class ThreadPool : boost::noncopyable {
} // namespace util
-#endif // UTIL_THREAD_POOL__
+#endif // UTIL_THREAD_POOL_H
diff --git a/util/tokenize_piece.hh b/util/tokenize_piece.hh
index a588c3f..908c8da 100644
--- a/util/tokenize_piece.hh
+++ b/util/tokenize_piece.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_TOKENIZE_PIECE__
-#define UTIL_TOKENIZE_PIECE__
+#ifndef UTIL_TOKENIZE_PIECE_H
+#define UTIL_TOKENIZE_PIECE_H
#include "util/exception.hh"
#include "util/string_piece.hh"
@@ -7,7 +7,8 @@
#include <boost/iterator/iterator_facade.hpp>
#include <algorithm>
-#include <iostream>
+
+#include <string.h>
namespace util {
@@ -58,6 +59,30 @@ class AnyCharacter {
StringPiece chars_;
};
+class BoolCharacter {
+ public:
+ BoolCharacter() {}
+
+ explicit BoolCharacter(const bool *delimiter) { delimiter_ = delimiter; }
+
+ StringPiece Find(const StringPiece &in) const {
+ for (const char *i = in.data(); i != in.data() + in.size(); ++i) {
+ if (delimiter_[static_cast<unsigned char>(*i)]) return StringPiece(i, 1);
+ }
+ return StringPiece(in.data() + in.size(), 0);
+ }
+
+ template <unsigned Length> static void Build(const char (&characters)[Length], bool (&out)[256]) {
+ memset(out, 0, sizeof(out));
+ for (const char *i = characters; i != characters + Length; ++i) {
+ out[static_cast<unsigned char>(*i)] = true;
+ }
+ }
+
+ private:
+ const bool *delimiter_;
+};
+
class AnyCharacterLast {
public:
AnyCharacterLast() {}
@@ -123,4 +148,4 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
} // namespace util
-#endif // UTIL_TOKENIZE_PIECE__
+#endif // UTIL_TOKENIZE_PIECE_H
diff --git a/util/usage.cc b/util/usage.cc
index 25fc097..5912f90 100644
--- a/util/usage.cc
+++ b/util/usage.cc
@@ -10,61 +10,119 @@
#include <string.h>
#include <ctype.h>
-#if !defined(_WIN32) && !defined(_WIN64)
+#include <time.h>
+#if defined(_WIN32) || defined(_WIN64)
+// This code lifted from physmem.c in gnulib. See the copyright statement
+// below.
+# define WIN32_LEAN_AND_MEAN
+# include <windows.h>
+/* MEMORYSTATUSEX is missing from older windows headers, so define
+ a local replacement. */
+typedef struct
+{
+ DWORD dwLength;
+ DWORD dwMemoryLoad;
+ DWORDLONG ullTotalPhys;
+ DWORDLONG ullAvailPhys;
+ DWORDLONG ullTotalPageFile;
+ DWORDLONG ullAvailPageFile;
+ DWORDLONG ullTotalVirtual;
+ DWORDLONG ullAvailVirtual;
+ DWORDLONG ullAvailExtendedVirtual;
+} lMEMORYSTATUSEX;
+// Is this really supposed to be defined like this?
+typedef int WINBOOL;
+typedef WINBOOL (WINAPI *PFN_MS_EX) (lMEMORYSTATUSEX*);
+#else
#include <sys/resource.h>
#include <sys/time.h>
-#include <time.h>
#include <unistd.h>
#endif
-namespace util {
+#if defined(__MACH__) || defined(__FreeBSD__) || defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
-#if !defined(_WIN32) && !defined(_WIN64)
+namespace util {
namespace {
-// On Mac OS X, clock_gettime is not implemented.
-// CLOCK_MONOTONIC is not defined either.
-#ifdef __MACH__
-#define CLOCK_MONOTONIC 0
-
-int clock_gettime(int clk_id, struct timespec *tp) {
+#if defined(__MACH__)
+typedef struct timeval Wall;
+Wall GetWall() {
struct timeval tv;
gettimeofday(&tv, NULL);
- tp->tv_sec = tv.tv_sec;
- tp->tv_nsec = tv.tv_usec * 1000;
- return 0;
+ return tv;
}
-#endif // __MACH__
-
-float FloatSec(const struct timeval &tv) {
- return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0);
+#elif defined(_WIN32) || defined(_WIN64)
+typedef time_t Wall;
+Wall GetWall() {
+ return time(NULL);
}
-float FloatSec(const struct timespec &tv) {
- return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_nsec) / 1000000000.0);
+#else
+typedef struct timespec Wall;
+Wall GetWall() {
+ Wall ret;
+ clock_gettime(CLOCK_MONOTONIC, &ret);
+ return ret;
}
+#endif
-const char *SkipSpaces(const char *at) {
- for (; *at == ' ' || *at == '\t'; ++at) {}
- return at;
+// Some of these functions are only used on some platforms.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
+// These all assume first > second
+double Subtract(time_t first, time_t second) {
+ return difftime(first, second);
+}
+double DoubleSec(time_t tv) {
+ return static_cast<double>(tv);
+}
+#if !defined(_WIN32) && !defined(_WIN64)
+double Subtract(const struct timeval &first, const struct timeval &second) {
+ return static_cast<double>(first.tv_sec - second.tv_sec) + static_cast<double>(first.tv_usec - second.tv_usec) / 1000000.0;
}
+double Subtract(const struct timespec &first, const struct timespec &second) {
+ return static_cast<double>(first.tv_sec - second.tv_sec) + static_cast<double>(first.tv_nsec - second.tv_nsec) / 1000000000.0;
+}
+double DoubleSec(const struct timeval &tv) {
+ return static_cast<double>(tv.tv_sec) + (static_cast<double>(tv.tv_usec) / 1000000.0);
+}
+double DoubleSec(const struct timespec &tv) {
+ return static_cast<double>(tv.tv_sec) + (static_cast<double>(tv.tv_nsec) / 1000000000.0);
+}
+#endif
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
class RecordStart {
public:
RecordStart() {
- clock_gettime(CLOCK_MONOTONIC, &started_);
+ started_ = GetWall();
}
- const struct timespec &Started() const {
+ const Wall &Started() const {
return started_;
}
private:
- struct timespec started_;
+ Wall started_;
};
const RecordStart kRecordStart;
+
+const char *SkipSpaces(const char *at) {
+ for (; *at == ' ' || *at == '\t'; ++at) {}
+ return at;
+}
} // namespace
-#endif
+
+double WallTime() {
+ return Subtract(GetWall(), kRecordStart.Started());
+}
void PrintUsage(std::ostream &out) {
#if !defined(_WIN32) && !defined(_WIN64)
@@ -88,27 +146,84 @@ void PrintUsage(std::ostream &out) {
return;
}
out << "RSSMax:" << usage.ru_maxrss << " kB" << '\t';
- out << "user:" << FloatSec(usage.ru_utime) << "\tsys:" << FloatSec(usage.ru_stime) << '\t';
- out << "CPU:" << (FloatSec(usage.ru_utime) + FloatSec(usage.ru_stime));
-
- struct timespec current;
- clock_gettime(CLOCK_MONOTONIC, &current);
- out << "\treal:" << (FloatSec(current) - FloatSec(kRecordStart.Started())) << '\n';
+ out << "user:" << DoubleSec(usage.ru_utime) << "\tsys:" << DoubleSec(usage.ru_stime) << '\t';
+ out << "CPU:" << (DoubleSec(usage.ru_utime) + DoubleSec(usage.ru_stime));
+ out << '\t';
#endif
+
+ out << "real:" << WallTime() << '\n';
}
+/* Adapted from physmem.c in gnulib 831b84c59ef413c57a36b67344467d66a8a2ba70 */
+/* Calculate the size of physical memory.
+
+ Copyright (C) 2000-2001, 2003, 2005-2006, 2009-2013 Free Software
+ Foundation, Inc.
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation; either version 2.1 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>. */
+
+/* Written by Paul Eggert. */
uint64_t GuessPhysicalMemory() {
+#if defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
+ {
+ long pages = sysconf(_SC_PHYS_PAGES);
+ long page_size = sysconf(_SC_PAGESIZE);
+ if (pages != -1 && page_size != -1)
+ return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size);
+ }
+#endif
+#ifdef HW_PHYSMEM
+ { /* This works on *bsd and darwin. */
+ unsigned int physmem;
+ size_t len = sizeof physmem;
+ static int mib[2] = { CTL_HW, HW_PHYSMEM };
+
+ if (sysctl (mib, sizeof(mib) / sizeof(mib[0]), &physmem, &len, NULL, 0) == 0
+ && len == sizeof (physmem))
+ return static_cast<uint64_t>(physmem);
+ }
+#endif
+
#if defined(_WIN32) || defined(_WIN64)
- return 0;
-#elif defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
- long pages = sysconf(_SC_PHYS_PAGES);
- if (pages == -1) return 0;
- long page_size = sysconf(_SC_PAGESIZE);
- if (page_size == -1) return 0;
- return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size);
-#else
- return 0;
+ { /* this works on windows */
+ PFN_MS_EX pfnex;
+ HMODULE h = GetModuleHandle ("kernel32.dll");
+
+ if (!h)
+ return 0;
+
+ /* Use GlobalMemoryStatusEx if available. */
+ if ((pfnex = (PFN_MS_EX) GetProcAddress (h, "GlobalMemoryStatusEx")))
+ {
+ lMEMORYSTATUSEX lms_ex;
+ lms_ex.dwLength = sizeof lms_ex;
+ if (!pfnex (&lms_ex))
+ return 0;
+ return lms_ex.ullTotalPhys;
+ }
+
+ /* Fall back to GlobalMemoryStatus which is always available.
+ but returns wrong results for physical memory > 4GB. */
+ else
+ {
+ MEMORYSTATUS ms;
+ GlobalMemoryStatus (&ms);
+ return ms.dwTotalPhys;
+ }
+ }
#endif
+ return 0;
}
namespace {
diff --git a/util/usage.hh b/util/usage.hh
index e19eda7..e578b0a 100644
--- a/util/usage.hh
+++ b/util/usage.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_USAGE__
-#define UTIL_USAGE__
+#ifndef UTIL_USAGE_H
+#define UTIL_USAGE_H
#include <cstddef>
#include <iosfwd>
#include <string>
@@ -7,6 +7,9 @@
#include <stdint.h>
namespace util {
+// Time in seconds since process started. Zero on unsupported platforms.
+double WallTime();
+
void PrintUsage(std::ostream &to);
// Determine how much physical memory there is. Return 0 on failure.
@@ -15,4 +18,4 @@ uint64_t GuessPhysicalMemory();
// Parse a size like unix sort. Sadly, this means the default multiplier is K.
uint64_t ParseSize(const std::string &arg);
} // namespace util
-#endif // UTIL_USAGE__
+#endif // UTIL_USAGE_H