Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2014-04-08 16:00:33 +0400
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2014-04-08 16:00:33 +0400
commit5512e96185c4f3894efab7c49b834509bb16b529 (patch)
tree0c69ef9c19fcd3a0639ea748a2616aa2254faf67
parent395acf26221024f17de22801dae22603dbd593eb (diff)
parent2edf319dd54ebb2bbcdaec258a0c46346b43059e (diff)
Merge branch 'pruning2' of github.com:kpu/kenlm into pruning2
-rw-r--r--.gitignore2
-rw-r--r--Jamroot4
-rw-r--r--LICENSE1
-rwxr-xr-xcompile_query_only.sh4
-rw-r--r--jam-files/sanity.jam10
-rw-r--r--lm/bhiksha.cc15
-rw-r--r--lm/bhiksha.hh15
-rw-r--r--lm/binary_format.cc248
-rw-r--r--lm/binary_format.hh118
-rw-r--r--lm/blank.hh6
-rw-r--r--lm/build_binary_main.cc7
-rw-r--r--lm/builder/adjust_counts.hh6
-rw-r--r--lm/builder/corpus_count.cc32
-rw-r--r--lm/builder/corpus_count.hh11
-rw-r--r--lm/builder/corpus_count_test.cc2
-rw-r--r--lm/builder/discount.hh6
-rw-r--r--lm/builder/header_info.hh4
-rw-r--r--lm/builder/initial_probabilities.hh6
-rw-r--r--lm/builder/interpolate.cc9
-rw-r--r--lm/builder/interpolate.hh11
-rw-r--r--lm/builder/joint_order.hh6
-rw-r--r--lm/builder/lmplz_main.cc35
-rw-r--r--lm/builder/multi_stream.hh84
-rw-r--r--lm/builder/ngram.hh6
-rw-r--r--lm/builder/ngram_stream.hh6
-rw-r--r--lm/builder/pipeline.cc13
-rw-r--r--lm/builder/pipeline.hh25
-rw-r--r--lm/builder/print.hh6
-rw-r--r--lm/builder/sort.hh10
-rw-r--r--lm/config.hh6
-rw-r--r--lm/enumerate_vocab.hh6
-rw-r--r--lm/facade.hh12
-rw-r--r--lm/filter/arpa_io.hh7
-rw-r--r--lm/filter/count_io.hh18
-rw-r--r--lm/filter/filter_main.cc155
-rw-r--r--lm/filter/format.hh6
-rw-r--r--lm/filter/phrase.hh6
-rw-r--r--lm/filter/thread.hh6
-rw-r--r--lm/filter/vocab.cc1
-rw-r--r--lm/filter/vocab.hh6
-rw-r--r--lm/filter/wrapper.hh16
-rw-r--r--lm/left.hh6
-rw-r--r--lm/lm_exception.hh4
-rw-r--r--lm/max_order.hh8
-rw-r--r--lm/model.cc84
-rw-r--r--lm/model.hh18
-rw-r--r--lm/model_test.cc8
-rw-r--r--lm/model_type.hh6
-rw-r--r--lm/ngram_query.hh54
-rw-r--r--lm/partial.hh6
-rw-r--r--lm/quantize.cc12
-rw-r--r--lm/quantize.hh11
-rw-r--r--lm/query_main.cc48
-rw-r--r--lm/read_arpa.cc2
-rw-r--r--lm/read_arpa.hh6
-rw-r--r--lm/return.hh6
-rw-r--r--lm/search_hashed.cc33
-rw-r--r--lm/search_hashed.hh18
-rw-r--r--lm/search_trie.cc35
-rw-r--r--lm/search_trie.hh24
-rw-r--r--lm/sizes.hh6
-rw-r--r--lm/state.hh10
-rw-r--r--lm/trie.hh15
-rw-r--r--lm/trie_sort.cc4
-rw-r--r--lm/trie_sort.hh6
-rw-r--r--lm/value.hh6
-rw-r--r--lm/value_build.hh6
-rw-r--r--lm/virtual_interface.hh12
-rw-r--r--lm/vocab.cc28
-rw-r--r--lm/vocab.hh18
-rw-r--r--lm/weights.hh6
-rw-r--r--lm/word_index.hh4
-rw-r--r--lm/wrappers/nplm.cc8
-rw-r--r--lm/wrappers/nplm.hh14
-rw-r--r--python/kenlm.cpp366
-rw-r--r--python/kenlm.pxd4
-rw-r--r--python/kenlm.pyx8
-rw-r--r--util/Jamfile19
-rw-r--r--util/bit_packing.hh6
-rw-r--r--util/cat_compressed_main.cc47
-rw-r--r--util/ersatz_progress.hh6
-rw-r--r--util/exception.cc8
-rw-r--r--util/exception.hh16
-rw-r--r--util/fake_ofstream.hh13
-rw-r--r--util/file.cc65
-rw-r--r--util/file.hh11
-rw-r--r--util/file_piece.cc7
-rw-r--r--util/file_piece.hh29
-rw-r--r--util/file_piece_test.cc12
-rw-r--r--util/fixed_array.hh94
-rw-r--r--util/getopt.hh6
-rw-r--r--util/have.hh6
-rw-r--r--util/joint_sort.hh35
-rw-r--r--util/joint_sort_test.cc12
-rw-r--r--util/mmap.cc69
-rw-r--r--util/mmap.hh88
-rw-r--r--util/multi_intersection.hh8
-rw-r--r--util/murmur_hash.cc7
-rw-r--r--util/murmur_hash.hh10
-rw-r--r--util/parallel_read.cc69
-rw-r--r--util/parallel_read.hh16
-rw-r--r--util/pcqueue.hh64
-rw-r--r--util/pcqueue_test.cc20
-rw-r--r--util/pool.hh6
-rw-r--r--util/probing_hash_table.hh91
-rw-r--r--util/proxy_iterator.hh18
-rw-r--r--util/read_compressed.cc407
-rw-r--r--util/read_compressed.hh6
-rw-r--r--util/read_compressed_test.cc22
-rw-r--r--util/scoped.hh6
-rw-r--r--util/sized_iterator.hh20
-rw-r--r--util/sized_iterator_test.cc16
-rw-r--r--util/sorted_uniform.hh6
-rw-r--r--util/stream/block.hh6
-rw-r--r--util/stream/chain.hh6
-rw-r--r--util/stream/config.hh6
-rw-r--r--util/stream/io.hh6
-rw-r--r--util/stream/line_input.hh6
-rw-r--r--util/stream/multi_progress.hh6
-rw-r--r--util/stream/sort.hh6
-rw-r--r--util/stream/stream.hh6
-rw-r--r--util/stream/timer.hh6
-rw-r--r--util/string_piece.hh6
-rw-r--r--util/string_piece_hash.hh6
-rw-r--r--util/thread_pool.hh16
-rw-r--r--util/tokenize_piece.hh33
-rw-r--r--util/usage.cc197
-rw-r--r--util/usage.hh9
128 files changed, 2183 insertions, 1342 deletions
diff --git a/.gitignore b/.gitignore
index b0067e3..c11f7d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,5 +19,5 @@ util/read_compressed_test
util/sorted_uniform_test
previous.sh
jam-files/bjam
-jam-files/engine/bin.linuxx86/
+jam-files/engine/bin.*/
jam-files/engine/bootstrap/
diff --git a/Jamroot b/Jamroot
index 1ff09ba..54c7a55 100644
--- a/Jamroot
+++ b/Jamroot
@@ -52,7 +52,7 @@ include $(TOP)/jam-files/sanity.jam ;
boost 103600 ;
project : requirements $(requirements) <include>. ;
-project : default-build <threading>multi <warnings>on <variant>release ;
+project : default-build <threading>multi <warnings>on <variant>release <link>static ;
external-lib z ;
@@ -61,7 +61,7 @@ build-project util ;
lib kenlm : lm//kenlm ;
-install-bin-libs lm//programs kenlm ;
+install-bin-libs lm//programs util//programs kenlm ;
install-headers headers : [ glob-tree *.hh : dist include ] : . ;
alias install : prefix-bin prefix-lib prefix-include ;
explicit headers ;
diff --git a/LICENSE b/LICENSE
index 9e2556e..e88a7e2 100644
--- a/LICENSE
+++ b/LICENSE
@@ -2,6 +2,7 @@ Most of the code here is licensed under the LGPL. There are exceptions that
have their own licenses, listed below. See comments in those files for more
details.
+util/getopt.* is getopt for Windows
util/murmur_hash.cc
util/string_piece.hh and util/string_piece.cc
util/double-conversion/LICENSE covers util/double-conversion except Jamfile
diff --git a/compile_query_only.sh b/compile_query_only.sh
index 7c6d3f6..7a82f49 100755
--- a/compile_query_only.sh
+++ b/compile_query_only.sh
@@ -30,5 +30,5 @@ mkdir -p bin
if [ "$(uname)" != Darwin ]; then
CXXFLAGS="$CXXFLAGS -lrt"
fi
-$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS
-$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS
+$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS $LDFLAGS
+$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS $LDFLAGS
diff --git a/jam-files/sanity.jam b/jam-files/sanity.jam
index 7f9c45d..bc07945 100644
--- a/jam-files/sanity.jam
+++ b/jam-files/sanity.jam
@@ -140,10 +140,9 @@ rule boost-lib ( name macro : deps * ) {
}
if $(boost-auto-shared) = "<link>shared" {
- alias boost_$(name) : inner_boost_$(name) : <link>shared ;
- requirements += <define>BOOST_$(macro) ;
+ alias boost_$(name) : inner_boost_$(name) : <link>shared : : <define>BOOST_$(macro) ;
} else {
- alias boost_$(name) : inner_boost_$(name) : <link>static ;
+ alias boost_$(name) : inner_boost_$(name) : : : <link>shared:<define>BOOST_$(macro) ;
}
}
@@ -171,9 +170,10 @@ rule boost ( min-version ) {
boost-auto-shared = [ auto-shared "boost_program_options"$(boost-lib-version) : $(L-boost-search) ] ;
#See tools/build/v2/contrib/boost.jam in a boost distribution for a table of macros to define.
+ boost-lib exception EXCEPTION_DYN_LINK ;
boost-lib system SYSTEM_DYN_LINK ;
- boost-lib thread THREAD_DYN_DLL : boost_system ;
- boost-lib program_options PROGRAM_OPTIONS_DYN_LINK ;
+ boost-lib thread THREAD_DYN_DLL : boost_system boost_exception ;
+ boost-lib program_options PROGRAM_OPTIONS_DYN_LINK : boost_exception ;
boost-lib unit_test_framework TEST_DYN_LINK ;
boost-lib iostreams IOSTREAMS_DYN_LINK ;
boost-lib filesystem FILE_SYSTEM_DYN_LINK ;
diff --git a/lm/bhiksha.cc b/lm/bhiksha.cc
index 088ea98..c8a18df 100644
--- a/lm/bhiksha.cc
+++ b/lm/bhiksha.cc
@@ -1,4 +1,6 @@
#include "lm/bhiksha.hh"
+
+#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "util/file.hh"
#include "util/exception.hh"
@@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_
const uint8_t kArrayBhikshaVersion = 0;
// TODO: put this in binary file header instead when I change the binary file format again.
-void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
- uint8_t version;
- uint8_t configured_bits;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &configured_bits, 1);
+void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ uint8_t buffer[2];
+ file.ReadForConfig(buffer, 2, offset);
+ uint8_t version = buffer[0];
+ uint8_t configured_bits = buffer[1];
if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
config.pointer_bhiksha_bits = configured_bits;
}
@@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) {
*(head_write++) = config.pointer_bhiksha_bits;
}
-void ArrayBhiksha::LoadedBinary() {
-}
-
} // namespace trie
} // namespace ngram
} // namespace lm
diff --git a/lm/bhiksha.hh b/lm/bhiksha.hh
index 8ff8865..db71766 100644
--- a/lm/bhiksha.hh
+++ b/lm/bhiksha.hh
@@ -10,8 +10,8 @@
* Currently only used for next pointers.
*/
-#ifndef LM_BHIKSHA__
-#define LM_BHIKSHA__
+#ifndef LM_BHIKSHA_H
+#define LM_BHIKSHA_H
#include <stdint.h>
#include <assert.h>
@@ -24,6 +24,7 @@
namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
namespace trie {
@@ -31,7 +32,7 @@ class DontBhiksha {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {}
static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
@@ -53,8 +54,6 @@ class DontBhiksha {
void FinishedLoading(const Config &/*config*/) {}
- void LoadedBinary() {}
-
uint8_t InlineBits() const { return next_.bits; }
private:
@@ -65,7 +64,7 @@ class ArrayBhiksha {
public:
static const ModelType kModelTypeAdd = kArrayAdd;
- static void UpdateConfigFromBinary(int fd, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
@@ -93,8 +92,6 @@ class ArrayBhiksha {
void FinishedLoading(const Config &config);
- void LoadedBinary();
-
uint8_t InlineBits() const { return next_inline_.bits; }
private:
@@ -112,4 +109,4 @@ class ArrayBhiksha {
} // namespace ngram
} // namespace lm
-#endif // LM_BHIKSHA__
+#endif // LM_BHIKSHA_H
diff --git a/lm/binary_format.cc b/lm/binary_format.cc
index bef51eb..9c744b1 100644
--- a/lm/binary_format.cc
+++ b/lm/binary_format.cc
@@ -14,6 +14,9 @@
namespace lm {
namespace ngram {
+
+const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
@@ -58,8 +61,6 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
@@ -81,83 +82,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
- if (config.write_mmap) {
- std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.file.reset(util::CreateOrThrow(config.write_mmap));
- if (config.write_method == Config::WRITE_MMAP) {
- backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
- } else {
- util::ResizeOrThrow(backing.file.get(), 0);
- util::MapAnonymous(total, backing.vocab);
- }
- strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
- return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
- } else {
- util::MapAnonymous(memory_size, backing.vocab);
- return reinterpret_cast<uint8_t*>(backing.vocab.get());
- }
-}
-
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
- std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
- if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
- try {
- util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
- } catch (util::ErrnoException &e) {
- e << " for file " << config.write_mmap;
- throw e;
- }
-
- if (config.write_method == Config::WRITE_AFTER) {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
- // mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- std::size_t page_size = util::SizePage();
- std::size_t alignment_cruft = adjusted_vocab % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
- } else {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
-}
-
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
- if (!config.write_mmap) return;
- switch (config.write_method) {
- case Config::WRITE_MMAP:
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
- util::SyncOrThrow(backing.search.get(), backing.search.size());
- break;
- case Config::WRITE_AFTER:
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
- util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
- util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
- util::FSyncOrThrow(backing.file.get());
- break;
- }
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params = Parameters();
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
- if (config.write_method == Config::WRITE_AFTER) {
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
- }
-}
-
-namespace detail {
-
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
@@ -209,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
-void SeekPastHeader(int fd, const Parameters &params) {
- util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+const std::size_t kInvalidSize = static_cast<std::size_t>(-1);
+
+BinaryFormat::BinaryFormat(const Config &config)
+ : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
+ header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {}
+
+void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params) {
+ file_.reset(fd);
+ write_mmap_ = NULL; // Ignore write requests; this is already in binary format.
+ ReadHeader(fd, params);
+ MatchCheck(model_type, search_version, params);
+ header_size_ = TotalHeaderSize(params.counts.size());
+}
+
+void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
+ assert(header_size_ != kInvalidSize);
+ util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_);
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
- const uint64_t file_size = util::SizeFile(backing.file.get());
+void *BinaryFormat::LoadBinary(std::size_t size) {
+ assert(header_size_ != kInvalidSize);
+ const uint64_t file_size = util::SizeFile(file_.get());
// The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
- if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
- UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+ uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
+ UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
- util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
+ util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
- if (config.enumerate_vocab && !params.fixed.has_vocabulary)
- UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+ vocab_string_offset_ = total_map;
+ return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+}
+
+void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
+ vocab_size_ = memory_size;
+ if (!write_mmap_) {
+ header_size_ = 0;
+ util::MapAnonymous(memory_size, memory_vocab_);
+ return reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ }
+ header_size_ = TotalHeaderSize(order);
+ std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
+ file_.reset(util::CreateOrThrow(write_mmap_));
+ // some gccs complain about uninitialized variables even though all enum values are covered.
+ void *vocab_base = NULL;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = mapping_.get();
+ break;
+ case Config::WRITE_AFTER:
+ util::ResizeOrThrow(file_.get(), 0);
+ util::MapAnonymous(total, memory_vocab_);
+ vocab_base = memory_vocab_.get();
+ break;
+ }
+ strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
+ return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
+}
- // Seek to vocabulary words
- util::SeekOrThrow(backing.file.get(), total_map);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
+void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) {
+ assert(vocab_size_ != kInvalidSize);
+ vocab_pad_ = vocab_pad;
+ std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
+ vocab_string_offset_ = new_size;
+ if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, memory_search_);
+ assert(header_size_ == 0 || write_mmap_);
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ return reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+
+ assert(write_method_ == Config::WRITE_MMAP);
+ // Also known as total size without vocab words.
+ // Grow the file to accomodate the search, using zeros.
+ // According to man mmap, behavior is undefined when the file is resized
+ // underneath a mmap that is not a multiple of the page size. So to be
+ // safe, we'll unmap it and map it again.
+ mapping_.reset();
+ util::ResizeOrThrow(file_.get(), new_size);
+ void *ret;
+ MapFile(vocab_base, ret);
+ return ret;
}
-void ComplainAboutARPA(const Config &config, ModelType model_type) {
- if (config.write_mmap || !config.messages) return;
- if (config.arpa_complain == Config::ALL) {
- *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE &&
- (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
- *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) {
+ // Checking Config's include_vocab is the responsibility of the caller.
+ assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
+ if (!write_mmap_) {
+ // Unchanged base.
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ return;
+ }
+ if (write_method_ == Config::WRITE_MMAP) {
+ mapping_.reset();
+ }
+ util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ if (write_method_ == Config::WRITE_MMAP) {
+ MapFile(vocab_base, search_base);
+ } else {
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+}
+
+void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
+ if (!write_mmap_) return;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
+ util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
+ util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
+ util::FSyncOrThrow(file_.get());
+ break;
+ }
+ // header and vocab share the same mmap.
+ Parameters params = Parameters();
+ memset(&params, 0, sizeof(Parameters));
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ WriteHeader(mapping_.get(), params);
+ util::SyncOrThrow(mapping_.get(), mapping_.size());
+ break;
+ case Config::WRITE_AFTER:
+ {
+ std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
+ WriteHeader(&buffer[0], params);
+ util::SeekOrThrow(file_.get(), 0);
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ }
+ break;
}
}
-} // namespace detail
+void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) {
+ mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+ search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_;
+}
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
- if (!detail::IsBinaryFormat(fd.get())) return false;
+ if (!IsBinaryFormat(fd.get())) {
+ return false;
+ }
Parameters params;
- detail::ReadHeader(fd.get(), params);
+ ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}
diff --git a/lm/binary_format.hh b/lm/binary_format.hh
index bf699d5..136d6b1 100644
--- a/lm/binary_format.hh
+++ b/lm/binary_format.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BINARY_FORMAT__
-#define LM_BINARY_FORMAT__
+#ifndef LM_BINARY_FORMAT_H
+#define LM_BINARY_FORMAT_H
#include "lm/config.hh"
#include "lm/model_type.hh"
@@ -17,6 +17,8 @@
namespace lm {
namespace ngram {
+extern const char *kModelNames[6];
+
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
* this header designed for use by decoder authors.
@@ -42,67 +44,63 @@ struct Parameters {
std::vector<uint64_t> counts;
};
-struct Backing {
- // File behind memory, if any.
- util::scoped_fd file;
- // Vocabulary lookup table. Not to be confused with the vocab words themselves.
- util::scoped_memory vocab;
- // Raw block of memory backing the language model data structures
- util::scoped_memory search;
-};
-
-// Create just enough of a binary file to write vocabulary to it.
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
-// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
-
-// Write header to binary file. This is done last to prevent incomplete files
-// from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
+class BinaryFormat {
+ public:
+ explicit BinaryFormat(const Config &config);
+
+ // Reading a binary file:
+ // Takes ownership of fd
+ void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params);
+ // Used to read parts of the file to update the config object before figuring out full size.
+ void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const;
+ // Actually load the binary file and return a pointer to the beginning of the search area.
+ void *LoadBinary(std::size_t size);
+
+ uint64_t VocabStringReadingOffset() const {
+ assert(vocab_string_offset_ != kInvalidOffset);
+ return vocab_string_offset_;
+ }
-namespace detail {
+ // Writing a binary file or initializing in RAM from ARPA:
+ // Size for vocabulary.
+ void *SetupJustVocab(std::size_t memory_size, uint8_t order);
+ // Warning: can change the vocaulary base pointer.
+ void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base);
+ // Warning: can change vocabulary and search base addresses.
+ void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base);
+ // Write the header at the beginning of the file.
+ void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts);
+
+ private:
+ void MapFile(void *&vocab_base, void *&search_base);
+
+ // Copied from configuration.
+ const Config::WriteMethod write_method_;
+ const char *write_mmap_;
+ util::LoadMethod load_method_;
+
+ // File behind memory, if any.
+ util::scoped_fd file_;
+
+ // If there is a file involved, a single mapping.
+ util::scoped_memory mapping_;
+
+ // If the data is only in memory, separately allocate each because the trie
+ // knows vocab's size before it knows search's size (because SRILM might
+ // have pruned).
+ util::scoped_memory memory_vocab_, memory_search_;
+
+ // Memory ranges. Note that these may not be contiguous and may not all
+ // exist.
+ std::size_t header_size_, vocab_size_, vocab_pad_;
+ // aka end of search.
+ uint64_t vocab_string_offset_;
+
+ static const uint64_t kInvalidOffset = (uint64_t)-1;
+};
bool IsBinaryFormat(int fd);
-void ReadHeader(int fd, Parameters &params);
-
-void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params);
-
-void SeekPastHeader(int fd, const Parameters &params);
-
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
-
-void ComplainAboutARPA(const Config &config, ModelType model_type);
-
-} // namespace detail
-
-template <class To> void LoadLM(const char *file, const Config &config, To &to) {
- Backing &backing = to.MutableBacking();
- backing.file.reset(util::OpenReadOrThrow(file));
-
- try {
- if (detail::IsBinaryFormat(backing.file.get())) {
- Parameters params;
- detail::ReadHeader(backing.file.get(), params);
- detail::MatchCheck(To::kModelType, To::kVersion, params);
- // Replace the run-time configured probing_multiplier with the one in the file.
- Config new_config(config);
- new_config.probing_multiplier = params.fixed.probing_multiplier;
- detail::SeekPastHeader(backing.file.get(), params);
- To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
- uint64_t memory_size = To::Size(params.counts, new_config);
- uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
- to.InitializeFromBinary(start, params, new_config, backing.file.get());
- } else {
- detail::ComplainAboutARPA(config, To::kModelType);
- to.InitializeFromARPA(file, config);
- }
- } catch (util::Exception &e) {
- e << " File: " << file;
- throw;
- }
-}
-
} // namespace ngram
} // namespace lm
-#endif // LM_BINARY_FORMAT__
+#endif // LM_BINARY_FORMAT_H
diff --git a/lm/blank.hh b/lm/blank.hh
index 4da8120..94a71ad 100644
--- a/lm/blank.hh
+++ b/lm/blank.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BLANK__
-#define LM_BLANK__
+#ifndef LM_BLANK_H
+#define LM_BLANK_H
#include <limits>
@@ -40,4 +40,4 @@ inline bool HasExtension(const float &backoff) {
} // namespace ngram
} // namespace lm
-#endif // LM_BLANK__
+#endif // LM_BLANK_H
diff --git a/lm/build_binary_main.cc b/lm/build_binary_main.cc
index 425a123..15b421e 100644
--- a/lm/build_binary_main.cc
+++ b/lm/build_binary_main.cc
@@ -52,6 +52,7 @@ void Usage(const char *name, const char *default_mem) {
"-a compresses pointers using an array of offsets. The parameter is the\n"
" maximum number of bits encoded by the array. Memory is minimized subject\n"
" to the maximum, so pick 255 to minimize memory.\n\n"
+"-h print this help message.\n\n"
"Get a memory estimate by passing an ARPA file without an output file name.\n";
exit(1);
}
@@ -104,12 +105,15 @@ int main(int argc, char *argv[]) {
const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
+ if (argc == 2 && !strcmp(argv[1], "--help"))
+ Usage(argv[0], default_mem);
+
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
config.building_memory = util::ParseSize(default_mem);
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -161,6 +165,7 @@ int main(int argc, char *argv[]) {
ParseFileList(optarg, config.rest_lower_files);
config.rest_function = Config::REST_LOWER;
break;
+ case 'h': // help
default:
Usage(argv[0], default_mem);
}
diff --git a/lm/builder/adjust_counts.hh b/lm/builder/adjust_counts.hh
index ea8a1e2..00b5d43 100644
--- a/lm/builder/adjust_counts.hh
+++ b/lm/builder/adjust_counts.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_ADJUST_COUNTS__
-#define LM_BUILDER_ADJUST_COUNTS__
+#ifndef LM_BUILDER_ADJUST_COUNTS_H
+#define LM_BUILDER_ADJUST_COUNTS_H
#include "lm/builder/discount.hh"
#include "util/exception.hh"
@@ -43,5 +43,5 @@ class AdjustCounts {
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_ADJUST_COUNTS__
+#endif // LM_BUILDER_ADJUST_COUNTS_H
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index aea93ad..b99edd0 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -223,29 +223,47 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return VocabHandout::MemUsage(vocab_estimate);
}
-CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
- dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
+ disallowed_symbol_action_(disallowed_symbol) {
}
-void CorpusCount::Run(const util::stream::ChainPosition &position) {
- UTIL_TIMER("(%w s) Counted n-grams\n");
+namespace {
+ void ComplainDisallowed(StringPiece word, WarningAction &action) {
+ switch (action) {
+ case SILENT:
+ return;
+ case COMPLAIN:
+ std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
+ action = SILENT;
+ return;
+ case THROW_UP:
+ UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
+ }
+ }
+} // namespace
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
VocabHandout vocab(vocab_write_, type_count_);
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
- StringPiece delimiters("\0\t\r ", 4);
+ bool delimiters[256];
+ util::BoolCharacter::Build("\0\t\n\r ", delimiters);
try {
while(true) {
StringPiece line(from_.ReadLine());
writer.StartSentence();
- for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
+ for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
WordIndex word = vocab.Lookup(*w);
- UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ if (word <= 2) {
+ ComplainDisallowed(*w, disallowed_symbol_action_);
+ continue;
+ }
writer.Append(word);
++count;
}
diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh
index aa0ed8e..da4ff9f 100644
--- a/lm/builder/corpus_count.hh
+++ b/lm/builder/corpus_count.hh
@@ -1,6 +1,7 @@
-#ifndef LM_BUILDER_CORPUS_COUNT__
-#define LM_BUILDER_CORPUS_COUNT__
+#ifndef LM_BUILDER_CORPUS_COUNT_H
+#define LM_BUILDER_CORPUS_COUNT_H
+#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/scoped.hh"
@@ -28,7 +29,7 @@ class CorpusCount {
// token_count: out.
// type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
- CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol);
void Run(const util::stream::ChainPosition &position);
@@ -40,8 +41,10 @@ class CorpusCount {
std::size_t dedupe_mem_size_;
util::scoped_malloc dedupe_mem_;
+
+ WarningAction disallowed_symbol_action_;
};
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_CORPUS_COUNT__
+#endif // LM_BUILDER_CORPUS_COUNT_H
diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc
index 6d325ef..26cb634 100644
--- a/lm/builder/corpus_count_test.cc
+++ b/lm/builder/corpus_count_test.cc
@@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(Short) {
NGramStream stream;
uint64_t token_count;
WordIndex type_count = 10;
- CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize(), SILENT);
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};
diff --git a/lm/builder/discount.hh b/lm/builder/discount.hh
index 4d0aa4f..e2f4084 100644
--- a/lm/builder/discount.hh
+++ b/lm/builder/discount.hh
@@ -1,5 +1,5 @@
-#ifndef BUILDER_DISCOUNT__
-#define BUILDER_DISCOUNT__
+#ifndef LM_BUILDER_DISCOUNT_H
+#define LM_BUILDER_DISCOUNT_H
#include <algorithm>
@@ -23,4 +23,4 @@ struct Discount {
} // namespace builder
} // namespace lm
-#endif // BUILDER_DISCOUNT__
+#endif // LM_BUILDER_DISCOUNT_H
diff --git a/lm/builder/header_info.hh b/lm/builder/header_info.hh
index ccca145..16f3f60 100644
--- a/lm/builder/header_info.hh
+++ b/lm/builder/header_info.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_HEADER_INFO__
-#define LM_BUILDER_HEADER_INFO__
+#ifndef LM_BUILDER_HEADER_INFO_H
+#define LM_BUILDER_HEADER_INFO_H
#include <string>
#include <stdint.h>
diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh
index 0729a6a..f2a2975 100644
--- a/lm/builder/initial_probabilities.hh
+++ b/lm/builder/initial_probabilities.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_INITIAL_PROBABILITIES__
-#define LM_BUILDER_INITIAL_PROBABILITIES__
+#ifndef LM_BUILDER_INITIAL_PROBABILITIES_H
+#define LM_BUILDER_INITIAL_PROBABILITIES_H
#include "lm/builder/discount.hh"
#include "util/stream/config.hh"
@@ -31,4 +31,4 @@ void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::v
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_INITIAL_PROBABILITIES__
+#endif // LM_BUILDER_INITIAL_PROBABILITIES_H
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index cbf4bb4..4f8ffed 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -5,6 +5,7 @@
#include "lm/builder/multi_stream.hh"
#include "lm/builder/sort.hh"
#include "lm/lm_exception.hh"
+#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
#include <assert.h>
@@ -74,15 +75,17 @@ class Callback {
void Exit(unsigned, const NGram &) const {}
private:
- FixedArray<util::stream::Stream> backoffs_;
+ util::FixedArray<util::stream::Stream> backoffs_;
std::vector<float> probs_;
const std::vector<uint64_t>& prune_thresholds_;
};
} // namespace
-Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
- : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs), prune_thresholds_(prune_thresholds) {}
+Interpolate::Interpolate(uint64_t vocab_size, const ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
+ : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
+ backoffs_(backoffs),
+ prune_thresholds_(prune_thresholds) {}
// perform order-wise interpolation
void Interpolate::Run(const ChainPositions &positions) {
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
index cc372b4..d266191 100644
--- a/lm/builder/interpolate.hh
+++ b/lm/builder/interpolate.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_INTERPOLATE__
-#define LM_BUILDER_INTERPOLATE__
+#ifndef LM_BUILDER_INTERPOLATE_H
+#define LM_BUILDER_INTERPOLATE_H
#include <stdint.h>
@@ -14,8 +14,9 @@ namespace lm { namespace builder {
*/
class Interpolate {
public:
- explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs,
- const std::vector<uint64_t> &prune_thresholds_);
+ // Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
+ // be larger when the user specifies a consistent vocabulary size.
+ explicit Interpolate(uint64_t vocab_size, const ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds);
void Run(const ChainPositions &positions);
@@ -26,4 +27,4 @@ class Interpolate {
};
}} // namespaces
-#endif // LM_BUILDER_INTERPOLATE__
+#endif // LM_BUILDER_INTERPOLATE_H
diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh
index b562014..b9c22a0 100644
--- a/lm/builder/joint_order.hh
+++ b/lm/builder/joint_order.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_JOINT_ORDER__
-#define LM_BUILDER_JOINT_ORDER__
+#ifndef LM_BUILDER_JOINT_ORDER_H
+#define LM_BUILDER_JOINT_ORDER_H
#include "lm/builder/multi_stream.hh"
#include "lm/lm_exception.hh"
@@ -40,4 +40,4 @@ template <class Callback, class Compare> void JointOrder(const ChainPositions &p
}} // namespaces
-#endif // LM_BUILDER_JOINT_ORDER__
+#endif // LM_BUILDER_JOINT_ORDER_H
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index 1e70972..c25fda1 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -1,4 +1,5 @@
#include "lm/builder/pipeline.hh"
+#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
@@ -79,24 +80,30 @@ int main(int argc, char *argv[]) {
options.add_options()
+ ("help,h", po::bool_switch(), "Show this help message")
("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200
->required()
#endif
, "Order of the model")
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
- ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
- ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
+ ("vocab_pad", po::value<std::size_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
("prune,P", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to -prune 0.");
- if (argc == 1) {
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, options), vm);
+
+ if (argc == 1 || vm["help"].as<bool>()) {
std::cerr <<
"Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
"Please cite:\n"
@@ -114,12 +121,17 @@ int main(int argc, char *argv[]) {
"setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
"Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
"Valid units are \% for percentage of memory (supported platforms only) and (in\n"
- "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n";
+ "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n";
+ uint64_t mem = util::GuessPhysicalMemory();
+ if (mem) {
+ std::cerr << "This machine has " << mem << " bytes of memory.\n\n";
+ } else {
+ std::cerr << "Unable to determine the amount of memory on this machine.\n\n";
+ }
std::cerr << options << std::endl;
return 1;
}
- po::variables_map vm;
- po::store(po::parse_command_line(argc, argv, options), vm);
+
po::notify(vm);
// required() appeared in Boost 1.42.0.
@@ -130,6 +142,17 @@ int main(int argc, char *argv[]) {
}
#endif
+ if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) {
+ std::cerr << "--vocab_pad requires --interpolate_unigrams" << std::endl;
+ return 1;
+ }
+
+ if (vm["skip_symbols"].as<bool>()) {
+ pipeline.disallowed_symbol_action = lm::COMPLAIN;
+ } else {
+ pipeline.disallowed_symbol_action = lm::THROW_UP;
+ }
+
// parse pruning thresholds. These depend on order, so it is not done as a notifier.
pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);
diff --git a/lm/builder/multi_stream.hh b/lm/builder/multi_stream.hh
index 707a98c..1a8eb8b 100644
--- a/lm/builder/multi_stream.hh
+++ b/lm/builder/multi_stream.hh
@@ -1,7 +1,8 @@
-#ifndef LM_BUILDER_MULTI_STREAM__
-#define LM_BUILDER_MULTI_STREAM__
+#ifndef LM_BUILDER_MULTI_STREAM_H
+#define LM_BUILDER_MULTI_STREAM_H
#include "lm/builder/ngram_stream.hh"
+#include "util/fixed_array.hh"
#include "util/scoped.hh"
#include "util/stream/chain.hh"
@@ -13,72 +14,9 @@
namespace lm { namespace builder {
-template <class T> class FixedArray {
- public:
- explicit FixedArray(std::size_t count) {
- Init(count);
- }
-
- FixedArray() : newed_end_(NULL) {}
-
- void Init(std::size_t count) {
- assert(!block_.get());
- block_.reset(malloc(sizeof(T) * count));
- if (!block_.get()) throw std::bad_alloc();
- newed_end_ = begin();
- }
-
- FixedArray(const FixedArray &from) {
- std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
- Init(size);
- for (std::size_t i = 0; i < size; ++i) {
- new(end()) T(from[i]);
- Constructed();
- }
- }
-
- ~FixedArray() { clear(); }
-
- T *begin() { return static_cast<T*>(block_.get()); }
- const T *begin() const { return static_cast<const T*>(block_.get()); }
- // Always call Constructed after successful completion of new.
- T *end() { return newed_end_; }
- const T *end() const { return newed_end_; }
-
- T &back() { return *(end() - 1); }
- const T &back() const { return *(end() - 1); }
-
- std::size_t size() const { return end() - begin(); }
- bool empty() const { return begin() == end(); }
-
- T &operator[](std::size_t i) { return begin()[i]; }
- const T &operator[](std::size_t i) const { return begin()[i]; }
-
- template <class C> void push_back(const C &c) {
- new (end()) T(c);
- Constructed();
- }
-
- void clear() {
- for (T *i = begin(); i != end(); ++i)
- i->~T();
- newed_end_ = begin();
- }
-
- protected:
- void Constructed() {
- ++newed_end_;
- }
-
- private:
- util::scoped_malloc block_;
-
- T *newed_end_;
-};
-
class Chains;
-class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+class ChainPositions : public util::FixedArray<util::stream::ChainPosition> {
public:
ChainPositions() {}
@@ -89,14 +27,14 @@ class ChainPositions : public FixedArray<util::stream::ChainPosition> {
}
};
-class Chains : public FixedArray<util::stream::Chain> {
+class Chains : public util::FixedArray<util::stream::Chain> {
private:
template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
typedef Chains type;
};
public:
- explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+ explicit Chains(std::size_t limit) : util::FixedArray<util::stream::Chain>(limit) {}
template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
@@ -129,7 +67,7 @@ class Chains : public FixedArray<util::stream::Chain> {
};
inline void ChainPositions::Init(Chains &chains) {
- FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ util::FixedArray<util::stream::ChainPosition>::Init(chains.size());
for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
new (end()) util::stream::ChainPosition(i->Add()); Constructed();
}
@@ -140,13 +78,13 @@ inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
return chains;
}
-class NGramStreams : public FixedArray<NGramStream> {
+class NGramStreams : public util::FixedArray<NGramStream> {
public:
NGramStreams() {}
// This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
void InitWithDummy(const ChainPositions &positions) {
- FixedArray<NGramStream>::Init(positions.size() + 1);
+ util::FixedArray<NGramStream>::Init(positions.size() + 1);
new (end()) NGramStream(); Constructed();
for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
push_back(*i);
@@ -155,7 +93,7 @@ class NGramStreams : public FixedArray<NGramStream> {
// Limit restricts to positions[0,limit)
void Init(const ChainPositions &positions, std::size_t limit) {
- FixedArray<NGramStream>::Init(limit);
+ util::FixedArray<NGramStream>::Init(limit);
for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
push_back(*i);
}
@@ -177,4 +115,4 @@ inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
}
}} // namespaces
-#endif // LM_BUILDER_MULTI_STREAM__
+#endif // LM_BUILDER_MULTI_STREAM_H
diff --git a/lm/builder/ngram.hh b/lm/builder/ngram.hh
index 756eaa6..0472bcb 100644
--- a/lm/builder/ngram.hh
+++ b/lm/builder/ngram.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_NGRAM__
-#define LM_BUILDER_NGRAM__
+#ifndef LM_BUILDER_NGRAM_H
+#define LM_BUILDER_NGRAM_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -106,4 +106,4 @@ const WordIndex kEOS = 2;
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_NGRAM__
+#endif // LM_BUILDER_NGRAM_H
diff --git a/lm/builder/ngram_stream.hh b/lm/builder/ngram_stream.hh
index 3c99466..d7bf23a 100644
--- a/lm/builder/ngram_stream.hh
+++ b/lm/builder/ngram_stream.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_NGRAM_STREAM__
-#define LM_BUILDER_NGRAM_STREAM__
+#ifndef LM_BUILDER_NGRAM_STREAM_H
+#define LM_BUILDER_NGRAM_STREAM_H
#include "lm/builder/ngram.hh"
#include "util/stream/chain.hh"
@@ -52,4 +52,4 @@ inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &
}
}} // namespaces
-#endif // LM_BUILDER_NGRAM_STREAM__
+#endif // LM_BUILDER_NGRAM_STREAM_H
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index f5548f7..c136753 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -204,7 +204,7 @@ class Master {
Chains chains_;
// Often only unigrams, but sometimes all orders.
- FixedArray<util::stream::FileBuffer> files_;
+ util::FixedArray<util::stream::FileBuffer> files_;
};
void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
@@ -225,17 +225,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
- CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
+ CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
chain.Wait(true);
+ std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
master.InitForAdjust(sorter, type_count);
}
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
- FixedArray<util::stream::FileBuffer> &gammas, std::vector<uint64_t> &prune_thresholds) {
+ util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds) {
const PipelineConfig &config = master.Config();
Chains second(config.order);
@@ -261,7 +262,7 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
master.SetupSorts(primary);
}
-void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) {
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) {
std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
const PipelineConfig &config = master.Config();
master.MaximumLazyInput(counts, primary);
@@ -279,7 +280,7 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
}
- master >> Interpolate(counts[0], ChainPositions(gamma_chains), config.prune_thresholds);
+ master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), ChainPositions(gamma_chains), config.prune_thresholds);
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
}
@@ -316,7 +317,7 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds);
{
- FixedArray<util::stream::FileBuffer> gammas;
+ util::FixedArray<util::stream::FileBuffer> gammas;
Sorts<SuffixOrder> primary;
InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds);
InterpolateProbabilities(counts_pruned, master, primary, gammas);
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index a937169..4395622 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -1,8 +1,9 @@
-#ifndef LM_BUILDER_PIPELINE__
-#define LM_BUILDER_PIPELINE__
+#ifndef LM_BUILDER_PIPELINE_H
+#define LM_BUILDER_PIPELINE_H
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
+#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
@@ -34,6 +35,24 @@ struct PipelineConfig {
// corresponding n-gram order
std::vector<uint64_t> prune_thresholds; //mjd
+ /* Computing the perplexity of LMs with different vocabularies is hard. For
+ * example, the lowest perplexity is attained by a unigram model that
+ * predicts p(<unk>) = 1 and has no other vocabulary. Also, linearly
+ * interpolated models will sum to more than 1 because <unk> is duplicated
+ * (SRI just pretends p(<unk>) = 0 for these purposes, which makes it sum to
+ * 1 but comes with its own problems). This option will make the vocabulary
+ * a particular size by replicating <unk> multiple times for purposes of
+ * computing vocabulary size. It has no effect if the actual vocabulary is
+ * larger. This parameter serves the same purpose as IRSTLM's "dub".
+ */
+ uint64_t vocab_size_for_unk;
+
+ /* What to do the first time <s>, </s>, or <unk> appears in the input. If
+ * this is anything but THROW_UP, then the symbol will always be treated as
+ * whitespace.
+ */
+ WarningAction disallowed_symbol_action;
+
const std::string &TempPrefix() const { return sort.temp_prefix; }
std::size_t TotalMemory() const { return sort.total_memory; }
};
@@ -42,4 +61,4 @@ struct PipelineConfig {
void Pipeline(PipelineConfig config, int text_file, int out_arpa);
}} // namespaces
-#endif // LM_BUILDER_PIPELINE__
+#endif // LM_BUILDER_PIPELINE_H
diff --git a/lm/builder/print.hh b/lm/builder/print.hh
index adbbb94..397ca95 100644
--- a/lm/builder/print.hh
+++ b/lm/builder/print.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_PRINT__
-#define LM_BUILDER_PRINT__
+#ifndef LM_BUILDER_PRINT_H
+#define LM_BUILDER_PRINT_H
#include "lm/builder/ngram.hh"
#include "lm/builder/multi_stream.hh"
@@ -100,4 +100,4 @@ class PrintARPA {
};
}} // namespaces
-#endif // LM_BUILDER_PRINT__
+#endif // LM_BUILDER_PRINT_H
diff --git a/lm/builder/sort.hh b/lm/builder/sort.hh
index 9989389..c7f2ff8 100644
--- a/lm/builder/sort.hh
+++ b/lm/builder/sort.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_SORT__
-#define LM_BUILDER_SORT__
+#ifndef LM_BUILDER_SORT_H
+#define LM_BUILDER_SORT_H
#include "lm/builder/multi_stream.hh"
#include "lm/builder/ngram.hh"
@@ -85,10 +85,10 @@ struct AddCombiner {
// The combiner is only used on a single chain, so I didn't bother to allow
// that template.
-template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > {
private:
typedef util::stream::Sort<Compare> S;
- typedef FixedArray<S> P;
+ typedef util::FixedArray<S> P;
public:
void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
@@ -100,4 +100,4 @@ template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Comp
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_SORT__
+#endif // LM_BUILDER_SORT_H
diff --git a/lm/config.hh b/lm/config.hh
index 0de7b7c..dab2812 100644
--- a/lm/config.hh
+++ b/lm/config.hh
@@ -1,5 +1,5 @@
-#ifndef LM_CONFIG__
-#define LM_CONFIG__
+#ifndef LM_CONFIG_H
+#define LM_CONFIG_H
#include "lm/lm_exception.hh"
#include "util/mmap.hh"
@@ -120,4 +120,4 @@ struct Config {
} /* namespace ngram */ } /* namespace lm */
-#endif // LM_CONFIG__
+#endif // LM_CONFIG_H
diff --git a/lm/enumerate_vocab.hh b/lm/enumerate_vocab.hh
index 2726362..f5ce789 100644
--- a/lm/enumerate_vocab.hh
+++ b/lm/enumerate_vocab.hh
@@ -1,5 +1,5 @@
-#ifndef LM_ENUMERATE_VOCAB__
-#define LM_ENUMERATE_VOCAB__
+#ifndef LM_ENUMERATE_VOCAB_H
+#define LM_ENUMERATE_VOCAB_H
#include "lm/word_index.hh"
#include "util/string_piece.hh"
@@ -24,5 +24,5 @@ class EnumerateVocab {
} // namespace lm
-#endif // LM_ENUMERATE_VOCAB__
+#endif // LM_ENUMERATE_VOCAB_H
diff --git a/lm/facade.hh b/lm/facade.hh
index 760e839..8e12b62 100644
--- a/lm/facade.hh
+++ b/lm/facade.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FACADE__
-#define LM_FACADE__
+#ifndef LM_FACADE_H
+#define LM_FACADE_H
#include "lm/virtual_interface.hh"
#include "util/string_piece.hh"
@@ -17,14 +17,14 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
typedef VocabularyT Vocabulary;
/* Translate from void* to State */
- FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScore(
*reinterpret_cast<const State*>(in_state),
new_word,
*reinterpret_cast<State*>(out_state));
}
- FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
+ FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->FullScoreForgotState(
context_rbegin,
context_rend,
@@ -37,7 +37,7 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob;
}
- float Score(const void *in_state, const WordIndex new_word, void *out_state) const {
+ float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const {
return static_cast<const Child*>(this)->Score(
*reinterpret_cast<const State*>(in_state),
new_word,
@@ -70,4 +70,4 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
} // mamespace base
} // namespace lm
-#endif // LM_FACADE__
+#endif // LM_FACADE_H
diff --git a/lm/filter/arpa_io.hh b/lm/filter/arpa_io.hh
index 5b31620..99c97b1 100644
--- a/lm/filter/arpa_io.hh
+++ b/lm/filter/arpa_io.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_ARPA_IO__
-#define LM_FILTER_ARPA_IO__
+#ifndef LM_FILTER_ARPA_IO_H
+#define LM_FILTER_ARPA_IO_H
/* Input and output for ARPA format language model files.
*/
#include "lm/read_arpa.hh"
@@ -14,7 +14,6 @@
#include <string>
#include <vector>
-#include <err.h>
#include <string.h>
#include <stdint.h>
@@ -112,4 +111,4 @@ template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
} // namespace lm
-#endif // LM_FILTER_ARPA_IO__
+#endif // LM_FILTER_ARPA_IO_H
diff --git a/lm/filter/count_io.hh b/lm/filter/count_io.hh
index 97c0fa2..de894ba 100644
--- a/lm/filter/count_io.hh
+++ b/lm/filter/count_io.hh
@@ -1,24 +1,22 @@
-#ifndef LM_FILTER_COUNT_IO__
-#define LM_FILTER_COUNT_IO__
+#ifndef LM_FILTER_COUNT_IO_H
+#define LM_FILTER_COUNT_IO_H
#include <fstream>
#include <iostream>
#include <string>
-#include <err.h>
-
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
namespace lm {
class CountOutput : boost::noncopyable {
public:
- explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+ explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {}
void AddNGram(const StringPiece &line) {
- if (!(file_ << line << '\n')) {
- err(3, "Writing counts file failed");
- }
+ file_ << line << '\n';
}
template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
@@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable {
}
private:
- std::fstream file_;
+ util::FakeOFStream file_;
};
class CountBatch {
@@ -88,4 +86,4 @@ template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
} // namespace lm
-#endif // LM_FILTER_COUNT_IO__
+#endif // LM_FILTER_COUNT_IO_H
diff --git a/lm/filter/filter_main.cc b/lm/filter/filter_main.cc
index 1736bc4..82fdc1e 100644
--- a/lm/filter/filter_main.cc
+++ b/lm/filter/filter_main.cc
@@ -6,6 +6,7 @@
#endif
#include "lm/filter/vocab.hh"
#include "lm/filter/wrapper.hh"
+#include "util/exception.hh"
#include "util/file_piece.hh"
#include <boost/ptr_container/ptr_vector.hpp>
@@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr
} // namespace lm
int main(int argc, char *argv[]) {
- if (argc < 4) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
+ try {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
- // I used to have boost::program_options, but some users didn't want to compile boost.
- lm::Config config;
- config.mode = lm::MODE_UNSET;
- for (int i = 1; i < argc - 2; ++i) {
- const char *str = argv[i];
- if (!std::strcmp(str, "copy")) {
- config.mode = lm::MODE_COPY;
- } else if (!std::strcmp(str, "single")) {
- config.mode = lm::MODE_SINGLE;
- } else if (!std::strcmp(str, "multiple")) {
- config.mode = lm::MODE_MULTIPLE;
- } else if (!std::strcmp(str, "union")) {
- config.mode = lm::MODE_UNION;
- } else if (!std::strcmp(str, "phrase")) {
- config.phrase = true;
- } else if (!std::strcmp(str, "context")) {
- config.context = true;
- } else if (!std::strcmp(str, "arpa")) {
- config.format = lm::FORMAT_ARPA;
- } else if (!std::strcmp(str, "raw")) {
- config.format = lm::FORMAT_COUNT;
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ config.mode = lm::MODE_UNSET;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ config.mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ config.mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ config.mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ config.mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
#ifndef NTHREAD
- } else if (!std::strncmp(str, "threads:", 8)) {
- config.threads = boost::lexical_cast<size_t>(str + 8);
- if (!config.threads) {
- std::cerr << "Specify at least one thread." << std::endl;
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
return 1;
}
- } else if (!std::strncmp(str, "batch_size:", 11)) {
- config.batch_size = boost::lexical_cast<size_t>(str + 11);
- if (config.batch_size < 5000) {
- std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
- if (!config.batch_size) return 1;
- }
-#endif
- } else {
+ }
+
+ if (config.mode == lm::MODE_UNSET) {
lm::DisplayHelp(argv[0]);
return 1;
}
- }
-
- if (config.mode == lm::MODE_UNSET) {
- lm::DisplayHelp(argv[0]);
- return 1;
- }
- if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
- std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
- return 1;
- }
+ if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
- bool cmd_is_model = true;
- const char *cmd_input = argv[argc - 2];
- if (!strncmp(cmd_input, "vocab:", 6)) {
- cmd_is_model = false;
- cmd_input += 6;
- } else if (!strncmp(cmd_input, "model:", 6)) {
- cmd_input += 6;
- } else if (strchr(cmd_input, ':')) {
- errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
- } else {
- std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
- }
- std::ifstream cmd_file;
- std::istream *vocab;
- if (cmd_is_model) {
- vocab = &std::cin;
- } else {
- cmd_file.open(cmd_input, std::ios::in);
- if (!cmd_file) {
- err(2, "Could not open input file %s", cmd_input);
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
+ return 1;
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input);
+ vocab = &cmd_file;
}
- vocab = &cmd_file;
- }
- util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
- if (config.format == lm::FORMAT_ARPA) {
- lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
- } else if (config.format == lm::FORMAT_COUNT) {
- lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
}
- return 0;
}
diff --git a/lm/filter/format.hh b/lm/filter/format.hh
index 7f945b0..5a2e2db 100644
--- a/lm/filter/format.hh
+++ b/lm/filter/format.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_FORMAT_H__
-#define LM_FITLER_FORMAT_H__
+#ifndef LM_FILTER_FORMAT_H
+#define LM_FILTER_FORMAT_H
#include "lm/filter/arpa_io.hh"
#include "lm/filter/count_io.hh"
@@ -247,4 +247,4 @@ class MultipleOutputBuffer {
} // namespace lm
-#endif // LM_FILTER_FORMAT_H__
+#endif // LM_FILTER_FORMAT_H
diff --git a/lm/filter/phrase.hh b/lm/filter/phrase.hh
index e8e8583..e5898c9 100644
--- a/lm/filter/phrase.hh
+++ b/lm/filter/phrase.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_PHRASE_H__
-#define LM_FILTER_PHRASE_H__
+#ifndef LM_FILTER_PHRASE_H
+#define LM_FILTER_PHRASE_H
#include "util/murmur_hash.hh"
#include "util/string_piece.hh"
@@ -165,4 +165,4 @@ class Multiple : public detail::ConditionCommon {
} // namespace phrase
} // namespace lm
-#endif // LM_FILTER_PHRASE_H__
+#endif // LM_FILTER_PHRASE_H
diff --git a/lm/filter/thread.hh b/lm/filter/thread.hh
index e785b26..6a6523f 100644
--- a/lm/filter/thread.hh
+++ b/lm/filter/thread.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_THREAD_H__
-#define LM_FILTER_THREAD_H__
+#ifndef LM_FILTER_THREAD_H
+#define LM_FILTER_THREAD_H
#include "util/thread_pool.hh"
@@ -164,4 +164,4 @@ template <class Filter, class OutputBuffer, class RealOutput> class Controller :
} // namespace lm
-#endif // LM_FILTER_THREAD_H__
+#endif // LM_FILTER_THREAD_H
diff --git a/lm/filter/vocab.cc b/lm/filter/vocab.cc
index 7ee4e84..011ab59 100644
--- a/lm/filter/vocab.cc
+++ b/lm/filter/vocab.cc
@@ -4,7 +4,6 @@
#include <iostream>
#include <ctype.h>
-#include <err.h>
namespace lm {
namespace vocab {
diff --git a/lm/filter/vocab.hh b/lm/filter/vocab.hh
index 7f0fada..2ee6e1f 100644
--- a/lm/filter/vocab.hh
+++ b/lm/filter/vocab.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_VOCAB_H__
-#define LM_FILTER_VOCAB_H__
+#ifndef LM_FILTER_VOCAB_H
+#define LM_FILTER_VOCAB_H
// Vocabulary-based filters for language models.
@@ -130,4 +130,4 @@ class Multiple {
} // namespace vocab
} // namespace lm
-#endif // LM_FILTER_VOCAB_H__
+#endif // LM_FILTER_VOCAB_H
diff --git a/lm/filter/wrapper.hh b/lm/filter/wrapper.hh
index 90b07a0..822c5c2 100644
--- a/lm/filter/wrapper.hh
+++ b/lm/filter/wrapper.hh
@@ -1,5 +1,5 @@
-#ifndef LM_FILTER_WRAPPER_H__
-#define LM_FILTER_WRAPPER_H__
+#ifndef LM_FILTER_WRAPPER_H
+#define LM_FILTER_WRAPPER_H
#include "util/string_piece.hh"
@@ -39,20 +39,18 @@ template <class FilterT> class ContextFilter {
explicit ContextFilter(Filter &backend) : backend_(backend) {}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
- pieces_.clear();
- // TODO: this copy could be avoided by a lookahead iterator.
- std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
- backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ // Find beginning of string or last space.
+ const char *last_space;
+ for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {}
+ backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output);
}
void Flush() const {}
private:
- std::vector<StringPiece> pieces_;
-
Filter backend_;
};
} // namespace lm
-#endif // LM_FILTER_WRAPPER_H__
+#endif // LM_FILTER_WRAPPER_H
diff --git a/lm/left.hh b/lm/left.hh
index 85c1ea3..36d6136 100644
--- a/lm/left.hh
+++ b/lm/left.hh
@@ -35,8 +35,8 @@
* phrase, even if hypotheses are generated left-to-right.
*/
-#ifndef LM_LEFT__
-#define LM_LEFT__
+#ifndef LM_LEFT_H
+#define LM_LEFT_H
#include "lm/max_order.hh"
#include "lm/state.hh"
@@ -213,4 +213,4 @@ template <class M> class RuleScore {
} // namespace ngram
} // namespace lm
-#endif // LM_LEFT__
+#endif // LM_LEFT_H
diff --git a/lm/lm_exception.hh b/lm/lm_exception.hh
index f607ced..8bb6108 100644
--- a/lm/lm_exception.hh
+++ b/lm/lm_exception.hh
@@ -1,5 +1,5 @@
-#ifndef LM_LM_EXCEPTION__
-#define LM_LM_EXCEPTION__
+#ifndef LM_LM_EXCEPTION_H
+#define LM_LM_EXCEPTION_H
// Named to avoid conflict with util/exception.hh.
diff --git a/lm/max_order.hh b/lm/max_order.hh
index 3eb97cc..f7344cd 100644
--- a/lm/max_order.hh
+++ b/lm/max_order.hh
@@ -1,9 +1,13 @@
-/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
+#ifndef LM_MAX_ORDER_H
+#define LM_MAX_ORDER_H
+/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER_H, THEN CHANGE THE BUILD SYSTEM.
* If not, this is the default maximum order.
* Having this limit means that State can be
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
#ifndef KENLM_ORDER_MESSAGE
-#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif
+
+#endif // LM_MAX_ORDER_H
diff --git a/lm/model.cc b/lm/model.cc
index a26654a..a5a16bf 100644
--- a/lm/model.cc
+++ b/lm/model.cc
@@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
-template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
- LoadLM(file, config, *this);
-
- // g++ prints warnings unless these are fully initialized.
- State begin_sentence = State();
- begin_sentence.length = 1;
- begin_sentence.words[0] = vocab_.BeginSentence();
- typename Search::Node ignored_node;
- bool ignored_independent_left;
- uint64_t ignored_extend_left;
- begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
- State null_context = State();
- null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.Order());
+namespace {
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
}
-namespace {
void CheckCounts(const std::vector<uint64_t> &counts) {
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
if (sizeof(uint64_t) > sizeof(std::size_t)) {
@@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) {
}
}
}
+
} // namespace
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
- CheckCounts(params.counts);
- SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
- search_.LoadedBinary();
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (IsBinaryFormat(fd.get())) {
+ Parameters parameters;
+ int fd_shallow = fd.release();
+ backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
+ CheckCounts(parameters.counts);
+
+ Config new_config(init_config);
+ new_config.probing_multiplier = parameters.fixed.probing_multiplier;
+ Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
+ UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
+ vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
+ } else {
+ ComplainAboutARPA(init_config, kModelType);
+ InitializeFromARPA(fd.release(), file, init_config);
+ }
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.length = 1;
+ begin_sentence.words[0] = vocab_.BeginSentence();
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
+ State null_context = State();
+ null_context.length = 0;
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
+ // Backing file is the ARPA.
+ util::FilePiece f(fd, file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+ vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);
- if (config.write_mmap) {
+ if (config.write_mmap && config.include_vocab) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
+ void *vocab_rebase, *search_rebase;
+ backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
+ // Due to writing at the end of file, mmap may have relocated data. So remap.
+ vocab_.Relocate(vocab_rebase);
+ search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
+ backing_.FinishFile(config, kModelType, kVersion, counts);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
- Search::UpdateConfigFromBinary(fd, counts, config);
-}
-
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
diff --git a/lm/model.hh b/lm/model.hh
index c9c17c4..6925a56 100644
--- a/lm/model.hh
+++ b/lm/model.hh
@@ -1,5 +1,5 @@
-#ifndef LM_MODEL__
-#define LM_MODEL__
+#ifndef LM_MODEL_H
+#define LM_MODEL_H
#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
@@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
}
private:
- friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
-
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
-
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
// Score bigrams and above. Do not include backoff.
@@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
-
- void InitializeFromARPA(const char *file, const Config &config);
+ void InitializeFromARPA(int fd, const char *file, const Config &config);
float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
- Backing &MutableBacking() { return backing_; }
-
- Backing backing_;
+ BinaryFormat backing_;
VocabularyT vocab_;
@@ -161,4 +153,4 @@ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(),
} // namespace ngram
} // namespace lm
-#endif // LM_MODEL__
+#endif // LM_MODEL_H
diff --git a/lm/model_test.cc b/lm/model_test.cc
index eb15909..7005b05 100644
--- a/lm/model_test.cc
+++ b/lm/model_test.cc
@@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
LoadingTest<QuantArrayTrieModel>();
}
-template <class ModelT> void BinaryTest() {
+template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
Config config;
config.write_mmap = "test.binary";
config.messages = NULL;
+ config.write_method = write_method;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
@@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() {
unlink("test_nounk.binary");
}
+template <class ModelT> void BinaryTest() {
+ BinaryTest<ModelT>(Config::WRITE_MMAP);
+ BinaryTest<ModelT>(Config::WRITE_AFTER);
+}
+
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<ProbingModel>();
}
diff --git a/lm/model_type.hh b/lm/model_type.hh
index 8b35c79..fbe1117 100644
--- a/lm/model_type.hh
+++ b/lm/model_type.hh
@@ -1,5 +1,5 @@
-#ifndef LM_MODEL_TYPE__
-#define LM_MODEL_TYPE__
+#ifndef LM_MODEL_TYPE_H
+#define LM_MODEL_TYPE_H
namespace lm {
namespace ngram {
@@ -20,4 +20,4 @@ const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE);
} // namespace ngram
} // namespace lm
-#endif // LM_MODEL_TYPE__
+#endif // LM_MODEL_TYPE_H
diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh
index ec2590f..454e856 100644
--- a/lm/ngram_query.hh
+++ b/lm/ngram_query.hh
@@ -1,8 +1,9 @@
-#ifndef LM_NGRAM_QUERY__
-#define LM_NGRAM_QUERY__
+#ifndef LM_NGRAM_QUERY_H
+#define LM_NGRAM_QUERY_H
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
+#include "util/file_piece.hh"
#include "util/usage.hh"
#include <cstdlib>
@@ -16,42 +17,41 @@
namespace lm {
namespace ngram {
-template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+template <class Model> void Query(const Model &model, bool sentence_context) {
typename Model::State state, out;
lm::FullScoreReturn ret;
- std::string word;
+ StringPiece word;
+
+ util::FilePiece in(0);
+ std::ostream &out_stream = std::cout;
double corpus_total = 0.0;
+ double corpus_total_oov_only = 0.0;
uint64_t corpus_oov = 0;
uint64_t corpus_tokens = 0;
- while (in_stream) {
+ while (true) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
- bool got = false;
uint64_t oov = 0;
- while (in_stream >> word) {
- got = true;
+
+ while (in.ReadWordSameLine(word)) {
lm::WordIndex vocab = model.GetVocabulary().Index(word);
- if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
+ if (vocab == model.GetVocabulary().NotFound()) {
+ ++oov;
+ corpus_total_oov_only += ret.prob;
+ }
total += ret.prob;
out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
++corpus_tokens;
state = out;
- char c;
- while (true) {
- c = in_stream.get();
- if (!in_stream) break;
- if (c == '\n') break;
- if (!isspace(c)) {
- in_stream.unget();
- break;
- }
- }
- if (c == '\n') break;
}
- if (!got && !in_stream) break;
+ // If people don't have a newline after their last query, this won't add a </s>.
+ // Sue me.
+ try {
+ UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused.");
+ } catch (const util::EndOfFileException &e) { break; }
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
@@ -62,18 +62,22 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
corpus_total += total;
corpus_oov += oov;
}
- out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl;
+ out_stream <<
+ "Perplexity including OOVs:\t" << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << "\n"
+ "Perplexity excluding OOVs:\t" << pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))) << "\n"
+ "OOVs:\t" << corpus_oov << "\n"
+ ;
}
-template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
+template <class M> void Query(const char *file, bool sentence_context) {
Config config;
M model(file, config);
- Query(model, sentence_context, in_stream, out_stream);
+ Query(model, sentence_context);
}
} // namespace ngram
} // namespace lm
-#endif // LM_NGRAM_QUERY__
+#endif // LM_NGRAM_QUERY_H
diff --git a/lm/partial.hh b/lm/partial.hh
index 1dede35..d8adc69 100644
--- a/lm/partial.hh
+++ b/lm/partial.hh
@@ -1,5 +1,5 @@
-#ifndef LM_PARTIAL__
-#define LM_PARTIAL__
+#ifndef LM_PARTIAL_H
+#define LM_PARTIAL_H
#include "lm/return.hh"
#include "lm/state.hh"
@@ -164,4 +164,4 @@ template <class Model> float Subsume(const Model &model, Left &first_left, const
} // namespace ngram
} // namespace lm
-#endif // LM_PARTIAL__
+#endif // LM_PARTIAL_H
diff --git a/lm/quantize.cc b/lm/quantize.cc
index b58c3f3..273ea39 100644
--- a/lm/quantize.cc
+++ b/lm/quantize.cc
@@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2;
} // namespace
-void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) {
- char version;
- util::ReadOrThrow(fd, &version, 1);
- util::ReadOrThrow(fd, &config.prob_bits, 1);
- util::ReadOrThrow(fd, &config.backoff_bits, 1);
+void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
+ unsigned char buffer[3];
+ file.ReadForConfig(buffer, 3, offset);
+ char version = buffer[0];
+ config.prob_bits = buffer[1];
+ config.backoff_bits = buffer[2];
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
- util::AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) {
diff --git a/lm/quantize.hh b/lm/quantize.hh
index 8ce2378..84a3087 100644
--- a/lm/quantize.hh
+++ b/lm/quantize.hh
@@ -1,5 +1,5 @@
-#ifndef LM_QUANTIZE_H__
-#define LM_QUANTIZE_H__
+#ifndef LM_QUANTIZE_H
+#define LM_QUANTIZE_H
#include "lm/blank.hh"
#include "lm/config.hh"
@@ -18,12 +18,13 @@ namespace lm {
namespace ngram {
struct Config;
+class BinaryFormat;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@@ -136,7 +137,7 @@ class SeparatelyQuantize {
public:
static const ModelType kModelTypeAdd = kQuantAdd;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
+ static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint8_t order, const Config &config) {
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
@@ -229,4 +230,4 @@ class SeparatelyQuantize {
} // namespace ngram
} // namespace lm
-#endif // LM_QUANTIZE_H__
+#endif // LM_QUANTIZE_H
diff --git a/lm/query_main.cc b/lm/query_main.cc
index b9db7b0..cd661f7 100644
--- a/lm/query_main.cc
+++ b/lm/query_main.cc
@@ -4,48 +4,62 @@
#include "lm/wrappers/nplm.hh"
#endif
+#include <stdlib.h>
+
+void Usage(const char *name) {
+ std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
+ std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl;
+ std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl;
+ exit(1);
+}
+
int main(int argc, char *argv[]) {
- if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
- std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl;
- std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
- std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
- return 1;
+ bool sentence_context = true;
+ const char *file = NULL;
+ for (char **arg = argv + 1; arg != argv + argc; ++arg) {
+ if (!strcmp(*arg, "-n")) {
+ sentence_context = false;
+ } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) {
+ Usage(argv[0]);
+ } else {
+ file = *arg;
+ }
}
+ if (!file) Usage(argv[0]);
try {
- bool sentence_context = (argc == 2);
using namespace lm::ngram;
ModelType model_type;
- if (RecognizeBinary(argv[1], model_type)) {
+ if (RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
- Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<lm::ngram::ProbingModel>(file, sentence_context);
break;
case REST_PROBING:
- Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<lm::ngram::RestProbingModel>(file, sentence_context);
break;
case TRIE:
- Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<TrieModel>(file, sentence_context);
break;
case QUANT_TRIE:
- Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<QuantTrieModel>(file, sentence_context);
break;
case ARRAY_TRIE:
- Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<ArrayTrieModel>(file, sentence_context);
break;
case QUANT_ARRAY_TRIE:
- Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<QuantArrayTrieModel>(file, sentence_context);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
#ifdef WITH_NPLM
- } else if (lm::np::Model::Recognize(argv[1])) {
- lm::np::Model model(argv[1]);
- Query(model, sentence_context, std::cin, std::cout);
+ } else if (lm::np::Model::Recognize(file)) {
+ lm::np::Model model(file);
+ Query(model, sentence_context);
#endif
} else {
- Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ Query<ProbingModel>(file, sentence_context);
}
std::cerr << "Total time including destruction:\n";
util::PrintUsage(std::cerr);
diff --git a/lm/read_arpa.cc b/lm/read_arpa.cc
index 5ccba71..fb8bbfa 100644
--- a/lm/read_arpa.cc
+++ b/lm/read_arpa.cc
@@ -150,7 +150,7 @@ void PositiveProbWarn::Warn(float prob) {
case THROW_UP:
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error");
case COMPLAIN:
- std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl;
+ std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapped to 0 log probability." << std::endl;
action_ = SILENT;
break;
case SILENT:
diff --git a/lm/read_arpa.hh b/lm/read_arpa.hh
index 234d130..16e1fc1 100644
--- a/lm/read_arpa.hh
+++ b/lm/read_arpa.hh
@@ -1,5 +1,5 @@
-#ifndef LM_READ_ARPA__
-#define LM_READ_ARPA__
+#ifndef LM_READ_ARPA_H
+#define LM_READ_ARPA_H
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
@@ -87,4 +87,4 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
} // namespace lm
-#endif // LM_READ_ARPA__
+#endif // LM_READ_ARPA_H
diff --git a/lm/return.hh b/lm/return.hh
index 622320c..982ffd6 100644
--- a/lm/return.hh
+++ b/lm/return.hh
@@ -1,5 +1,5 @@
-#ifndef LM_RETURN__
-#define LM_RETURN__
+#ifndef LM_RETURN_H
+#define LM_RETURN_H
#include <stdint.h>
@@ -39,4 +39,4 @@ struct FullScoreReturn {
};
} // namespace lm
-#endif // LM_RETURN__
+#endif // LM_RETURN_H
diff --git a/lm/search_hashed.cc b/lm/search_hashed.cc
index 62275d2..354a56b 100644
--- a/lm/search_hashed.cc
+++ b/lm/search_hashed.cc
@@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams(
namespace detail {
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- std::size_t allocated = Unigram::Size(counts[0]);
- unigram_ = Unigram(start, counts[0], allocated);
- start += allocated;
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ std::size_t allocated;
+ middle_.clear();
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
@@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start,
return start;
}
-template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) {
- // TODO: fix sorted.
- SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
+/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+ unigram_ = Unigram(start, counts[0]);
+ start += Unigram::Size(counts[0]);
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ middle[n-2].Relocate(start);
+ start += Middle::Size(counts[n - 1], config.probing_multiplier)
+ }
+ longest_.Relocate(start);
+}*/
+
+template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) {
+ void *vocab_rebase;
+ void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase);
+ vocab.Relocate(vocab_rebase);
+ SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
@@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui
ReadEnd(f);
}
-template <class Value> void HashedSearch<Value>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
template class HashedSearch<BackoffValue>;
template class HashedSearch<RestValue>;
diff --git a/lm/search_hashed.hh b/lm/search_hashed.hh
index 9d067bc..9dc8445 100644
--- a/lm/search_hashed.hh
+++ b/lm/search_hashed.hh
@@ -1,5 +1,5 @@
-#ifndef LM_SEARCH_HASHED__
-#define LM_SEARCH_HASHED__
+#ifndef LM_SEARCH_HASHED_H
+#define LM_SEARCH_HASHED_H
#include "lm/model_type.hh"
#include "lm/config.hh"
@@ -18,7 +18,7 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class ProbingVocabulary;
namespace detail {
@@ -72,7 +72,7 @@ template <class Value> class HashedSearch {
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
- static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
+ static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Unigram::Size(counts[0]);
@@ -84,9 +84,7 @@ template <class Value> class HashedSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
-
- void LoadedBinary();
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_.size() + 2;
@@ -148,7 +146,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
@@ -168,8 +166,6 @@ template <class Value> class HashedSearch {
typename Value::Weights &Unknown() { return unigram_[0]; }
- void LoadedBinary() {}
-
// For building.
typename Value::Weights *Raw() { return unigram_; }
@@ -193,4 +189,4 @@ template <class Value> class HashedSearch {
} // namespace ngram
} // namespace lm
-#endif // LM_SEARCH_HASHED__
+#endif // LM_SEARCH_HASHED_H
diff --git a/lm/search_trie.cc b/lm/search_trie.cc
index 1b0d9b2..4a88194 100644
--- a/lm/search_trie.cc
+++ b/lm/search_trie.cc
@@ -253,11 +253,6 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
- void Cleanup() {
- --counts_[0];
- }
-
const std::vector<uint64_t> &Counts() const {
return counts_;
}
@@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries {
typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
}
- void Cleanup() {}
-
private:
RecordReader *contexts_;
const Quant &quant_;
@@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
WordIndex unigram = 0;
std::priority_queue<Gram> grams;
- grams.push(Gram(&unigram, 1));
+ if (unigram_count) grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
}
BlankManager<Doing> blank(total_order, doing);
- while (true) {
+ while (!grams.empty()) {
Gram top = grams.top();
grams.pop();
unsigned char order = top.end - top.begin;
@@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
doing.Unigram(unigram);
progress.Set(unigram);
- if (++unigram == unigram_count + 1) break;
- grams.push(top);
+ if (++unigram < unigram_count) grams.push(top);
} else {
if (order == total_order) {
blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
@@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con
if (++reader) grams.push(top);
}
}
- assert(grams.empty());
- doing.Cleanup();
}
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
@@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c
} // namespace
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {
RecordReader inputs[KENLM_MAX_ORDER - 1];
RecordReader contexts[KENLM_MAX_ORDER - 1];
@@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
- out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
+ void *vocab_relocate;
+ void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate);
+ vocab.Relocate(vocab_relocate);
+ out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
+ // Write the last unigram entry, which is the end pointer for the bigrams.
+ writer.Unigram(counts[0]);
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
- unigram_.LoadedBinary();
- for (Middle *i = middle_begin_; i != middle_end_; ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-}
-
-template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
std::string temporary_prefix;
if (config.temporary_directory_prefix) {
temporary_prefix = config.temporary_directory_prefix;
diff --git a/lm/search_trie.hh b/lm/search_trie.hh
index 763fd1a..d8838d2 100644
--- a/lm/search_trie.hh
+++ b/lm/search_trie.hh
@@ -1,5 +1,5 @@
-#ifndef LM_SEARCH_TRIE__
-#define LM_SEARCH_TRIE__
+#ifndef LM_SEARCH_TRIE_H
+#define LM_SEARCH_TRIE_H
#include "lm/config.hh"
#include "lm/model_type.hh"
@@ -17,13 +17,13 @@
namespace lm {
namespace ngram {
-struct Backing;
+class BinaryFormat;
class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
-template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
@@ -39,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch {
static const unsigned int kVersion = 1;
- static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- Quant::UpdateConfigFromBinary(fd, counts, config);
- util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
+ Quant::UpdateConfigFromBinary(file, offset, config);
// Currently the unigram pointers are not compresssed, so there will only be a header for order > 2.
- if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config);
+ if (counts.size() > 2)
+ Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -60,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- void LoadedBinary();
-
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_end_ - middle_begin_ + 2;
@@ -103,7 +101,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
}
private:
- friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
@@ -129,4 +127,4 @@ template <class Quant, class Bhiksha> class TrieSearch {
} // namespace ngram
} // namespace lm
-#endif // LM_SEARCH_TRIE__
+#endif // LM_SEARCH_TRIE_H
diff --git a/lm/sizes.hh b/lm/sizes.hh
index 85abade..eb7e99d 100644
--- a/lm/sizes.hh
+++ b/lm/sizes.hh
@@ -1,5 +1,5 @@
-#ifndef LM_SIZES__
-#define LM_SIZES__
+#ifndef LM_SIZES_H
+#define LM_SIZES_H
#include <vector>
@@ -14,4 +14,4 @@ void ShowSizes(const std::vector<uint64_t> &counts);
void ShowSizes(const char *file, const lm::ngram::Config &config);
}} // namespaces
-#endif // LM_SIZES__
+#endif // LM_SIZES_H
diff --git a/lm/state.hh b/lm/state.hh
index d8e6c13..f6c51d6 100644
--- a/lm/state.hh
+++ b/lm/state.hh
@@ -1,5 +1,5 @@
-#ifndef LM_STATE__
-#define LM_STATE__
+#ifndef LM_STATE_H
+#define LM_STATE_H
#include "lm/max_order.hh"
#include "lm/word_index.hh"
@@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) {
}
struct ChartState {
- bool operator==(const ChartState &other) {
+ bool operator==(const ChartState &other) const {
return (right == other.right) && (left == other.left);
}
@@ -102,7 +102,7 @@ struct ChartState {
}
bool operator<(const ChartState &other) const {
- return Compare(other) == -1;
+ return Compare(other) < 0;
}
void ZeroRemaining() {
@@ -122,4 +122,4 @@ inline uint64_t hash_value(const ChartState &state) {
} // namespace ngram
} // namespace lm
-#endif // LM_STATE__
+#endif // LM_STATE_H
diff --git a/lm/trie.hh b/lm/trie.hh
index 9ea3c54..cd39298 100644
--- a/lm/trie.hh
+++ b/lm/trie.hh
@@ -1,5 +1,5 @@
-#ifndef LM_TRIE__
-#define LM_TRIE__
+#ifndef LM_TRIE_H
+#define LM_TRIE_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -62,8 +62,6 @@ class Unigram {
return unigram_;
}
- void LoadedBinary() {}
-
UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
next.begin = val->next;
@@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
void FinishedLoading(uint64_t next_end, const Config &config);
- void LoadedBinary() { bhiksha_.LoadedBinary(); }
-
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
@@ -138,18 +134,13 @@ class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_bits);
}
- void LoadedBinary() {}
-
util::BitAddress Insert(WordIndex word);
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
-
- private:
- uint8_t quant_bits_;
};
} // namespace trie
} // namespace ngram
} // namespace lm
-#endif // LM_TRIE__
+#endif // LM_TRIE_H
diff --git a/lm/trie_sort.cc b/lm/trie_sort.cc
index dc542bb..126d43a 100644
--- a/lm/trie_sort.cc
+++ b/lm/trie_sort.cc
@@ -50,6 +50,10 @@ class PartialViewProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
+ friend void swap(PartialViewProxy first, PartialViewProxy second) {
+ std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data()));
+ }
+
private:
friend class util::ProxyIterator<PartialViewProxy>;
diff --git a/lm/trie_sort.hh b/lm/trie_sort.hh
index 1afd956..e5406d9 100644
--- a/lm/trie_sort.hh
+++ b/lm/trie_sort.hh
@@ -1,7 +1,7 @@
// Step of trie builder: create sorted files.
-#ifndef LM_TRIE_SORT__
-#define LM_TRIE_SORT__
+#ifndef LM_TRIE_SORT_H
+#define LM_TRIE_SORT_H
#include "lm/max_order.hh"
#include "lm/word_index.hh"
@@ -111,4 +111,4 @@ class SortedFiles {
} // namespace ngram
} // namespace lm
-#endif // LM_TRIE_SORT__
+#endif // LM_TRIE_SORT_H
diff --git a/lm/value.hh b/lm/value.hh
index ba71671..36e8708 100644
--- a/lm/value.hh
+++ b/lm/value.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VALUE__
-#define LM_VALUE__
+#ifndef LM_VALUE_H
+#define LM_VALUE_H
#include "lm/model_type.hh"
#include "lm/value_build.hh"
@@ -154,4 +154,4 @@ struct RestValue {
} // namespace ngram
} // namespace lm
-#endif // LM_VALUE__
+#endif // LM_VALUE_H
diff --git a/lm/value_build.hh b/lm/value_build.hh
index 461e6a5..6fd26ef 100644
--- a/lm/value_build.hh
+++ b/lm/value_build.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VALUE_BUILD__
-#define LM_VALUE_BUILD__
+#ifndef LM_VALUE_BUILD_H
+#define LM_VALUE_BUILD_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -94,4 +94,4 @@ template <class Model> class LowerRestBuild {
} // namespace ngram
} // namespace lm
-#endif // LM_VALUE_BUILD__
+#endif // LM_VALUE_BUILD_H
diff --git a/lm/virtual_interface.hh b/lm/virtual_interface.hh
index ff4a388..2a2690e 100644
--- a/lm/virtual_interface.hh
+++ b/lm/virtual_interface.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VIRTUAL_INTERFACE__
-#define LM_VIRTUAL_INTERFACE__
+#ifndef LM_VIRTUAL_INTERFACE_H
+#define LM_VIRTUAL_INTERFACE_H
#include "lm/return.hh"
#include "lm/word_index.hh"
@@ -125,13 +125,13 @@ class Model {
void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
- virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Requires in_state != out_state
- virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
// Prefer to use FullScore. The context words should be provided in reverse order.
- virtual FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
+ virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;
unsigned char Order() const { return order_; }
@@ -157,4 +157,4 @@ class Model {
} // mamespace base
} // namespace lm
-#endif // LM_VIRTUAL_INTERFACE__
+#endif // LM_VIRTUAL_INTERFACE_H
diff --git a/lm/vocab.cc b/lm/vocab.cc
index fd7f96d..7f0878f 100644
--- a/lm/vocab.cc
+++ b/lm/vocab.cc
@@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
+ util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
char check_unk[6];
util::ReadOrThrow(fd, check_unk, 6);
@@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd, uint64_t start) {
- util::SeekOrThrow(fd, start);
- util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
-}
-
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size
saw_unk_ = false;
}
+void SortedVocabulary::Relocate(void *new_start) {
+ std::size_t delta = end_ - begin_;
+ begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
+ end_ = begin_ + delta;
+}
+
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
enumerate_ = to;
if (enumerate_) {
@@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
namespace {
@@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz
saw_unk_ = false;
}
+void ProbingVocabulary::Relocate(void *new_start) {
+ header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
+ lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
+}
+
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
enumerate_ = to;
if (enumerate_) {
@@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
- lookup_.LoadedBinary();
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
diff --git a/lm/vocab.hh b/lm/vocab.hh
index 226ae43..dcd298a 100644
--- a/lm/vocab.hh
+++ b/lm/vocab.hh
@@ -1,5 +1,5 @@
-#ifndef LM_VOCAB__
-#define LM_VOCAB__
+#ifndef LM_VOCAB_H
+#define LM_VOCAB_H
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
@@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd, uint64_t start);
+ const std::string &Buffer() const { return buffer_; }
private:
EnumerateVocab *inner_;
@@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
uint64_t *begin_, *end_;
WordIndex bound_;
- WordIndex highest_value_;
-
bool saw_unk_;
EnumerateVocab *enumerate_;
@@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary {
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void Relocate(void *new_start);
+
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
- void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
+ void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
void InternalFinishedLoading();
@@ -182,4 +184,4 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc
} // namespace ngram
} // namespace lm
-#endif // LM_VOCAB__
+#endif // LM_VOCAB_H
diff --git a/lm/weights.hh b/lm/weights.hh
index bd5d803..da1963d 100644
--- a/lm/weights.hh
+++ b/lm/weights.hh
@@ -1,5 +1,5 @@
-#ifndef LM_WEIGHTS__
-#define LM_WEIGHTS__
+#ifndef LM_WEIGHTS_H
+#define LM_WEIGHTS_H
// Weights for n-grams. Probability and possibly a backoff.
@@ -19,4 +19,4 @@ struct RestWeights {
};
} // namespace lm
-#endif // LM_WEIGHTS__
+#endif // LM_WEIGHTS_H
diff --git a/lm/word_index.hh b/lm/word_index.hh
index e09557a..a5a0fda 100644
--- a/lm/word_index.hh
+++ b/lm/word_index.hh
@@ -1,6 +1,6 @@
// Separate header because this is used often.
-#ifndef LM_WORD_INDEX__
-#define LM_WORD_INDEX__
+#ifndef LM_WORD_INDEX_H
+#define LM_WORD_INDEX_H
#include <limits.h>
diff --git a/lm/wrappers/nplm.cc b/lm/wrappers/nplm.cc
index 6a3fa0d..70622bd 100644
--- a/lm/wrappers/nplm.cc
+++ b/lm/wrappers/nplm.cc
@@ -13,7 +13,7 @@ namespace np {
Vocabulary::Vocabulary(const nplm::vocabulary &vocab)
: base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")),
- vocab_(vocab) {}
+ vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {}
Vocabulary::~Vocabulary() {}
@@ -33,8 +33,11 @@ bool Model::Recognize(const std::string &name) {
}
}
-Model::Model(const std::string &file) : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()) {
+Model::Model(const std::string &file, std::size_t cache)
+ : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) {
UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile.");
+ // log10 compatible with backoff models.
+ base_instance_->set_log_base(10.0);
State begin_sentence, null_context;
std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>"));
null_word_ = base_instance_->lookup_word("<null>");
@@ -50,6 +53,7 @@ FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, St
if (!lm) {
lm = new nplm::neuralLM(*base_instance_);
backend_.reset(lm);
+ lm->set_cache(cache_size_);
}
// State is in natural word order.
FullScoreReturn ret;
diff --git a/lm/wrappers/nplm.hh b/lm/wrappers/nplm.hh
index 90f1d49..b7dd4a2 100644
--- a/lm/wrappers/nplm.hh
+++ b/lm/wrappers/nplm.hh
@@ -1,5 +1,5 @@
-#ifndef LM_WRAPPER_NPLM__
-#define LM_WRAPPER_NPLM__
+#ifndef LM_WRAPPERS_NPLM_H
+#define LM_WRAPPERS_NPLM_H
#include "lm/facade.hh"
#include "lm/max_order.hh"
@@ -34,8 +34,12 @@ class Vocabulary : public base::Vocabulary {
return Index(std::string(str.data(), str.size()));
}
+ lm::WordIndex NullWord() const { return null_word_; }
+
private:
const nplm::vocabulary &vocab_;
+
+ const lm::WordIndex null_word_;
};
// Sorry for imposing my limitations on your code.
@@ -53,7 +57,7 @@ class Model : public lm::base::ModelFacade<Model, State, Vocabulary> {
// Does this look like an NPLM?
static bool Recognize(const std::string &file);
- explicit Model(const std::string &file);
+ explicit Model(const std::string &file, std::size_t cache_size = 1 << 20);
~Model();
@@ -69,9 +73,11 @@ class Model : public lm::base::ModelFacade<Model, State, Vocabulary> {
Vocabulary vocab_;
lm::WordIndex null_word_;
+
+ const std::size_t cache_size_;
};
} // namespace np
} // namespace lm
-#endif // LM_WRAPPER_NPLM__
+#endif // LM_WRAPPERS_NPLM_H
diff --git a/python/kenlm.cpp b/python/kenlm.cpp
index d401047..c7052fc 100644
--- a/python/kenlm.cpp
+++ b/python/kenlm.cpp
@@ -1,4 +1,4 @@
-/* Generated by Cython 0.19.1 on Wed Jun 5 13:00:19 2013 */
+/* Generated by Cython 0.19.1 on Tue Jan 28 09:32:20 2014 */
#define PY_SSIZE_T_CLEAN
#ifndef CYTHON_USE_PYLONG_INTERNALS
@@ -493,23 +493,8 @@ static const char *__pyx_f[] = {
};
/*--- Type declarations ---*/
-struct __pyx_obj_5kenlm_LanguageModel;
struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores;
-
-/* "kenlm.pyx":10
- * raise TypeError('Cannot convert %s to string' % type(data))
- *
- * cdef class LanguageModel: # <<<<<<<<<<<<<<
- * cdef Model* model
- * cdef public bytes path
- */
-struct __pyx_obj_5kenlm_LanguageModel {
- PyObject_HEAD
- lm::base::Model *model;
- PyObject *path;
- const lm::base::Vocabulary *vocab;
-};
-
+struct __pyx_obj_5kenlm_LanguageModel;
/* "kenlm.pyx":44
* return total
@@ -532,6 +517,21 @@ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores {
Py_ssize_t __pyx_t_1;
};
+
+/* "kenlm.pyx":10
+ * raise TypeError('Cannot convert %s to string' % type(data))
+ *
+ * cdef class LanguageModel: # <<<<<<<<<<<<<<
+ * cdef Model* model
+ * cdef public bytes path
+ */
+struct __pyx_obj_5kenlm_LanguageModel {
+ PyObject_HEAD
+ lm::base::Model *model;
+ PyObject *path;
+ const lm::base::Vocabulary *vocab;
+};
+
#ifndef CYTHON_REFNANNY
#define CYTHON_REFNANNY 0
#endif
@@ -771,8 +771,8 @@ static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); /*proto*/
/* Module declarations from 'kenlm' */
-static PyTypeObject *__pyx_ptype_5kenlm_LanguageModel = 0;
static PyTypeObject *__pyx_ptype_5kenlm___pyx_scope_struct__full_scores = 0;
+static PyTypeObject *__pyx_ptype_5kenlm_LanguageModel = 0;
static PyObject *__pyx_f_5kenlm_as_str(PyObject *); /*proto*/
#define __Pyx_MODULE_NAME "kenlm"
int __pyx_module_is_main_kenlm = 0;
@@ -792,8 +792,8 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_13__reduce__(struct __pyx_obj_5
static PyObject *__pyx_pf_5kenlm_13LanguageModel_4path___get__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self); /* proto */
static int __pyx_pf_5kenlm_13LanguageModel_4path_2__set__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self, PyObject *__pyx_v_value); /* proto */
static int __pyx_pf_5kenlm_13LanguageModel_4path_4__del__(struct __pyx_obj_5kenlm_LanguageModel *__pyx_v_self); /* proto */
-static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
+static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/
static char __pyx_k_2[] = "Cannot convert %s to string";
static char __pyx_k_3[] = "\n";
static char __pyx_k_4[] = " ";
@@ -1392,7 +1392,7 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
* cdef State out_state
* cdef float total = 0 # <<<<<<<<<<<<<<
* for word in words:
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
*/
__pyx_v_total = 0.0;
@@ -1400,7 +1400,7 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
* cdef State out_state
* cdef float total = 0
* for word in words: # <<<<<<<<<<<<<<
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
* state = out_state
*/
if (unlikely(((PyObject *)__pyx_v_words) == Py_None)) {
@@ -1422,18 +1422,18 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
/* "kenlm.pyx":39
* cdef float total = 0
* for word in words:
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<<
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<<
* state = out_state
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
*/
__pyx_t_4 = __Pyx_PyObject_AsString(__pyx_v_word); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 39; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->Score((&__pyx_v_state), __pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_v_out_state)));
+ __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->BaseScore((&__pyx_v_state), __pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_v_out_state)));
/* "kenlm.pyx":40
* for word in words:
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
* state = out_state # <<<<<<<<<<<<<<
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
* return total
*/
__pyx_v_state = __pyx_v_out_state;
@@ -1441,17 +1441,17 @@ static PyObject *__pyx_pf_5kenlm_13LanguageModel_4score(struct __pyx_obj_5kenlm_
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
/* "kenlm.pyx":41
- * total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
* state = out_state
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<<
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<<
* return total
*
*/
- __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->Score((&__pyx_v_state), __pyx_v_self->vocab->EndSentence(), (&__pyx_v_out_state)));
+ __pyx_v_total = (__pyx_v_total + __pyx_v_self->model->BaseScore((&__pyx_v_state), __pyx_v_self->vocab->EndSentence(), (&__pyx_v_out_state)));
/* "kenlm.pyx":42
* state = out_state
- * total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ * total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
* return total # <<<<<<<<<<<<<<
*
* def full_scores(self, sentence):
@@ -1597,7 +1597,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
* cdef FullScoreReturn ret
* cdef float total = 0 # <<<<<<<<<<<<<<
* for word in words:
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
*/
__pyx_cur_scope->__pyx_v_total = 0.0;
@@ -1605,7 +1605,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
* cdef FullScoreReturn ret
* cdef float total = 0
* for word in words: # <<<<<<<<<<<<<<
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.Index(word), &out_state)
*/
if (unlikely(((PyObject *)__pyx_cur_scope->__pyx_v_words) == Py_None)) {
@@ -1628,20 +1628,20 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
/* "kenlm.pyx":53
* for word in words:
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.Index(word), &out_state) # <<<<<<<<<<<<<<
* yield (ret.prob, ret.ngram_length)
* state = out_state
*/
__pyx_t_4 = __Pyx_PyObject_AsString(__pyx_cur_scope->__pyx_v_word); if (unlikely((!__pyx_t_4) && PyErr_Occurred())) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 53; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->FullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_cur_scope->__pyx_v_out_state));
+ __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->BaseFullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->Index(__pyx_t_4), (&__pyx_cur_scope->__pyx_v_out_state));
/* "kenlm.pyx":54
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.Index(word), &out_state)
* yield (ret.prob, ret.ngram_length) # <<<<<<<<<<<<<<
* state = out_state
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
*/
__pyx_t_2 = PyFloat_FromDouble(__pyx_cur_scope->__pyx_v_ret.prob); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 54; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_2);
@@ -1676,7 +1676,7 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
* self.vocab.Index(word), &out_state)
* yield (ret.prob, ret.ngram_length)
* state = out_state # <<<<<<<<<<<<<<
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.EndSentence(), &out_state)
*/
__pyx_cur_scope->__pyx_v_state = __pyx_cur_scope->__pyx_v_out_state;
@@ -1685,15 +1685,15 @@ static PyObject *__pyx_gb_5kenlm_13LanguageModel_8generator(__pyx_GeneratorObjec
/* "kenlm.pyx":57
* state = out_state
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.EndSentence(), &out_state) # <<<<<<<<<<<<<<
* yield (ret.prob, ret.ngram_length)
*
*/
- __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->FullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->EndSentence(), (&__pyx_cur_scope->__pyx_v_out_state));
+ __pyx_cur_scope->__pyx_v_ret = __pyx_cur_scope->__pyx_v_self->model->BaseFullScore((&__pyx_cur_scope->__pyx_v_state), __pyx_cur_scope->__pyx_v_self->vocab->EndSentence(), (&__pyx_cur_scope->__pyx_v_out_state));
/* "kenlm.pyx":58
- * ret = self.model.FullScore(&state,
+ * ret = self.model.BaseFullScore(&state,
* self.vocab.EndSentence(), &out_state)
* yield (ret.prob, ret.ngram_length) # <<<<<<<<<<<<<<
*
@@ -2047,6 +2047,147 @@ static int __pyx_pf_5kenlm_13LanguageModel_4path_4__del__(struct __pyx_obj_5kenl
return __pyx_r;
}
+static struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[8];
+static int __pyx_freecount_5kenlm___pyx_scope_struct__full_scores = 0;
+
+static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) {
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p;
+ PyObject *o;
+ if (likely((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)))) {
+ o = (PyObject*)__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[--__pyx_freecount_5kenlm___pyx_scope_struct__full_scores];
+ memset(o, 0, sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores));
+ PyObject_INIT(o, t);
+ PyObject_GC_Track(o);
+ } else {
+ o = (*t->tp_alloc)(t, 0);
+ if (unlikely(!o)) return 0;
+ }
+ p = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
+ p->__pyx_v_self = 0;
+ p->__pyx_v_sentence = 0;
+ p->__pyx_v_word = 0;
+ p->__pyx_v_words = 0;
+ p->__pyx_t_0 = 0;
+ return o;
+}
+
+static void __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
+ PyObject_GC_UnTrack(o);
+ Py_CLEAR(p->__pyx_v_self);
+ Py_CLEAR(p->__pyx_v_sentence);
+ Py_CLEAR(p->__pyx_v_word);
+ Py_CLEAR(p->__pyx_v_words);
+ Py_CLEAR(p->__pyx_t_0);
+ if ((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores))) {
+ __pyx_freelist_5kenlm___pyx_scope_struct__full_scores[__pyx_freecount_5kenlm___pyx_scope_struct__full_scores++] = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
+ } else {
+ (*Py_TYPE(o)->tp_free)(o);
+ }
+}
+
+static int __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores(PyObject *o, visitproc v, void *a) {
+ int e;
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
+ if (p->__pyx_v_self) {
+ e = (*v)(((PyObject*)p->__pyx_v_self), a); if (e) return e;
+ }
+ if (p->__pyx_v_sentence) {
+ e = (*v)(p->__pyx_v_sentence, a); if (e) return e;
+ }
+ if (p->__pyx_v_word) {
+ e = (*v)(p->__pyx_v_word, a); if (e) return e;
+ }
+ if (p->__pyx_v_words) {
+ e = (*v)(p->__pyx_v_words, a); if (e) return e;
+ }
+ if (p->__pyx_t_0) {
+ e = (*v)(p->__pyx_t_0, a); if (e) return e;
+ }
+ return 0;
+}
+
+static int __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
+ struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
+ PyObject* tmp;
+ tmp = ((PyObject*)p->__pyx_v_self);
+ p->__pyx_v_self = ((struct __pyx_obj_5kenlm_LanguageModel *)Py_None); Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_v_sentence);
+ p->__pyx_v_sentence = Py_None; Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_v_word);
+ p->__pyx_v_word = Py_None; Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_v_words);
+ p->__pyx_v_words = ((PyObject*)Py_None); Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ tmp = ((PyObject*)p->__pyx_t_0);
+ p->__pyx_t_0 = Py_None; Py_INCREF(Py_None);
+ Py_XDECREF(tmp);
+ return 0;
+}
+
+static PyMethodDef __pyx_methods_5kenlm___pyx_scope_struct__full_scores[] = {
+ {0, 0, 0, 0}
+};
+
+static PyTypeObject __pyx_type_5kenlm___pyx_scope_struct__full_scores = {
+ PyVarObject_HEAD_INIT(0, 0)
+ __Pyx_NAMESTR("kenlm.__pyx_scope_struct__full_scores"), /*tp_name*/
+ sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores), /*tp_basicsize*/
+ 0, /*tp_itemsize*/
+ __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores, /*tp_dealloc*/
+ 0, /*tp_print*/
+ 0, /*tp_getattr*/
+ 0, /*tp_setattr*/
+ #if PY_MAJOR_VERSION < 3
+ 0, /*tp_compare*/
+ #else
+ 0, /*reserved*/
+ #endif
+ 0, /*tp_repr*/
+ 0, /*tp_as_number*/
+ 0, /*tp_as_sequence*/
+ 0, /*tp_as_mapping*/
+ 0, /*tp_hash*/
+ 0, /*tp_call*/
+ 0, /*tp_str*/
+ 0, /*tp_getattro*/
+ 0, /*tp_setattro*/
+ 0, /*tp_as_buffer*/
+ Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/
+ 0, /*tp_doc*/
+ __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores, /*tp_traverse*/
+ __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores, /*tp_clear*/
+ 0, /*tp_richcompare*/
+ 0, /*tp_weaklistoffset*/
+ 0, /*tp_iter*/
+ 0, /*tp_iternext*/
+ __pyx_methods_5kenlm___pyx_scope_struct__full_scores, /*tp_methods*/
+ 0, /*tp_members*/
+ 0, /*tp_getset*/
+ 0, /*tp_base*/
+ 0, /*tp_dict*/
+ 0, /*tp_descr_get*/
+ 0, /*tp_descr_set*/
+ 0, /*tp_dictoffset*/
+ 0, /*tp_init*/
+ 0, /*tp_alloc*/
+ __pyx_tp_new_5kenlm___pyx_scope_struct__full_scores, /*tp_new*/
+ 0, /*tp_free*/
+ 0, /*tp_is_gc*/
+ 0, /*tp_bases*/
+ 0, /*tp_mro*/
+ 0, /*tp_cache*/
+ 0, /*tp_subclasses*/
+ 0, /*tp_weaklist*/
+ 0, /*tp_del*/
+ #if PY_VERSION_HEX >= 0x02060000
+ 0, /*tp_version_tag*/
+ #endif
+};
+
static PyObject *__pyx_tp_new_5kenlm_LanguageModel(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) {
struct __pyx_obj_5kenlm_LanguageModel *p;
PyObject *o;
@@ -2190,147 +2331,6 @@ static PyTypeObject __pyx_type_5kenlm_LanguageModel = {
#endif
};
-static struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[8];
-static int __pyx_freecount_5kenlm___pyx_scope_struct__full_scores = 0;
-
-static PyObject *__pyx_tp_new_5kenlm___pyx_scope_struct__full_scores(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) {
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p;
- PyObject *o;
- if (likely((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores > 0) & (t->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores)))) {
- o = (PyObject*)__pyx_freelist_5kenlm___pyx_scope_struct__full_scores[--__pyx_freecount_5kenlm___pyx_scope_struct__full_scores];
- memset(o, 0, sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores));
- PyObject_INIT(o, t);
- PyObject_GC_Track(o);
- } else {
- o = (*t->tp_alloc)(t, 0);
- if (unlikely(!o)) return 0;
- }
- p = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
- p->__pyx_v_self = 0;
- p->__pyx_v_sentence = 0;
- p->__pyx_v_word = 0;
- p->__pyx_v_words = 0;
- p->__pyx_t_0 = 0;
- return o;
-}
-
-static void __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
- PyObject_GC_UnTrack(o);
- Py_CLEAR(p->__pyx_v_self);
- Py_CLEAR(p->__pyx_v_sentence);
- Py_CLEAR(p->__pyx_v_word);
- Py_CLEAR(p->__pyx_v_words);
- Py_CLEAR(p->__pyx_t_0);
- if ((__pyx_freecount_5kenlm___pyx_scope_struct__full_scores < 8) & (Py_TYPE(o)->tp_basicsize == sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores))) {
- __pyx_freelist_5kenlm___pyx_scope_struct__full_scores[__pyx_freecount_5kenlm___pyx_scope_struct__full_scores++] = ((struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o);
- } else {
- (*Py_TYPE(o)->tp_free)(o);
- }
-}
-
-static int __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores(PyObject *o, visitproc v, void *a) {
- int e;
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
- if (p->__pyx_v_self) {
- e = (*v)(((PyObject*)p->__pyx_v_self), a); if (e) return e;
- }
- if (p->__pyx_v_sentence) {
- e = (*v)(p->__pyx_v_sentence, a); if (e) return e;
- }
- if (p->__pyx_v_word) {
- e = (*v)(p->__pyx_v_word, a); if (e) return e;
- }
- if (p->__pyx_v_words) {
- e = (*v)(p->__pyx_v_words, a); if (e) return e;
- }
- if (p->__pyx_t_0) {
- e = (*v)(p->__pyx_t_0, a); if (e) return e;
- }
- return 0;
-}
-
-static int __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores(PyObject *o) {
- struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *p = (struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores *)o;
- PyObject* tmp;
- tmp = ((PyObject*)p->__pyx_v_self);
- p->__pyx_v_self = ((struct __pyx_obj_5kenlm_LanguageModel *)Py_None); Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_v_sentence);
- p->__pyx_v_sentence = Py_None; Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_v_word);
- p->__pyx_v_word = Py_None; Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_v_words);
- p->__pyx_v_words = ((PyObject*)Py_None); Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- tmp = ((PyObject*)p->__pyx_t_0);
- p->__pyx_t_0 = Py_None; Py_INCREF(Py_None);
- Py_XDECREF(tmp);
- return 0;
-}
-
-static PyMethodDef __pyx_methods_5kenlm___pyx_scope_struct__full_scores[] = {
- {0, 0, 0, 0}
-};
-
-static PyTypeObject __pyx_type_5kenlm___pyx_scope_struct__full_scores = {
- PyVarObject_HEAD_INIT(0, 0)
- __Pyx_NAMESTR("kenlm.__pyx_scope_struct__full_scores"), /*tp_name*/
- sizeof(struct __pyx_obj_5kenlm___pyx_scope_struct__full_scores), /*tp_basicsize*/
- 0, /*tp_itemsize*/
- __pyx_tp_dealloc_5kenlm___pyx_scope_struct__full_scores, /*tp_dealloc*/
- 0, /*tp_print*/
- 0, /*tp_getattr*/
- 0, /*tp_setattr*/
- #if PY_MAJOR_VERSION < 3
- 0, /*tp_compare*/
- #else
- 0, /*reserved*/
- #endif
- 0, /*tp_repr*/
- 0, /*tp_as_number*/
- 0, /*tp_as_sequence*/
- 0, /*tp_as_mapping*/
- 0, /*tp_hash*/
- 0, /*tp_call*/
- 0, /*tp_str*/
- 0, /*tp_getattro*/
- 0, /*tp_setattro*/
- 0, /*tp_as_buffer*/
- Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_HAVE_GC, /*tp_flags*/
- 0, /*tp_doc*/
- __pyx_tp_traverse_5kenlm___pyx_scope_struct__full_scores, /*tp_traverse*/
- __pyx_tp_clear_5kenlm___pyx_scope_struct__full_scores, /*tp_clear*/
- 0, /*tp_richcompare*/
- 0, /*tp_weaklistoffset*/
- 0, /*tp_iter*/
- 0, /*tp_iternext*/
- __pyx_methods_5kenlm___pyx_scope_struct__full_scores, /*tp_methods*/
- 0, /*tp_members*/
- 0, /*tp_getset*/
- 0, /*tp_base*/
- 0, /*tp_dict*/
- 0, /*tp_descr_get*/
- 0, /*tp_descr_set*/
- 0, /*tp_dictoffset*/
- 0, /*tp_init*/
- 0, /*tp_alloc*/
- __pyx_tp_new_5kenlm___pyx_scope_struct__full_scores, /*tp_new*/
- 0, /*tp_free*/
- 0, /*tp_is_gc*/
- 0, /*tp_bases*/
- 0, /*tp_mro*/
- 0, /*tp_cache*/
- 0, /*tp_subclasses*/
- 0, /*tp_weaklist*/
- 0, /*tp_del*/
- #if PY_VERSION_HEX >= 0x02060000
- 0, /*tp_version_tag*/
- #endif
-};
-
static PyMethodDef __pyx_methods[] = {
{0, 0, 0, 0}
};
@@ -2508,11 +2508,11 @@ PyMODINIT_FUNC PyInit_kenlm(void)
/*--- Variable export code ---*/
/*--- Function export code ---*/
/*--- Type init code ---*/
+ if (PyType_Ready(&__pyx_type_5kenlm___pyx_scope_struct__full_scores) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 44; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
+ __pyx_ptype_5kenlm___pyx_scope_struct__full_scores = &__pyx_type_5kenlm___pyx_scope_struct__full_scores;
if (PyType_Ready(&__pyx_type_5kenlm_LanguageModel) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
if (__Pyx_SetAttrString(__pyx_m, "LanguageModel", (PyObject *)&__pyx_type_5kenlm_LanguageModel) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__pyx_ptype_5kenlm_LanguageModel = &__pyx_type_5kenlm_LanguageModel;
- if (PyType_Ready(&__pyx_type_5kenlm___pyx_scope_struct__full_scores) < 0) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 44; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
- __pyx_ptype_5kenlm___pyx_scope_struct__full_scores = &__pyx_type_5kenlm___pyx_scope_struct__full_scores;
/*--- Type import code ---*/
/*--- Variable import code ---*/
/*--- Function import code ---*/
diff --git a/python/kenlm.pxd b/python/kenlm.pxd
index 9a397f3..7d68fc4 100644
--- a/python/kenlm.pxd
+++ b/python/kenlm.pxd
@@ -24,8 +24,8 @@ cdef extern from "lm/virtual_interface.hh" namespace "lm::base":
void NullContextWrite(void *)
unsigned int Order()
const_Vocabulary& BaseVocabulary()
- float Score(void *in_state, WordIndex new_word, void *out_state)
- FullScoreReturn FullScore(void *in_state, WordIndex new_word, void *out_state)
+ float BaseScore(void *in_state, WordIndex new_word, void *out_state)
+ FullScoreReturn BaseFullScore(void *in_state, WordIndex new_word, void *out_state)
cdef extern from "lm/model.hh" namespace "lm::ngram":
cdef Model *LoadVirtual(char *) except +
diff --git a/python/kenlm.pyx b/python/kenlm.pyx
index 7f965d4..eb45169 100644
--- a/python/kenlm.pyx
+++ b/python/kenlm.pyx
@@ -36,9 +36,9 @@ cdef class LanguageModel:
cdef State out_state
cdef float total = 0
for word in words:
- total += self.model.Score(&state, self.vocab.Index(word), &out_state)
+ total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
state = out_state
- total += self.model.Score(&state, self.vocab.EndSentence(), &out_state)
+ total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
return total
def full_scores(self, sentence):
@@ -49,11 +49,11 @@ cdef class LanguageModel:
cdef FullScoreReturn ret
cdef float total = 0
for word in words:
- ret = self.model.FullScore(&state,
+ ret = self.model.BaseFullScore(&state,
self.vocab.Index(word), &out_state)
yield (ret.prob, ret.ngram_length)
state = out_state
- ret = self.model.FullScore(&state,
+ ret = self.model.BaseFullScore(&state,
self.vocab.EndSentence(), &out_state)
yield (ret.prob, ret.ngram_length)
diff --git a/util/Jamfile b/util/Jamfile
index 910b305..18b20a3 100644
--- a/util/Jamfile
+++ b/util/Jamfile
@@ -19,15 +19,18 @@ alias read_compressed : read_compressed.o $(compressed_deps) ;
obj read_compressed_test.o : read_compressed_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
-fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
+fakelib parallel_read : parallel_read.cc : <threading>multi:<source>/top//boost_thread <threading>multi:<define>WITH_THREADS : : <include>.. ;
+
+fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc parallel_read pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
+
+exe cat_compressed : cat_compressed_main.cc kenutil ;
+
+alias programs : cat_compressed ;
import testing ;
-unit-test bit_packing_test : bit_packing_test.cc kenutil /top//boost_unit_test_framework ;
run file_piece_test.o kenutil /top//boost_unit_test_framework : : file_piece.cc ;
-unit-test read_compressed_test : read_compressed_test.o kenutil /top//boost_unit_test_framework ;
-unit-test joint_sort_test : joint_sort_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test probing_hash_table_test : probing_hash_table_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test sorted_uniform_test : sorted_uniform_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test tokenize_piece_test : tokenize_piece_test.cc kenutil /top//boost_unit_test_framework ;
-unit-test multi_intersection_test : multi_intersection_test.cc kenutil /top//boost_unit_test_framework ;
+for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc ] {
+ local name = [ MATCH "(.*)\.cc" : $(t) ] ;
+ unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_system ;
+}
diff --git a/util/bit_packing.hh b/util/bit_packing.hh
index dcbd814..1e34d9a 100644
--- a/util/bit_packing.hh
+++ b/util/bit_packing.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_BIT_PACKING__
-#define UTIL_BIT_PACKING__
+#ifndef UTIL_BIT_PACKING_H
+#define UTIL_BIT_PACKING_H
/* Bit-level packing routines
*
@@ -183,4 +183,4 @@ struct BitAddress {
} // namespace util
-#endif // UTIL_BIT_PACKING__
+#endif // UTIL_BIT_PACKING_H
diff --git a/util/cat_compressed_main.cc b/util/cat_compressed_main.cc
new file mode 100644
index 0000000..2b4d729
--- /dev/null
+++ b/util/cat_compressed_main.cc
@@ -0,0 +1,47 @@
+// Like cat but interprets compressed files.
+#include "util/file.hh"
+#include "util/read_compressed.hh"
+
+#include <string.h>
+#include <iostream>
+
+namespace {
+const std::size_t kBufSize = 16384;
+void Copy(util::ReadCompressed &from, int to) {
+ util::scoped_malloc buffer(util::MallocOrThrow(kBufSize));
+ while (std::size_t amount = from.Read(buffer.get(), kBufSize)) {
+ util::WriteOrThrow(to, buffer.get(), amount);
+ }
+}
+} // namespace
+
+int main(int argc, char *argv[]) {
+ // Lane Schwartz likes -h and --help
+ for (int i = 1; i < argc; ++i) {
+ char *arg = argv[i];
+ if (!strcmp(arg, "--")) break;
+ if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
+ std::cerr <<
+ "A cat implementation that interprets compressed files.\n"
+ "Usage: " << argv[0] << " [file1] [file2] ...\n"
+ "If no file is provided, then stdin is read.\n";
+ return 1;
+ }
+ }
+
+ try {
+ if (argc == 1) {
+ util::ReadCompressed in(0);
+ Copy(in, 1);
+ } else {
+ for (int i = 1; i < argc; ++i) {
+ util::ReadCompressed in(util::OpenReadOrThrow(argv[i]));
+ Copy(in, 1);
+ }
+ }
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 2;
+ }
+ return 0;
+}
diff --git a/util/ersatz_progress.hh b/util/ersatz_progress.hh
index b94399a..535dbde 100644
--- a/util/ersatz_progress.hh
+++ b/util/ersatz_progress.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_ERSATZ_PROGRESS__
-#define UTIL_ERSATZ_PROGRESS__
+#ifndef UTIL_ERSATZ_PROGRESS_H
+#define UTIL_ERSATZ_PROGRESS_H
#include <iostream>
#include <string>
@@ -55,4 +55,4 @@ class ErsatzProgress {
} // namespace util
-#endif // UTIL_ERSATZ_PROGRESS__
+#endif // UTIL_ERSATZ_PROGRESS_H
diff --git a/util/exception.cc b/util/exception.cc
index 557c398..083bac2 100644
--- a/util/exception.cc
+++ b/util/exception.cc
@@ -51,6 +51,11 @@ void Exception::SetLocation(const char *file, unsigned int line, const char *fun
}
namespace {
+// At least one of these functions will not be called.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
// The XOPEN version.
const char *HandleStrerror(int ret, const char *buf) {
if (!ret) return buf;
@@ -61,6 +66,9 @@ const char *HandleStrerror(int ret, const char *buf) {
const char *HandleStrerror(const char *ret, const char * /*buf*/) {
return ret;
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
} // namespace
ErrnoException::ErrnoException() throw() : errno_(errno) {
diff --git a/util/exception.hh b/util/exception.hh
index 74046cf..0966740 100644
--- a/util/exception.hh
+++ b/util/exception.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_EXCEPTION__
-#define UTIL_EXCEPTION__
+#ifndef UTIL_EXCEPTION_H
+#define UTIL_EXCEPTION_H
#include <exception>
#include <limits>
@@ -98,6 +98,9 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep
#define UTIL_THROW_IF(Condition, Exception, Modify) \
UTIL_THROW_IF_ARG(Condition, Exception, , Modify)
+#define UTIL_THROW_IF2(Condition, Modify) \
+ UTIL_THROW_IF_ARG(Condition, util::Exception, , Modify)
+
// Exception that records errno and adds it to the message.
class ErrnoException : public Exception {
public:
@@ -111,6 +114,13 @@ class ErrnoException : public Exception {
int errno_;
};
+// file wasn't there, or couldn't be open for some reason
+class FileOpenException : public Exception {
+ public:
+ FileOpenException() throw() {}
+ ~FileOpenException() throw() {}
+};
+
// Utilities for overflow checking.
class OverflowException : public Exception {
public:
@@ -133,4 +143,4 @@ inline std::size_t CheckOverflow(uint64_t value) {
} // namespace util
-#endif // UTIL_EXCEPTION__
+#endif // UTIL_EXCEPTION_H
diff --git a/util/fake_ofstream.hh b/util/fake_ofstream.hh
index bcdebe4..eefb1ed 100644
--- a/util/fake_ofstream.hh
+++ b/util/fake_ofstream.hh
@@ -2,6 +2,9 @@
* Does not support many data types. Currently, it's targeted at writing ARPA
* files quickly.
*/
+#ifndef UTIL_FAKE_OFSTREAM_H
+#define UTIL_FAKE_OFSTREAM_H
+
#include "util/double-conversion/double-conversion.h"
#include "util/double-conversion/utils.h"
#include "util/file.hh"
@@ -17,7 +20,8 @@ class FakeOFStream {
static const std::size_t kOutBuf = 1048576;
// Does not take ownership of out.
- explicit FakeOFStream(int out)
+ // Allows default constructor, but must call SetFD.
+ explicit FakeOFStream(int out = -1)
: buf_(util::MallocOrThrow(kOutBuf)),
builder_(static_cast<char*>(buf_.get()), kOutBuf),
// Mostly the default but with inf instead. And no flags.
@@ -28,6 +32,11 @@ class FakeOFStream {
if (buf_.get()) Flush();
}
+ void SetFD(int to) {
+ if (builder_.position()) Flush();
+ fd_ = to;
+ }
+
FakeOFStream &operator<<(float value) {
// Odd, but this is the largest number found in the comments.
EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
@@ -92,3 +101,5 @@ class FakeOFStream {
};
} // namespace
+
+#endif
diff --git a/util/file.cc b/util/file.cc
index bef04cb..0d9adf2 100644
--- a/util/file.cc
+++ b/util/file.cc
@@ -17,7 +17,11 @@
#include <fcntl.h>
#include <stdint.h>
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+#include <windows.h>
+#include <unistd.h>
+#warning "The file functions on MinGW have not been tested for file sizes above 2^31 - 1. Please read https://stackoverflow.com/questions/12539488/determine-64-bit-file-size-in-c-on-mingw-32-bit and fix"
+#elif defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#include <io.h>
#include <algorithm>
@@ -76,7 +80,13 @@ int CreateOrThrow(const char *name) {
}
uint64_t SizeFile(int fd) {
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ struct stat sb;
+ // Does this handle 64-bit?
+ int ret = fstat(fd, &sb);
+ if (ret == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize;
+ return sb.st_size;
+#elif defined(_WIN32) || defined(_WIN64)
__int64 ret = _filelengthi64(fd);
return (ret == -1) ? kBadSize : ret;
#else // Not windows.
@@ -100,7 +110,10 @@ uint64_t SizeOrThrow(int fd) {
}
void ResizeOrThrow(int fd, uint64_t to) {
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ // Does this handle 64-bit?
+ int ret = ftruncate
+#elif defined(_WIN32) || defined(_WIN64)
errno_t ret = _chsize_s
#elif defined(OS_ANDROID)
int ret = ftruncate64
@@ -115,8 +128,10 @@ namespace {
std::size_t GuardLarge(std::size_t size) {
// The following operating systems have broken read/write/pread/pwrite that
// only supports up to 2^31.
-#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID)
- return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size);
+ // OS X man pages claim to support 64-bit, but Kareem M. Darwish had problems
+ // building with larger files, so APPLE is also here.
+#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) || defined(__MINGW32__)
+ return size < INT_MAX ? size : INT_MAX;
#else
return size;
#endif
@@ -179,16 +194,15 @@ void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) {
#else
ssize_t ret;
errno = 0;
- do {
- ret =
+ ret =
#ifdef OS_ANDROID
- pread64
+ pread64
#else
- pread
+ pread
#endif
- (fd, to, GuardLarge(size), off);
- } while (ret == -1 && errno == EINTR);
+ (fd, to, GuardLarge(size), off);
if (ret <= 0) {
+ if (ret == -1 && errno == EINTR) continue;
UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd));
UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off);
}
@@ -251,7 +265,10 @@ typedef CheckOffT<sizeof(off_t)>::True IgnoredType;
// Can't we all just get along?
void InternalSeek(int fd, int64_t off, int whence) {
if (
-#if defined(_WIN32) || defined(_WIN64)
+#if defined __MINGW32__
+ // Does this handle 64-bit?
+ (off_t)-1 == lseek(fd, off, whence)
+#elif defined(_WIN32) || defined(_WIN64)
(__int64)-1 == _lseeki64(fd, off, whence)
#elif defined(OS_ANDROID)
(off64_t)-1 == lseek64(fd, off, whence)
@@ -427,8 +444,8 @@ void NormalizeTempPrefix(std::string &base) {
) base += '/';
}
-int MakeTemp(const std::string &base) {
- std::string name(base);
+int MakeTemp(const StringPiece &base) {
+ std::string name(base.data(), base.size());
name += "XXXXXX";
name.push_back(0);
int ret;
@@ -436,7 +453,7 @@ int MakeTemp(const std::string &base) {
return ret;
}
-std::FILE *FMakeTemp(const std::string &base) {
+std::FILE *FMakeTemp(const StringPiece &base) {
util::scoped_fd file(MakeTemp(base));
return FDOpenOrThrow(file);
}
@@ -462,14 +479,18 @@ bool TryName(int fd, std::string &out) {
if (-1 == lstat(name.c_str(), &sb))
return false;
out.resize(sb.st_size + 1);
- ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1);
- if (-1 == ret)
- return false;
- if (ret > sb.st_size) {
- // Increased in size?!
- return false;
+ // lstat gave us a size, but I've seen it grow, possibly due to symlinks on top of symlinks.
+ while (true) {
+ ssize_t ret = readlink(name.c_str(), &out[0], out.size());
+ if (-1 == ret)
+ return false;
+ if ((size_t)ret < out.size()) {
+ out.resize(ret);
+ break;
+ }
+ // Exponential growth.
+ out.resize(out.size() * 2);
}
- out.resize(ret);
// Don't use the non-file names.
if (!out.empty() && out[0] != '/')
return false;
diff --git a/util/file.hh b/util/file.hh
index be88431..170a7c7 100644
--- a/util/file.hh
+++ b/util/file.hh
@@ -1,7 +1,8 @@
-#ifndef UTIL_FILE__
-#define UTIL_FILE__
+#ifndef UTIL_FILE_H
+#define UTIL_FILE_H
#include "util/exception.hh"
+#include "util/string_piece.hh"
#include <cstddef>
#include <cstdio>
@@ -125,8 +126,8 @@ std::FILE *FDOpenReadOrThrow(scoped_fd &file);
// Temporary files
// Append a / if base is a directory.
void NormalizeTempPrefix(std::string &base);
-int MakeTemp(const std::string &prefix);
-std::FILE *FMakeTemp(const std::string &prefix);
+int MakeTemp(const StringPiece &prefix);
+std::FILE *FMakeTemp(const StringPiece &prefix);
// dup an fd.
int DupOrThrow(int fd);
@@ -139,4 +140,4 @@ std::string NameFromFD(int fd);
} // namespace util
-#endif // UTIL_FILE__
+#endif // UTIL_FILE_H
diff --git a/util/file_piece.cc b/util/file_piece.cc
index 9c7e00c..4aaa250 100644
--- a/util/file_piece.cc
+++ b/util/file_piece.cc
@@ -84,6 +84,13 @@ StringPiece FilePiece::ReadLine(char delim) {
}
}
+bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) {
+ try {
+ to = ReadLine(delim);
+ } catch (const util::EndOfFileException &e) { return false; }
+ return true;
+}
+
float FilePiece::ReadFloat() {
return ReadNumber<float>();
}
diff --git a/util/file_piece.hh b/util/file_piece.hh
index ed3dc5a..5495ddc 100644
--- a/util/file_piece.hh
+++ b/util/file_piece.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_FILE_PIECE__
-#define UTIL_FILE_PIECE__
+#ifndef UTIL_FILE_PIECE_H
+#define UTIL_FILE_PIECE_H
#include "util/ersatz_progress.hh"
#include "util/exception.hh"
@@ -56,10 +56,33 @@ class FilePiece {
return Consume(FindDelimiterOrEOF(delim));
}
+ // Read word until the line or file ends.
+ bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) {
+ assert(delim[static_cast<unsigned char>('\n')]);
+ // Skip non-enter spaces.
+ for (; ; ++position_) {
+ if (position_ == position_end_) {
+ try {
+ Shift();
+ } catch (const util::EndOfFileException &e) { return false; }
+ // And break out at end of file.
+ if (position_ == position_end_) return false;
+ }
+ if (!delim[static_cast<unsigned char>(*position_)]) break;
+ if (*position_ == '\n') return false;
+ }
+ // We can't be at the end of file because there's at least one character open.
+ to = Consume(FindDelimiterOrEOF(delim));
+ return true;
+ }
+
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n');
+ // Doesn't throw EndOfFileException, just returns false.
+ bool ReadLineOrEOF(StringPiece &to, char delim = '\n');
+
float ReadFloat();
double ReadDouble();
long int ReadLong();
@@ -132,4 +155,4 @@ class FilePiece {
} // namespace util
-#endif // UTIL_FILE_PIECE__
+#endif // UTIL_FILE_PIECE_H
diff --git a/util/file_piece_test.cc b/util/file_piece_test.cc
index 7336007..4361877 100644
--- a/util/file_piece_test.cc
+++ b/util/file_piece_test.cc
@@ -1,4 +1,4 @@
-// Tests might fail if you have creative characters in your path. Sue me.
+// Tests might fail if you have creative characters in your path. Sue me.
#include "util/file_piece.hh"
#include "util/file.hh"
@@ -55,7 +55,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
/* Apple isn't happy with the popen, fileno, dup. And I don't want to
- * reimplement popen. This is an issue with the test.
+ * reimplement popen. This is an issue with the test.
*/
/* read() implementation */
BOOST_AUTO_TEST_CASE(StreamReadLine) {
@@ -67,7 +67,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) {
FILE *catter = popen(popen_args.c_str(), "r");
BOOST_REQUIRE(catter);
-
+
FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
@@ -107,8 +107,8 @@ BOOST_AUTO_TEST_CASE(PlainZipReadLine) {
}
// gzip stream. Apple doesn't like popen, fileno, dup. This is an issue with
-// the test.
-#ifndef __APPLE__
+// the test.
+#if !defined __APPLE__ && !defined __MINGW32__
BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
std::fstream ref(FileLocation().c_str(), std::ios::in);
@@ -117,7 +117,7 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
FILE * catter = popen(command.c_str(), "r");
BOOST_REQUIRE(catter);
-
+
FilePiece test(dup(fileno(catter)), "file_piece.cc.gz", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
diff --git a/util/fixed_array.hh b/util/fixed_array.hh
new file mode 100644
index 0000000..bae13de
--- /dev/null
+++ b/util/fixed_array.hh
@@ -0,0 +1,94 @@
+#ifndef UTIL_FIXED_ARRAY_H
+#define UTIL_FIXED_ARRAY_H
+
+// Ever want an array of things by they don't have a default constructor or are
+// non-copyable? FixedArray allows constructing one at a time.
+#include "util/scoped.hh"
+
+#include <cstddef>
+
+#include <assert.h>
+
+namespace util {
+
+template <class T> class FixedArray {
+ public:
+ // Initialize with a given size bound but do not construct the objects.
+ explicit FixedArray(std::size_t limit) {
+ Init(limit);
+ }
+
+ FixedArray()
+ : newed_end_(NULL)
+#ifndef NDEBUG
+ , allocated_end_(NULL)
+#endif
+ {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+#ifndef NDEBUG
+ allocated_end_ = begin() + count;
+#endif
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ push_back(from[i]);
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+#ifndef NDEBUG
+ assert(newed_end_ <= allocated_end_);
+#endif
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+
+#ifndef NDEBUG
+ T *allocated_end_;
+#endif
+};
+
+} // namespace util
+
+#endif // UTIL_FIXED_ARRAY_H
diff --git a/util/getopt.hh b/util/getopt.hh
index 6ad9773..50eab56 100644
--- a/util/getopt.hh
+++ b/util/getopt.hh
@@ -11,8 +11,8 @@ Code given out at the 1985 UNIFORUM conference in Dallas.
#endif
#ifndef __GNUC__
-#ifndef _WINGETOPT_H_
-#define _WINGETOPT_H_
+#ifndef UTIL_GETOPT_H
+#define UTIL_GETOPT_H
#ifdef __cplusplus
extern "C" {
@@ -28,6 +28,6 @@ extern int getopt(int argc, char **argv, char *opts);
}
#endif
-#endif /* _GETOPT_H_ */
+#endif /* UTIL_GETOPT_H */
#endif /* __GNUC__ */
diff --git a/util/have.hh b/util/have.hh
index 6e18529..dc3f633 100644
--- a/util/have.hh
+++ b/util/have.hh
@@ -1,6 +1,6 @@
/* Optional packages. You might want to integrate this with your build system e.g. config.h from ./configure. */
-#ifndef UTIL_HAVE__
-#define UTIL_HAVE__
+#ifndef UTIL_HAVE_H
+#define UTIL_HAVE_H
#ifdef HAVE_CONFIG_H
#include "config.h"
@@ -10,4 +10,4 @@
//#define HAVE_ICU
#endif
-#endif // UTIL_HAVE__
+#endif // UTIL_HAVE_H
diff --git a/util/joint_sort.hh b/util/joint_sort.hh
index 1b43ddc..de4b554 100644
--- a/util/joint_sort.hh
+++ b/util/joint_sort.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_JOINT_SORT__
-#define UTIL_JOINT_SORT__
+#ifndef UTIL_JOINT_SORT_H
+#define UTIL_JOINT_SORT_H
/* A terrifying amount of C++ to coax std::sort into soring one range while
* also permuting another range the same way.
@@ -9,7 +9,6 @@
#include <algorithm>
#include <functional>
-#include <iostream>
namespace util {
@@ -35,9 +34,16 @@ template <class KeyIter, class ValueIter> class JointIter {
return *this;
}
- void swap(const JointIter &other) {
- std::swap(key_, other.key_);
- std::swap(value_, other.value_);
+ friend void swap(JointIter &first, JointIter &second) {
+ using std::swap;
+ swap(first.key_, second.key_);
+ swap(first.value_, second.value_);
+ }
+
+ void DeepSwap(JointIter &other) {
+ using std::swap;
+ swap(*key_, *other.key_);
+ swap(*value_, *other.value_);
}
private:
@@ -83,9 +89,8 @@ template <class KeyIter, class ValueIter> class JointProxy {
return *(inner_.key_);
}
- void swap(JointProxy<KeyIter, ValueIter> &other) {
- std::swap(*inner_.key_, *other.inner_.key_);
- std::swap(*inner_.value_, *other.inner_.value_);
+ friend void swap(JointProxy<KeyIter, ValueIter> first, JointProxy<KeyIter, ValueIter> second) {
+ first.Inner().DeepSwap(second.Inner());
}
private:
@@ -138,14 +143,4 @@ template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begi
} // namespace util
-namespace std {
-template <class KeyIter, class ValueIter> void swap(util::detail::JointIter<KeyIter, ValueIter> &left, util::detail::JointIter<KeyIter, ValueIter> &right) {
- left.swap(right);
-}
-
-template <class KeyIter, class ValueIter> void swap(util::detail::JointProxy<KeyIter, ValueIter> &left, util::detail::JointProxy<KeyIter, ValueIter> &right) {
- left.swap(right);
-}
-} // namespace std
-
-#endif // UTIL_JOINT_SORT__
+#endif // UTIL_JOINT_SORT_H
diff --git a/util/joint_sort_test.cc b/util/joint_sort_test.cc
index 4dc8591..b24c602 100644
--- a/util/joint_sort_test.cc
+++ b/util/joint_sort_test.cc
@@ -47,4 +47,16 @@ BOOST_AUTO_TEST_CASE(char_int) {
BOOST_CHECK_EQUAL(327, values[3]);
}
+BOOST_AUTO_TEST_CASE(swap_proxy) {
+ char keys[2] = {0, 1};
+ int values[2] = {2, 3};
+ detail::JointProxy<char *, int *> first(keys, values);
+ detail::JointProxy<char *, int *> second(keys + 1, values + 1);
+ swap(first, second);
+ BOOST_CHECK_EQUAL(1, keys[0]);
+ BOOST_CHECK_EQUAL(0, keys[1]);
+ BOOST_CHECK_EQUAL(3, values[0]);
+ BOOST_CHECK_EQUAL(2, values[1]);
+}
+
}} // namespace anonymous util
diff --git a/util/mmap.cc b/util/mmap.cc
index cee6a97..a3c8a02 100644
--- a/util/mmap.cc
+++ b/util/mmap.cc
@@ -6,6 +6,7 @@
#include "util/exception.hh"
#include "util/file.hh"
+#include "util/parallel_read.hh"
#include "util/scoped.hh"
#include <iostream>
@@ -40,7 +41,7 @@ void SyncOrThrow(void *start, size_t length) {
#if defined(_WIN32) || defined(_WIN64)
UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap");
#else
- UTIL_THROW_IF(msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap");
+ UTIL_THROW_IF(length && msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap");
#endif
}
@@ -154,6 +155,10 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope
SeekOrThrow(fd, offset);
ReadOrThrow(fd, out.get(), size);
break;
+ case PARALLEL_READ:
+ out.reset(MallocOrThrow(size), size, scoped_memory::MALLOC_ALLOCATED);
+ ParallelRead(fd, out.get(), size, offset);
+ break;
}
}
@@ -189,4 +194,66 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
}
}
+Rolling::Rolling(const Rolling &copy_from, uint64_t increase) {
+ *this = copy_from;
+ IncreaseBase(increase);
+}
+
+Rolling &Rolling::operator=(const Rolling &copy_from) {
+ fd_ = copy_from.fd_;
+ file_begin_ = copy_from.file_begin_;
+ file_end_ = copy_from.file_end_;
+ for_write_ = copy_from.for_write_;
+ block_ = copy_from.block_;
+ read_bound_ = copy_from.read_bound_;
+
+ current_begin_ = 0;
+ if (copy_from.IsPassthrough()) {
+ current_end_ = copy_from.current_end_;
+ ptr_ = copy_from.ptr_;
+ } else {
+ // Force call on next mmap.
+ current_end_ = 0;
+ ptr_ = NULL;
+ }
+ return *this;
+}
+
+Rolling::Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount) {
+ current_begin_ = 0;
+ current_end_ = 0;
+ fd_ = fd;
+ file_begin_ = offset;
+ file_end_ = offset + amount;
+ for_write_ = for_write;
+ block_ = block;
+ read_bound_ = read_bound;
+}
+
+void *Rolling::ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size) {
+ out.reset();
+ if (IsPassthrough()) return static_cast<uint8_t*>(get()) + index;
+ uint64_t offset = index + file_begin_;
+ // Round down to multiple of page size.
+ uint64_t cruft = offset % static_cast<uint64_t>(SizePage());
+ std::size_t map_size = static_cast<std::size_t>(size + cruft);
+ out.reset(MapOrThrow(map_size, for_write_, kFileFlags, true, fd_, offset - cruft), map_size, scoped_memory::MMAP_ALLOCATED);
+ return static_cast<uint8_t*>(out.get()) + static_cast<std::size_t>(cruft);
+}
+
+void Rolling::Roll(uint64_t index) {
+ assert(!IsPassthrough());
+ std::size_t amount;
+ if (file_end_ - (index + file_begin_) > static_cast<uint64_t>(block_)) {
+ amount = block_;
+ current_end_ = index + amount - read_bound_;
+ } else {
+ amount = file_end_ - (index + file_begin_);
+ current_end_ = index + amount;
+ }
+ ptr_ = static_cast<uint8_t*>(ExtractNonRolling(mem_, index, amount)) - index;
+
+ current_begin_ = index;
+}
+
} // namespace util
diff --git a/util/mmap.hh b/util/mmap.hh
index b218c4d..9b1e120 100644
--- a/util/mmap.hh
+++ b/util/mmap.hh
@@ -1,8 +1,9 @@
-#ifndef UTIL_MMAP__
-#define UTIL_MMAP__
+#ifndef UTIL_MMAP_H
+#define UTIL_MMAP_H
// Utilities for mmaped files.
#include <cstddef>
+#include <limits>
#include <stdint.h>
#include <sys/types.h>
@@ -52,6 +53,9 @@ class scoped_memory {
public:
typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc;
+ scoped_memory(void *data, std::size_t size, Alloc source)
+ : data_(data), size_(size), source_(source) {}
+
scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {}
~scoped_memory() { reset(); }
@@ -72,7 +76,6 @@ class scoped_memory {
void call_realloc(std::size_t to);
private:
-
void *data_;
std::size_t size_;
@@ -90,7 +93,9 @@ typedef enum {
// Populate on Linux. malloc and read on non-Linux.
POPULATE_OR_READ,
// malloc and read.
- READ
+ READ,
+ // malloc and read in parallel (recommended for Lustre)
+ PARALLEL_READ,
} LoadMethod;
extern const int kFileFlags;
@@ -109,6 +114,79 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file);
// msync wrapper
void SyncOrThrow(void *start, size_t length);
+// Forward rolling memory map with no overlap.
+class Rolling {
+ public:
+ Rolling() {}
+
+ explicit Rolling(void *data) { Init(data); }
+
+ Rolling(const Rolling &copy_from, uint64_t increase = 0);
+ Rolling &operator=(const Rolling &copy_from);
+
+ // For an actual rolling mmap.
+ explicit Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount);
+
+ // For a static mapping
+ void Init(void *data) {
+ ptr_ = data;
+ current_end_ = std::numeric_limits<uint64_t>::max();
+ current_begin_ = 0;
+ // Mark as a pass-through.
+ fd_ = -1;
+ }
+
+ void IncreaseBase(uint64_t by) {
+ file_begin_ += by;
+ ptr_ = static_cast<uint8_t*>(ptr_) + by;
+ if (!IsPassthrough()) current_end_ = 0;
+ }
+
+ void DecreaseBase(uint64_t by) {
+ file_begin_ -= by;
+ ptr_ = static_cast<uint8_t*>(ptr_) - by;
+ if (!IsPassthrough()) current_end_ = 0;
+ }
+
+ void *ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size);
+
+ // Returns base pointer
+ void *get() const { return ptr_; }
+
+ // Returns base pointer.
+ void *CheckedBase(uint64_t index) {
+ if (index >= current_end_ || index < current_begin_) {
+ Roll(index);
+ }
+ return ptr_;
+ }
+
+ // Returns indexed pointer.
+ void *CheckedIndex(uint64_t index) {
+ return static_cast<uint8_t*>(CheckedBase(index)) + index;
+ }
+
+ private:
+ void Roll(uint64_t index);
+
+ // True if this is just a thin wrapper on a pointer.
+ bool IsPassthrough() const { return fd_ == -1; }
+
+ void *ptr_;
+ uint64_t current_begin_;
+ uint64_t current_end_;
+
+ scoped_memory mem_;
+
+ int fd_;
+ uint64_t file_begin_;
+ uint64_t file_end_;
+
+ bool for_write_;
+ std::size_t block_;
+ std::size_t read_bound_;
+};
+
} // namespace util
-#endif // UTIL_MMAP__
+#endif // UTIL_MMAP_H
diff --git a/util/multi_intersection.hh b/util/multi_intersection.hh
index 8334d39..2955acc 100644
--- a/util/multi_intersection.hh
+++ b/util/multi_intersection.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_MULTI_INTERSECTION__
-#define UTIL_MULTI_INTERSECTION__
+#ifndef UTIL_MULTI_INTERSECTION_H
+#define UTIL_MULTI_INTERSECTION_H
#include <boost/optional.hpp>
#include <boost/range/iterator_range.hpp>
@@ -66,7 +66,7 @@ template <class Iterator, class Output, class Less> void AllIntersection(std::ve
std::sort(sets.begin(), sets.end(), detail::RangeLessBySize<boost::iterator_range<Iterator> >());
boost::optional<Value> ret;
- for (boost::optional<Value> ret; ret = detail::FirstIntersectionSorted(sets, less); sets.front().advance_begin(1)) {
+ for (boost::optional<Value> ret; (ret = detail::FirstIntersectionSorted(sets, less)); sets.front().advance_begin(1)) {
out(*ret);
}
}
@@ -77,4 +77,4 @@ template <class Iterator, class Output> void AllIntersection(std::vector<boost::
} // namespace util
-#endif // UTIL_MULTI_INTERSECTION__
+#endif // UTIL_MULTI_INTERSECTION_H
diff --git a/util/murmur_hash.cc b/util/murmur_hash.cc
index 4f51931..189668c 100644
--- a/util/murmur_hash.cc
+++ b/util/murmur_hash.cc
@@ -153,12 +153,19 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, uint64_t seed )
// Trick to test for 64-bit architecture at compile time.
namespace {
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
template <unsigned L> inline uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, uint64_t seed) {
return MurmurHash64A(key, len, seed);
}
template <> inline uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, uint64_t seed) {
return MurmurHash64B(key, len, seed);
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
} // namespace
uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed) {
diff --git a/util/murmur_hash.hh b/util/murmur_hash.hh
index ae7e88d..f17157c 100644
--- a/util/murmur_hash.hh
+++ b/util/murmur_hash.hh
@@ -1,14 +1,18 @@
-#ifndef UTIL_MURMUR_HASH__
-#define UTIL_MURMUR_HASH__
+#ifndef UTIL_MURMUR_HASH_H
+#define UTIL_MURMUR_HASH_H
#include <cstddef>
#include <stdint.h>
namespace util {
+// 64-bit machine version
uint64_t MurmurHash64A(const void * key, std::size_t len, uint64_t seed = 0);
+// 32-bit machine version (not the same function as above)
uint64_t MurmurHash64B(const void * key, std::size_t len, uint64_t seed = 0);
+// Use the version for this arch. Because the values differ across
+// architectures, really only use it for in-memory structures.
uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0);
} // namespace util
-#endif // UTIL_MURMUR_HASH__
+#endif // UTIL_MURMUR_HASH_H
diff --git a/util/parallel_read.cc b/util/parallel_read.cc
new file mode 100644
index 0000000..10972d7
--- /dev/null
+++ b/util/parallel_read.cc
@@ -0,0 +1,69 @@
+#include "util/parallel_read.hh"
+
+#include "util/file.hh"
+
+#ifdef WITH_THREADS
+#include "util/thread_pool.hh"
+
+namespace util {
+namespace {
+
+class Reader {
+ public:
+ explicit Reader(int fd) : fd_(fd) {}
+
+ struct Request {
+ void *to;
+ std::size_t size;
+ uint64_t offset;
+
+ bool operator==(const Request &other) const {
+ return (to == other.to) && (size == other.size) && (offset == other.offset);
+ }
+ };
+
+ void operator()(const Request &request) {
+ util::PReadOrThrow(fd_, request.to, request.size, request.offset);
+ }
+
+ private:
+ int fd_;
+};
+
+} // namespace
+
+void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) {
+ Reader::Request poison;
+ poison.to = NULL;
+ poison.size = 0;
+ poison.offset = 0;
+ unsigned threads = boost::thread::hardware_concurrency();
+ if (!threads) threads = 2;
+ ThreadPool<Reader> pool(2 /* don't need much of a queue */, threads, fd, poison);
+ const std::size_t kBatch = 1ULL << 25; // 32 MB
+ Reader::Request request;
+ request.to = to;
+ request.size = kBatch;
+ request.offset = offset;
+ for (; amount > kBatch; amount -= kBatch) {
+ pool.Produce(request);
+ request.to = reinterpret_cast<uint8_t*>(request.to) + kBatch;
+ request.offset += kBatch;
+ }
+ request.size = amount;
+ if (request.size) {
+ pool.Produce(request);
+ }
+}
+
+} // namespace util
+
+#else // WITH_THREADS
+
+namespace util {
+void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) {
+ util::PReadOrThrow(fd, to, amount, offset);
+}
+} // namespace util
+
+#endif
diff --git a/util/parallel_read.hh b/util/parallel_read.hh
new file mode 100644
index 0000000..1e96e79
--- /dev/null
+++ b/util/parallel_read.hh
@@ -0,0 +1,16 @@
+#ifndef UTIL_PARALLEL_READ__
+#define UTIL_PARALLEL_READ__
+
+/* Read pieces of a file in parallel. This has a very specific use case:
+ * reading files from Lustre is CPU bound so multiple threads actually
+ * increases throughput. Speed matters when an LM takes a terabyte.
+ */
+
+#include <cstddef>
+#include <stdint.h>
+
+namespace util {
+void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset);
+} // namespace util
+
+#endif // UTIL_PARALLEL_READ__
diff --git a/util/pcqueue.hh b/util/pcqueue.hh
index 3df8749..312a66f 100644
--- a/util/pcqueue.hh
+++ b/util/pcqueue.hh
@@ -1,5 +1,7 @@
-#ifndef UTIL_PCQUEUE__
-#define UTIL_PCQUEUE__
+#ifndef UTIL_PCQUEUE_H
+#define UTIL_PCQUEUE_H
+
+#include "util/exception.hh"
#include <boost/interprocess/sync/interprocess_semaphore.hpp>
#include <boost/scoped_array.hpp>
@@ -8,20 +10,68 @@
#include <errno.h>
+#ifdef __APPLE__
+#include <mach/semaphore.h>
+#include <mach/task.h>
+#include <mach/mach_traps.h>
+#include <mach/mach.h>
+#endif // __APPLE__
+
namespace util {
-inline void WaitSemaphore (boost::interprocess::interprocess_semaphore &on) {
+/* OS X Maverick and Boost interprocess were doing "Function not implemented."
+ * So this is my own wrapper around the mach kernel APIs.
+ */
+#ifdef __APPLE__
+
+#define MACH_CALL(call) UTIL_THROW_IF(KERN_SUCCESS != (call), Exception, "Mach call failure")
+
+class Semaphore {
+ public:
+ explicit Semaphore(int value) : task_(mach_task_self()) {
+ MACH_CALL(semaphore_create(task_, &back_, SYNC_POLICY_FIFO, value));
+ }
+
+ ~Semaphore() {
+ MACH_CALL(semaphore_destroy(task_, back_));
+ }
+
+ void wait() {
+ MACH_CALL(semaphore_wait(back_));
+ }
+
+ void post() {
+ MACH_CALL(semaphore_signal(back_));
+ }
+
+ private:
+ semaphore_t back_;
+ task_t task_;
+};
+
+inline void WaitSemaphore(Semaphore &semaphore) {
+ semaphore.wait();
+}
+
+#else
+typedef boost::interprocess::interprocess_semaphore Semaphore;
+
+inline void WaitSemaphore (Semaphore &on) {
while (1) {
try {
on.wait();
break;
}
catch (boost::interprocess::interprocess_exception &e) {
- if (e.get_native_error() != EINTR) throw;
+ if (e.get_native_error() != EINTR) {
+ throw;
+ }
}
}
}
+#endif // __APPLE__
+
/* Producer consumer queue safe for multiple producers and multiple consumers.
* T must be default constructable and have operator=.
* The value is copied twice for Consume(T &out) or three times for Consume(),
@@ -82,9 +132,9 @@ template <class T> class PCQueue : boost::noncopyable {
private:
// Number of empty spaces in storage_.
- boost::interprocess::interprocess_semaphore empty_;
+ Semaphore empty_;
// Number of occupied spaces in storage_.
- boost::interprocess::interprocess_semaphore used_;
+ Semaphore used_;
boost::scoped_array<T> storage_;
@@ -102,4 +152,4 @@ template <class T> class PCQueue : boost::noncopyable {
} // namespace util
-#endif // UTIL_PCQUEUE__
+#endif // UTIL_PCQUEUE_H
diff --git a/util/pcqueue_test.cc b/util/pcqueue_test.cc
new file mode 100644
index 0000000..22ed2c6
--- /dev/null
+++ b/util/pcqueue_test.cc
@@ -0,0 +1,20 @@
+#include "util/pcqueue.hh"
+
+#define BOOST_TEST_MODULE PCQueueTest
+#include <boost/test/unit_test.hpp>
+
+namespace util {
+namespace {
+
+BOOST_AUTO_TEST_CASE(SingleThread) {
+ PCQueue<int> queue(10);
+ for (int i = 0; i < 10; ++i) {
+ queue.Produce(i);
+ }
+ for (int i = 0; i < 10; ++i) {
+ BOOST_CHECK_EQUAL(i, queue.Consume());
+ }
+}
+
+}
+} // namespace util
diff --git a/util/pool.hh b/util/pool.hh
index 72f8a0c..89e793d 100644
--- a/util/pool.hh
+++ b/util/pool.hh
@@ -1,8 +1,8 @@
// Very simple pool. It can only allocate memory. And all of the memory it
// allocates must be freed at the same time.
-#ifndef UTIL_POOL__
-#define UTIL_POOL__
+#ifndef UTIL_POOL_H
+#define UTIL_POOL_H
#include <vector>
@@ -42,4 +42,4 @@ class Pool {
} // namespace util
-#endif // UTIL_POOL__
+#endif // UTIL_POOL_H
diff --git a/util/probing_hash_table.hh b/util/probing_hash_table.hh
index 51a2944..c1fe917 100644
--- a/util/probing_hash_table.hh
+++ b/util/probing_hash_table.hh
@@ -1,7 +1,8 @@
-#ifndef UTIL_PROBING_HASH_TABLE__
-#define UTIL_PROBING_HASH_TABLE__
+#ifndef UTIL_PROBING_HASH_TABLE_H
+#define UTIL_PROBING_HASH_TABLE_H
#include "util/exception.hh"
+#include "util/scoped.hh"
#include <algorithm>
#include <cstddef>
@@ -25,6 +26,8 @@ struct IdentityHash {
template <class T> T operator()(T arg) const { return arg; }
};
+template <class EntryT, class HashT, class EqualT> class AutoProbing;
+
/* Non-standard hash table
* Buckets must be set at the beginning and must be greater than maximum number
* of elements, else it throws ProbingSizeException.
@@ -33,7 +36,6 @@ struct IdentityHash {
* Uses linear probing to find value.
* Only insert and lookup operations.
*/
-
template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable {
public:
typedef EntryT Entry;
@@ -43,7 +45,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
typedef HashT Hash;
typedef EqualT Equal;
- public:
static uint64_t Size(uint64_t entries, float multiplier) {
uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)));
return buckets * sizeof(Entry);
@@ -69,6 +70,11 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#endif
{}
+ void Relocate(void *new_base) {
+ begin_ = reinterpret_cast<MutableIterator>(new_base);
+ end_ = begin_ + buckets_;
+ }
+
template <class T> MutableIterator Insert(const T &t) {
#ifdef DEBUG
assert(initialized_);
@@ -82,7 +88,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#ifdef DEBUG
assert(initialized_);
#endif
- for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
+ for (MutableIterator i = Ideal(t);;) {
Key got(i->GetKey());
if (equal_(got, t.GetKey())) { out = i; return true; }
if (equal_(got, invalid_)) {
@@ -97,8 +103,6 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
void FinishedInserting() {}
- void LoadedBinary() {}
-
// Don't change anything related to GetKey,
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
#ifdef DEBUG
@@ -224,6 +228,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
private:
+ friend class AutoProbing<Entry, Hash, Equal>;
+
template <class T> MutableIterator Ideal(const T &t) {
return begin_ + (hash_(t.GetKey()) % buckets_);
}
@@ -247,6 +253,75 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#endif
};
+// Resizable linear probing hash table. This owns the memory.
+template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing {
+ private:
+ typedef ProbingHashTable<EntryT, HashT, EqualT> Backend;
+ public:
+ typedef EntryT Entry;
+ typedef typename Entry::Key Key;
+ typedef const Entry *ConstIterator;
+ typedef Entry *MutableIterator;
+ typedef HashT Hash;
+ typedef EqualT Equal;
+
+ AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) :
+ allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) {
+ threshold_ = initial_size * 1.2;
+ Clear();
+ }
+
+ // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup.
+ template <class T> MutableIterator Insert(const T &t) {
+ DoubleIfNeeded();
+ return backend_.UncheckedInsert(t);
+ }
+
+ template <class T> bool FindOrInsert(const T &t, MutableIterator &out) {
+ DoubleIfNeeded();
+ return backend_.FindOrInsert(t, out);
+ }
+
+ template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
+ return backend_.UnsafeMutableFind(key, out);
+ }
+
+ template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
+ return backend_.UnsafeMutableMustFind(key);
+ }
+
+ template <class Key> bool Find(const Key key, ConstIterator &out) const {
+ return backend_.Find(key, out);
+ }
+
+ template <class Key> ConstIterator MustFind(const Key key) const {
+ return backend_.MustFind(key);
+ }
+
+ std::size_t Size() const {
+ return backend_.SizeNoSerialization();
+ }
+
+ void Clear() {
+ backend_.Clear();
+ }
+
+ private:
+ void DoubleIfNeeded() {
+ if (Size() < threshold_)
+ return;
+ mem_.call_realloc(backend_.DoubleTo());
+ allocated_ = backend_.DoubleTo();
+ backend_.Double(mem_.get());
+ threshold_ *= 2;
+ }
+
+ std::size_t allocated_;
+ util::scoped_malloc mem_;
+ Backend backend_;
+ std::size_t threshold_;
+};
+
} // namespace util
-#endif // UTIL_PROBING_HASH_TABLE__
+#endif // UTIL_PROBING_HASH_TABLE_H
diff --git a/util/proxy_iterator.hh b/util/proxy_iterator.hh
index 0ee1716..8aa697b 100644
--- a/util/proxy_iterator.hh
+++ b/util/proxy_iterator.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_PROXY_ITERATOR__
-#define UTIL_PROXY_ITERATOR__
+#ifndef UTIL_PROXY_ITERATOR_H
+#define UTIL_PROXY_ITERATOR_H
#include <cstddef>
#include <iterator>
@@ -38,8 +38,8 @@ template <class Proxy> class ProxyIterator {
typedef std::random_access_iterator_tag iterator_category;
typedef typename Proxy::value_type value_type;
typedef std::ptrdiff_t difference_type;
- typedef Proxy & reference;
- typedef Proxy * pointer;
+ typedef Proxy reference;
+ typedef ProxyIterator<Proxy> * pointer;
ProxyIterator() {}
@@ -47,10 +47,10 @@ template <class Proxy> class ProxyIterator {
template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {}
explicit ProxyIterator(const Proxy &p) : p_(p) {}
- // p_'s swap does value swapping, but here we want iterator swapping
+/* // p_'s swap does value swapping, but here we want iterator swapping
friend inline void swap(ProxyIterator<Proxy> &first, ProxyIterator<Proxy> &second) {
swap(first.I(), second.I());
- }
+ }*/
// p_'s operator= does value copying, but here we want iterator copying.
S &operator=(const S &other) {
@@ -77,8 +77,8 @@ template <class Proxy> class ProxyIterator {
std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); }
- Proxy &operator*() { return p_; }
- const Proxy &operator*() const { return p_; }
+ Proxy operator*() { return p_; }
+ const Proxy operator*() const { return p_; }
Proxy *operator->() { return &p_; }
const Proxy *operator->() const { return &p_; }
Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
@@ -98,4 +98,4 @@ template <class Proxy> ProxyIterator<Proxy> operator+(std::ptrdiff_t amount, con
} // namespace util
-#endif // UTIL_PROXY_ITERATOR__
+#endif // UTIL_PROXY_ITERATOR_H
diff --git a/util/read_compressed.cc b/util/read_compressed.cc
index b62a6e8..71ef0e2 100644
--- a/util/read_compressed.cc
+++ b/util/read_compressed.cc
@@ -49,6 +49,8 @@ class ReadBase {
thunk.internal_.reset(with);
}
+ ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); }
+
static uint64_t &ReadCount(ReadCompressed &thunk) {
return thunk.raw_amount_;
}
@@ -56,6 +58,8 @@ class ReadBase {
namespace {
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed);
+
// Completed file that other classes can thunk to.
class Complete : public ReadBase {
public:
@@ -80,7 +84,7 @@ class Uncompressed : public ReadBase {
class UncompressedWithHeader : public ReadBase {
public:
- UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) {
+ UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) {
assert(already_size);
buf_.reset(malloc(already_size));
if (!buf_.get()) throw std::bad_alloc();
@@ -91,6 +95,7 @@ class UncompressedWithHeader : public ReadBase {
std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
assert(buf_.get());
+ assert(remain_ != end_);
std::size_t sending = std::min<std::size_t>(amount, end_ - remain_);
memcpy(to, remain_, sending);
remain_ += sending;
@@ -108,23 +113,51 @@ class UncompressedWithHeader : public ReadBase {
scoped_fd fd_;
};
-#ifdef HAVE_ZLIB
-class GZip : public ReadBase {
+static const std::size_t kInputBuffer = 16384;
+
+template <class Compression> class StreamCompressed : public ReadBase {
+ public:
+ StreamCompressed(int fd, const void *already_data, std::size_t already_size)
+ : file_(fd),
+ in_buffer_(MallocOrThrow(kInputBuffer)),
+ back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {}
+
+ std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+ if (amount == 0) return 0;
+ back_.SetOutput(to, amount);
+ do {
+ if (!back_.Stream().avail_in) ReadInput(thunk);
+ if (!back_.Process()) {
+ // reached end, at least for the compressed portion.
+ std::size_t ret = static_cast<const uint8_t *>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
+ ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk);
+ if (ret) return ret;
+ // We did not read anything this round, so clients might think EOF. Transfer responsibility to the next reader.
+ return Current(thunk)->Read(to, amount, thunk);
+ }
+ } while (back_.Stream().next_out == to);
+ return static_cast<const uint8_t*>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
+ }
+
private:
- static const std::size_t kInputBuffer = 16384;
+ void ReadInput(ReadCompressed &thunk) {
+ assert(!back_.Stream().avail_in);
+ std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
+ back_.SetInput(in_buffer_.get(), got);
+ ReadCount(thunk) += got;
+ }
+
+ scoped_fd file_;
+ scoped_malloc in_buffer_;
+
+ Compression back_;
+};
+
+#ifdef HAVE_ZLIB
+class GZip {
public:
- GZip(int fd, void *already_data, std::size_t already_size)
- : file_(fd), in_buffer_(malloc(kInputBuffer)) {
- if (!in_buffer_.get()) throw std::bad_alloc();
- assert(already_size < kInputBuffer);
- if (already_size) {
- memcpy(in_buffer_.get(), already_data, already_size);
- stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
- stream_.avail_in = already_size;
- stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
- } else {
- stream_.avail_in = 0;
- }
+ GZip(const void *base, std::size_t amount) {
+ SetInput(base, amount);
stream_.zalloc = Z_NULL;
stream_.zfree = Z_NULL;
stream_.opaque = Z_NULL;
@@ -141,227 +174,154 @@ class GZip : public ReadBase {
}
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- if (amount == 0) return 0;
+ void SetOutput(void *to, std::size_t amount) {
stream_.next_out = static_cast<Bytef*>(to);
stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount);
- do {
- if (!stream_.avail_in) ReadInput(thunk);
- int result = inflate(&stream_, 0);
- switch (result) {
- case Z_OK:
- break;
- case Z_STREAM_END:
- {
- std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- case Z_ERRNO:
- UTIL_THROW(ErrnoException, "zlib error");
- default:
- UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
- }
- } while (stream_.next_out == to);
- return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
}
- private:
- void ReadInput(ReadCompressed &thunk) {
- assert(!stream_.avail_in);
- stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
- stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
- ReadCount(thunk) += stream_.avail_in;
+ void SetInput(const void *base, std::size_t amount) {
+ assert(amount < static_cast<std::size_t>(std::numeric_limits<uInt>::max()));
+ stream_.next_in = const_cast<Bytef*>(static_cast<const Bytef*>(base));
+ stream_.avail_in = amount;
}
- scoped_fd file_;
- scoped_malloc in_buffer_;
+ const z_stream &Stream() const { return stream_; }
+
+ bool Process() {
+ int result = inflate(&stream_, 0);
+ switch (result) {
+ case Z_OK:
+ return true;
+ case Z_STREAM_END:
+ return false;
+ case Z_ERRNO:
+ UTIL_THROW(ErrnoException, "zlib error");
+ default:
+ UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
+ }
+ }
+
+ private:
z_stream stream_;
};
#endif // HAVE_ZLIB
-const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
-
#ifdef HAVE_BZLIB
-class BZip : public ReadBase {
+class BZip {
public:
- BZip(int fd, void *already_data, std::size_t already_size) {
- scoped_fd hold(fd);
- closer_.reset(FDOpenReadOrThrow(hold));
- file_ = NULL;
- Open(already_data, already_size);
+ BZip(const void *base, std::size_t amount) {
+ memset(&stream_, 0, sizeof(stream_));
+ SetInput(base, amount);
+ HandleError(BZ2_bzDecompressInit(&stream_, 0, 0));
}
- BZip(FILE *file, void *already_data, std::size_t already_size) {
- closer_.reset(file);
- file_ = NULL;
- Open(already_data, already_size);
+ ~BZip() {
+ try {
+ HandleError(BZ2_bzDecompressEnd(&stream_));
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+ }
}
- ~BZip() {
- Close(file_);
+ bool Process() {
+ int ret = BZ2_bzDecompress(&stream_);
+ if (ret == BZ_STREAM_END) return false;
+ HandleError(ret);
+ return true;
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- assert(file_);
- int bzerror = BZ_OK;
- int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
- long pos = ftell(closer_.get());
- if (pos != -1) ReadCount(thunk) = pos;
- switch (bzerror) {
- case BZ_STREAM_END:
- /* bzip2 files can be concatenated by e.g. pbzip2. Annoyingly, the
- * library doesn't handle this internally. This gets the trailing
- * data, grows it up to magic as needed, validates the magic, and
- * reopens.
- */
- {
- bzerror = BZ_OK;
- void *trailing_data;
- int trailing_size;
- BZ2_bzReadGetUnused(&bzerror, file_, &trailing_data, &trailing_size);
- UTIL_THROW_IF(bzerror != BZ_OK, BZException, "bzip2 error in BZ2_bzReadGetUnused " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- std::string trailing(static_cast<const char*>(trailing_data), trailing_size);
- Close(file_);
-
- if (trailing_size < (int)sizeof(kBZMagic)) {
- trailing.resize(sizeof(kBZMagic));
- if (1 != fread(&trailing[trailing_size], sizeof(kBZMagic) - trailing_size, 1, closer_.get())) {
- UTIL_THROW_IF(trailing_size, BZException, "File has trailing cruft");
- // Legitimate end of file.
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- }
- UTIL_THROW_IF(memcmp(trailing.data(), kBZMagic, sizeof(kBZMagic)), BZException, "Trailing cruft is not another bzip2 stream");
- Open(&trailing[0], trailing.size());
- }
- return ret;
- case BZ_OK:
- return ret;
- default:
- UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
- }
+ void SetOutput(void *base, std::size_t amount) {
+ stream_.next_out = static_cast<char*>(base);
+ stream_.avail_out = std::min<std::size_t>(std::numeric_limits<unsigned int>::max(), amount);
+ }
+
+ void SetInput(const void *base, std::size_t amount) {
+ stream_.next_in = const_cast<char*>(static_cast<const char*>(base));
+ stream_.avail_in = amount;
}
+ const bz_stream &Stream() const { return stream_; }
+
private:
- void Open(void *already_data, std::size_t already_size) {
- assert(!file_);
- int bzerror = BZ_OK;
- file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size);
- switch (bzerror) {
+ void HandleError(int value) {
+ switch(value) {
case BZ_OK:
return;
case BZ_CONFIG_ERROR:
- UTIL_THROW(BZException, "Looks like bzip2 was miscompiled.");
+ UTIL_THROW(BZException, "bzip2 seems to be miscompiled.");
case BZ_PARAM_ERROR:
- UTIL_THROW(BZException, "Parameter error");
- case BZ_IO_ERROR:
- UTIL_THROW(BZException, "IO error reading file");
+ UTIL_THROW(BZException, "bzip2 Parameter error");
+ case BZ_DATA_ERROR:
+ UTIL_THROW(BZException, "bzip2 detected a corrupt file");
+ case BZ_DATA_ERROR_MAGIC:
+ UTIL_THROW(BZException, "bzip2 detected bad magic bytes. Perhaps this was not a bzip2 file after all?");
case BZ_MEM_ERROR:
throw std::bad_alloc();
default:
- UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror);
- }
- assert(file_);
- }
-
- static void Close(BZFILE *&file) {
- if (file == NULL) return;
- int bzerror = BZ_OK;
- BZ2_bzReadClose(&bzerror, file);
- if (bzerror != BZ_OK) {
- std::cerr << "bz2 readclose error number " << bzerror << std::endl;
- abort();
+ UTIL_THROW(BZException, "Unknown bzip2 error code " << value);
}
- file = NULL;
}
- scoped_FILE closer_;
- BZFILE *file_;
+ bz_stream stream_;
};
#endif // HAVE_BZLIB
#ifdef HAVE_XZLIB
-class XZip : public ReadBase {
- private:
- static const std::size_t kInputBuffer = 16384;
+class XZip {
public:
- XZip(int fd, void *already_data, std::size_t already_size)
- : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) {
- if (!in_buffer_.get()) throw std::bad_alloc();
- assert(already_size < kInputBuffer);
- if (already_size) {
- memcpy(in_buffer_.get(), already_data, already_size);
- stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
- stream_.avail_in = already_size;
- stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
- } else {
- stream_.avail_in = 0;
- }
- stream_.allocator = NULL;
- lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED);
- switch (ret) {
- case LZMA_OK:
- break;
- case LZMA_MEM_ERROR:
- UTIL_THROW(ErrnoException, "xz open error");
- default:
- UTIL_THROW(XZException, "xz error code " << ret);
- }
+ XZip(const void *base, std::size_t amount)
+ : stream_(), action_(LZMA_RUN) {
+ memset(&stream_, 0, sizeof(stream_));
+ SetInput(base, amount);
+ HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0));
}
~XZip() {
lzma_end(&stream_);
}
- std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
- if (amount == 0) return 0;
- stream_.next_out = static_cast<uint8_t*>(to);
+ void SetOutput(void *base, std::size_t amount) {
+ stream_.next_out = static_cast<uint8_t*>(base);
stream_.avail_out = amount;
- do {
- if (!stream_.avail_in) ReadInput(thunk);
- lzma_ret status = lzma_code(&stream_, action_);
- switch (status) {
- case LZMA_OK:
- break;
- case LZMA_STREAM_END:
- UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet.");
- {
- std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
- ReplaceThis(new Complete(), thunk);
- return ret;
- }
- case LZMA_MEM_ERROR:
- throw std::bad_alloc();
- case LZMA_FORMAT_ERROR:
- UTIL_THROW(XZException, "xzlib says file format not recognized");
- case LZMA_OPTIONS_ERROR:
- UTIL_THROW(XZException, "xzlib says unsupported compression options");
- case LZMA_DATA_ERROR:
- UTIL_THROW(XZException, "xzlib says this file is corrupt");
- case LZMA_BUF_ERROR:
- UTIL_THROW(XZException, "xzlib says unexpected end of input");
- default:
- UTIL_THROW(XZException, "unrecognized xzlib error " << status);
- }
- } while (stream_.next_out == to);
- return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+ }
+
+ void SetInput(const void *base, std::size_t amount) {
+ stream_.next_in = static_cast<const uint8_t*>(base);
+ stream_.avail_in = amount;
+ if (!amount) action_ = LZMA_FINISH;
+ }
+
+ const lzma_stream &Stream() const { return stream_; }
+
+ bool Process() {
+ lzma_ret status = lzma_code(&stream_, action_);
+ if (status == LZMA_STREAM_END) return false;
+ HandleError(status);
+ return true;
}
private:
- void ReadInput(ReadCompressed &thunk) {
- assert(!stream_.avail_in);
- stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
- stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
- if (!stream_.avail_in) action_ = LZMA_FINISH;
- ReadCount(thunk) += stream_.avail_in;
+ void HandleError(lzma_ret value) {
+ switch (value) {
+ case LZMA_OK:
+ return;
+ case LZMA_MEM_ERROR:
+ throw std::bad_alloc();
+ case LZMA_FORMAT_ERROR:
+ UTIL_THROW(XZException, "xzlib says file format not recognized");
+ case LZMA_OPTIONS_ERROR:
+ UTIL_THROW(XZException, "xzlib says unsupported compression options");
+ case LZMA_DATA_ERROR:
+ UTIL_THROW(XZException, "xzlib says this file is corrupt");
+ case LZMA_BUF_ERROR:
+ UTIL_THROW(XZException, "xzlib says unexpected end of input");
+ default:
+ UTIL_THROW(XZException, "unrecognized xzlib error " << value);
+ }
}
- scoped_fd file_;
- scoped_malloc in_buffer_;
lzma_stream stream_;
-
lzma_action action_;
};
#endif // HAVE_XZLIB
@@ -384,66 +344,68 @@ class IStreamReader : public ReadBase {
};
enum MagicResult {
- UNKNOWN, GZIP, BZIP, XZIP
+ UTIL_UNKNOWN, UTIL_GZIP, UTIL_BZIP, UTIL_XZIP
};
-MagicResult DetectMagic(const void *from_void) {
+MagicResult DetectMagic(const void *from_void, std::size_t length) {
const uint8_t *header = static_cast<const uint8_t*>(from_void);
- if (header[0] == 0x1f && header[1] == 0x8b) {
- return GZIP;
+ if (length >= 2 && header[0] == 0x1f && header[1] == 0x8b) {
+ return UTIL_GZIP;
}
- if (!memcmp(header, kBZMagic, sizeof(kBZMagic))) {
- return BZIP;
+ const uint8_t kBZMagic[3] = {'B', 'Z', 'h'};
+ if (length >= sizeof(kBZMagic) && !memcmp(header, kBZMagic, sizeof(kBZMagic))) {
+ return UTIL_BZIP;
}
const uint8_t kXZMagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
- if (!memcmp(header, kXZMagic, sizeof(kXZMagic))) {
- return XZIP;
+ if (length >= sizeof(kXZMagic) && !memcmp(header, kXZMagic, sizeof(kXZMagic))) {
+ return UTIL_XZIP;
}
- return UNKNOWN;
+ return UTIL_UNKNOWN;
}
-ReadBase *ReadFactory(int fd, uint64_t &raw_amount) {
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) {
scoped_fd hold(fd);
- unsigned char header[ReadCompressed::kMagicSize];
- raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize);
- if (!raw_amount)
- return new Uncompressed(hold.release());
- if (raw_amount != ReadCompressed::kMagicSize)
- return new UncompressedWithHeader(hold.release(), header, raw_amount);
- switch (DetectMagic(header)) {
- case GZIP:
+ std::string header(reinterpret_cast<const char*>(already_data), already_size);
+ if (header.size() < ReadCompressed::kMagicSize) {
+ std::size_t original = header.size();
+ header.resize(ReadCompressed::kMagicSize);
+ std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original);
+ raw_amount += got;
+ header.resize(original + got);
+ }
+ if (header.empty()) {
+ hold.release();
+ return new Complete();
+ }
+ switch (DetectMagic(&header[0], header.size())) {
+ case UTIL_GZIP:
#ifdef HAVE_ZLIB
- return new GZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<GZip>(hold.release(), header.data(), header.size());
#else
UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in.");
#endif
- case BZIP:
+ case UTIL_BZIP:
#ifdef HAVE_BZLIB
- return new BZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<BZip>(hold.release(), &header[0], header.size());
#else
- UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in.");
+ UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in.");
#endif
- case XZIP:
+ case UTIL_XZIP:
#ifdef HAVE_XZLIB
- return new XZip(hold.release(), header, ReadCompressed::kMagicSize);
+ return new StreamCompressed<XZip>(hold.release(), header.data(), header.size());
#else
UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in.");
#endif
- case UNKNOWN:
- break;
- }
- try {
- SeekOrThrow(fd, 0);
- } catch (const util::ErrnoException &e) {
- return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize);
+ default:
+ UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file. This could be supported but usually indicates an error.");
+ return new UncompressedWithHeader(hold.release(), header.data(), header.size());
}
- return new Uncompressed(hold.release());
}
} // namespace
bool ReadCompressed::DetectCompressedMagic(const void *from_void) {
- return DetectMagic(from_void) != UNKNOWN;
+ return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN;
}
ReadCompressed::ReadCompressed(int fd) {
@@ -459,8 +421,9 @@ ReadCompressed::ReadCompressed() {}
ReadCompressed::~ReadCompressed() {}
void ReadCompressed::Reset(int fd) {
+ raw_amount_ = 0;
internal_.reset();
- internal_.reset(ReadFactory(fd, raw_amount_));
+ internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false));
}
void ReadCompressed::Reset(std::istream &in) {
diff --git a/util/read_compressed.hh b/util/read_compressed.hh
index 8b54c9e..763e6bb 100644
--- a/util/read_compressed.hh
+++ b/util/read_compressed.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_READ_COMPRESSED__
-#define UTIL_READ_COMPRESSED__
+#ifndef UTIL_READ_COMPRESSED_H
+#define UTIL_READ_COMPRESSED_H
#include "util/exception.hh"
#include "util/scoped.hh"
@@ -78,4 +78,4 @@ class ReadCompressed {
} // namespace util
-#endif // UTIL_READ_COMPRESSED__
+#endif // UTIL_READ_COMPRESSED_H
diff --git a/util/read_compressed_test.cc b/util/read_compressed_test.cc
index 9cb4a4b..301e8f4 100644
--- a/util/read_compressed_test.cc
+++ b/util/read_compressed_test.cc
@@ -12,6 +12,23 @@
#include <stdlib.h>
+#if defined __MINGW32__
+#include <time.h>
+#include <fcntl.h>
+
+#if !defined mkstemp
+// TODO insecure
+int mkstemp(char * stemplate)
+{
+ char *filename = mktemp(stemplate);
+ if (filename == NULL)
+ return -1;
+ return open(filename, O_RDWR | O_CREAT, 0600);
+}
+#endif
+
+#endif // defined
+
namespace util {
namespace {
@@ -96,6 +113,11 @@ BOOST_AUTO_TEST_CASE(ReadXZ) {
}
#endif
+#ifdef HAVE_ZLIB
+BOOST_AUTO_TEST_CASE(AppendGZ) {
+}
+#endif
+
BOOST_AUTO_TEST_CASE(IStream) {
std::string name(WriteRandom());
std::fstream stream(name.c_str(), std::ios::in);
diff --git a/util/scoped.hh b/util/scoped.hh
index b642d06..ae70b6b 100644
--- a/util/scoped.hh
+++ b/util/scoped.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_SCOPED__
-#define UTIL_SCOPED__
+#ifndef UTIL_SCOPED_H
+#define UTIL_SCOPED_H
/* Other scoped objects in the style of scoped_ptr. */
#include "util/exception.hh"
@@ -101,4 +101,4 @@ template <class T> class scoped_ptr {
} // namespace util
-#endif // UTIL_SCOPED__
+#endif // UTIL_SCOPED_H
diff --git a/util/sized_iterator.hh b/util/sized_iterator.hh
index dce8f22..75f6886 100644
--- a/util/sized_iterator.hh
+++ b/util/sized_iterator.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_SIZED_ITERATOR__
-#define UTIL_SIZED_ITERATOR__
+#ifndef UTIL_SIZED_ITERATOR_H
+#define UTIL_SIZED_ITERATOR_H
#include "util/proxy_iterator.hh"
@@ -36,7 +36,7 @@ class SizedInnerIterator {
void *Data() { return ptr_; }
std::size_t EntrySize() const { return size_; }
- friend inline void swap(SizedInnerIterator &first, SizedInnerIterator &second) {
+ friend void swap(SizedInnerIterator &first, SizedInnerIterator &second) {
std::swap(first.ptr_, second.ptr_);
std::swap(first.size_, second.size_);
}
@@ -69,17 +69,7 @@ class SizedProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
- /**
- // TODO: this (deep) swap was recently added. why? if any std heap sort etc
- // algs are using swap, that's going to be worse performance than using
- // =. i'm not sure why we *want* a deep swap. if C++11 compilers are
- // choosing between move constructor and swap, then we'd better implement a
- // (deep) move constructor. it may also be that this is moot since i made
- // ProxyIterator a reference and added a shallow ProxyIterator swap? (I
- // need Ken or someone competent to judge whether that's correct also. -
- // let me know at graehl@gmail.com
- */
- friend void swap(SizedProxy &first, SizedProxy &second) {
+ friend void swap(SizedProxy first, SizedProxy second) {
std::swap_ranges(
static_cast<char*>(first.inner_.Data()),
static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(),
@@ -127,4 +117,4 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public
};
} // namespace util
-#endif // UTIL_SIZED_ITERATOR__
+#endif // UTIL_SIZED_ITERATOR_H
diff --git a/util/sized_iterator_test.cc b/util/sized_iterator_test.cc
new file mode 100644
index 0000000..c36bcb2
--- /dev/null
+++ b/util/sized_iterator_test.cc
@@ -0,0 +1,16 @@
+#include "util/sized_iterator.hh"
+
+#define BOOST_TEST_MODULE SizedIteratorTest
+#include <boost/test/unit_test.hpp>
+
+namespace util { namespace {
+
+BOOST_AUTO_TEST_CASE(swap_works) {
+ char str[2] = { 0, 1 };
+ SizedProxy first(str, 1), second(str + 1, 1);
+ swap(first, second);
+ BOOST_CHECK_EQUAL(1, str[0]);
+ BOOST_CHECK_EQUAL(0, str[1]);
+}
+
+}} // namespace anonymous util
diff --git a/util/sorted_uniform.hh b/util/sorted_uniform.hh
index 7700d9e..a7dba5e 100644
--- a/util/sorted_uniform.hh
+++ b/util/sorted_uniform.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_SORTED_UNIFORM__
-#define UTIL_SORTED_UNIFORM__
+#ifndef UTIL_SORTED_UNIFORM_H
+#define UTIL_SORTED_UNIFORM_H
#include <algorithm>
#include <cstddef>
@@ -124,4 +124,4 @@ template <class Iterator, class Accessor> Iterator BinaryBelow(
} // namespace util
-#endif // UTIL_SORTED_UNIFORM__
+#endif // UTIL_SORTED_UNIFORM_H
diff --git a/util/stream/block.hh b/util/stream/block.hh
index 11aa991..b9f1d49 100644
--- a/util/stream/block.hh
+++ b/util/stream/block.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_BLOCK__
-#define UTIL_STREAM_BLOCK__
+#ifndef UTIL_STREAM_BLOCK_H
+#define UTIL_STREAM_BLOCK_H
#include <cstddef>
#include <stdint.h>
@@ -40,4 +40,4 @@ class Block {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_BLOCK__
+#endif // UTIL_STREAM_BLOCK_H
diff --git a/util/stream/chain.hh b/util/stream/chain.hh
index 0cc83a8..bda653a 100644
--- a/util/stream/chain.hh
+++ b/util/stream/chain.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_CHAIN__
-#define UTIL_STREAM_CHAIN__
+#ifndef UTIL_STREAM_CHAIN_H
+#define UTIL_STREAM_CHAIN_H
#include "util/stream/block.hh"
#include "util/stream/config.hh"
@@ -195,4 +195,4 @@ inline Chain &operator>>(Chain &chain, Link &link) {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_CHAIN__
+#endif // UTIL_STREAM_CHAIN_H
diff --git a/util/stream/config.hh b/util/stream/config.hh
index 1eeb3a8..052a05b 100644
--- a/util/stream/config.hh
+++ b/util/stream/config.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_CONFIG__
-#define UTIL_STREAM_CONFIG__
+#ifndef UTIL_STREAM_CONFIG_H
+#define UTIL_STREAM_CONFIG_H
#include <cstddef>
#include <string>
@@ -29,4 +29,4 @@ struct SortConfig {
};
}} // namespaces
-#endif // UTIL_STREAM_CONFIG__
+#endif // UTIL_STREAM_CONFIG_H
diff --git a/util/stream/io.hh b/util/stream/io.hh
index 934b6b3..65918a9 100644
--- a/util/stream/io.hh
+++ b/util/stream/io.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_IO__
-#define UTIL_STREAM_IO__
+#ifndef UTIL_STREAM_IO_H
+#define UTIL_STREAM_IO_H
#include "util/exception.hh"
#include "util/file.hh"
@@ -73,4 +73,4 @@ class FileBuffer {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_IO__
+#endif // UTIL_STREAM_IO_H
diff --git a/util/stream/line_input.hh b/util/stream/line_input.hh
index 86db1dd..a870a66 100644
--- a/util/stream/line_input.hh
+++ b/util/stream/line_input.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_LINE_INPUT__
-#define UTIL_STREAM_LINE_INPUT__
+#ifndef UTIL_STREAM_LINE_INPUT_H
+#define UTIL_STREAM_LINE_INPUT_H
namespace util {namespace stream {
class ChainPosition;
@@ -19,4 +19,4 @@ class LineInput {
};
}} // namespaces
-#endif // UTIL_STREAM_LINE_INPUT__
+#endif // UTIL_STREAM_LINE_INPUT_H
diff --git a/util/stream/multi_progress.hh b/util/stream/multi_progress.hh
index c4dd45a..82e698a 100644
--- a/util/stream/multi_progress.hh
+++ b/util/stream/multi_progress.hh
@@ -1,6 +1,6 @@
/* Progress bar suitable for chains of workers */
-#ifndef UTIL_MULTI_PROGRESS__
-#define UTIL_MULTI_PROGRESS__
+#ifndef UTIL_STREAM_MULTI_PROGRESS_H
+#define UTIL_STREAM_MULTI_PROGRESS_H
#include <boost/thread/mutex.hpp>
@@ -87,4 +87,4 @@ class WorkerProgress {
}} // namespaces
-#endif // UTIL_MULTI_PROGRESS__
+#endif // UTIL_STREAM_MULTI_PROGRESS_H
diff --git a/util/stream/sort.hh b/util/stream/sort.hh
index 16aa6a0..e18c6ae 100644
--- a/util/stream/sort.hh
+++ b/util/stream/sort.hh
@@ -15,8 +15,8 @@
* sort. Use a hash table for that.
*/
-#ifndef UTIL_STREAM_SORT__
-#define UTIL_STREAM_SORT__
+#ifndef UTIL_STREAM_SORT_H
+#define UTIL_STREAM_SORT_H
#include "util/stream/chain.hh"
#include "util/stream/config.hh"
@@ -545,4 +545,4 @@ template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, cons
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_SORT__
+#endif // UTIL_STREAM_SORT_H
diff --git a/util/stream/stream.hh b/util/stream/stream.hh
index 6ff45b8..5cd1bdd 100644
--- a/util/stream/stream.hh
+++ b/util/stream/stream.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_STREAM__
-#define UTIL_STREAM_STREAM__
+#ifndef UTIL_STREAM_STREAM_H
+#define UTIL_STREAM_STREAM_H
#include "util/stream/chain.hh"
@@ -71,4 +71,4 @@ inline Chain &operator>>(Chain &chain, Stream &stream) {
} // namespace stream
} // namespace util
-#endif // UTIL_STREAM_STREAM__
+#endif // UTIL_STREAM_STREAM_H
diff --git a/util/stream/timer.hh b/util/stream/timer.hh
index 7e1a588..06488a1 100644
--- a/util/stream/timer.hh
+++ b/util/stream/timer.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STREAM_TIMER__
-#define UTIL_STREAM_TIMER__
+#ifndef UTIL_STREAM_TIMER_H
+#define UTIL_STREAM_TIMER_H
// Sorry Jon, this was adding library dependencies in Moses and people complained.
@@ -13,4 +13,4 @@
#define UTIL_TIMER(str)
//#endif
-#endif // UTIL_STREAM_TIMER__
+#endif // UTIL_STREAM_TIMER_H
diff --git a/util/string_piece.hh b/util/string_piece.hh
index 84431db..114e254 100644
--- a/util/string_piece.hh
+++ b/util/string_piece.hh
@@ -45,8 +45,8 @@
// conversions from "const char*" to "string" and back again.
//
-#ifndef BASE_STRING_PIECE_H__
-#define BASE_STRING_PIECE_H__
+#ifndef UTIL_STRING_PIECE_H
+#define UTIL_STRING_PIECE_H
#include "util/have.hh"
@@ -267,4 +267,4 @@ U_NAMESPACE_END
using U_NAMESPACE_QUALIFIER StringPiece;
#endif
-#endif // BASE_STRING_PIECE_H__
+#endif // UTIL_STRING_PIECE_H
diff --git a/util/string_piece_hash.hh b/util/string_piece_hash.hh
index f206b1d..5c8c525 100644
--- a/util/string_piece_hash.hh
+++ b/util/string_piece_hash.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_STRING_PIECE_HASH__
-#define UTIL_STRING_PIECE_HASH__
+#ifndef UTIL_STRING_PIECE_HASH_H
+#define UTIL_STRING_PIECE_HASH_H
#include "util/string_piece.hh"
@@ -40,4 +40,4 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece
#endif
}
-#endif // UTIL_STRING_PIECE_HASH__
+#endif // UTIL_STRING_PIECE_HASH_H
diff --git a/util/thread_pool.hh b/util/thread_pool.hh
index 84e257e..d1a883a 100644
--- a/util/thread_pool.hh
+++ b/util/thread_pool.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_THREAD_POOL__
-#define UTIL_THREAD_POOL__
+#ifndef UTIL_THREAD_POOL_H
+#define UTIL_THREAD_POOL_H
#include "util/pcqueue.hh"
@@ -18,8 +18,8 @@ template <class HandlerT> class Worker : boost::noncopyable {
typedef HandlerT Handler;
typedef typename Handler::Request Request;
- template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, Request &poison)
- : in_(in), handler_(construct), thread_(boost::ref(*this)), poison_(poison) {}
+ template <class Construct> Worker(PCQueue<Request> &in, Construct &construct, const Request &poison)
+ : in_(in), handler_(construct), poison_(poison), thread_(boost::ref(*this)) {}
// Only call from thread.
void operator()() {
@@ -30,7 +30,7 @@ template <class HandlerT> class Worker : boost::noncopyable {
try {
(*handler_)(request);
}
- catch(std::exception &e) {
+ catch(const std::exception &e) {
std::cerr << "Handler threw " << e.what() << std::endl;
abort();
}
@@ -49,10 +49,10 @@ template <class HandlerT> class Worker : boost::noncopyable {
PCQueue<Request> &in_;
boost::optional<Handler> handler_;
+
+ const Request poison_;
boost::thread thread_;
-
- Request poison_;
};
template <class HandlerT> class ThreadPool : boost::noncopyable {
@@ -92,4 +92,4 @@ template <class HandlerT> class ThreadPool : boost::noncopyable {
} // namespace util
-#endif // UTIL_THREAD_POOL__
+#endif // UTIL_THREAD_POOL_H
diff --git a/util/tokenize_piece.hh b/util/tokenize_piece.hh
index a588c3f..908c8da 100644
--- a/util/tokenize_piece.hh
+++ b/util/tokenize_piece.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_TOKENIZE_PIECE__
-#define UTIL_TOKENIZE_PIECE__
+#ifndef UTIL_TOKENIZE_PIECE_H
+#define UTIL_TOKENIZE_PIECE_H
#include "util/exception.hh"
#include "util/string_piece.hh"
@@ -7,7 +7,8 @@
#include <boost/iterator/iterator_facade.hpp>
#include <algorithm>
-#include <iostream>
+
+#include <string.h>
namespace util {
@@ -58,6 +59,30 @@ class AnyCharacter {
StringPiece chars_;
};
+class BoolCharacter {
+ public:
+ BoolCharacter() {}
+
+ explicit BoolCharacter(const bool *delimiter) { delimiter_ = delimiter; }
+
+ StringPiece Find(const StringPiece &in) const {
+ for (const char *i = in.data(); i != in.data() + in.size(); ++i) {
+ if (delimiter_[static_cast<unsigned char>(*i)]) return StringPiece(i, 1);
+ }
+ return StringPiece(in.data() + in.size(), 0);
+ }
+
+ template <unsigned Length> static void Build(const char (&characters)[Length], bool (&out)[256]) {
+ memset(out, 0, sizeof(out));
+ for (const char *i = characters; i != characters + Length; ++i) {
+ out[static_cast<unsigned char>(*i)] = true;
+ }
+ }
+
+ private:
+ const bool *delimiter_;
+};
+
class AnyCharacterLast {
public:
AnyCharacterLast() {}
@@ -123,4 +148,4 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
} // namespace util
-#endif // UTIL_TOKENIZE_PIECE__
+#endif // UTIL_TOKENIZE_PIECE_H
diff --git a/util/usage.cc b/util/usage.cc
index 25fc097..5912f90 100644
--- a/util/usage.cc
+++ b/util/usage.cc
@@ -10,61 +10,119 @@
#include <string.h>
#include <ctype.h>
-#if !defined(_WIN32) && !defined(_WIN64)
+#include <time.h>
+#if defined(_WIN32) || defined(_WIN64)
+// This code lifted from physmem.c in gnulib. See the copyright statement
+// below.
+# define WIN32_LEAN_AND_MEAN
+# include <windows.h>
+/* MEMORYSTATUSEX is missing from older windows headers, so define
+ a local replacement. */
+typedef struct
+{
+ DWORD dwLength;
+ DWORD dwMemoryLoad;
+ DWORDLONG ullTotalPhys;
+ DWORDLONG ullAvailPhys;
+ DWORDLONG ullTotalPageFile;
+ DWORDLONG ullAvailPageFile;
+ DWORDLONG ullTotalVirtual;
+ DWORDLONG ullAvailVirtual;
+ DWORDLONG ullAvailExtendedVirtual;
+} lMEMORYSTATUSEX;
+// Is this really supposed to be defined like this?
+typedef int WINBOOL;
+typedef WINBOOL (WINAPI *PFN_MS_EX) (lMEMORYSTATUSEX*);
+#else
#include <sys/resource.h>
#include <sys/time.h>
-#include <time.h>
#include <unistd.h>
#endif
-namespace util {
+#if defined(__MACH__) || defined(__FreeBSD__) || defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
-#if !defined(_WIN32) && !defined(_WIN64)
+namespace util {
namespace {
-// On Mac OS X, clock_gettime is not implemented.
-// CLOCK_MONOTONIC is not defined either.
-#ifdef __MACH__
-#define CLOCK_MONOTONIC 0
-
-int clock_gettime(int clk_id, struct timespec *tp) {
+#if defined(__MACH__)
+typedef struct timeval Wall;
+Wall GetWall() {
struct timeval tv;
gettimeofday(&tv, NULL);
- tp->tv_sec = tv.tv_sec;
- tp->tv_nsec = tv.tv_usec * 1000;
- return 0;
+ return tv;
}
-#endif // __MACH__
-
-float FloatSec(const struct timeval &tv) {
- return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0);
+#elif defined(_WIN32) || defined(_WIN64)
+typedef time_t Wall;
+Wall GetWall() {
+ return time(NULL);
}
-float FloatSec(const struct timespec &tv) {
- return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_nsec) / 1000000000.0);
+#else
+typedef struct timespec Wall;
+Wall GetWall() {
+ Wall ret;
+ clock_gettime(CLOCK_MONOTONIC, &ret);
+ return ret;
}
+#endif
-const char *SkipSpaces(const char *at) {
- for (; *at == ' ' || *at == '\t'; ++at) {}
- return at;
+// Some of these functions are only used on some platforms.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
+// These all assume first > second
+double Subtract(time_t first, time_t second) {
+ return difftime(first, second);
+}
+double DoubleSec(time_t tv) {
+ return static_cast<double>(tv);
+}
+#if !defined(_WIN32) && !defined(_WIN64)
+double Subtract(const struct timeval &first, const struct timeval &second) {
+ return static_cast<double>(first.tv_sec - second.tv_sec) + static_cast<double>(first.tv_usec - second.tv_usec) / 1000000.0;
}
+double Subtract(const struct timespec &first, const struct timespec &second) {
+ return static_cast<double>(first.tv_sec - second.tv_sec) + static_cast<double>(first.tv_nsec - second.tv_nsec) / 1000000000.0;
+}
+double DoubleSec(const struct timeval &tv) {
+ return static_cast<double>(tv.tv_sec) + (static_cast<double>(tv.tv_usec) / 1000000.0);
+}
+double DoubleSec(const struct timespec &tv) {
+ return static_cast<double>(tv.tv_sec) + (static_cast<double>(tv.tv_nsec) / 1000000000.0);
+}
+#endif
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
class RecordStart {
public:
RecordStart() {
- clock_gettime(CLOCK_MONOTONIC, &started_);
+ started_ = GetWall();
}
- const struct timespec &Started() const {
+ const Wall &Started() const {
return started_;
}
private:
- struct timespec started_;
+ Wall started_;
};
const RecordStart kRecordStart;
+
+const char *SkipSpaces(const char *at) {
+ for (; *at == ' ' || *at == '\t'; ++at) {}
+ return at;
+}
} // namespace
-#endif
+
+double WallTime() {
+ return Subtract(GetWall(), kRecordStart.Started());
+}
void PrintUsage(std::ostream &out) {
#if !defined(_WIN32) && !defined(_WIN64)
@@ -88,27 +146,84 @@ void PrintUsage(std::ostream &out) {
return;
}
out << "RSSMax:" << usage.ru_maxrss << " kB" << '\t';
- out << "user:" << FloatSec(usage.ru_utime) << "\tsys:" << FloatSec(usage.ru_stime) << '\t';
- out << "CPU:" << (FloatSec(usage.ru_utime) + FloatSec(usage.ru_stime));
-
- struct timespec current;
- clock_gettime(CLOCK_MONOTONIC, &current);
- out << "\treal:" << (FloatSec(current) - FloatSec(kRecordStart.Started())) << '\n';
+ out << "user:" << DoubleSec(usage.ru_utime) << "\tsys:" << DoubleSec(usage.ru_stime) << '\t';
+ out << "CPU:" << (DoubleSec(usage.ru_utime) + DoubleSec(usage.ru_stime));
+ out << '\t';
#endif
+
+ out << "real:" << WallTime() << '\n';
}
+/* Adapted from physmem.c in gnulib 831b84c59ef413c57a36b67344467d66a8a2ba70 */
+/* Calculate the size of physical memory.
+
+ Copyright (C) 2000-2001, 2003, 2005-2006, 2009-2013 Free Software
+ Foundation, Inc.
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Lesser General Public License as published by
+ the Free Software Foundation; either version 2.1 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>. */
+
+/* Written by Paul Eggert. */
uint64_t GuessPhysicalMemory() {
+#if defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
+ {
+ long pages = sysconf(_SC_PHYS_PAGES);
+ long page_size = sysconf(_SC_PAGESIZE);
+ if (pages != -1 && page_size != -1)
+ return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size);
+ }
+#endif
+#ifdef HW_PHYSMEM
+ { /* This works on *bsd and darwin. */
+ unsigned int physmem;
+ size_t len = sizeof physmem;
+ static int mib[2] = { CTL_HW, HW_PHYSMEM };
+
+ if (sysctl (mib, sizeof(mib) / sizeof(mib[0]), &physmem, &len, NULL, 0) == 0
+ && len == sizeof (physmem))
+ return static_cast<uint64_t>(physmem);
+ }
+#endif
+
#if defined(_WIN32) || defined(_WIN64)
- return 0;
-#elif defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
- long pages = sysconf(_SC_PHYS_PAGES);
- if (pages == -1) return 0;
- long page_size = sysconf(_SC_PAGESIZE);
- if (page_size == -1) return 0;
- return static_cast<uint64_t>(pages) * static_cast<uint64_t>(page_size);
-#else
- return 0;
+ { /* this works on windows */
+ PFN_MS_EX pfnex;
+ HMODULE h = GetModuleHandle ("kernel32.dll");
+
+ if (!h)
+ return 0;
+
+ /* Use GlobalMemoryStatusEx if available. */
+ if ((pfnex = (PFN_MS_EX) GetProcAddress (h, "GlobalMemoryStatusEx")))
+ {
+ lMEMORYSTATUSEX lms_ex;
+ lms_ex.dwLength = sizeof lms_ex;
+ if (!pfnex (&lms_ex))
+ return 0;
+ return lms_ex.ullTotalPhys;
+ }
+
+ /* Fall back to GlobalMemoryStatus which is always available.
+ but returns wrong results for physical memory > 4GB. */
+ else
+ {
+ MEMORYSTATUS ms;
+ GlobalMemoryStatus (&ms);
+ return ms.dwTotalPhys;
+ }
+ }
#endif
+ return 0;
}
namespace {
diff --git a/util/usage.hh b/util/usage.hh
index e19eda7..e578b0a 100644
--- a/util/usage.hh
+++ b/util/usage.hh
@@ -1,5 +1,5 @@
-#ifndef UTIL_USAGE__
-#define UTIL_USAGE__
+#ifndef UTIL_USAGE_H
+#define UTIL_USAGE_H
#include <cstddef>
#include <iosfwd>
#include <string>
@@ -7,6 +7,9 @@
#include <stdint.h>
namespace util {
+// Time in seconds since process started. Zero on unsupported platforms.
+double WallTime();
+
void PrintUsage(std::ostream &to);
// Determine how much physical memory there is. Return 0 on failure.
@@ -15,4 +18,4 @@ uint64_t GuessPhysicalMemory();
// Parse a size like unix sort. Sadly, this means the default multiplier is K.
uint64_t ParseSize(const std::string &arg);
} // namespace util
-#endif // UTIL_USAGE__
+#endif // UTIL_USAGE_H