Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-01-24 16:07:46 +0400
committerKenneth Heafield <github@kheafield.com>2013-01-24 16:07:46 +0400
commit03b077364a39b367a125092418278e9f4240c35f (patch)
tree4b410417dfec95f5bb1bf3f8ad1c3aad870bb979 /lm
parent22bf1c77e9866f5010708b288a33957a24627481 (diff)
KenLM 31a6644 resizable probing hash table, build fixes
Diffstat (limited to 'lm')
-rw-r--r--lm/builder/corpus_count.cc77
-rw-r--r--lm/builder/corpus_count.hh5
-rw-r--r--lm/builder/corpus_count_test.cc2
-rw-r--r--lm/builder/initial_probabilities.cc4
-rw-r--r--lm/builder/lmplz_main.cc2
-rw-r--r--lm/builder/pipeline.cc7
-rw-r--r--lm/builder/pipeline.hh7
7 files changed, 73 insertions, 31 deletions
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index abea4ed06..3714dddad 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -23,9 +23,32 @@ namespace lm {
namespace builder {
namespace {
+#pragma pack(push)
+#pragma pack(4)
+struct VocabEntry {
+ typedef uint64_t Key;
+
+ uint64_t GetKey() const { return key; }
+ void SetKey(uint64_t to) { key = to; }
+
+ uint64_t key;
+ lm::WordIndex value;
+};
+#pragma pack(pop)
+
+const float kProbingMultiplier = 1.5;
+
class VocabHandout {
public:
- explicit VocabHandout(int fd) {
+ static std::size_t MemUsage(WordIndex initial_guess) {
+ if (initial_guess < 2) initial_guess = 2;
+ return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
+ }
+
+ explicit VocabHandout(int fd, WordIndex initial_guess) :
+ table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
+ table_(table_backing_.get(), MemUsage(initial_guess)),
+ double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)) {
util::scoped_fd duped(util::DupOrThrow(fd));
word_list_.reset(util::FDOpenOrThrow(duped));
@@ -35,25 +58,38 @@ class VocabHandout {
}
WordIndex Lookup(const StringPiece &word) {
- uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
- std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
- if (ret.second) {
- char null_delimit = 0;
- util::WriteOrThrow(word_list_.get(), word.data(), word.size());
- util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
- UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ VocabEntry entry;
+ entry.key = util::MurmurHashNative(word.data(), word.size());
+ entry.value = table_.SizeNoSerialization();
+
+ Table::MutableIterator it;
+ if (table_.FindOrInsert(entry, it))
+ return it->value;
+ char null_delimit = 0;
+ util::WriteOrThrow(word_list_.get(), word.data(), word.size());
+ util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
+ UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ if (Size() >= double_cutoff_) {
+ table_backing_.call_realloc(table_.DoubleTo());
+ table_.Double(table_backing_.get());
+ double_cutoff_ *= 2;
}
- return ret.first->second;
+ return entry.value;
}
WordIndex Size() const {
- return seen_.size();
+ return table_.SizeNoSerialization();
}
private:
- typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+ // TODO: factor out a resizable probing hash table.
+ // TODO: use mremap on linux to get all zeros on resizes.
+ util::scoped_malloc table_backing_;
+
+ typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
+ Table table_;
- Seen seen_;
+ std::size_t double_cutoff_;
util::scoped_FILE word_list_;
};
@@ -85,6 +121,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn
struct DedupeEntry {
typedef WordIndex *Key;
Key GetKey() const { return key; }
+ void SetKey(WordIndex *to) { key = to; }
Key key;
static DedupeEntry Construct(WordIndex *at) {
DedupeEntry ret;
@@ -95,8 +132,6 @@ struct DedupeEntry {
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
-const float kProbingMultiplier = 1.5;
-
class Writer {
public:
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
@@ -105,7 +140,7 @@ class Writer {
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
buffer_(new WordIndex[order - 1]),
block_size_(position.GetChain().BlockSize()) {
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
// Add special words. AdjustCounts is responsible if order != 1.
@@ -149,7 +184,7 @@ class Writer {
}
// Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
@@ -187,18 +222,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}
+std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
+ return VocabHandout::MemUsage(vocab_estimate);
+}
+
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
- token_count_ = 0;
- type_count_ = 0;
}
void CorpusCount::Run(const util::stream::ChainPosition &position) {
UTIL_TIMER("(%w s) Counted n-grams\n");
- VocabHandout vocab(vocab_write_);
+ VocabHandout vocab(vocab_write_, type_count_);
+ token_count_ = 0;
+ type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh
index e255bad13..aa0ed8ede 100644
--- a/lm/builder/corpus_count.hh
+++ b/lm/builder/corpus_count.hh
@@ -23,6 +23,11 @@ class CorpusCount {
// Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size
static float DedupeMultiplier(std::size_t order);
+ // How much memory vocabulary will use based on estimated size of the vocab.
+ static std::size_t VocabUsage(std::size_t vocab_estimate);
+
+ // token_count: out.
+ // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
void Run(const util::stream::ChainPosition &position);
diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc
index 8d53ca9d1..6d325ef52 100644
--- a/lm/builder/corpus_count_test.cc
+++ b/lm/builder/corpus_count_test.cc
@@ -44,7 +44,7 @@ BOOST_AUTO_TEST_CASE(Short) {
util::stream::Chain chain(config);
NGramStream stream;
uint64_t token_count;
- WordIndex type_count;
+ WordIndex type_count = 10;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc
index 1e905c3e0..58b42a20c 100644
--- a/lm/builder/initial_probabilities.cc
+++ b/lm/builder/initial_probabilities.cc
@@ -24,7 +24,6 @@ struct BufferEntry {
class OnlyGamma {
public:
void Run(const util::stream::ChainPosition &position) {
- uint64_t count = 0;
for (util::stream::Link block_it(position); block_it; ++block_it) {
float *out = static_cast<float*>(block_it->Get());
const float *in = out;
@@ -33,10 +32,7 @@ class OnlyGamma {
*out = *in;
}
block_it->SetValidSize(block_it->ValidSize() / 2);
- count += block_it->ValidSize() / sizeof(float);
}
- std::cerr << std::endl;
- std::cerr << "Backoff count is " << count << std::endl;
}
};
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index 8b4953ba2..1e086dcce 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -42,9 +42,9 @@ int main(int argc, char *argv[]) {
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
- ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
+ ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index 14a1f7218..b89ea6ba5 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -207,17 +207,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
const PipelineConfig &config = master.Config();
std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
- UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
+ const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
+ UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
std::size_t memory_for_chain =
// This much memory to work with after vocab hash table.
- static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
+ static_cast<float>(config.TotalMemory() - vocab_usage) /
// Solve for block size including the dedupe multiplier for one block.
(static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
// Chain likes memory expressed in terms of total memory.
static_cast<float>(config.block_count);
util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
- WordIndex type_count;
+ WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index f1d6c5f61..fc3314bf1 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -3,6 +3,7 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
+#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
@@ -19,9 +20,9 @@ struct PipelineConfig {
util::stream::ChainConfig read_backoffs;
bool verbose_header;
- // Amount of memory to assume that the vocabulary hash table will use. This
- // is subtracted from total memory for CorpusCount.
- std::size_t assume_vocab_hash_size;
+ // Estimated vocabulary size. Used for sizing CorpusCount memory and
+ // initial probing hash table sizing, also in CorpusCount.
+ lm::WordIndex vocab_estimate;
// Minimum block size to tolerate.
std::size_t minimum_block;