github.com/moses-smt/mosesdecoder.git
author     Kenneth Heafield <github@kheafield.com>   2015-05-19 22:27:30 +0300
committer  Kenneth Heafield <github@kheafield.com>   2015-05-19 22:27:30 +0300
commit     a70d37e46fce323f7a9720e3a621f35d19e4ac9f (patch)
tree       d45651f2d44ba7c373f2a198611dca3b196e36d9 /lm
parent     90309aebfa0184ac611725520443b34c3331794b (diff)
KenLM 7408730be415db9b650560a8b2bd3e4e3af49ec9.
unistd.hh is dead.
Diffstat (limited to 'lm')
-rw-r--r--  lm/binary_format.cc                                                  |   4
-rw-r--r--  lm/builder/Jamfile                                                   |   4
-rw-r--r--  lm/builder/adjust_counts.cc                                          |  87
-rw-r--r--  lm/builder/adjust_counts_test.cc                                     |  27
-rw-r--r--  lm/builder/combine_counts.hh                                         |  31
-rw-r--r--  lm/builder/corpus_count.cc                                           |  37
-rw-r--r--  lm/builder/corpus_count_test.cc                                      |  13
-rw-r--r--  lm/builder/initial_probabilities.cc                                  |  73
-rw-r--r--  lm/builder/initial_probabilities.hh                                  |   6
-rw-r--r--  lm/builder/interpolate.cc                                            |  27
-rw-r--r--  lm/builder/interpolate.hh                                            |   5
-rw-r--r--  lm/builder/joint_order.hh                                            |   7
-rw-r--r--  lm/builder/lmplz_main.cc                                             |  19
-rw-r--r--  lm/builder/output.cc                                                 |  33
-rw-r--r--  lm/builder/output.hh                                                 |  26
-rw-r--r--  lm/builder/payload.hh                                                |  48
-rw-r--r--  lm/builder/pipeline.cc                                               | 213
-rw-r--r--  lm/builder/pipeline.hh                                               |   3
-rw-r--r--  lm/builder/print.cc                                                  |  33
-rw-r--r--  lm/builder/print.hh                                                  |  27
-rw-r--r--  lm/builder/special.hh                                                |  27
-rw-r--r--  lm/common/Jamfile                                                    |   2
-rw-r--r--  lm/common/compare.hh (renamed from lm/builder/sort.hh)               |  76
-rw-r--r--  lm/common/model_buffer.cc                                            |  82
-rw-r--r--  lm/common/model_buffer.hh                                            |  45
-rw-r--r--  lm/common/ngram.hh (renamed from lm/builder/ngram.hh)                |  82
-rw-r--r--  lm/common/ngram_stream.hh (renamed from lm/builder/ngram_stream.hh)  |  30
-rw-r--r--  lm/common/renumber.cc                                                |  17
-rw-r--r--  lm/common/renumber.hh                                                |  30
-rw-r--r--  lm/kenlm_benchmark_main.cc                                           | 128
-rw-r--r--  lm/ngram_query.hh                                                    |  64
-rw-r--r--  lm/query_main.cc                                                     |  47
-rw-r--r--  lm/value.hh                                                          |   1
-rw-r--r--  lm/vocab.cc                                                          | 104
-rw-r--r--  lm/vocab.hh                                                          |  40
-rw-r--r--  lm/word_index.hh                                                     |   1
36 files changed, 1020 insertions(+), 479 deletions(-)
diff --git a/lm/binary_format.cc b/lm/binary_format.cc
index 4ad893d44..2b34a778a 100644
--- a/lm/binary_format.cc
+++ b/lm/binary_format.cc
@@ -170,6 +170,7 @@ void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
if (!write_mmap_) {
header_size_ = 0;
util::MapAnonymous(memory_size, memory_vocab_);
+ util::AdviseHugePages(memory_vocab_.get(), memory_size);
return reinterpret_cast<uint8_t*>(memory_vocab_.get());
}
header_size_ = TotalHeaderSize(order);
@@ -189,6 +190,7 @@ void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
break;
}
strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
+ util::AdviseHugePages(vocab_base, total);
return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
}
@@ -201,6 +203,7 @@ void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad
util::MapAnonymous(memory_size, memory_search_);
assert(header_size_ == 0 || write_mmap_);
vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ util::AdviseHugePages(memory_search_.get(), memory_size);
return reinterpret_cast<uint8_t*>(memory_search_.get());
}
@@ -214,6 +217,7 @@ void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad
util::ResizeOrThrow(file_.get(), new_size);
void *ret;
MapFile(vocab_base, ret);
+ util::AdviseHugePages(ret, new_size);
return ret;
}
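
The new util::AdviseHugePages calls above hint to the kernel that these large anonymous or file-backed mappings should be backed by huge pages. As an illustrative aside (not taken from this commit), a helper with the signature used here could plausibly be a thin wrapper around madvise on Linux:

// Hypothetical sketch only: assumes util::AdviseHugePages wraps
// madvise(MADV_HUGEPAGE); the real implementation may differ (e.g. it would
// likely round addr down to a page boundary first).
#include <sys/mman.h>
#include <cstddef>

namespace util {

void AdviseHugePages(void *addr, std::size_t size) {
#ifdef MADV_HUGEPAGE
  // Advisory only: if transparent huge pages are unavailable the call
  // fails harmlessly and the mapping keeps working with regular pages.
  madvise(addr, size, MADV_HUGEPAGE);
#endif
}

} // namespace util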
diff --git a/lm/builder/Jamfile b/lm/builder/Jamfile
index 1e0e18b5f..329a8e076 100644
--- a/lm/builder/Jamfile
+++ b/lm/builder/Jamfile
@@ -1,5 +1,5 @@
-fakelib builder : [ glob *.cc : *test.cc *main.cc ]
- ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm
+fakelib builder : [ glob *.cc : *test.cc *main.cc ]
+ ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ../common//common
: : : <library>/top//boost_thread $(timer-link) ;
exe lmplz : lmplz_main.cc builder /top//boost_program_options ;
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc
index bcaa71998..3ac3e8d20 100644
--- a/lm/builder/adjust_counts.cc
+++ b/lm/builder/adjust_counts.cc
@@ -1,5 +1,6 @@
#include "lm/builder/adjust_counts.hh"
-#include "lm/builder/ngram_stream.hh"
+#include "lm/common/ngram_stream.hh"
+#include "lm/builder/payload.hh"
#include "util/stream/timer.hh"
#include <algorithm>
@@ -13,7 +14,7 @@ BadDiscountException::~BadDiscountException() throw() {}
namespace {
// Return last word in full that is different.
-const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
+const WordIndex* FindDifference(const NGram<BuildingPayload> &full, const NGram<BuildingPayload> &lower_last) {
const WordIndex *cur_word = full.end() - 1;
const WordIndex *pre_word = lower_last.end() - 1;
// Find last difference.
@@ -111,15 +112,15 @@ class StatCollector {
class CollapseStream {
public:
CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector<bool>& prune_words) :
- current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
prune_threshold_(prune_threshold),
prune_words_(prune_words),
block_(position) {
StartBlock();
}
- const NGram &operator*() const { return current_; }
- const NGram *operator->() const { return &current_; }
+ const NGram<BuildingPayload> &operator*() const { return current_; }
+ const NGram<BuildingPayload> *operator->() const { return &current_; }
operator bool() const { return block_; }
@@ -131,14 +132,14 @@ class CollapseStream {
UpdateCopyFrom();
// Mark highest order n-grams for later pruning
- if(current_.Count() <= prune_threshold_) {
- current_.Mark();
+ if(current_.Value().count <= prune_threshold_) {
+ current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
- current_.Mark();
+ current_.Value().Mark();
break;
}
}
@@ -155,14 +156,14 @@ class CollapseStream {
}
// Mark highest order n-grams for later pruning
- if(current_.Count() <= prune_threshold_) {
- current_.Mark();
+ if(current_.Value().count <= prune_threshold_) {
+ current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
- current_.Mark();
+ current_.Value().Mark();
break;
}
}
@@ -182,14 +183,14 @@ class CollapseStream {
UpdateCopyFrom();
// Mark highest order n-grams for later pruning
- if(current_.Count() <= prune_threshold_) {
- current_.Mark();
+ if(current_.Value().count <= prune_threshold_) {
+ current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
- current_.Mark();
+ current_.Value().Mark();
break;
}
}
@@ -200,11 +201,11 @@ class CollapseStream {
// Find last without bos.
void UpdateCopyFrom() {
for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
- if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
+ if (NGram<BuildingPayload>(copy_from_, current_.Order()).begin()[1] != kBOS) break;
}
}
- NGram current_;
+ NGram<BuildingPayload> current_;
// Goes backwards in the block
uint8_t *copy_from_;
@@ -223,36 +224,36 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
if (order == 1) {
// Only unigrams. Just collect stats.
- for (NGramStream full(positions[0]); full; ++full) {
+ for (NGramStream<BuildingPayload> full(positions[0]); full; ++full) {
// Do not prune <s> </s> <unk>
if(*full->begin() > 2) {
- if(full->Count() <= prune_thresholds_[0])
- full->Mark();
+ if(full->Value().count <= prune_thresholds_[0])
+ full->Value().Mark();
if(!prune_words_.empty() && prune_words_[*full->begin()])
- full->Mark();
+ full->Value().Mark();
}
- stats.AddFull(full->UnmarkedCount(), full->IsMarked());
+ stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
}
stats.CalculateDiscounts(discount_config_);
return;
}
- NGramStreams streams;
+ NGramStreams<BuildingPayload> streams;
streams.Init(positions, positions.size() - 1);
CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_);
// Initialization: <unk> has count 0 and so does <s>.
- NGramStream *lower_valid = streams.begin();
- const NGramStream *const streams_begin = streams.begin();
- streams[0]->Count() = 0;
+ NGramStream<BuildingPayload> *lower_valid = streams.begin();
+ const NGramStream<BuildingPayload> *const streams_begin = streams.begin();
+ streams[0]->Value().count = 0;
*streams[0]->begin() = kUNK;
stats.Add(0, 0);
- (++streams[0])->Count() = 0;
+ (++streams[0])->Value().count = 0;
*streams[0]->begin() = kBOS;
// <s> is not in stats yet because it will get put in later.
@@ -271,28 +272,28 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
for (; lower_valid >= &streams[same]; --lower_valid) {
uint64_t order_minus_1 = lower_valid - streams_begin;
if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1])
- (*lower_valid)->Mark();
+ (*lower_valid)->Value().Mark();
if(!prune_words_.empty()) {
for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) {
if(prune_words_[*i]) {
- (*lower_valid)->Mark();
+ (*lower_valid)->Value().Mark();
break;
}
}
}
- stats.Add(order_minus_1, (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
+ stats.Add(order_minus_1, (*lower_valid)->Value().UnmarkedCount(), (*lower_valid)->Value().IsMarked());
++*lower_valid;
}
// STEP 2: Update n-grams that still match.
// n-grams that match get count from the full entry.
for (std::size_t i = 0; i < same; ++i) {
- actual_counts[i] += full->UnmarkedCount();
+ actual_counts[i] += full->Value().UnmarkedCount();
}
// Increment the number of unique extensions for the longest match.
- if (same) ++streams[same - 1]->Count();
+ if (same) ++streams[same - 1]->Value().count;
// STEP 3: Initialize new n-grams.
// This is here because bos is also const WordIndex *, so copy gets
@@ -301,47 +302,47 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
// Initialize and mark as valid up to bos.
const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
- NGramStream &to = *++lower_valid;
+ NGramStream<BuildingPayload> &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
- to->Count() = 1;
- actual_counts[lower_valid - streams_begin] = full->UnmarkedCount();
+ to->Value().count = 1;
+ actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
}
// Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) {
// There is an <s> beyond the 0th word.
- NGramStream &to = *++lower_valid;
+ NGramStream<BuildingPayload> &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
// Anything that begins with <s> has full non adjusted count.
- to->Count() = full->UnmarkedCount();
- actual_counts[lower_valid - streams_begin] = full->UnmarkedCount();
+ to->Value().count = full->Value().UnmarkedCount();
+ actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
} else {
- stats.AddFull(full->UnmarkedCount(), full->IsMarked());
+ stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
}
assert(lower_valid >= &streams[0]);
}
// The above loop outputs n-grams when it observes changes. This outputs
// the last n-grams.
- for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
+ for (NGramStream<BuildingPayload> *s = streams.begin(); s <= lower_valid; ++s) {
uint64_t lower_count = actual_counts[(*s)->Order() - 1];
if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
- (*s)->Mark();
+ (*s)->Value().Mark();
if(!prune_words_.empty()) {
for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) {
if(prune_words_[*i]) {
- (*s)->Mark();
+ (*s)->Value().Mark();
break;
}
}
}
- stats.Add(s - streams.begin(), lower_count, (*s)->IsMarked());
+ stats.Add(s - streams.begin(), lower_count, (*s)->Value().IsMarked());
++*s;
}
// Poison everyone! Except the N-grams which were already poisoned by the input.
- for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
+ for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s)
s->Poison();
stats.CalculateDiscounts(discount_config_);
diff --git a/lm/builder/adjust_counts_test.cc b/lm/builder/adjust_counts_test.cc
index 2a9d78ae0..fff551f7c 100644
--- a/lm/builder/adjust_counts_test.cc
+++ b/lm/builder/adjust_counts_test.cc
@@ -1,6 +1,7 @@
#include "lm/builder/adjust_counts.hh"
-#include "lm/builder/ngram_stream.hh"
+#include "lm/common/ngram_stream.hh"
+#include "lm/builder/payload.hh"
#include "util/scoped.hh"
#include <boost/thread/thread.hpp>
@@ -37,7 +38,7 @@ struct Gram4 {
class WriteInput {
public:
void Run(const util::stream::ChainPosition &position) {
- NGramStream input(position);
+ NGramStream<BuildingPayload> input(position);
Gram4 grams[] = {
{{0,0,0,0},10},
{{0,0,3,0},3},
@@ -47,7 +48,7 @@ class WriteInput {
};
for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
- input->Count() = grams[i].count;
+ input->Value().count = grams[i].count;
}
input.Poison();
}
@@ -63,7 +64,7 @@ BOOST_AUTO_TEST_CASE(Simple) {
config.block_count = 1;
util::stream::Chains chains(4);
for (unsigned i = 0; i < 4; ++i) {
- config.entry_size = NGram::TotalSize(i + 1);
+ config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
chains.push_back(config);
}
@@ -86,25 +87,25 @@ BOOST_AUTO_TEST_CASE(Simple) {
/* BOOST_CHECK_EQUAL(4UL, counts[1]);
BOOST_CHECK_EQUAL(3UL, counts[2]);
BOOST_CHECK_EQUAL(3UL, counts[3]);*/
- BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
- NGram uni(outputs[0].Get(), 1);
+ BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(1) * 4, outputs[0].Size());
+ NGram<BuildingPayload> uni(outputs[0].Get(), 1);
BOOST_CHECK_EQUAL(kUNK, *uni.begin());
- BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(kBOS, *uni.begin());
- BOOST_CHECK_EQUAL(0ULL, uni.Count());
+ BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(0UL, *uni.begin());
- BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
uni.NextInMemory();
- BOOST_CHECK_EQUAL(2ULL, uni.Count());
+ BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
BOOST_CHECK_EQUAL(2UL, *uni.begin());
- BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
- NGram bi(outputs[1].Get(), 2);
+ BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(2) * 4, outputs[1].Size());
+ NGram<BuildingPayload> bi(outputs[1].Get(), 2);
BOOST_CHECK_EQUAL(0UL, *bi.begin());
BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
- BOOST_CHECK_EQUAL(1ULL, bi.Count());
+ BOOST_CHECK_EQUAL(1ULL, bi.Value().count);
bi.NextInMemory();
}
diff --git a/lm/builder/combine_counts.hh b/lm/builder/combine_counts.hh
new file mode 100644
index 000000000..2eda51704
--- /dev/null
+++ b/lm/builder/combine_counts.hh
@@ -0,0 +1,31 @@
+#ifndef LM_BUILDER_COMBINE_COUNTS_H
+#define LM_BUILDER_COMBINE_COUNTS_H
+
+#include "lm/builder/payload.hh"
+#include "lm/common/ngram.hh"
+#include "lm/common/compare.hh"
+#include "lm/word_index.hh"
+#include "util/stream/sort.hh"
+
+#include <functional>
+#include <string>
+
+namespace lm {
+namespace builder {
+
+// Sum counts for the same n-gram.
+struct CombineCounts {
+ bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
+ NGram<BuildingPayload> first(first_void, compare.Order());
+ // There isn't a const version of NGram.
+ NGram<BuildingPayload> second(const_cast<void*>(second_void), compare.Order());
+ if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
+ first.Value().count += second.Value().count;
+ return true;
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_COMBINE_COUNTS_H
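
The struct above is the combiner handed to util::stream::Sort: when the suffix comparator finds two adjacent records with identical words, CombineCounts folds the second count into the first and returns true so the duplicate is dropped during the merge. A toy, self-contained illustration of that contract (the Record type and Combine function below are invented for the example, not part of KenLM):

#include <cstddef>
#include <cstring>
#include <iostream>
#include <stdint.h>
#include <vector>

struct Record { uint32_t words[3]; uint64_t count; };

// Same shape as CombineCounts::operator(): equal keys -> sum counts, keep one.
bool Combine(Record &first, const Record &second) {
  if (std::memcmp(first.words, second.words, sizeof(first.words))) return false;
  first.count += second.count;
  return true;
}

int main() {
  // Already sorted by words, as the stream sorter guarantees.
  std::vector<Record> sorted;
  Record a = {{1, 2, 3}, 5}, b = {{1, 2, 3}, 2}, c = {{1, 2, 4}, 1};
  sorted.push_back(a); sorted.push_back(b); sorted.push_back(c);

  std::vector<Record> merged;
  for (std::size_t i = 0; i < sorted.size(); ++i) {
    if (merged.empty() || !Combine(merged.back(), sorted[i])) merged.push_back(sorted[i]);
  }
  for (std::size_t i = 0; i < merged.size(); ++i)
    std::cout << merged[i].words[0] << ' ' << merged[i].words[1] << ' '
              << merged[i].words[2] << " -> " << merged[i].count << '\n';
  // Prints: 1 2 3 -> 7
  //         1 2 4 -> 1
  return 0;
}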
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index 889eeb7a9..9f23b28a8 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -1,6 +1,7 @@
#include "lm/builder/corpus_count.hh"
-#include "lm/builder/ngram.hh"
+#include "lm/builder/payload.hh"
+#include "lm/common/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "lm/word_index.hh"
@@ -25,19 +26,6 @@ namespace lm {
namespace builder {
namespace {
-#pragma pack(push)
-#pragma pack(4)
-struct VocabEntry {
- typedef uint64_t Key;
-
- uint64_t GetKey() const { return key; }
- void SetKey(uint64_t to) { key = to; }
-
- uint64_t key;
- lm::WordIndex value;
-};
-#pragma pack(pop)
-
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
public:
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
@@ -115,17 +103,17 @@ class Writer {
bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
if (found) {
// Already present.
- NGram already(at->key, gram_.Order());
- ++(already.Count());
+ NGram<BuildingPayload> already(at->key, gram_.Order());
+ ++(already.Value().count);
// Shift left by one.
memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
return;
}
// Complete the write.
- gram_.Count() = 1;
+ gram_.Value().count = 1;
// Prepare the next n-gram.
if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
- NGram last(gram_);
+ NGram<BuildingPayload> last(gram_);
gram_.NextInMemory();
std::copy(last.begin() + 1, last.end(), gram_.begin());
return;
@@ -141,7 +129,7 @@ class Writer {
private:
void AddUnigramWord(WordIndex index) {
*gram_.begin() = index;
- gram_.Count() = 0;
+ gram_.Value().count = 0;
gram_.NextInMemory();
if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
block_->SetValidSize(block_size_);
@@ -151,7 +139,7 @@ class Writer {
util::stream::Link block_;
- NGram gram_;
+ NGram<BuildingPayload> gram_;
// This is the memory behind the invalid value in dedupe_.
std::vector<WordIndex> dedupe_invalid_;
@@ -167,7 +155,7 @@ class Writer {
} // namespace
float CorpusCount::DedupeMultiplier(std::size_t order) {
- return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
+ return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram<BuildingPayload>::TotalSize(order));
}
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
@@ -202,7 +190,7 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.FindOrInsert("</s>");
- Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
+ Writer writer(NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
bool delimiters[256];
util::BoolCharacter::Build("\0\t\n\r ", delimiters);
@@ -233,9 +221,8 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
prune_words_.resize(vocab.Size(), true);
try {
while (true) {
- StringPiece line(prune_vocab_file.ReadLine());
- for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w)
- prune_words_[vocab.Index(*w)] = false;
+ StringPiece word(prune_vocab_file.ReadDelimited(delimiters));
+ prune_words_[vocab.Index(word)] = false;
}
} catch (const util::EndOfFileException &e) {}
diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc
index 18301656f..82f859690 100644
--- a/lm/builder/corpus_count_test.cc
+++ b/lm/builder/corpus_count_test.cc
@@ -1,7 +1,8 @@
#include "lm/builder/corpus_count.hh"
-#include "lm/builder/ngram.hh"
-#include "lm/builder/ngram_stream.hh"
+#include "lm/builder/payload.hh"
+#include "lm/common/ngram_stream.hh"
+#include "lm/common/ngram.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
@@ -14,13 +15,13 @@
namespace lm { namespace builder { namespace {
-#define Check(str, count) { \
+#define Check(str, cnt) { \
BOOST_REQUIRE(stream); \
w = stream->begin(); \
for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
BOOST_CHECK_EQUAL(*t, v[*w]); \
} \
- BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
+ BOOST_CHECK_EQUAL((uint64_t)cnt, stream->Value().count); \
++stream; \
}
@@ -35,14 +36,14 @@ BOOST_AUTO_TEST_CASE(Short) {
util::FilePiece input_piece(input_file.release(), "temp file");
util::stream::ChainConfig config;
- config.entry_size = NGram::TotalSize(3);
+ config.entry_size = NGram<BuildingPayload>::TotalSize(3);
config.total_memory = config.entry_size * 20;
config.block_count = 2;
util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
util::stream::Chain chain(config);
- NGramStream stream;
+ NGramStream<BuildingPayload> stream;
uint64_t token_count;
WordIndex type_count = 10;
std::vector<bool> prune_words;
diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc
index 80063eb2e..ef8a8ecfd 100644
--- a/lm/builder/initial_probabilities.cc
+++ b/lm/builder/initial_probabilities.cc
@@ -1,9 +1,10 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/discount.hh"
-#include "lm/builder/ngram_stream.hh"
-#include "lm/builder/sort.hh"
+#include "lm/builder/special.hh"
#include "lm/builder/hash_gamma.hh"
+#include "lm/builder/payload.hh"
+#include "lm/common/ngram_stream.hh"
#include "util/murmur_hash.hh"
#include "util/file.hh"
#include "util/stream/chain.hh"
@@ -32,17 +33,18 @@ struct HashBufferEntry : public BufferEntry {
// threshold.
class PruneNGramStream {
public:
- PruneNGramStream(const util::stream::ChainPosition &position) :
- current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
- dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ PruneNGramStream(const util::stream::ChainPosition &position, const SpecialVocab &specials) :
+ current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
+ dest_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
currentCount_(0),
- block_(position)
+ block_(position),
+ specials_(specials)
{
StartBlock();
}
- NGram &operator*() { return current_; }
- NGram *operator->() { return &current_; }
+ NGram<BuildingPayload> &operator*() { return current_; }
+ NGram<BuildingPayload> *operator->() { return &current_; }
operator bool() const {
return block_;
@@ -50,8 +52,7 @@ class PruneNGramStream {
PruneNGramStream &operator++() {
assert(block_);
-
- if(current_.Order() == 1 && *current_.begin() <= 2)
+ if(UTIL_UNLIKELY(current_.Order() == 1 && specials_.IsSpecial(*current_.begin())))
dest_.NextInMemory();
else if(currentCount_ > 0) {
if(dest_.Base() < current_.Base()) {
@@ -68,10 +69,10 @@ class PruneNGramStream {
++block_;
StartBlock();
if (block_) {
- currentCount_ = current_.CutoffCount();
+ currentCount_ = current_.Value().CutoffCount();
}
} else {
- currentCount_ = current_.CutoffCount();
+ currentCount_ = current_.Value().CutoffCount();
}
return *this;
@@ -84,23 +85,25 @@ class PruneNGramStream {
if (block_->ValidSize()) break;
}
current_.ReBase(block_->Get());
- currentCount_ = current_.CutoffCount();
+ currentCount_ = current_.Value().CutoffCount();
dest_.ReBase(block_->Get());
}
- NGram current_; // input iterator
- NGram dest_; // output iterator
+ NGram<BuildingPayload> current_; // input iterator
+ NGram<BuildingPayload> dest_; // output iterator
uint64_t currentCount_;
util::stream::Link block_;
+
+ const SpecialVocab specials_;
};
// Extract an array of HashedGamma from an array of BufferEntry.
class OnlyGamma {
public:
- OnlyGamma(bool pruning) : pruning_(pruning) {}
+ explicit OnlyGamma(bool pruning) : pruning_(pruning) {}
void Run(const util::stream::ChainPosition &position) {
for (util::stream::Link block_it(position); block_it; ++block_it) {
@@ -143,7 +146,7 @@ class AddRight {
: discount_(discount), input_(input), pruning_(pruning) {}
void Run(const util::stream::ChainPosition &output) {
- NGramStream in(input_);
+ NGramStream<BuildingPayload> in(input_);
util::stream::Stream out(output);
std::vector<WordIndex> previous(in->Order() - 1);
@@ -159,17 +162,17 @@ class AddRight {
uint64_t counts[4];
memset(counts, 0, sizeof(counts));
do {
- denominator += in->UnmarkedCount();
+ denominator += in->Value().UnmarkedCount();
// Collect unused probability mass from pruning.
// Becomes 0 for unpruned ngrams.
- normalizer += in->UnmarkedCount() - in->CutoffCount();
+ normalizer += in->Value().UnmarkedCount() - in->Value().CutoffCount();
// Chen&Goodman do not mention counting based on cutoffs, but
// backoff becomes larger than 1 otherwise, so probably needs
// to count cutoffs. Counts normally without pruning.
- if(in->CutoffCount() > 0)
- ++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))];
+ if(in->Value().CutoffCount() > 0)
+ ++counts[std::min(in->Value().CutoffCount(), static_cast<uint64_t>(3))];
} while (++in && !memcmp(previous_raw, in->begin(), size));
@@ -202,15 +205,15 @@ class AddRight {
class MergeRight {
public:
- MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
- : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
+ MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount, const SpecialVocab &specials)
+ : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount), specials_(specials) {}
// calculate the initial probability of each n-gram (before order-interpolation)
// Run() gets invoked once for each order
void Run(const util::stream::ChainPosition &primary) {
util::stream::Stream summed(from_adder_);
- PruneNGramStream grams(primary);
+ PruneNGramStream grams(primary, specials_);
// Without interpolation, the interpolation weight goes to <unk>.
if (grams->Order() == 1) {
@@ -228,17 +231,21 @@ class MergeRight {
grams->Value().uninterp.prob = sums.gamma;
}
grams->Value().uninterp.gamma = gamma_assign;
- ++grams;
+
+ for (++grams; *grams->begin() != specials_.BOS(); ++grams) {
+ grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
+ grams->Value().uninterp.gamma = gamma_assign;
+ }
// Special case for <s>: probability 1.0. This allows <s> to be
- // explicitly scores as part of the sentence without impacting
+ // explicitly scored as part of the sentence without impacting
// probability and computes q correctly as b(<s>).
- assert(*grams->begin() == kBOS);
+ assert(*grams->begin() == specials_.BOS());
grams->Value().uninterp.prob = 1.0;
grams->Value().uninterp.gamma = 0.0;
while (++grams) {
- grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
+ grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
grams->Value().uninterp.gamma = gamma_assign;
}
++summed;
@@ -252,8 +259,8 @@ class MergeRight {
const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
do {
- Payload &pay = grams->Value();
- pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator;
+ BuildingPayload &pay = grams->Value();
+ pay.uninterp.prob = discount_.Apply(grams->Value().UnmarkedCount()) / sums.denominator;
pay.uninterp.gamma = sums.gamma;
} while (++grams && !memcmp(&previous[0], grams->begin(), size));
}
@@ -263,6 +270,7 @@ class MergeRight {
bool interpolate_unigrams_;
util::stream::ChainPosition from_adder_;
Discount discount_;
+ const SpecialVocab specials_;
};
} // namespace
@@ -274,7 +282,8 @@ void InitialProbabilities(
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
const std::vector<uint64_t> &prune_thresholds,
- bool prune_vocab) {
+ bool prune_vocab,
+ const SpecialVocab &specials) {
for (size_t i = 0; i < primary.size(); ++i) {
util::stream::ChainConfig gamma_config = config.adder_out;
if(prune_vocab || prune_thresholds[i] > 0)
@@ -287,7 +296,7 @@ void InitialProbabilities(
gamma_out.push_back(gamma_config);
gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0);
- primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
+ primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i], specials);
// Don't bother with the OnlyGamma thread for something to discard.
if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0);
diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh
index a8ecf4dc2..dddbbb913 100644
--- a/lm/builder/initial_probabilities.hh
+++ b/lm/builder/initial_probabilities.hh
@@ -2,6 +2,7 @@
#define LM_BUILDER_INITIAL_PROBABILITIES_H
#include "lm/builder/discount.hh"
+#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include <vector>
@@ -11,6 +12,8 @@ namespace util { namespace stream { class Chains; } }
namespace lm {
namespace builder {
+class SpecialVocab;
+
struct InitialProbabilitiesConfig {
// These should be small buffers to keep the adder from getting too far ahead
util::stream::ChainConfig adder_in;
@@ -34,7 +37,8 @@ void InitialProbabilities(
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
const std::vector<uint64_t> &prune_thresholds,
- bool prune_vocab);
+ bool prune_vocab,
+ const SpecialVocab &vocab);
} // namespace builder
} // namespace lm
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index 5b04cb3ff..84672e068 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -2,8 +2,8 @@
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/joint_order.hh"
-#include "lm/builder/ngram_stream.hh"
-#include "lm/builder/sort.hh"
+#include "lm/common/ngram_stream.hh"
+#include "lm/common/compare.hh"
#include "lm/lm_exception.hh"
#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
@@ -65,11 +65,12 @@ class OutputProbBackoff {
template <class Output> class Callback {
public:
- Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab)
+ Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials)
: backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
prune_thresholds_(prune_thresholds),
prune_vocab_(prune_vocab),
- output_(backoffs.size() + 1 /* order */) {
+ output_(backoffs.size() + 1 /* order */),
+ specials_(specials) {
probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) {
backoffs_.push_back(backoffs[i]);
@@ -89,13 +90,13 @@ template <class Output> class Callback {
}
}
- void Enter(unsigned order_minus_1, NGram &gram) {
- Payload &pay = gram.Value();
+ void Enter(unsigned order_minus_1, NGram<BuildingPayload> &gram) {
+ BuildingPayload &pay = gram.Value();
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
float out_backoff;
- if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1]) {
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != specials_.UNK() && *(gram.end() - 1) != specials_.EOS() && backoffs_[order_minus_1]) {
if(prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) {
//Compute hash value for current context
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
@@ -123,7 +124,7 @@ template <class Output> class Callback {
output_.Gram(order_minus_1, out_backoff, pay.complete);
}
- void Exit(unsigned, const NGram &) const {}
+ void Exit(unsigned, const NGram<BuildingPayload> &) const {}
private:
util::FixedArray<util::stream::Stream> backoffs_;
@@ -133,26 +134,28 @@ template <class Output> class Callback {
bool prune_vocab_;
Output output_;
+ const SpecialVocab specials_;
};
} // namespace
-Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q)
+Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials)
: uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
backoffs_(backoffs),
prune_thresholds_(prune_thresholds),
prune_vocab_(prune_vocab),
- output_q_(output_q) {}
+ output_q_(output_q),
+ specials_(specials) {}
// perform order-wise interpolation
void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1);
if (output_q_) {
typedef Callback<OutputQ> C;
- C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
+ C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
JointOrder<C, SuffixOrder>(positions, callback);
} else {
typedef Callback<OutputProbBackoff> C;
- C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
+ C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
JointOrder<C, SuffixOrder>(positions, callback);
}
}
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
index 207a16dfd..dcee75adb 100644
--- a/lm/builder/interpolate.hh
+++ b/lm/builder/interpolate.hh
@@ -1,6 +1,8 @@
#ifndef LM_BUILDER_INTERPOLATE_H
#define LM_BUILDER_INTERPOLATE_H
+#include "lm/builder/special.hh"
+#include "lm/word_index.hh"
#include "util/stream/multi_stream.hh"
#include <vector>
@@ -18,7 +20,7 @@ class Interpolate {
public:
// Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
// be larger when the user specifies a consistent vocabulary size.
- explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q_);
+ explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials);
void Run(const util::stream::ChainPositions &positions);
@@ -28,6 +30,7 @@ class Interpolate {
const std::vector<uint64_t> prune_thresholds_;
bool prune_vocab_;
bool output_q_;
+ const SpecialVocab specials_;
};
}} // namespaces
diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh
index b05ef67fd..5f62a4578 100644
--- a/lm/builder/joint_order.hh
+++ b/lm/builder/joint_order.hh
@@ -1,7 +1,8 @@
#ifndef LM_BUILDER_JOINT_ORDER_H
#define LM_BUILDER_JOINT_ORDER_H
-#include "lm/builder/ngram_stream.hh"
+#include "lm/common/ngram_stream.hh"
+#include "lm/builder/payload.hh"
#include "lm/lm_exception.hh"
#ifdef DEBUG
@@ -15,9 +16,9 @@ namespace lm { namespace builder {
template <class Callback, class Compare> void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) {
// Allow matching to reference streams[-1].
- NGramStreams streams_with_dummy;
+ NGramStreams<BuildingPayload> streams_with_dummy;
streams_with_dummy.InitWithDummy(positions);
- NGramStream *streams = streams_with_dummy.begin() + 1;
+ NGramStream<BuildingPayload> *streams = streams_with_dummy.begin() + 1;
unsigned int order;
for (order = 0; order < positions.size() && streams[order]; ++order) {}
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index 5c9d86deb..c27490665 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -87,7 +87,7 @@ int main(int argc, char *argv[]) {
po::options_description options("Language model building options");
lm::builder::PipelineConfig pipeline;
- std::string text, arpa;
+ std::string text, intermediate, arpa;
std::vector<std::string> pruning;
std::vector<std::string> discount_fallback;
std::vector<std::string> discount_fallback_default;
@@ -116,6 +116,8 @@ int main(int argc, char *argv[]) {
("verbose_header", po::bool_switch(&verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
+ ("intermediate", po::value<std::string>(&intermediate), "Write ngrams to an intermediate file. Turns off ARPA output (which can be reactivated by --arpa file). Forces --renumber on. Implicitly makes --vocab_file be the provided name + .vocab.")
+ ("renumber", po::bool_switch(&pipeline.renumber_vocabulary), "Rrenumber the vocabulary identifiers so that they are monotone with the hash of each string. This is consistent with the ordering used by the trie data structure.")
("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.")
("limit_vocab_file", po::value<std::string>(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg")
@@ -212,8 +214,19 @@ int main(int argc, char *argv[]) {
}
try {
- lm::builder::Output output;
- output.Add(new lm::builder::PrintARPA(out.release(), verbose_header));
+ bool writing_intermediate = vm.count("intermediate");
+ if (writing_intermediate) {
+ pipeline.renumber_vocabulary = true;
+ if (!pipeline.vocab_file.empty()) {
+ std::cerr << "--intermediate and --vocab_file are incompatible because --intermediate already makes a vocab file." << std::endl;
+ return 1;
+ }
+ pipeline.vocab_file = intermediate + ".vocab";
+ }
+ lm::builder::Output output(writing_intermediate ? intermediate : pipeline.sort.temp_prefix, writing_intermediate);
+ if (!writing_intermediate || vm.count("arpa")) {
+ output.Add(new lm::builder::PrintARPA(out.release(), verbose_header));
+ }
lm::builder::Pipeline(pipeline, in.release(), output);
} catch (const util::MallocException &e) {
std::cerr << e.what() << std::endl;
diff --git a/lm/builder/output.cc b/lm/builder/output.cc
index 0fc0197c4..76478ad06 100644
--- a/lm/builder/output.cc
+++ b/lm/builder/output.cc
@@ -1,14 +1,41 @@
#include "lm/builder/output.hh"
+
+#include "lm/common/model_buffer.hh"
#include "util/stream/multi_stream.hh"
-#include <boost/ref.hpp>
+#include <iostream>
namespace lm { namespace builder {
OutputHook::~OutputHook() {}
-void OutputHook::Apply(util::stream::Chains &chains) {
- chains >> boost::ref(*this);
+Output::Output(StringPiece file_base, bool keep_buffer)
+ : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer) {}
+
+void Output::SinkProbs(util::stream::Chains &chains, bool output_q) {
+ Apply(PROB_PARALLEL_HOOK, chains);
+ if (!keep_buffer_ && !Have(PROB_SEQUENTIAL_HOOK)) {
+ chains >> util::stream::kRecycle;
+ chains.Wait(true);
+ return;
+ }
+ lm::common::ModelBuffer buf(file_base_, keep_buffer_, output_q);
+ buf.Sink(chains);
+ chains >> util::stream::kRecycle;
+ chains.Wait(false);
+ if (Have(PROB_SEQUENTIAL_HOOK)) {
+ std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
+ buf.Source(chains);
+ Apply(PROB_SEQUENTIAL_HOOK, chains);
+ chains >> util::stream::kRecycle;
+ chains.Wait(true);
+ }
+}
+
+void Output::Apply(HookType hook_type, util::stream::Chains &chains) {
+ for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
+ entry->Sink(chains);
+ }
}
}} // namespaces
diff --git a/lm/builder/output.hh b/lm/builder/output.hh
index 0ef769ae2..c1e0d1469 100644
--- a/lm/builder/output.hh
+++ b/lm/builder/output.hh
@@ -7,16 +7,14 @@
#include <boost/ptr_container/ptr_vector.hpp>
#include <boost/utility.hpp>
-#include <map>
-
namespace util { namespace stream { class Chains; class ChainPositions; } }
-/* Outputs from lmplz: ARPA< sharded files, etc */
+/* Outputs from lmplz: ARPA, sharded files, etc */
namespace lm { namespace builder {
// These are different types of hooks. Values should be consecutive to enable a vector lookup.
enum HookType {
- COUNT_HOOK, // Raw N-gram counts, highest order only.
+ // TODO: counts.
PROB_PARALLEL_HOOK, // Probability and backoff (or just q). Output must process the orders in parallel or there will be a deadlock.
PROB_SEQUENTIAL_HOOK, // Probability and backoff (or just q). Output can process orders any way it likes. This requires writing the data to disk then reading. Useful for ARPA files, which put unigrams first etc.
NUMBER_OF_HOOKS // Keep this last so we know how many values there are.
@@ -30,9 +28,7 @@ class OutputHook {
virtual ~OutputHook();
- virtual void Apply(util::stream::Chains &chains);
-
- virtual void Run(const util::stream::ChainPositions &positions) = 0;
+ virtual void Sink(util::stream::Chains &chains) = 0;
protected:
const HeaderInfo &GetHeader() const;
@@ -46,7 +42,7 @@ class OutputHook {
class Output : boost::noncopyable {
public:
- Output() {}
+ Output(StringPiece file_base, bool keep_buffer);
// Takes ownership.
void Add(OutputHook *hook) {
@@ -64,16 +60,20 @@ class Output : boost::noncopyable {
void SetHeader(const HeaderInfo &header) { header_ = header; }
const HeaderInfo &GetHeader() const { return header_; }
- void Apply(HookType hook_type, util::stream::Chains &chains) {
- for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
- entry->Apply(chains);
- }
- }
+ // This is called by the pipeline.
+ void SinkProbs(util::stream::Chains &chains, bool output_q);
+
+ unsigned int Steps() const { return Have(PROB_SEQUENTIAL_HOOK); }
private:
+ void Apply(HookType hook_type, util::stream::Chains &chains);
+
boost::ptr_vector<OutputHook> outputs_[NUMBER_OF_HOOKS];
int vocab_fd_;
HeaderInfo header_;
+
+ std::string file_base_;
+ bool keep_buffer_;
};
inline const HeaderInfo &OutputHook::GetHeader() const {
diff --git a/lm/builder/payload.hh b/lm/builder/payload.hh
new file mode 100644
index 000000000..ba12725a4
--- /dev/null
+++ b/lm/builder/payload.hh
@@ -0,0 +1,48 @@
+#ifndef LM_BUILDER_PAYLOAD_H
+#define LM_BUILDER_PAYLOAD_H
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+#include <stdint.h>
+
+namespace lm { namespace builder {
+
+struct Uninterpolated {
+ float prob; // Uninterpolated probability.
+ float gamma; // Interpolation weight for lower order.
+};
+
+union BuildingPayload {
+ uint64_t count;
+ Uninterpolated uninterp;
+ ProbBackoff complete;
+
+ /*mjd**********************************************************************/
+ bool IsMarked() const {
+ return count >> (sizeof(count) * 8 - 1);
+ }
+
+ void Mark() {
+ count |= (1ul << (sizeof(count) * 8 - 1));
+ }
+
+ void Unmark() {
+ count &= ~(1ul << (sizeof(count) * 8 - 1));
+ }
+
+ uint64_t UnmarkedCount() const {
+ return count & ~(1ul << (sizeof(count) * 8 - 1));
+ }
+
+ uint64_t CutoffCount() const {
+ return IsMarked() ? 0 : UnmarkedCount();
+ }
+ /*mjd**********************************************************************/
+};
+
+const WordIndex kBOS = 1;
+const WordIndex kEOS = 2;
+
+}} // namespaces
+
+#endif // LM_BUILDER_PAYLOAD_H
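
The union above overlays the raw count, the uninterpolated prob/gamma pair, and the final prob/backoff in the same bytes, and reuses the top bit of the 64-bit count as a prune mark: a marked n-gram still reports its adjusted count through UnmarkedCount(), while CutoffCount() returns zero. A quick standalone check of that bit convention (illustrative only, not part of the commit):

#include <cassert>
#include <stdint.h>

int main() {
  const uint64_t kMarkBit = 1ULL << 63;  // high bit of the count
  uint64_t count = 42;

  count |= kMarkBit;                            // Mark()
  assert(count >> 63);                          // IsMarked() is now true
  assert((count & ~kMarkBit) == 42);            // UnmarkedCount() still 42
  uint64_t cutoff = (count >> 63) ? 0 : (count & ~kMarkBit);  // CutoffCount()
  assert(cutoff == 0);                          // marked n-grams prune to 0

  count &= ~kMarkBit;                           // Unmark()
  assert(count == 42);
  return 0;
}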
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index 1ca2e26f5..d588beedf 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -1,14 +1,17 @@
#include "lm/builder/pipeline.hh"
#include "lm/builder/adjust_counts.hh"
+#include "lm/builder/combine_counts.hh"
#include "lm/builder/corpus_count.hh"
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/interpolate.hh"
#include "lm/builder/output.hh"
-#include "lm/builder/sort.hh"
+#include "lm/common/compare.hh"
+#include "lm/common/renumber.hh"
#include "lm/sizes.hh"
+#include "lm/vocab.hh"
#include "util/exception.hh"
#include "util/file.hh"
@@ -21,7 +24,10 @@
namespace lm { namespace builder {
+using util::stream::Sorts;
+
namespace {
+
void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) {
std::cerr << "Statistics:\n";
for (size_t i = 0; i < counts.size(); ++i) {
@@ -37,9 +43,9 @@ void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint
class Master {
public:
- explicit Master(PipelineConfig &config)
- : config_(config), chains_(config.order), files_(config.order) {
- config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
+ explicit Master(PipelineConfig &config, unsigned output_steps)
+ : config_(config), chains_(config.order), unigrams_(util::MakeTemp(config_.TempPrefix())), steps_(output_steps + 4) {
+ config_.minimum_block = std::max(NGram<BuildingPayload>::TotalSize(config_.order), config_.minimum_block);
}
const PipelineConfig &Config() const { return config_; }
@@ -52,40 +58,42 @@ class Master {
}
// This takes the (partially) sorted ngrams and sets up for adjusted counts.
- void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) {
+ void InitForAdjust(util::stream::Sort<SuffixOrder, CombineCounts> &ngrams, WordIndex types, std::size_t subtract_for_numbering) {
const std::size_t each_order_min = config_.minimum_block * config_.block_count;
// We know how many unigrams there are. Don't allocate more than needed to them.
const std::size_t min_chains = (config_.order - 1) * each_order_min +
- std::min(types * NGram::TotalSize(1), each_order_min);
+ std::min(types * NGram<BuildingPayload>::TotalSize(1), each_order_min);
+ // Prevent overflow in subtracting.
+ const std::size_t total = std::max<std::size_t>(config_.TotalMemory(), min_chains + subtract_for_numbering + config_.minimum_block);
// Do merge sort with calculated laziness.
- const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy()));
+ const std::size_t merge_using = ngrams.Merge(std::min(total - min_chains - subtract_for_numbering, ngrams.DefaultLazy()));
std::vector<uint64_t> count_bounds(1, types);
- CreateChains(config_.TotalMemory() - merge_using, count_bounds);
+ CreateChains(total - merge_using - subtract_for_numbering, count_bounds);
ngrams.Output(chains_.back(), merge_using);
-
- // Setup unigram file.
- files_.push_back(util::MakeTemp(config_.TempPrefix()));
}
// For initial probabilities, but this is generic.
void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) {
+ bool unigrams_are_sorted = !config_.renumber_vocabulary;
// Do merge first before allocating chain memory.
- for (std::size_t i = 1; i < config_.order; ++i) {
- sorts[i - 1].Merge(0);
+ for (std::size_t i = 0; i < config_.order - unigrams_are_sorted; ++i) {
+ sorts[i].Merge(0);
}
// There's no lazy merge, so just divide memory amongst the chains.
CreateChains(config_.TotalMemory(), counts);
chains_.back().ActivateProgress();
- chains_[0] >> files_[0].Source();
- second_config.entry_size = NGram::TotalSize(1);
- second.push_back(second_config);
- second.back() >> files_[0].Source();
- for (std::size_t i = 1; i < config_.order; ++i) {
- util::scoped_fd fd(sorts[i - 1].StealCompleted());
+ if (unigrams_are_sorted) {
+ chains_[0] >> unigrams_.Source();
+ second_config.entry_size = NGram<BuildingPayload>::TotalSize(1);
+ second.push_back(second_config);
+ second.back() >> unigrams_.Source();
+ }
+ for (std::size_t i = unigrams_are_sorted; i < config_.order; ++i) {
+ util::scoped_fd fd(sorts[i - unigrams_are_sorted].StealCompleted());
chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
- second_config.entry_size = NGram::TotalSize(i + 1);
+ second_config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
second.push_back(second_config);
second.back() >> util::stream::PRead(fd.release(), true);
}
@@ -96,7 +104,7 @@ class Master {
// Determine the minimum we can use for all the chains.
std::size_t min_chains = 0;
for (std::size_t i = 0; i < config_.order; ++i) {
- min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
+ min_chains += std::min(counts[i] * NGram<BuildingPayload>::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
}
std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
std::vector<std::size_t> laziness;
@@ -110,36 +118,24 @@ class Master {
CreateChains(for_merge + min_chains, counts);
chains_.back().ActivateProgress();
- chains_[0] >> files_[0].Source();
+ chains_[0] >> unigrams_.Source();
for (std::size_t i = 1; i < config_.order; ++i) {
sorts[i - 1].Output(chains_[i], laziness[i - 1]);
}
}
- void BufferFinal(const std::vector<uint64_t> &counts) {
- chains_[0] >> files_[0].Sink();
- for (std::size_t i = 1; i < config_.order; ++i) {
- files_.push_back(util::MakeTemp(config_.TempPrefix()));
- chains_[i] >> files_[i].Sink();
- }
- chains_.Wait(true);
- // Use less memory. Because we can.
- CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
- for (std::size_t i = 0; i < config_.order; ++i) {
- chains_[i] >> files_[i].Source();
- }
- }
-
- template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
- sorts.Init(config_.order - 1);
+ template <class Compare> void SetupSorts(Sorts<Compare> &sorts, bool exclude_unigrams) {
+ sorts.Init(config_.order - exclude_unigrams);
// Unigrams don't get sorted because their order is always the same.
- chains_[0] >> files_[0].Sink();
- for (std::size_t i = 1; i < config_.order; ++i) {
+ if (exclude_unigrams) chains_[0] >> unigrams_.Sink();
+ for (std::size_t i = exclude_unigrams; i < config_.order; ++i) {
sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
}
chains_.Wait(true);
}
+ unsigned int Steps() const { return steps_; }
+
private:
// Create chains, allocating memory to them. Totally heuristic. Count
// bounds are upper bounds on the counts or not present.
@@ -150,7 +146,7 @@ class Master {
for (std::size_t i = 0; i < count_bounds.size(); ++i) {
assignments.push_back(static_cast<std::size_t>(std::min(
static_cast<uint64_t>(remaining_mem),
- count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
+ count_bounds[i] * static_cast<uint64_t>(NGram<BuildingPayload>::TotalSize(i + 1)))));
}
assignments.resize(config_.order, remaining_mem);
@@ -160,7 +156,7 @@ class Master {
// Indices of orders that have yet to be assigned.
std::vector<std::size_t> unassigned;
for (std::size_t i = 0; i < config_.order; ++i) {
- portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
+ portions.push_back(static_cast<float>((i+1) * NGram<BuildingPayload>::TotalSize(i+1)));
unassigned.push_back(i);
}
/*If somebody doesn't eat their full dinner, give it to the rest of the
@@ -196,7 +192,7 @@ class Master {
std::cerr << "Chain sizes:";
for (std::size_t i = 0; i < config_.order; ++i) {
std::cerr << ' ' << (i+1) << ":" << assignments[i];
- chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i]));
+ chains_.push_back(util::stream::ChainConfig(NGram<BuildingPayload>::TotalSize(i + 1), block_count[i], assignments[i]));
}
std::cerr << std::endl;
}
@@ -204,13 +200,15 @@ class Master {
PipelineConfig &config_;
util::stream::Chains chains_;
- // Often only unigrams, but sometimes all orders.
- util::FixedArray<util::stream::FileBuffer> files_;
+
+ util::stream::FileBuffer unigrams_;
+
+ const unsigned int steps_;
};
-void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name, std::vector<bool> &prune_words) {
+util::stream::Sort<SuffixOrder, CombineCounts> *CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, WordIndex &type_count, std::string &text_file_name, std::vector<bool> &prune_words) {
const PipelineConfig &config = master.Config();
- std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
+ std::cerr << "=== 1/" << master.Steps() << " Counting and sorting n-grams ===" << std::endl;
const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
@@ -221,37 +219,34 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
(static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
// Chain likes memory expressed in terms of total memory.
static_cast<float>(config.block_count);
- util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
+ util::stream::Chain chain(util::stream::ChainConfig(NGram<BuildingPayload>::TotalSize(config.order), config.block_count, memory_for_chain));
- WordIndex type_count = config.vocab_estimate;
+ type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);
- util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
+ util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorter(new util::stream::Sort<SuffixOrder, CombineCounts>(chain, config.sort, SuffixOrder(config.order), CombineCounts()));
chain.Wait(true);
- std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
- std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
- master.InitForAdjust(sorter, type_count);
+ return sorter.release();
}
-void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
- util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab) {
+void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials) {
const PipelineConfig &config = master.Config();
util::stream::Chains second(config.order);
{
Sorts<ContextOrder> sorts;
- master.SetupSorts(sorts);
+ master.SetupSorts(sorts, !config.renumber_vocabulary);
PrintStatistics(counts, counts_pruned, discounts);
lm::ngram::ShowSizes(counts_pruned);
- std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
+ std::cerr << "=== 3/" << master.Steps() << " Calculating and sorting initial probabilities ===" << std::endl;
master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in);
}
util::stream::Chains gamma_chains(config.order);
- InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab, specials);
// Don't care about gamma for 0.
gamma_chains[0] >> util::stream::kRecycle;
gammas.Init(config.order - 1);
@@ -260,11 +255,11 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
gamma_chains[i] >> gammas[i - 1].Sink();
}
// Has to be done here due to gamma_chains scope.
- master.SetupSorts(primary);
+ master.SetupSorts(primary, true);
}
-void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) {
- std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
+void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas, Output &output, const SpecialVocab &specials) {
+ std::cerr << "=== 4/" << master.Steps() << " Calculating and writing order-interpolated probabilities ===" << std::endl;
const PipelineConfig &config = master.Config();
master.MaximumLazyInput(counts, primary);
@@ -278,13 +273,62 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
read_backoffs.entry_size = sizeof(float);
gamma_chains.push_back(read_backoffs);
- gamma_chains.back() >> gammas[i].Source();
+ gamma_chains.back() >> gammas[i].Source(true);
}
- master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q);
+ master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q, specials);
gamma_chains >> util::stream::kRecycle;
- master.BufferFinal(counts);
+ output.SinkProbs(master.MutableChains(), config.output_q);
}
+class VocabNumbering {
+ public:
+ VocabNumbering(StringPiece vocab_file, StringPiece temp_prefix, bool renumber)
+ : vocab_file_(vocab_file.data(), vocab_file.size()),
+ temp_prefix_(temp_prefix.data(), temp_prefix.size()),
+ renumber_(renumber),
+ specials_(kBOS, kEOS) {
+ InitFile(renumber || vocab_file.empty());
+ }
+
+ int File() const { return null_delimited_.get(); }
+
+ // Compute the vocabulary mapping and return the memory used.
+ std::size_t ComputeMapping(WordIndex type_count) {
+ if (!renumber_) return 0;
+ util::scoped_fd previous(null_delimited_.release());
+ InitFile(vocab_file_.empty());
+ ngram::SortedVocabulary::ComputeRenumbering(type_count, previous.get(), null_delimited_.get(), vocab_mapping_);
+ return sizeof(WordIndex) * vocab_mapping_.size();
+ }
+
+ void ApplyRenumber(util::stream::Chains &chains) {
+ if (!renumber_) return;
+ for (std::size_t i = 0; i < chains.size(); ++i) {
+ chains[i] >> Renumber(&*vocab_mapping_.begin(), i + 1);
+ }
+ specials_ = SpecialVocab(vocab_mapping_[specials_.BOS()], vocab_mapping_[specials_.EOS()]);
+ }
+
+ const SpecialVocab &Specials() const { return specials_; }
+
+ private:
+ void InitFile(bool temp) {
+ null_delimited_.reset(temp ?
+ util::MakeTemp(temp_prefix_) :
+ util::CreateOrThrow(vocab_file_.c_str()));
+ }
+
+ std::string vocab_file_, temp_prefix_;
+
+ util::scoped_fd null_delimited_;
+
+ bool renumber_;
+
+ std::vector<WordIndex> vocab_mapping_;
+
+ SpecialVocab specials_;
+};
+
} // namespace
void Pipeline(PipelineConfig &config, int text_file, Output &output) {
@@ -293,48 +337,49 @@ void Pipeline(PipelineConfig &config, int text_file, Output &output) {
config.sort.buffer_size = config.TotalMemory() / 4;
std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
}
- if (config.minimum_block < NGram::TotalSize(config.order)) {
- config.minimum_block = NGram::TotalSize(config.order);
+ if (config.minimum_block < NGram<BuildingPayload>::TotalSize(config.order)) {
+ config.minimum_block = NGram<BuildingPayload>::TotalSize(config.order);
std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
}
UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
"Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
- UTIL_TIMER("(%w s) Total wall time elapsed\n");
-
- Master master(config);
+ Master master(config, output.Steps());
// master's destructor will wait for chains. But they might be deadlocked if
// this thread dies because e.g. it ran out of memory.
try {
- util::scoped_fd vocab_file(config.vocab_file.empty() ?
- util::MakeTemp(config.TempPrefix()) :
- util::CreateOrThrow(config.vocab_file.c_str()));
- output.SetVocabFD(vocab_file.get());
+ VocabNumbering numbering(config.vocab_file, config.TempPrefix(), config.renumber_vocabulary);
uint64_t token_count;
+ WordIndex type_count;
std::string text_file_name;
-
std::vector<bool> prune_words;
- CountText(text_file, vocab_file.get(), master, token_count, text_file_name, prune_words);
+ util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorted_counts(
+ CountText(text_file, numbering.File(), master, token_count, type_count, text_file_name, prune_words));
+ std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
+
+ // Create vocab mapping, which uses temporary memory, while nothing else is happening.
+ std::size_t subtract_for_numbering = numbering.ComputeMapping(type_count);
+ output.SetVocabFD(numbering.File());
+
+ std::cerr << "=== 2/" << master.Steps() << " Calculating and sorting adjusted counts ===" << std::endl;
+ master.InitForAdjust(*sorted_counts, type_count, subtract_for_numbering);
+ sorted_counts.reset();
std::vector<uint64_t> counts;
std::vector<uint64_t> counts_pruned;
std::vector<Discount> discounts;
master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, prune_words, config.discount, discounts);
+ numbering.ApplyRenumber(master.MutableChains());
{
util::FixedArray<util::stream::FileBuffer> gammas;
Sorts<SuffixOrder> primary;
- InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab);
- InterpolateProbabilities(counts_pruned, master, primary, gammas);
+ InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab, numbering.Specials());
+ output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned));
+ // Also does output.
+ InterpolateProbabilities(counts_pruned, master, primary, gammas, output, numbering.Specials());
}
-
- std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
-
- output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned));
- output.Apply(PROB_SEQUENTIAL_HOOK, master.MutableChains());
- master >> util::stream::kRecycle;
- master.MutableChains().Wait(true);
} catch (const util::Exception &e) {
std::cerr << e.what() << std::endl;
abort();
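
With renumbering enabled, <s> and </s> no longer keep their original ids: ApplyRenumber pushes them through the same old-to-new table it applies to the chains and then rebuilds the SpecialVocab that the later stages receive. A minimal standalone sketch of that hand-off, using std containers only and a made-up mapping rather than the library's own code:

    #include <iostream>
    #include <stdint.h>
    #include <vector>

    int main() {
      // Hypothetical old -> new id table as ComputeMapping would produce it; <unk> stays 0.
      std::vector<uint32_t> mapping;
      mapping.push_back(0);   // <unk>
      mapping.push_back(3);   // old <s>
      mapping.push_back(1);   // old </s>
      mapping.push_back(2);   // an ordinary word
      uint32_t bos = 1, eos = 2;          // ids before renumbering
      bos = mapping[bos];                 // what ApplyRenumber does before constructing
      eos = mapping[eos];                 // the new SpecialVocab for downstream stages
      std::cout << "<s>=" << bos << " </s>=" << eos << '\n';   // prints <s>=3 </s>=1
    }
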
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index 1987daff1..695ecf7bd 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -39,6 +39,9 @@ struct PipelineConfig {
bool prune_vocab;
std::string prune_vocab_file;
+ /* Renumber the vocabulary the way the trie likes it? */
+ bool renumber_vocabulary;
+
// What to do with discount failures.
DiscountConfig discount;
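
The new flag is plain data like the rest of the struct; callers opt in explicitly. A minimal sketch of turning it on, assuming PipelineConfig can still be default-constructed and filled in field by field as lmplz does:

    #include "lm/builder/pipeline.hh"

    int main() {
      lm::builder::PipelineConfig config;   // other fields left at their defaults here
      config.renumber_vocabulary = true;    // ask the builder for trie-ordered vocab ids
      return config.renumber_vocabulary ? 0 : 1;
    }
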
diff --git a/lm/builder/print.cc b/lm/builder/print.cc
index 56a3134d8..178e54a21 100644
--- a/lm/builder/print.cc
+++ b/lm/builder/print.cc
@@ -23,30 +23,29 @@ VocabReconstitute::VocabReconstitute(int fd) {
map_.push_back(i);
}
+void PrintARPA::Sink(util::stream::Chains &chains) {
+ chains >> boost::ref(*this);
+}
+
void PrintARPA::Run(const util::stream::ChainPositions &positions) {
VocabReconstitute vocab(GetVocabFD());
+ util::FakeOFStream out(out_fd_.get());
- // Write header. TODO: integers in FakeOFStream.
- {
- std::stringstream stream;
- if (verbose_header_) {
- stream << "# Input file: " << GetHeader().input_file << '\n';
- stream << "# Token count: " << GetHeader().token_count << '\n';
- stream << "# Smoothing: Modified Kneser-Ney" << '\n';
- }
- stream << "\\data\\\n";
- for (size_t i = 0; i < positions.size(); ++i) {
- stream << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n';
- }
- stream << '\n';
- std::string as_string(stream.str());
- util::WriteOrThrow(out_fd_.get(), as_string.data(), as_string.size());
+ // Write header.
+ if (verbose_header_) {
+ out << "# Input file: " << GetHeader().input_file << '\n';
+ out << "# Token count: " << GetHeader().token_count << '\n';
+ out << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ out << "\\data\\\n";
+ for (size_t i = 0; i < positions.size(); ++i) {
+ out << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n';
}
+ out << '\n';
- util::FakeOFStream out(out_fd_.get());
for (unsigned order = 1; order <= positions.size(); ++order) {
out << "\\" << order << "-grams:" << '\n';
- for (NGramStream stream(positions[order - 1]); stream; ++stream) {
+ for (NGramStream<BuildingPayload> stream(positions[order - 1]); stream; ++stream) {
// Correcting for numerical precision issues. Take that IRST.
out << stream->Value().complete.prob << '\t' << vocab.Lookup(*stream->begin());
for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
diff --git a/lm/builder/print.hh b/lm/builder/print.hh
index 093a35697..5f293de85 100644
--- a/lm/builder/print.hh
+++ b/lm/builder/print.hh
@@ -1,14 +1,17 @@
#ifndef LM_BUILDER_PRINT_H
#define LM_BUILDER_PRINT_H
-#include "lm/builder/ngram.hh"
-#include "lm/builder/ngram_stream.hh"
+#include "lm/common/ngram_stream.hh"
#include "lm/builder/output.hh"
+#include "lm/builder/payload.hh"
+#include "lm/common/ngram.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/string_piece.hh"
+#include <boost/lexical_cast.hpp>
+
#include <ostream>
#include <cassert>
@@ -43,15 +46,15 @@ class VocabReconstitute {
};
// Not defined, only specialized.
-template <class T> void PrintPayload(util::FakeOFStream &to, const Payload &payload);
-template <> inline void PrintPayload<uint64_t>(util::FakeOFStream &to, const Payload &payload) {
+template <class T> void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload);
+template <> inline void PrintPayload<uint64_t>(util::FakeOFStream &to, const BuildingPayload &payload) {
// TODO slow
- to << boost::lexical_cast<std::string>(payload.count);
+ to << payload.count;
}
-template <> inline void PrintPayload<Uninterpolated>(util::FakeOFStream &to, const Payload &payload) {
+template <> inline void PrintPayload<Uninterpolated>(util::FakeOFStream &to, const BuildingPayload &payload) {
to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
}
-template <> inline void PrintPayload<ProbBackoff>(util::FakeOFStream &to, const Payload &payload) {
+template <> inline void PrintPayload<ProbBackoff>(util::FakeOFStream &to, const BuildingPayload &payload) {
to << payload.complete.prob << ' ' << payload.complete.backoff;
}
@@ -70,8 +73,8 @@ template <class V> class Print {
void Run(const util::stream::ChainPositions &chains) {
util::scoped_fd fd(to_);
util::FakeOFStream out(to_);
- NGramStreams streams(chains);
- for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
+ NGramStreams<BuildingPayload> streams(chains);
+ for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s) {
DumpStream(*s, out);
}
}
@@ -79,12 +82,12 @@ template <class V> class Print {
void Run(const util::stream::ChainPosition &position) {
util::scoped_fd fd(to_);
util::FakeOFStream out(to_);
- NGramStream stream(position);
+ NGramStream<BuildingPayload> stream(position);
DumpStream(stream, out);
}
private:
- void DumpStream(NGramStream &stream, util::FakeOFStream &to) {
+ void DumpStream(NGramStream<BuildingPayload> &stream, util::FakeOFStream &to) {
for (; stream; ++stream) {
PrintPayload<V>(to, stream->Value());
for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
@@ -103,6 +106,8 @@ class PrintARPA : public OutputHook {
explicit PrintARPA(int fd, bool verbose_header)
: OutputHook(PROB_SEQUENTIAL_HOOK), out_fd_(fd), verbose_header_(verbose_header) {}
+ void Sink(util::stream::Chains &chains);
+
void Run(const util::stream::ChainPositions &positions);
private:
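
PrintPayload relies on a declared-but-undefined primary template: only the three specializations have bodies, so instantiating it with any other payload type fails at link time instead of printing garbage. A standalone sketch of the same idiom; Count and Prob are simplified stand-ins, not KenLM's payload types:

    #include <iostream>

    struct Count { unsigned long count; };
    struct Prob  { float prob; float backoff; };

    // Primary template is declared but never defined.
    template <class T> void PrintPayload(std::ostream &to, const T &payload);

    template <> void PrintPayload<Count>(std::ostream &to, const Count &p) {
      to << p.count;
    }
    template <> void PrintPayload<Prob>(std::ostream &to, const Prob &p) {
      to << p.prob << ' ' << p.backoff;
    }

    int main() {
      Count c = {42};
      PrintPayload(std::cout, c);   // picks the Count specialization
      std::cout << '\n';
    }
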
diff --git a/lm/builder/special.hh b/lm/builder/special.hh
new file mode 100644
index 000000000..c70865ce1
--- /dev/null
+++ b/lm/builder/special.hh
@@ -0,0 +1,27 @@
+#ifndef LM_BUILDER_SPECIAL_H
+#define LM_BUILDER_SPECIAL_H
+
+#include "lm/word_index.hh"
+
+namespace lm { namespace builder {
+
+class SpecialVocab {
+ public:
+ SpecialVocab(WordIndex bos, WordIndex eos) : bos_(bos), eos_(eos) {}
+
+ bool IsSpecial(WordIndex word) const {
+ return word == kUNK || word == bos_ || word == eos_;
+ }
+
+ WordIndex UNK() const { return kUNK; }
+ WordIndex BOS() const { return bos_; }
+ WordIndex EOS() const { return eos_; }
+
+ private:
+ WordIndex bos_;
+ WordIndex eos_;
+};
+
+}} // namespaces
+
+#endif // LM_BUILDER_SPECIAL_H
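
SpecialVocab bundles the ids of <unk>, <s> and </s> so that code running after renumbering does not have to assume fixed constants. A small usage sketch, assuming a KenLM build tree so the header added above is on the include path:

    #include "lm/builder/special.hh"
    #include <iostream>

    int main() {
      // Pre-renumbering ids for <s> and </s>; after ApplyRenumber these could be anything.
      lm::builder::SpecialVocab specials(/*bos=*/1, /*eos=*/2);
      std::cout << std::boolalpha
                << specials.IsSpecial(0) << ' '                    // <unk> is always special
                << specials.IsSpecial(specials.EOS()) << ' '       // true
                << specials.IsSpecial(5) << '\n';                  // ordinary word: false
    }
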
diff --git a/lm/common/Jamfile b/lm/common/Jamfile
new file mode 100644
index 000000000..1c9c37210
--- /dev/null
+++ b/lm/common/Jamfile
@@ -0,0 +1,2 @@
+fakelib common : [ glob *.cc : *test.cc *main.cc ]
+ ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ;
diff --git a/lm/builder/sort.hh b/lm/common/compare.hh
index ed20b4b79..1c7cd2499 100644
--- a/lm/builder/sort.hh
+++ b/lm/common/compare.hh
@@ -1,18 +1,12 @@
-#ifndef LM_BUILDER_SORT_H
-#define LM_BUILDER_SORT_H
+#ifndef LM_COMMON_COMPARE_H
+#define LM_COMMON_COMPARE_H
-#include "lm/builder/ngram_stream.hh"
-#include "lm/builder/ngram.hh"
#include "lm/word_index.hh"
-#include "util/stream/sort.hh"
-
-#include "util/stream/timer.hh"
#include <functional>
#include <string>
namespace lm {
-namespace builder {
/**
* Abstract parent class for defining custom n-gram comparators.
@@ -175,70 +169,6 @@ class PrefixOrder : public Comparator<PrefixOrder> {
static const unsigned kMatchOffset = 0;
};
-// Sum counts for the same n-gram.
-struct AddCombiner {
- bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
- NGram first(first_void, compare.Order());
- // There isn't a const version of NGram.
- NGram second(const_cast<void*>(second_void), compare.Order());
- if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
- first.Count() += second.Count();
- return true;
- }
-};
-
-// The combiner is only used on a single chain, so I didn't bother to allow
-// that template.
-/**
- * Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects.
- *
- * In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object
- * for each n-gram order (ranging from 1 up to the maximum n-gram order being processed).
- * Use in this manner would enable the n-grams each n-gram order to be sorted, in parallel.
- *
- * @tparam Compare An @ref Comparator "ngram comparator" to use during sorting.
- */
-template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > {
- private:
- typedef util::stream::Sort<Compare> S;
- typedef util::FixedArray<S> P;
-
- public:
-
- /**
- * Constructs, but does not initialize.
- *
- * @ref util::FixedArray::Init() "Init" must be called before use.
- *
- * @see util::FixedArray::Init()
- */
- Sorts() {}
-
- /**
- * Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects.
- *
- * @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array"
- * @see util::FixedArray::FixedArray()
- */
- explicit Sorts(std::size_t number) : util::FixedArray<util::stream::Sort<Compare> >(number) {}
-
- /**
- * Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array".
- *
- * The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator";
- * once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored
- * in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain".
- *
- * @see util::stream::Sort::Sort()
- * @see util::stream::Chain::operator>>()
- */
- void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
- new (P::end()) S(chain, config, compare); // use "placement new" syntax to initalize S in an already-allocated memory location
- P::Constructed();
- }
-};
-
-} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_SORT_H
+#endif // LM_COMMON_COMPARE_H
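
The AddCombiner removed here is superseded by the CombineCounts functor now passed to the corpus-count sorter (see the pipeline.cc hunk above); only the comparators stay in the shared header. A standalone sketch of the combining idea itself, on plain structs rather than raw sort blocks:

    #include <algorithm>
    #include <iostream>
    #include <stdint.h>
    #include <vector>

    struct Gram {
      uint32_t words[2];   // a fixed-order bigram for the sketch
      uint64_t count;
    };

    // Suffix-style ordering stand-in: compare the word ids left to right.
    struct CompareWords {
      bool operator()(const Gram &a, const Gram &b) const {
        return std::lexicographical_compare(a.words, a.words + 2, b.words, b.words + 2);
      }
    };

    int main() {
      Gram raw[] = { {{1, 2}, 3}, {{2, 7}, 1}, {{1, 2}, 5} };
      std::vector<Gram> grams(raw, raw + 3);
      std::sort(grams.begin(), grams.end(), CompareWords());
      std::vector<Gram> combined;
      for (std::size_t i = 0; i < grams.size(); ++i) {
        if (!combined.empty() &&
            std::equal(grams[i].words, grams[i].words + 2, combined.back().words)) {
          combined.back().count += grams[i].count;   // the AddCombiner / CombineCounts step
        } else {
          combined.push_back(grams[i]);
        }
      }
      for (std::size_t i = 0; i < combined.size(); ++i)
        std::cout << combined[i].words[0] << ' ' << combined[i].words[1]
                  << '\t' << combined[i].count << '\n';   // "1 2  8" then "2 7  1"
    }

The real combiner does this during the merge phase of util::stream::Sort, on raw block memory, which is why it takes void pointers plus a comparator rather than typed records.
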
diff --git a/lm/common/model_buffer.cc b/lm/common/model_buffer.cc
new file mode 100644
index 000000000..d4635da51
--- /dev/null
+++ b/lm/common/model_buffer.cc
@@ -0,0 +1,82 @@
+#include "lm/common/model_buffer.hh"
+#include "util/exception.hh"
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/stream/io.hh"
+#include "util/stream/multi_stream.hh"
+
+#include <boost/lexical_cast.hpp>
+
+namespace lm { namespace common {
+
+namespace {
+const char kMetadataHeader[] = "KenLM intermediate binary file";
+} // namespace
+
+ModelBuffer::ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q)
+ : file_base_(file_base), keep_buffer_(keep_buffer), output_q_(output_q) {}
+
+ModelBuffer::ModelBuffer(const std::string &file_base)
+ : file_base_(file_base), keep_buffer_(false) {
+ const std::string full_name = file_base_ + ".kenlm_intermediate";
+ util::FilePiece in(full_name.c_str());
+ StringPiece token = in.ReadLine();
+ UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);
+
+ token = in.ReadDelimited();
+ UTIL_THROW_IF2(token != "Order", "Expected Order, got \"" << token << "\" in " << full_name);
+ unsigned long order = in.ReadULong();
+
+ token = in.ReadDelimited();
+ UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
+ token = in.ReadDelimited();
+ if (token == "q") {
+ output_q_ = true;
+ } else if (token == "pb") {
+ output_q_ = false;
+ } else {
+ UTIL_THROW(util::Exception, "Unknown payload " << token);
+ }
+
+ files_.Init(order);
+ for (unsigned long i = 0; i < order; ++i) {
+ files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
+ }
+}
+
+// Out-of-line destructor.
+ModelBuffer::~ModelBuffer() {}
+
+void ModelBuffer::Sink(util::stream::Chains &chains) {
+ // Open files.
+ files_.Init(chains.size());
+ for (std::size_t i = 0; i < chains.size(); ++i) {
+ if (keep_buffer_) {
+ files_.push_back(util::CreateOrThrow(
+ (file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()
+ ));
+ } else {
+ files_.push_back(util::MakeTemp(file_base_));
+ }
+ chains[i] >> util::stream::Write(files_.back().get());
+ }
+ if (keep_buffer_) {
+ util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
+ util::FakeOFStream meta(metadata.get(), 200);
+ meta << kMetadataHeader << "\nOrder " << chains.size() << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
+ }
+}
+
+void ModelBuffer::Source(util::stream::Chains &chains) {
+ assert(chains.size() == files_.size());
+ for (unsigned int i = 0; i < files_.size(); ++i) {
+ chains[i] >> util::stream::PRead(files_[i].get());
+ }
+}
+
+std::size_t ModelBuffer::Order() const {
+ return files_.size();
+}
+
+}} // namespaces
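
The intermediate format is deliberately simple: one binary file per order plus a small text metadata file recording the order and whether payloads are quantized ("q") or probability/backoff ("pb"). A standalone sketch of reading that metadata with std streams; the file name "lm.kenlm_intermediate" is hypothetical, and the class itself parses it with util::FilePiece:

    #include <fstream>
    #include <iostream>
    #include <string>

    int main() {
      std::ifstream in("lm.kenlm_intermediate");   // hypothetical file_base "lm"
      std::string header, key, payload;
      unsigned long order = 0;
      std::getline(in, header);                    // "KenLM intermediate binary file"
      in >> key >> order;                          // "Order" N
      in >> key >> payload;                        // "Payload" then "q" or "pb"
      if (header != "KenLM intermediate binary file") {
        std::cerr << "bad header\n";
        return 1;
      }
      std::cout << "order " << order << ", payload " << payload << '\n';
    }
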
diff --git a/lm/common/model_buffer.hh b/lm/common/model_buffer.hh
new file mode 100644
index 000000000..6a5c7bf49
--- /dev/null
+++ b/lm/common/model_buffer.hh
@@ -0,0 +1,45 @@
+#ifndef LM_BUILDER_MODEL_BUFFER_H
+#define LM_BUILDER_MODEL_BUFFER_H
+
+/* Format with separate files in suffix order. Each file contains
+ * n-grams of the same order.
+ */
+
+#include "util/file.hh"
+#include "util/fixed_array.hh"
+
+#include <string>
+
+namespace util { namespace stream { class Chains; } }
+
+namespace lm { namespace common {
+
+class ModelBuffer {
+ public:
+ // Construct for writing.
+ ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q);
+
+ // Load from file.
+ explicit ModelBuffer(const std::string &file_base);
+
+ // Declared here, defined out-of-line in model_buffer.cc.
+ ~ModelBuffer();
+
+ void Sink(util::stream::Chains &chains);
+
+ void Source(util::stream::Chains &chains);
+
+ // The order of the n-gram model that is associated with the model buffer.
+ std::size_t Order() const;
+
+ private:
+ const std::string file_base_;
+ const bool keep_buffer_;
+ bool output_q_;
+
+ util::FixedArray<util::scoped_fd> files_;
+};
+
+}} // namespaces
+
+#endif // LM_BUILDER_MODEL_BUFFER_H
diff --git a/lm/builder/ngram.hh b/lm/common/ngram.hh
index d0033206c..813017640 100644
--- a/lm/builder/ngram.hh
+++ b/lm/common/ngram.hh
@@ -1,5 +1,5 @@
-#ifndef LM_BUILDER_NGRAM_H
-#define LM_BUILDER_NGRAM_H
+#ifndef LM_COMMON_NGRAM_H
+#define LM_COMMON_NGRAM_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@@ -10,22 +10,10 @@
#include <cstring>
namespace lm {
-namespace builder {
-struct Uninterpolated {
- float prob; // Uninterpolated probability.
- float gamma; // Interpolation weight for lower order.
-};
-
-union Payload {
- uint64_t count;
- Uninterpolated uninterp;
- ProbBackoff complete;
-};
-
-class NGram {
+class NGramHeader {
public:
- NGram(void *begin, std::size_t order)
+ NGramHeader(void *begin, std::size_t order)
: begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
@@ -37,24 +25,29 @@ class NGram {
end_ = begin_ + difference;
}
- // Would do operator++ but that can get confusing for a stream.
- void NextInMemory() {
- ReBase(&Value() + 1);
- }
-
+ // These are for the vocab index.
// Lower-case in deference to STL.
const WordIndex *begin() const { return begin_; }
WordIndex *begin() { return begin_; }
const WordIndex *end() const { return end_; }
WordIndex *end() { return end_; }
- const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); }
- Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
+ std::size_t Order() const { return end_ - begin_; }
- uint64_t &Count() { return Value().count; }
- uint64_t Count() const { return Value().count; }
+ private:
+ WordIndex *begin_, *end_;
+};
- std::size_t Order() const { return end_ - begin_; }
+template <class PayloadT> class NGram : public NGramHeader {
+ public:
+ typedef PayloadT Payload;
+
+ NGram(void *begin, std::size_t order) : NGramHeader(begin, order) {}
+
+ // Would do operator++ but that can get confusing for a stream.
+ void NextInMemory() {
+ ReBase(&Value() + 1);
+ }
static std::size_t TotalSize(std::size_t order) {
return order * sizeof(WordIndex) + sizeof(Payload);
@@ -63,46 +56,17 @@ class NGram {
// Compiler should optimize this.
return TotalSize(Order());
}
+
static std::size_t OrderFromSize(std::size_t size) {
std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
assert(size == TotalSize(ret));
return ret;
}
- // manipulate msb to signal that ngram can be pruned
- /*mjd**********************************************************************/
-
- bool IsMarked() const {
- return Value().count >> (sizeof(Value().count) * 8 - 1);
- }
-
- void Mark() {
- Value().count |= (1ul << (sizeof(Value().count) * 8 - 1));
- }
-
- void Unmark() {
- Value().count &= ~(1ul << (sizeof(Value().count) * 8 - 1));
- }
-
- uint64_t UnmarkedCount() const {
- return Value().count & ~(1ul << (sizeof(Value().count) * 8 - 1));
- }
-
- uint64_t CutoffCount() const {
- return IsMarked() ? 0 : UnmarkedCount();
- }
-
- /*mjd**********************************************************************/
-
- private:
- WordIndex *begin_, *end_;
+ const Payload &Value() const { return *reinterpret_cast<const Payload *>(end()); }
+ Payload &Value() { return *reinterpret_cast<Payload *>(end()); }
};
-const WordIndex kUNK = 0;
-const WordIndex kBOS = 1;
-const WordIndex kEOS = 2;
-
-} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_NGRAM_H
+#endif // LM_COMMON_NGRAM_H
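
NGram<Payload> still describes the same flat record: Order() word ids followed immediately by one payload, which is why TotalSize(order) is order * sizeof(WordIndex) + sizeof(Payload) and Value() lives at end(). A standalone sketch of that layout on a local buffer, with a two-float struct standing in for lm::ProbBackoff:

    #include <cstring>
    #include <iostream>
    #include <stdint.h>
    #include <vector>

    typedef unsigned int WordIndex;
    struct ProbBackoff { float prob; float backoff; };

    int main() {
      const std::size_t order = 3;
      // TotalSize(order): the words, then the payload, nothing else.
      std::vector<unsigned char> record(order * sizeof(WordIndex) + sizeof(ProbBackoff));
      WordIndex words[3] = {4, 7, 9};
      ProbBackoff value = {-1.5f, -0.25f};
      std::memcpy(&record[0], words, sizeof(words));                   // begin() .. end()
      std::memcpy(&record[0] + sizeof(words), &value, sizeof(value));  // Value() sits at end()
      ProbBackoff read;
      std::memcpy(&read, &record[0] + order * sizeof(WordIndex), sizeof(read));
      std::cout << read.prob << ' ' << read.backoff << '\n';           // -1.5 -0.25
    }
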
diff --git a/lm/builder/ngram_stream.hh b/lm/common/ngram_stream.hh
index ab42734c4..53c4ffcb8 100644
--- a/lm/builder/ngram_stream.hh
+++ b/lm/common/ngram_stream.hh
@@ -1,16 +1,16 @@
#ifndef LM_BUILDER_NGRAM_STREAM_H
#define LM_BUILDER_NGRAM_STREAM_H
-#include "lm/builder/ngram.hh"
+#include "lm/common/ngram.hh"
#include "util/stream/chain.hh"
#include "util/stream/multi_stream.hh"
#include "util/stream/stream.hh"
#include <cstddef>
-namespace lm { namespace builder {
+namespace lm {
-class NGramStream {
+template <class Payload> class NGramStream {
public:
NGramStream() : gram_(NULL, 0) {}
@@ -20,14 +20,14 @@ class NGramStream {
void Init(const util::stream::ChainPosition &position) {
stream_.Init(position);
- gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize()));
+ gram_ = NGram<Payload>(stream_.Get(), NGram<Payload>::OrderFromSize(position.GetChain().EntrySize()));
}
- NGram &operator*() { return gram_; }
- const NGram &operator*() const { return gram_; }
+ NGram<Payload> &operator*() { return gram_; }
+ const NGram<Payload> &operator*() const { return gram_; }
- NGram *operator->() { return &gram_; }
- const NGram *operator->() const { return &gram_; }
+ NGram<Payload> *operator->() { return &gram_; }
+ const NGram<Payload> *operator->() const { return &gram_; }
void *Get() { return stream_.Get(); }
const void *Get() const { return stream_.Get(); }
@@ -43,16 +43,22 @@ class NGramStream {
}
private:
- NGram gram_;
+ NGram<Payload> gram_;
util::stream::Stream stream_;
};
-inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) {
+template <class Payload> inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream<Payload> &str) {
str.Init(chain.Add());
return chain;
}
-typedef util::stream::GenericStreams<NGramStream> NGramStreams;
+template <class Payload> class NGramStreams : public util::stream::GenericStreams<NGramStream<Payload> > {
+ private:
+ typedef util::stream::GenericStreams<NGramStream<Payload> > P;
+ public:
+ NGramStreams() : P() {}
+ NGramStreams(const util::stream::ChainPositions &positions) : P(positions) {}
+};
-}} // namespaces
+} // namespace
#endif // LM_BUILDER_NGRAM_STREAM_H
diff --git a/lm/common/renumber.cc b/lm/common/renumber.cc
new file mode 100644
index 000000000..0632a149b
--- /dev/null
+++ b/lm/common/renumber.cc
@@ -0,0 +1,17 @@
+#include "lm/common/renumber.hh"
+#include "lm/common/ngram.hh"
+
+#include "util/stream/stream.hh"
+
+namespace lm {
+
+void Renumber::Run(const util::stream::ChainPosition &position) {
+ for (util::stream::Stream stream(position); stream; ++stream) {
+ NGramHeader gram(stream.Get(), order_);
+ for (WordIndex *w = gram.begin(); w != gram.end(); ++w) {
+ *w = new_numbers_[*w];
+ }
+ }
+}
+
+} // namespace lm
diff --git a/lm/common/renumber.hh b/lm/common/renumber.hh
new file mode 100644
index 000000000..ca25c4dc6
--- /dev/null
+++ b/lm/common/renumber.hh
@@ -0,0 +1,30 @@
+/* Map vocab ids. This is useful to merge independently collected counts or
+ * change the vocab ids to the order used by the trie.
+ */
+#ifndef LM_COMMON_RENUMBER_H
+#define LM_COMMON_RENUMBER_H
+
+#include "lm/word_index.hh"
+
+#include <cstddef>
+
+namespace util { namespace stream { class ChainPosition; }}
+
+namespace lm {
+
+class Renumber {
+ public:
+ // Assumes the array is large enough to map all words and stays alive while
+ // the thread is active.
+ Renumber(const WordIndex *new_numbers, std::size_t order)
+ : new_numbers_(new_numbers), order_(order) {}
+
+ void Run(const util::stream::ChainPosition &position);
+
+ private:
+ const WordIndex *new_numbers_;
+ std::size_t order_;
+};
+
+} // namespace lm
+#endif // LM_COMMON_RENUMBER_H
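
Renumber::Run is a streaming worker: for every record in the block it rewrites each word id through the new_numbers table and leaves the payload untouched. A standalone sketch of the per-word rewrite on a plain array, with a made-up mapping:

    #include <iostream>

    typedef unsigned int WordIndex;

    // Same rewrite the worker applies to each n-gram in a block.
    void RenumberWords(WordIndex *begin, WordIndex *end, const WordIndex *new_numbers) {
      for (WordIndex *w = begin; w != end; ++w) *w = new_numbers[*w];
    }

    int main() {
      WordIndex mapping[] = {0, 2, 1, 3};   // hypothetical old -> new ids; <unk> stays 0
      WordIndex gram[] = {1, 3, 2};         // one trigram in old ids
      RenumberWords(gram, gram + 3, mapping);
      std::cout << gram[0] << ' ' << gram[1] << ' ' << gram[2] << '\n';   // 2 3 1
    }
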
diff --git a/lm/kenlm_benchmark_main.cc b/lm/kenlm_benchmark_main.cc
new file mode 100644
index 000000000..d8b659139
--- /dev/null
+++ b/lm/kenlm_benchmark_main.cc
@@ -0,0 +1,128 @@
+#include "lm/model.hh"
+#include "util/fake_ofstream.hh"
+#include "util/file.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <stdint.h>
+
+namespace {
+
+template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
+ util::FilePiece in(fd_in);
+ util::FakeOFStream out(1);
+ Width width;
+ StringPiece word;
+ const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
+ while (true) {
+ while (in.ReadWordSameLine(word)) {
+ width = (Width)model.GetVocabulary().Index(word);
+ out.write(&width, sizeof(Width));
+ }
+ if (!in.ReadLineOrEOF(word)) break;
+ out.write(&end_sentence, sizeof(Width));
+ }
+}
+
+template <class Model, class Width> void QueryFromBytes(const Model &model, int fd_in) {
+ lm::ngram::State state[3];
+ const lm::ngram::State *const begin_state = &model.BeginSentenceState();
+ const lm::ngram::State *next_state = begin_state;
+ Width kEOS = model.GetVocabulary().EndSentence();
+ Width buf[4096];
+ float sum = 0.0;
+ while (true) {
+ std::size_t got = util::ReadOrEOF(fd_in, buf, sizeof(buf));
+ if (!got) break;
+ UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
+ got /= sizeof(Width);
+ // Do even stuff first.
+ const Width *even_end = buf + (got & ~1);
+ // Alternating states
+ const Width *i;
+ for (i = buf; i != even_end;) {
+ sum += model.FullScore(*next_state, *i, state[1]).prob;
+ next_state = (*i++ == kEOS) ? begin_state : &state[1];
+ sum += model.FullScore(*next_state, *i, state[0]).prob;
+ next_state = (*i++ == kEOS) ? begin_state : &state[0];
+ }
+ // Odd corner case.
+ if (got & 1) {
+ sum += model.FullScore(*next_state, *i, state[2]).prob;
+ next_state = (*i++ == kEOS) ? begin_state : &state[2];
+ }
+ }
+ std::cout << "Sum is " << sum << std::endl;
+}
+
+template <class Model, class Width> void DispatchFunction(const Model &model, bool query) {
+ if (query) {
+ QueryFromBytes<Model, Width>(model, 0);
+ } else {
+ ConvertToBytes<Model, Width>(model, 0);
+ }
+}
+
+template <class Model> void DispatchWidth(const char *file, bool query) {
+ Model model(file);
+ lm::WordIndex bound = model.GetVocabulary().Bound();
+ if (bound <= 256) {
+ DispatchFunction<Model, uint8_t>(model, query);
+ } else if (bound <= 65536) {
+ DispatchFunction<Model, uint16_t>(model, query);
+ } else if (bound <= (1ULL << 32)) {
+ DispatchFunction<Model, uint32_t>(model, query);
+ } else {
+ DispatchFunction<Model, uint64_t>(model, query);
+ }
+}
+
+void Dispatch(const char *file, bool query) {
+ using namespace lm::ngram;
+ lm::ngram::ModelType model_type;
+ if (lm::ngram::RecognizeBinary(file, model_type)) {
+ switch(model_type) {
+ case PROBING:
+ DispatchWidth<lm::ngram::ProbingModel>(file, query);
+ break;
+ case REST_PROBING:
+ DispatchWidth<lm::ngram::RestProbingModel>(file, query);
+ break;
+ case TRIE:
+ DispatchWidth<lm::ngram::TrieModel>(file, query);
+ break;
+ case QUANT_TRIE:
+ DispatchWidth<lm::ngram::QuantTrieModel>(file, query);
+ break;
+ case ARRAY_TRIE:
+ DispatchWidth<lm::ngram::ArrayTrieModel>(file, query);
+ break;
+ case QUANT_ARRAY_TRIE:
+ DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, query);
+ break;
+ default:
+ UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
+ }
+ } else {
+ UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
+ }
+}
+
+} // namespace
+
+int main(int argc, char *argv[]) {
+ if (argc != 3 || (strcmp(argv[1], "vocab") && strcmp(argv[1], "query"))) {
+ std::cerr
+ << "Benchmark program for KenLM. Intended usage:\n"
+ << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
+ << argv[0] << " vocab $model <$text >$text.vocab\n"
+ << "#Ensure files are in RAM.\n"
+ << "cat $text.vocab $model >/dev/null\n"
+ << "#Timed query against the model, including loading.\n"
+ << "time " << argv[0] << " query $model <$text.vocab\n";
+ return 1;
+ }
+ Dispatch(argv[2], !strcmp(argv[1], "query"));
+ util::PrintUsage(std::cerr);
+ return 0;
+}
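
QueryFromBytes scores two words per iteration so the output state of one FullScore call becomes the input state of the next without copying; hitting </s> resets the chain to the begin-sentence state. A standalone sketch of that ping-pong pattern, with score() as a stand-in for Model::FullScore():

    #include <iostream>
    #include <stdint.h>

    struct State { uint32_t history; };

    // Stand-in for Model::FullScore: reads one state, writes another, returns a score.
    float score(const State &in, uint32_t word, State &out) {
      out.history = in.history * 31u + word;   // fake: fold the word into the history
      return -static_cast<float>(word);        // fake log-probability
    }

    int main() {
      const uint32_t kEOS = 0;
      State begin = {1u}, state[2];
      const State *next = &begin;
      uint32_t words[] = {5, 7, kEOS, 3};
      float sum = 0.0f;
      for (std::size_t i = 0; i + 1 < 4; i += 2) {        // two words per iteration
        sum += score(*next, words[i], state[1]);
        next = (words[i] == kEOS) ? &begin : &state[1];   // reset at end of sentence
        sum += score(*next, words[i + 1], state[0]);
        next = (words[i + 1] == kEOS) ? &begin : &state[0];
      }
      std::cout << sum << '\n';
    }
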
diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh
index 937fe2421..b19c5aa4f 100644
--- a/lm/ngram_query.hh
+++ b/lm/ngram_query.hh
@@ -3,45 +3,53 @@
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
+#include "util/fake_ofstream.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
#include <cstdlib>
-#include <iostream>
-#include <ostream>
-#include <istream>
#include <string>
#include <cmath>
namespace lm {
namespace ngram {
-struct BasicPrint {
- void Word(StringPiece, WordIndex, const FullScoreReturn &) const {}
- void Line(uint64_t oov, float total) const {
- std::cout << "Total: " << total << " OOV: " << oov << '\n';
- }
- void Summary(double, double, uint64_t, uint64_t) {}
+class QueryPrinter {
+ public:
+ QueryPrinter(int fd, bool print_word, bool print_line, bool print_summary, bool flush)
+ : out_(fd), print_word_(print_word), print_line_(print_line), print_summary_(print_summary), flush_(flush) {}
-};
+ void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) {
+ if (!print_word_) return;
+ out_ << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
+ if (flush_) out_.flush();
+ }
-struct FullPrint : public BasicPrint {
- void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) const {
- std::cout << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
- }
+ void Line(uint64_t oov, float total) {
+ if (!print_line_) return;
+ out_ << "Total: " << total << " OOV: " << oov << '\n';
+ if (flush_) out_.flush();
+ }
- void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
- std::cout <<
- "Perplexity including OOVs:\t" << ppl_including_oov << "\n"
- "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
- "OOVs:\t" << corpus_oov << "\n"
- "Tokens:\t" << corpus_tokens << '\n'
- ;
- }
+ void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
+ if (!print_summary_) return;
+ out_ <<
+ "Perplexity including OOVs:\t" << ppl_including_oov << "\n"
+ "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
+ "OOVs:\t" << corpus_oov << "\n"
+ "Tokens:\t" << corpus_tokens << '\n';
+ out_.flush();
+ }
+
+ private:
+ util::FakeOFStream out_;
+ bool print_word_;
+ bool print_line_;
+ bool print_summary_;
+ bool flush_;
};
-template <class Model, class Printer> void Query(const Model &model, bool sentence_context) {
- Printer printer;
+template <class Model, class Printer> void Query(const Model &model, bool sentence_context, Printer &printer) {
typename Model::State state, out;
lm::FullScoreReturn ret;
StringPiece word;
@@ -92,13 +100,9 @@ template <class Model, class Printer> void Query(const Model &model, bool senten
corpus_tokens);
}
-template <class Model> void Query(const char *file, const Config &config, bool sentence_context, bool show_words) {
+template <class Model> void Query(const char *file, const Config &config, bool sentence_context, QueryPrinter &printer) {
Model model(file, config);
- if (show_words) {
- Query<Model, FullPrint>(model, sentence_context);
- } else {
- Query<Model, BasicPrint>(model, sentence_context);
- }
+ Query<Model, QueryPrinter>(model, sentence_context, printer);
}
} // namespace ngram
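
QueryPrinter folds the old BasicPrint/FullPrint pair into one object whose verbosity is chosen at run time, writing through util::FakeOFStream rather than std::cout. A minimal usage sketch, assuming a KenLM build tree; the numbers passed to Line() and Summary() are invented for illustration:

    #include "lm/ngram_query.hh"

    int main() {
      // fd 1 is stdout; word, line, and summary output enabled; no per-write flush.
      lm::ngram::QueryPrinter printer(1, true, true, true, false);
      printer.Line(/*oov=*/2, /*total=*/-41.7f);
      printer.Summary(/*ppl_including_oov=*/151.3, /*ppl_excluding_oov=*/132.9,
                      /*corpus_oov=*/2, /*corpus_tokens=*/100);
    }
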
diff --git a/lm/query_main.cc b/lm/query_main.cc
index 3013ff21e..0bd28f7a9 100644
--- a/lm/query_main.cc
+++ b/lm/query_main.cc
@@ -10,9 +10,10 @@
void Usage(const char *name) {
std::cerr <<
"KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n"
- "Usage: " << name << " [-n] [-s] lm_file\n"
+ "Usage: " << name << " [-b] [-n] [-w] [-s] lm_file\n"
+ "-b: Do not buffer output.\n"
"-n: Do not wrap the input in <s> and </s>.\n"
- "-s: Sentence totals only.\n"
+ "-v summary|sentence|word: Level of verbosity\n"
"-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n"
"The default loading method is populate on Linux and read on others.\n";
exit(1);
@@ -24,16 +25,28 @@ int main(int argc, char *argv[]) {
lm::ngram::Config config;
bool sentence_context = true;
- bool show_words = true;
+ unsigned int verbosity = 2;
+ bool flush = false;
int opt;
- while ((opt = getopt(argc, argv, "hnsl:")) != -1) {
+ while ((opt = getopt(argc, argv, "bnv:l:")) != -1) {
switch (opt) {
+ case 'b':
+ flush = true;
+ break;
case 'n':
sentence_context = false;
break;
- case 's':
- show_words = false;
+ case 'v':
+ if (!strcmp(optarg, "word") || !strcmp(optarg, "2")) {
+ verbosity = 2;
+ } else if (!strcmp(optarg, "sentence") || !strcmp(optarg, "1")) {
+ verbosity = 1;
+ } else if (!strcmp(optarg, "summary") || !strcmp(optarg, "0")) {
+ verbosity = 0;
+ } else {
+ Usage(argv[0]);
+ }
break;
case 'l':
if (!strcmp(optarg, "lazy")) {
@@ -55,6 +68,7 @@ int main(int argc, char *argv[]) {
}
if (optind + 1 != argc)
Usage(argv[0]);
+ lm::ngram::QueryPrinter printer(1, verbosity >= 2, verbosity >= 1, true, flush);
const char *file = argv[optind];
try {
using namespace lm::ngram;
@@ -62,22 +76,22 @@ int main(int argc, char *argv[]) {
if (RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
- Query<lm::ngram::ProbingModel>(file, config, sentence_context, show_words);
+ Query<lm::ngram::ProbingModel>(file, config, sentence_context, printer);
break;
case REST_PROBING:
- Query<lm::ngram::RestProbingModel>(file, config, sentence_context, show_words);
+ Query<lm::ngram::RestProbingModel>(file, config, sentence_context, printer);
break;
case TRIE:
- Query<TrieModel>(file, config, sentence_context, show_words);
+ Query<TrieModel>(file, config, sentence_context, printer);
break;
case QUANT_TRIE:
- Query<QuantTrieModel>(file, config, sentence_context, show_words);
+ Query<QuantTrieModel>(file, config, sentence_context, printer);
break;
case ARRAY_TRIE:
- Query<ArrayTrieModel>(file, config, sentence_context, show_words);
+ Query<ArrayTrieModel>(file, config, sentence_context, printer);
break;
case QUANT_ARRAY_TRIE:
- Query<QuantArrayTrieModel>(file, config, sentence_context, show_words);
+ Query<QuantArrayTrieModel>(file, config, sentence_context, printer);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
@@ -86,14 +100,11 @@ int main(int argc, char *argv[]) {
#ifdef WITH_NPLM
} else if (lm::np::Model::Recognize(file)) {
lm::np::Model model(file);
- if (show_words) {
- Query<lm::np::Model, lm::ngram::FullPrint>(model, sentence_context);
- } else {
- Query<lm::np::Model, lm::ngram::BasicPrint>(model, sentence_context);
- }
+ Query<lm::np::Model, lm::ngram::QueryPrinter>(model, sentence_context, printer);
#endif
} else {
- Query<ProbingModel>(file, config, sentence_context, show_words);
+ Query<ProbingModel>(file, config, sentence_context, printer);
}
util::PrintUsage(std::cerr);
} catch (const std::exception &e) {
diff --git a/lm/value.hh b/lm/value.hh
index d017d59fc..d2425cc13 100644
--- a/lm/value.hh
+++ b/lm/value.hh
@@ -1,6 +1,7 @@
#ifndef LM_VALUE_H
#define LM_VALUE_H
+#include "lm/config.hh"
#include "lm/model_type.hh"
#include "lm/value_build.hh"
#include "lm/weights.hh"
diff --git a/lm/vocab.cc b/lm/vocab.cc
index f6d834323..5696e60b3 100644
--- a/lm/vocab.cc
+++ b/lm/vocab.cc
@@ -6,13 +6,14 @@
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
-#include <string>
#include <cstring>
+#include <string>
namespace lm {
namespace ngram {
@@ -31,6 +32,7 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
+// TODO: replace with FilePiece.
void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
@@ -69,10 +71,17 @@ void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint
UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
}
+// Constructor ordering madness.
+int SeekAndReturn(int fd, uint64_t start) {
+ util::SeekOrThrow(fd, start);
+ return fd;
+}
} // namespace
+ImmediateWriteWordsWrapper::ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start)
+ : inner_(inner), stream_(SeekAndReturn(fd, start)) {}
+
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
-WriteWordsWrapper::~WriteWordsWrapper() {}
void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
if (inner_) inner_->Add(index, str);
@@ -80,6 +89,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
+void WriteWordsWrapper::Write(int fd, uint64_t start) {
+ util::SeekOrThrow(fd, start);
+ util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
+ // Free memory from the string.
+ std::string for_swap;
+ std::swap(buffer_, for_swap);
+}
+
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@@ -126,10 +143,78 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
return end_ - begin_;
}
-void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
+void SortedVocabulary::FinishedLoading(ProbBackoff *reorder) {
+ GenericFinished(reorder);
+}
+
+namespace {
+#pragma pack(push)
+#pragma pack(4)
+struct RenumberEntry {
+ uint64_t hash;
+ const char *str;
+ WordIndex old;
+ bool operator<(const RenumberEntry &other) const {
+ return hash < other.hash;
+ }
+};
+#pragma pack(pop)
+} // namespace
+
+void SortedVocabulary::ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping) {
+ mapping.clear();
+ uint64_t file_size = util::SizeOrThrow(from_words);
+ util::scoped_memory strings;
+ util::MapRead(util::POPULATE_OR_READ, from_words, 0, file_size, strings);
+ const char *const start = static_cast<const char*>(strings.get());
+ UTIL_THROW_IF(memcmp(start, "<unk>", 6), FormatLoadException, "Vocab file does not begin with <unk> followed by null");
+ std::vector<RenumberEntry> entries;
+ entries.reserve(types - 1);
+ RenumberEntry entry;
+ entry.old = 1;
+ for (entry.str = start + 6 /* skip <unk>\0 */; entry.str < start + file_size; ++entry.old) {
+ StringPiece str(entry.str, strlen(entry.str));
+ entry.hash = detail::HashForVocab(str);
+ entries.push_back(entry);
+ entry.str += str.size() + 1;
+ }
+ UTIL_THROW_IF2(entries.size() != types - 1, "Wrong number of vocab ids. Got " << (entries.size() + 1) << " expected " << types);
+ std::sort(entries.begin(), entries.end());
+ // Write out new vocab file.
+ {
+ util::FakeOFStream out(to_words);
+ out << "<unk>" << '\0';
+ for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
+ out << i->str << '\0';
+ }
+ }
+ strings.reset();
+
+ mapping.resize(types);
+ mapping[0] = 0; // <unk>
+ for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
+ mapping[i->old] = i + 1 - entries.begin();
+ }
+}
+
+void SortedVocabulary::Populated() {
+ saw_unk_ = true;
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+ bound_ = end_ - begin_ + 1;
+ *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
+}
+
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
+ end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+ bound_ = end_ - begin_ + 1;
+ if (have_words) ReadWords(fd, to, bound_, offset);
+}
+
+template <class T> void SortedVocabulary::GenericFinished(T *reorder) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
- util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::PairedIterator<T*, StringPiece*> values(reorder + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
@@ -139,7 +224,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
strings_to_enumerate_.clear();
string_backing_.FreeAll();
} else {
- util::JointSort(begin_, end_, reorder_vocab + 1);
+ util::JointSort(begin_, end_, reorder + 1);
}
SetSpecial(Index("<s>"), Index("</s>"), 0);
// Save size. Excludes UNK.
@@ -148,13 +233,6 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
- end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
- SetSpecial(Index("<s>"), Index("</s>"), 0);
- bound_ = end_ - begin_ + 1;
- if (have_words) ReadWords(fd, to, bound_, offset);
-}
-
namespace {
const unsigned int kProbingVocabularyVersion = 0;
} // namespace
@@ -209,7 +287,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
}
}
-void ProbingVocabulary::FinishedLoading() {
+void ProbingVocabulary::InternalFinishedLoading() {
lookup_.FinishedInserting();
header_->bound = bound_;
header_->version = kProbingVocabularyVersion;
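
ComputeRenumbering's rule is: keep <unk> at id 0, order every other word by its vocabulary hash, and map each old id to one plus its rank in that order. A standalone sketch of the same rule, with a small FNV-style function standing in for detail::HashForVocab (the real code uses MurmurHash):

    #include <algorithm>
    #include <iostream>
    #include <stdint.h>
    #include <string>
    #include <utility>
    #include <vector>

    // Illustrative hash only; not the library's hash function.
    uint64_t Hash(const std::string &word) {
      uint64_t h = 1469598103934665603ULL;
      for (std::size_t i = 0; i < word.size(); ++i) {
        h ^= static_cast<unsigned char>(word[i]);
        h *= 1099511628211ULL;
      }
      return h;
    }

    int main() {
      const char *words[] = {"<unk>", "the", "cat", "sat"};   // old ids 0..3, <unk> first
      std::vector<std::pair<uint64_t, uint32_t> > by_hash;    // (hash, old id), skipping <unk>
      for (uint32_t old = 1; old < 4; ++old)
        by_hash.push_back(std::make_pair(Hash(words[old]), old));
      std::sort(by_hash.begin(), by_hash.end());              // hash order
      std::vector<uint32_t> mapping(4);
      mapping[0] = 0;                                         // <unk> keeps id 0
      for (std::size_t rank = 0; rank < by_hash.size(); ++rank)
        mapping[by_hash[rank].second] = rank + 1;             // new id = 1 + rank in hash order
      for (uint32_t old = 0; old < 4; ++old)
        std::cout << words[old] << ": " << old << " -> " << mapping[old] << '\n';
    }
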
diff --git a/lm/vocab.hh b/lm/vocab.hh
index 2659b9ba8..b42566f23 100644
--- a/lm/vocab.hh
+++ b/lm/vocab.hh
@@ -30,15 +30,32 @@ inline uint64_t HashForVocab(const StringPiece &str) {
struct ProbingVocabularyHeader;
} // namespace detail
+// Writes words immediately to a file instead of buffering, because we know
+// where in the file to put them.
+class ImmediateWriteWordsWrapper : public EnumerateVocab {
+ public:
+ ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start);
+
+ void Add(WordIndex index, const StringPiece &str) {
+ stream_ << str << '\0';
+ if (inner_) inner_->Add(index, str);
+ }
+
+ private:
+ EnumerateVocab *inner_;
+
+ util::FakeOFStream stream_;
+};
+
+// When the binary size isn't known yet.
class WriteWordsWrapper : public EnumerateVocab {
public:
WriteWordsWrapper(EnumerateVocab *inner);
- ~WriteWordsWrapper();
-
void Add(WordIndex index, const StringPiece &str);
const std::string &Buffer() const { return buffer_; }
+ void Write(int fd, uint64_t start);
private:
EnumerateVocab *inner_;
@@ -67,6 +84,12 @@ class SortedVocabulary : public base::Vocabulary {
// Size for purposes of file writing
static uint64_t Size(uint64_t entries, const Config &config);
+ /* Read null-delimited words from file from_words, renumber according to
+ * hash order, write null-delimited words to to_words, and create a mapping
+ * from old id to new id. The 0th vocab word must be <unk>.
+ */
+ static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping);
+
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
WordIndex Bound() const { return bound_; }
@@ -77,8 +100,8 @@ class SortedVocabulary : public base::Vocabulary {
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+ // Insert and FinishedLoading go together.
WordIndex Insert(const StringPiece &str);
-
// Reorders reorder_vocab so that the IDs are sorted.
void FinishedLoading(ProbBackoff *reorder_vocab);
@@ -89,7 +112,13 @@ class SortedVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
+ uint64_t *&EndHack() { return end_; }
+
+ void Populated();
+
private:
+ template <class T> void GenericFinished(T *reorder);
+
uint64_t *begin_, *end_;
WordIndex bound_;
@@ -153,9 +182,8 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Insert(const StringPiece &str);
template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
- FinishedLoading();
+ InternalFinishedLoading();
}
- void FinishedLoading();
std::size_t UnkCountChangePadding() const { return 0; }
@@ -164,6 +192,8 @@ class ProbingVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
+ void InternalFinishedLoading();
+
typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
Lookup lookup_;
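
Both wrappers and ComputeRenumbering agree on the same on-disk vocabulary layout: words are written back to back as raw bytes, each terminated by '\0', with <unk> first, and a word's id is simply its position in the file (the order in which ReadWords hands ids back). A standalone round trip of that layout with std streams; "vocab.words" is a hypothetical file name:

    #include <cstring>
    #include <fstream>
    #include <iostream>
    #include <string>

    int main() {
      {
        std::ofstream out("vocab.words", std::ios::binary);
        const char *words[] = {"<unk>", "the", "cat"};        // <unk> must come first
        for (int i = 0; i < 3; ++i)
          out.write(words[i], std::strlen(words[i]) + 1);     // word bytes plus the '\0'
      }
      std::ifstream in("vocab.words", std::ios::binary);
      std::string word;
      unsigned id = 0;
      while (std::getline(in, word, '\0'))                    // ids follow file position
        std::cout << id++ << '\t' << word << '\n';
    }
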
diff --git a/lm/word_index.hh b/lm/word_index.hh
index ad59a7c2f..59b24d7d2 100644
--- a/lm/word_index.hh
+++ b/lm/word_index.hh
@@ -7,6 +7,7 @@
namespace lm {
typedef unsigned int WordIndex;
const WordIndex kMaxWordIndex = UINT_MAX;
+const WordIndex kUNK = 0;
} // namespace lm
typedef lm::WordIndex LMWordIndex;