github.com/moses-smt/mosesdecoder.git
path: root/lm
author     Kenneth Heafield <github@kheafield.com>  2015-01-22 19:42:46 +0300
committer  Kenneth Heafield <github@kheafield.com>  2015-01-22 19:42:46 +0300
commit     769c19d10ca8e05f7e983dc4553f3c4c80968fc9 (patch)
tree       d3ba24b9a4e66a68eea3d6df62e879d0664f4f79 /lm
parent     9235534269e8047942fecd6d5c3f2244485d4acd (diff)
KenLM a6d57501dcac95a31719a8628f6cbd288f6741e2 including Marcin's fixed pruning
Diffstat (limited to 'lm')
-rw-r--r--  lm/build_binary_main.cc               1
-rw-r--r--  lm/builder/adjust_counts.cc         131
-rw-r--r--  lm/builder/adjust_counts.hh           5
-rw-r--r--  lm/builder/corpus_count.cc           28
-rw-r--r--  lm/builder/corpus_count.hh            5
-rw-r--r--  lm/builder/header_info.hh            12
-rw-r--r--  lm/builder/initial_probabilities.cc  52
-rw-r--r--  lm/builder/initial_probabilities.hh   3
-rw-r--r--  lm/builder/interpolate.cc            17
-rw-r--r--  lm/builder/interpolate.hh             3
-rw-r--r--  lm/builder/joint_order.hh            26
-rw-r--r--  lm/builder/lmplz_main.cc             22
-rw-r--r--  lm/builder/output.cc                 14
-rw-r--r--  lm/builder/output.hh                 89
-rw-r--r--  lm/builder/pipeline.cc               41
-rw-r--r--  lm/builder/pipeline.hh                9
-rw-r--r--  lm/builder/print.cc                  41
-rw-r--r--  lm/builder/print.hh                  52
-rw-r--r--  lm/config.cc                          2
-rw-r--r--  lm/config.hh                          4
-rw-r--r--  lm/max_order.hh                       4
-rw-r--r--  lm/search_trie.cc                     2
22 files changed, 421 insertions, 142 deletions
diff --git a/lm/build_binary_main.cc b/lm/build_binary_main.cc
index 15b421e9f..2af2222e3 100644
--- a/lm/build_binary_main.cc
+++ b/lm/build_binary_main.cc
@@ -137,6 +137,7 @@ int main(int argc, char *argv[]) {
case 't': // legacy
case 'T':
config.temporary_directory_prefix = optarg;
+ util::NormalizeTempPrefix(config.temporary_directory_prefix);
break;
case 'm': // legacy
config.building_memory = ParseUInt(optarg) * 1048576;
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc
index 803c557d0..03ccbb934 100644
--- a/lm/builder/adjust_counts.cc
+++ b/lm/builder/adjust_counts.cc
@@ -4,6 +4,7 @@
#include <algorithm>
#include <iostream>
+#include <limits>
namespace lm { namespace builder {
@@ -108,9 +109,10 @@ class StatCollector {
// order but we don't care because the data is going to be sorted again.
class CollapseStream {
public:
- CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold) :
+ CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector<bool>& prune_words) :
current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
prune_threshold_(prune_threshold),
+ prune_words_(prune_words),
block_(position) {
StartBlock();
}
@@ -132,6 +134,15 @@ class CollapseStream {
current_.Mark();
}
+ if(!prune_words_.empty()) {
+ for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
+ if(prune_words_[*i]) {
+ current_.Mark();
+ break;
+ }
+ }
+ }
+
}
current_.NextInMemory();
@@ -146,6 +157,15 @@ class CollapseStream {
if(current_.Count() <= prune_threshold_) {
current_.Mark();
}
+
+ if(!prune_words_.empty()) {
+ for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
+ if(prune_words_[*i]) {
+ current_.Mark();
+ break;
+ }
+ }
+ }
return *this;
}
@@ -164,6 +184,15 @@ class CollapseStream {
if(current_.Count() <= prune_threshold_) {
current_.Mark();
}
+
+ if(!prune_words_.empty()) {
+ for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
+ if(prune_words_[*i]) {
+ current_.Mark();
+ break;
+ }
+ }
+ }
}
@@ -179,6 +208,7 @@ class CollapseStream {
// Goes backwards in the block
uint8_t *copy_from_;
uint64_t prune_threshold_;
+ const std::vector<bool>& prune_words_;
util::stream::Link block_;
};
@@ -192,8 +222,19 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
if (order == 1) {
// Only unigrams. Just collect stats.
- for (NGramStream full(positions[0]); full; ++full)
- stats.AddFull(full->Count());
+ for (NGramStream full(positions[0]); full; ++full) {
+
+ // Do not prune <s> </s> <unk>
+ if(*full->begin() > 2) {
+ if(full->Count() <= prune_thresholds_[0])
+ full->Mark();
+
+ if(!prune_words_.empty() && prune_words_[*full->begin()])
+ full->Mark();
+ }
+
+ stats.AddFull(full->UnmarkedCount(), full->IsMarked());
+ }
stats.CalculateDiscounts(discount_config_);
return;
@@ -202,56 +243,67 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
NGramStreams streams;
streams.Init(positions, positions.size() - 1);
- CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back());
+ CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_);
// Initialization: <unk> has count 0 and so does <s>.
NGramStream *lower_valid = streams.begin();
+ const NGramStream *const streams_begin = streams.begin();
streams[0]->Count() = 0;
*streams[0]->begin() = kUNK;
stats.Add(0, 0);
(++streams[0])->Count() = 0;
*streams[0]->begin() = kBOS;
- // not in stats because it will get put in later.
+ // <s> is not in stats yet because it will get put in later.
- std::vector<uint64_t> lower_counts(positions.size(), 0);
+ // This keeps track of actual counts for lower orders. It is not output
+ // (only adjusted counts are), but used to determine pruning.
+ std::vector<uint64_t> actual_counts(positions.size(), 0);
+ // Something of a hack: don't prune <s>.
+ actual_counts[0] = std::numeric_limits<uint64_t>::max();
- // iterate over full (the stream of the highest order ngrams)
- for (; full; ++full) {
+ // Iterate over full (the stream of the highest order ngrams)
+ for (; full; ++full) {
const WordIndex *different = FindDifference(*full, **lower_valid);
std::size_t same = full->end() - 1 - different;
- // Increment the adjusted count.
- if (same) ++streams[same - 1]->Count();
- // Output all the valid ones that changed.
+ // STEP 1: Output all the n-grams that changed.
for (; lower_valid >= &streams[same]; --lower_valid) {
-
- // mjd: review this!
- uint64_t order = (*lower_valid)->Order();
- uint64_t realCount = lower_counts[order - 1];
- if(order > 1 && prune_thresholds_[order - 1] && realCount <= prune_thresholds_[order - 1])
+ uint64_t order_minus_1 = lower_valid - streams_begin;
+ if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1])
(*lower_valid)->Mark();
- stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
+ if(!prune_words_.empty()) {
+ for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) {
+ if(prune_words_[*i]) {
+ (*lower_valid)->Mark();
+ break;
+ }
+ }
+ }
+
+ stats.Add(order_minus_1, (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
++*lower_valid;
}
-
- // Count the true occurrences of lower-order n-grams
- for (std::size_t i = 0; i < lower_counts.size(); ++i) {
- if (i >= same) {
- lower_counts[i] = 0;
- }
- lower_counts[i] += full->UnmarkedCount();
+
+ // STEP 2: Update n-grams that still match.
+ // n-grams that match get count from the full entry.
+ for (std::size_t i = 0; i < same; ++i) {
+ actual_counts[i] += full->UnmarkedCount();
}
+ // Increment the number of unique extensions for the longest match.
+ if (same) ++streams[same - 1]->Count();
+ // STEP 3: Initialize new n-grams.
// This is here because bos is also const WordIndex *, so copy gets
// consistent argument types.
const WordIndex *full_end = full->end();
// Initialize and mark as valid up to bos.
const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
- ++lower_valid;
- std::copy(bos, full_end, (*lower_valid)->begin());
- (*lower_valid)->Count() = 1;
+ NGramStream &to = *++lower_valid;
+ std::copy(bos, full_end, to->begin());
+ to->Count() = 1;
+ actual_counts[lower_valid - streams_begin] = full->UnmarkedCount();
}
// Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) {
@@ -259,19 +311,32 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
NGramStream &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
- // mjd: what is this doing?
- to->Count() = full->UnmarkedCount();
+ // Anything that begins with <s> has the full non-adjusted count.
+ to->Count() = full->UnmarkedCount();
+ actual_counts[lower_valid - streams_begin] = full->UnmarkedCount();
} else {
- stats.AddFull(full->UnmarkedCount(), full->IsMarked());
+ stats.AddFull(full->UnmarkedCount(), full->IsMarked());
}
assert(lower_valid >= &streams[0]);
}
- // Output everything valid.
+ // The above loop outputs n-grams when it observes changes. This outputs
+ // the last n-grams.
for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
- if((*s)->Count() <= prune_thresholds_[(*s)->Order() - 1])
+ uint64_t lower_count = actual_counts[(*s)->Order() - 1];
+ if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
(*s)->Mark();
- stats.Add(s - streams.begin(), (*s)->UnmarkedCount(), (*s)->IsMarked());
+
+ if(!prune_words_.empty()) {
+ for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) {
+ if(prune_words_[*i]) {
+ (*s)->Mark();
+ break;
+ }
+ }
+ }
+
+ stats.Add(s - streams.begin(), lower_count, (*s)->IsMarked());
++*s;
}
// Poison everyone! Except the N-grams which were already poisoned by the input.
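
The word-based check above recurs four times (three spots in CollapseStream plus the lower-order output loop). Factored out, the predicate reduces to the following; a self-contained sketch with a stand-in WordIndex typedef rather than the real lm/word_index.hh:

#include <stdint.h>
#include <vector>

typedef uint32_t WordIndex;  // stand-in for lm::WordIndex

// Returns true if any word of the n-gram [begin, end) is flagged for pruning.
// An empty prune_words vector means --limit_vocab_file was not given, so
// word-based pruning is disabled and the check short-circuits to false.
bool ContainsPrunedWord(const WordIndex *begin, const WordIndex *end,
                        const std::vector<bool> &prune_words) {
  if (prune_words.empty()) return false;
  for (const WordIndex *i = begin; i != end; ++i) {
    if (prune_words[*i]) return true;
  }
  return false;
}
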
diff --git a/lm/builder/adjust_counts.hh b/lm/builder/adjust_counts.hh
index a5435c282..b169950e9 100644
--- a/lm/builder/adjust_counts.hh
+++ b/lm/builder/adjust_counts.hh
@@ -46,9 +46,11 @@ class AdjustCounts {
const std::vector<uint64_t> &prune_thresholds,
std::vector<uint64_t> &counts,
std::vector<uint64_t> &counts_pruned,
+ const std::vector<bool> &prune_words,
const DiscountConfig &discount_config,
std::vector<Discount> &discounts)
- : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts)
+ : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned),
+ prune_words_(prune_words), discount_config_(discount_config), discounts_(discounts)
{}
void Run(const util::stream::ChainPositions &positions);
@@ -57,6 +59,7 @@ class AdjustCounts {
const std::vector<uint64_t> &prune_thresholds_;
std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
+ const std::vector<bool> &prune_words_;
DiscountConfig discount_config_;
std::vector<Discount> &discounts_;
diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc
index 590e79fad..7f3dafa27 100644
--- a/lm/builder/corpus_count.cc
+++ b/lm/builder/corpus_count.cc
@@ -174,8 +174,9 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
}
-CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
+ prune_words_(prune_words), prune_vocab_filename_(prune_vocab_filename),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
disallowed_symbol_action_(disallowed_symbol) {
@@ -223,6 +224,31 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
} catch (const util::EndOfFileException &e) {}
token_count_ = count;
type_count_ = vocab.Size();
+
+ // Create list of unigrams that are supposed to be pruned
+ if (!prune_vocab_filename_.empty()) {
+ try {
+ util::FilePiece prune_vocab_file(prune_vocab_filename_.c_str());
+
+ prune_words_.resize(vocab.Size(), true);
+ try {
+ while (true) {
+ StringPiece line(prune_vocab_file.ReadLine());
+ for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w)
+ prune_words_[vocab.Index(*w)] = false;
+ }
+ } catch (const util::EndOfFileException &e) {}
+
+ // Never prune <unk>, <s>, </s>
+ prune_words_[kUNK] = false;
+ prune_words_[kBOS] = false;
+ prune_words_[kEOS] = false;
+
+ } catch (const util::Exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+ }
+ }
}
} // namespace builder
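
For reference, the allowed-vocabulary loading added above reduces to the following; a sketch with standard-library I/O in place of util::FilePiece and util::TokenIter, and a std::map standing in for the probing vocabulary:

#include <fstream>
#include <map>
#include <stdint.h>
#include <string>
#include <vector>

typedef uint32_t WordIndex;
// Special IDs follow KenLM's convention, matching the `> 2` test above.
const WordIndex kUNK = 0, kBOS = 1, kEOS = 2;

std::vector<bool> BuildPruneWords(const std::string &file_name,
                                  const std::map<std::string, WordIndex> &vocab) {
  // Start with every known word flagged for pruning...
  std::vector<bool> prune_words(vocab.size(), true);
  std::ifstream in(file_name.c_str());
  std::string word;
  // ...then clear the flag for every word on the allowed list.
  while (in >> word) {
    std::map<std::string, WordIndex>::const_iterator it = vocab.find(word);
    if (it != vocab.end()) prune_words[it->second] = false;
  }
  // Never prune <unk>, <s>, </s>.
  prune_words[kUNK] = prune_words[kBOS] = prune_words[kEOS] = false;
  return prune_words;
}
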
diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh
index da4ff9fc6..d3121ca45 100644
--- a/lm/builder/corpus_count.hh
+++ b/lm/builder/corpus_count.hh
@@ -8,6 +8,7 @@
#include <cstddef>
#include <string>
#include <stdint.h>
+#include <vector>
namespace util {
class FilePiece;
@@ -29,7 +30,7 @@ class CorpusCount {
// token_count: out.
// type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
- CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol);
+ CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol);
void Run(const util::stream::ChainPosition &position);
@@ -38,6 +39,8 @@ class CorpusCount {
int vocab_write_;
uint64_t &token_count_;
WordIndex &type_count_;
+ std::vector<bool>& prune_words_;
+ const std::string& prune_vocab_filename_;
std::size_t dedupe_mem_size_;
util::scoped_malloc dedupe_mem_;
diff --git a/lm/builder/header_info.hh b/lm/builder/header_info.hh
index 16f3f6090..146195233 100644
--- a/lm/builder/header_info.hh
+++ b/lm/builder/header_info.hh
@@ -2,16 +2,20 @@
#define LM_BUILDER_HEADER_INFO_H
#include <string>
+#include <vector>
#include <stdint.h>
// Some configuration info that is used to add
// comments to the beginning of an ARPA file
struct HeaderInfo {
- const std::string input_file;
- const uint64_t token_count;
+ std::string input_file;
+ uint64_t token_count;
+ std::vector<uint64_t> counts_pruned;
- HeaderInfo(const std::string& input_file_in, uint64_t token_count_in)
- : input_file(input_file_in), token_count(token_count_in) {}
+ HeaderInfo() {}
+
+ HeaderInfo(const std::string& input_file_in, uint64_t token_count_in, const std::vector<uint64_t> &counts_pruned_in)
+ : input_file(input_file_in), token_count(token_count_in), counts_pruned(counts_pruned_in) {}
// TODO: Add smoothing type
// TODO: More info if multiple models were interpolated
diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc
index 5d19a8973..b1dd96f31 100644
--- a/lm/builder/initial_probabilities.cc
+++ b/lm/builder/initial_probabilities.cc
@@ -51,15 +51,13 @@ class PruneNGramStream {
PruneNGramStream &operator++() {
assert(block_);
- if (current_.Order() > 1) {
- if(currentCount_ > 0) {
- if(dest_.Base() < current_.Base()) {
- memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
- }
- dest_.NextInMemory();
+ if(current_.Order() == 1 && *current_.begin() <= 2)
+ dest_.NextInMemory();
+ else if(currentCount_ > 0) {
+ if(dest_.Base() < current_.Base()) {
+ memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
}
- } else {
- dest_.NextInMemory();
+ dest_.NextInMemory();
}
current_.NextInMemory();
@@ -78,7 +76,7 @@ class PruneNGramStream {
return *this;
}
-
+
private:
void StartBlock() {
for (; ; ++block_) {
@@ -215,14 +213,33 @@ class MergeRight {
PruneNGramStream grams(primary);
// Without interpolation, the interpolation weight goes to <unk>.
- if (grams->Order() == 1 && !interpolate_unigrams_) {
+ if (grams->Order() == 1) {
BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
+ // Special case for <unk>
assert(*grams->begin() == kUNK);
- grams->Value().uninterp.prob = sums.gamma;
+ float gamma_assign;
+ if (interpolate_unigrams_) {
+ // Default: treat <unk> like a zeroton.
+ gamma_assign = sums.gamma;
+ grams->Value().uninterp.prob = 0.0;
+ } else {
+ // SRI: give all the interpolation mass to <unk>
+ gamma_assign = 0.0;
+ grams->Value().uninterp.prob = sums.gamma;
+ }
+ grams->Value().uninterp.gamma = gamma_assign;
+ ++grams;
+
+ // Special case for <s>: probability 1.0. This allows <s> to be
+ // explicitly scored as part of the sentence without impacting
+ // probability and computes q correctly as b(<s>).
+ assert(*grams->begin() == kBOS);
+ grams->Value().uninterp.prob = 1.0;
grams->Value().uninterp.gamma = 0.0;
+
while (++grams) {
grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
- grams->Value().uninterp.gamma = 0.0;
+ grams->Value().uninterp.gamma = gamma_assign;
}
++summed;
return;
@@ -256,10 +273,11 @@ void InitialProbabilities(
util::stream::Chains &primary,
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
- const std::vector<uint64_t> &prune_thresholds) {
+ const std::vector<uint64_t> &prune_thresholds,
+ bool prune_vocab) {
for (size_t i = 0; i < primary.size(); ++i) {
util::stream::ChainConfig gamma_config = config.adder_out;
- if(prune_thresholds[i] > 0)
+ if(prune_vocab || prune_thresholds[i] > 0)
gamma_config.entry_size = sizeof(HashBufferEntry);
else
gamma_config.entry_size = sizeof(BufferEntry);
@@ -267,12 +285,12 @@ void InitialProbabilities(
util::stream::ChainPosition second(second_in[i].Add());
second_in[i] >> util::stream::kRecycle;
gamma_out.push_back(gamma_config);
- gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0);
+ gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0);
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
-
+
// Don't bother with the OnlyGamma thread for something to discard.
- if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0);
+ if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0);
}
}
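
The MergeRight branching above for the unigram special case boils down to one decision; a sketch restating the diff (the real code writes into the uninterp payload of the stream):

#include <utility>

// Returns (gamma assigned to every unigram, probability assigned to <unk>),
// given the summed leftover interpolation mass for the unigram level.
std::pair<float, float> UnigramSpecialCase(bool interpolate_unigrams,
                                           float sums_gamma) {
  if (interpolate_unigrams) {
    // Default: treat <unk> as a zeroton and spread the leftover mass over
    // all unigrams via their gamma.
    return std::make_pair(sums_gamma, 0.0f);
  }
  // SRI compatibility: give the entire interpolation mass to <unk>.
  return std::make_pair(0.0f, sums_gamma);
}
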
diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh
index c1010e082..57e09cd51 100644
--- a/lm/builder/initial_probabilities.hh
+++ b/lm/builder/initial_probabilities.hh
@@ -33,7 +33,8 @@ void InitialProbabilities(
util::stream::Chains &primary,
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
- const std::vector<uint64_t> &prune_thresholds);
+ const std::vector<uint64_t> &prune_thresholds,
+ bool prune_vocab);
} // namespace builder
} // namespace lm
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index a7947a422..7de7852b9 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -65,9 +65,10 @@ class OutputProbBackoff {
template <class Output> class Callback {
public:
- Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds)
+ Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab)
: backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
prune_thresholds_(prune_thresholds),
+ prune_vocab_(prune_vocab),
output_(backoffs.size() + 1 /* order */) {
probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) {
@@ -77,7 +78,7 @@ template <class Output> class Callback {
~Callback() {
for (std::size_t i = 0; i < backoffs_.size(); ++i) {
- if(prune_thresholds_[i + 1] > 0)
+ if(prune_vocab_ || prune_thresholds_[i + 1] > 0)
while(backoffs_[i])
++backoffs_[i];
@@ -94,8 +95,8 @@ template <class Output> class Callback {
probs_[order_minus_1 + 1] = pay.complete.prob;
float out_backoff;
- if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
- if(prune_thresholds_[order_minus_1 + 1] > 0) {
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1]) {
+ if(prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) {
//Compute hash value for current context
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
@@ -129,15 +130,17 @@ template <class Output> class Callback {
std::vector<float> probs_;
const std::vector<uint64_t>& prune_thresholds_;
+ bool prune_vocab_;
Output output_;
};
} // namespace
-Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool output_q)
+Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q)
: uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
backoffs_(backoffs),
prune_thresholds_(prune_thresholds),
+ prune_vocab_(prune_vocab),
output_q_(output_q) {}
// perform order-wise interpolation
@@ -145,11 +148,11 @@ void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1);
if (output_q_) {
typedef Callback<OutputQ> C;
- C callback(uniform_prob_, backoffs_, prune_thresholds_);
+ C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
JointOrder<C, SuffixOrder>(positions, callback);
} else {
typedef Callback<OutputProbBackoff> C;
- C callback(uniform_prob_, backoffs_, prune_thresholds_);
+ C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
JointOrder<C, SuffixOrder>(positions, callback);
}
}
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
index 0acece926..adfd9198f 100644
--- a/lm/builder/interpolate.hh
+++ b/lm/builder/interpolate.hh
@@ -18,7 +18,7 @@ class Interpolate {
public:
// Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
// be larger when the user specifies a consistent vocabulary size.
- explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool output_q_);
+ explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q_);
void Run(const util::stream::ChainPositions &positions);
@@ -26,6 +26,7 @@ class Interpolate {
float uniform_prob_;
util::stream::ChainPositions backoffs_;
const std::vector<uint64_t> prune_thresholds_;
+ bool prune_vocab_;
bool output_q_;
};
diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh
index 7235d4f7b..9ed89097a 100644
--- a/lm/builder/joint_order.hh
+++ b/lm/builder/joint_order.hh
@@ -4,6 +4,11 @@
#include "lm/builder/ngram_stream.hh"
#include "lm/lm_exception.hh"
+#ifdef DEBUG
+#include "util/fixed_array.hh"
+#include <iostream>
+#endif
+
#include <string.h>
namespace lm { namespace builder {
@@ -17,21 +22,40 @@ template <class Callback, class Compare> void JointOrder(const util::stream::Cha
unsigned int order;
for (order = 0; order < positions.size() && streams[order]; ++order) {}
assert(order); // should always have <unk>.
+
+ // Debugging only: call comparison function to sanity check order.
+#ifdef DEBUG
+ util::FixedArray<Compare> less_compare(order);
+ for (unsigned i = 0; i < order; ++i)
+ less_compare.push_back(i + 1);
+#endif // DEBUG
+
unsigned int current = 0;
while (true) {
- // Does the context match the lower one?
+ // Does the context match the lower one?
if (!memcmp(streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) {
callback.Enter(current, *streams[current]);
// Transition to looking for extensions.
if (++current < order) continue;
}
+#ifdef DEBUG
+ // less_compare[current - 1] compares current-grams.
+ // The lower-order stream (which skips fewer current-grams) should always be <= the higher-order stream (which can skip current-grams).
+ else if (!less_compare[current - 1](streams[static_cast<int>(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset)) {
+ std::cerr << "Stream out of order detected" << std::endl;
+ abort();
+ }
+#endif // DEBUG
// No extension left.
while(true) {
assert(current > 0);
--current;
callback.Exit(current, *streams[current]);
+
if (++streams[current]) break;
+
UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix");
+
order = current;
if (!order) return;
}
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index 265dd2164..d3bd99d23 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -1,4 +1,6 @@
+#include "lm/builder/output.hh"
#include "lm/builder/pipeline.hh"
+#include "lm/builder/print.hh"
#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
@@ -51,8 +53,7 @@ std::vector<uint64_t> ParsePruning(const std::vector<std::string> &param, std::s
// throw if an n-gram order has no threshold specified
UTIL_THROW_IF(prune_thresholds.size() > order, util::Exception, "You specified pruning thresholds for orders 1 through " << prune_thresholds.size() << " but the model only has order " << order);
// threshold for unigram can only be 0 (no pruning)
- UTIL_THROW_IF(prune_thresholds[0] != 0, util::Exception, "Unigram pruning is not implemented, so the first pruning threshold must be 0.");
-
+
// check if threshold are not in decreasing order
uint64_t lower_threshold = 0;
for (std::vector<uint64_t>::iterator it = prune_thresholds.begin(); it != prune_thresholds.end(); ++it) {
@@ -93,6 +94,7 @@ int main(int argc, char *argv[]) {
discount_fallback_default.push_back("0.5");
discount_fallback_default.push_back("1");
discount_fallback_default.push_back("1.5");
+ bool verbose_header;
options.add_options()
("help,h", po::bool_switch(), "Show this help message")
@@ -111,11 +113,12 @@ int main(int argc, char *argv[]) {
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
- ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
+ ("verbose_header", po::bool_switch(&verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
- ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.")
+ ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.")
+ ("limit_vocab_file", po::value<std::string>(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg")
("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail.");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, options), vm);
@@ -181,6 +184,13 @@ int main(int argc, char *argv[]) {
// parse pruning thresholds. These depend on order, so it is not done as a notifier.
pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);
+
+ if (!vm["limit_vocab_file"].as<std::string>().empty()) {
+ pipeline.prune_vocab = true;
+ }
+ else {
+ pipeline.prune_vocab = false;
+ }
util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
@@ -202,7 +212,9 @@ int main(int argc, char *argv[]) {
// Read from stdin
try {
- lm::builder::Pipeline(pipeline, in.release(), out.release());
+ lm::builder::Output output;
+ output.Add(new lm::builder::PrintARPA(out.release(), verbose_header));
+ lm::builder::Pipeline(pipeline, in.release(), output);
} catch (const util::MallocException &e) {
std::cerr << e.what() << std::endl;
std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
diff --git a/lm/builder/output.cc b/lm/builder/output.cc
new file mode 100644
index 000000000..0fc0197c4
--- /dev/null
+++ b/lm/builder/output.cc
@@ -0,0 +1,14 @@
+#include "lm/builder/output.hh"
+#include "util/stream/multi_stream.hh"
+
+#include <boost/ref.hpp>
+
+namespace lm { namespace builder {
+
+OutputHook::~OutputHook() {}
+
+void OutputHook::Apply(util::stream::Chains &chains) {
+ chains >> boost::ref(*this);
+}
+
+}} // namespaces
diff --git a/lm/builder/output.hh b/lm/builder/output.hh
new file mode 100644
index 000000000..0ef769ae2
--- /dev/null
+++ b/lm/builder/output.hh
@@ -0,0 +1,89 @@
+#ifndef LM_BUILDER_OUTPUT_H
+#define LM_BUILDER_OUTPUT_H
+
+#include "lm/builder/header_info.hh"
+#include "util/file.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+#include <boost/utility.hpp>
+
+#include <map>
+
+namespace util { namespace stream { class Chains; class ChainPositions; } }
+
+/* Outputs from lmplz: ARPA, sharded files, etc. */
+namespace lm { namespace builder {
+
+// These are different types of hooks. Values should be consecutive to enable a vector lookup.
+enum HookType {
+ COUNT_HOOK, // Raw N-gram counts, highest order only.
+ PROB_PARALLEL_HOOK, // Probability and backoff (or just q). Output must process the orders in parallel or there will be a deadlock.
+ PROB_SEQUENTIAL_HOOK, // Probability and backoff (or just q). Output can process orders any way it likes. This requires writing the data to disk then reading. Useful for ARPA files, which put unigrams first etc.
+ NUMBER_OF_HOOKS // Keep this last so we know how many values there are.
+};
+
+class Output;
+
+class OutputHook {
+ public:
+ explicit OutputHook(HookType hook_type) : type_(hook_type), master_(NULL) {}
+
+ virtual ~OutputHook();
+
+ virtual void Apply(util::stream::Chains &chains);
+
+ virtual void Run(const util::stream::ChainPositions &positions) = 0;
+
+ protected:
+ const HeaderInfo &GetHeader() const;
+ int GetVocabFD() const;
+
+ private:
+ friend class Output;
+ const HookType type_;
+ const Output *master_;
+};
+
+class Output : boost::noncopyable {
+ public:
+ Output() {}
+
+ // Takes ownership.
+ void Add(OutputHook *hook) {
+ hook->master_ = this;
+ outputs_[hook->type_].push_back(hook);
+ }
+
+ bool Have(HookType hook_type) const {
+ return !outputs_[hook_type].empty();
+ }
+
+ void SetVocabFD(int to) { vocab_fd_ = to; }
+ int GetVocabFD() const { return vocab_fd_; }
+
+ void SetHeader(const HeaderInfo &header) { header_ = header; }
+ const HeaderInfo &GetHeader() const { return header_; }
+
+ void Apply(HookType hook_type, util::stream::Chains &chains) {
+ for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
+ entry->Apply(chains);
+ }
+ }
+
+ private:
+ boost::ptr_vector<OutputHook> outputs_[NUMBER_OF_HOOKS];
+ int vocab_fd_;
+ HeaderInfo header_;
+};
+
+inline const HeaderInfo &OutputHook::GetHeader() const {
+ return master_->GetHeader();
+}
+
+inline int OutputHook::GetVocabFD() const {
+ return master_->GetVocabFD();
+}
+
+}} // namespaces
+
+#endif // LM_BUILDER_OUTPUT_H
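
A hypothetical third-party hook (not part of this commit) would subclass OutputHook, pick a HookType, and drain every chain position handed to Run, the same way PrintARPA in print.cc does:

#include "lm/builder/ngram_stream.hh"
#include "lm/builder/output.hh"
#include "util/stream/multi_stream.hh"

#include <iostream>
#include <stdint.h>

// Illustrative hook that consumes each order and reports how many n-grams
// passed through. Every position must be read to completion, like any
// stream worker, or the pipeline would stall.
class CountNGramsHook : public lm::builder::OutputHook {
 public:
  CountNGramsHook() : lm::builder::OutputHook(lm::builder::PROB_SEQUENTIAL_HOOK) {}

  void Run(const util::stream::ChainPositions &positions) {
    for (unsigned order = 1; order <= positions.size(); ++order) {
      uint64_t seen = 0;
      for (lm::builder::NGramStream stream(positions[order - 1]); stream; ++stream)
        ++seen;
      std::cerr << "Order " << order << ": " << seen << " n-grams" << std::endl;
    }
  }
};

Registering it is one call, with ownership passing to Output: output.Add(new CountNGramsHook());
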
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index 21064ab3a..fced0e3bd 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -5,7 +5,7 @@
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/interpolate.hh"
-#include "lm/builder/print.hh"
+#include "lm/builder/output.hh"
#include "lm/builder/sort.hh"
#include "lm/sizes.hh"
@@ -16,6 +16,7 @@
#include <algorithm>
#include <iostream>
+#include <fstream>
#include <vector>
namespace lm { namespace builder {
@@ -36,7 +37,7 @@ void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint
class Master {
public:
- explicit Master(const PipelineConfig &config)
+ explicit Master(PipelineConfig &config)
: config_(config), chains_(config.order), files_(config.order) {
config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
}
@@ -200,14 +201,14 @@ class Master {
std::cerr << std::endl;
}
- PipelineConfig config_;
+ PipelineConfig &config_;
util::stream::Chains chains_;
// Often only unigrams, but sometimes all orders.
util::FixedArray<util::stream::FileBuffer> files_;
};
-void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
+void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name, std::vector<bool> &prune_words) {
const PipelineConfig &config = master.Config();
std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
@@ -225,7 +226,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
- CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
+ CorpusCount counter(text, vocab_file, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
@@ -236,7 +237,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
}
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
- util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds) {
+ util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab) {
const PipelineConfig &config = master.Config();
util::stream::Chains second(config.order);
@@ -250,7 +251,7 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
}
util::stream::Chains gamma_chains(config.order);
- InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds);
+ InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab);
// Don't care about gamma for 0.
gamma_chains[0] >> util::stream::kRecycle;
gammas.Init(config.order - 1);
@@ -271,8 +272,7 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
for (std::size_t i = 0; i < config.order - 1; ++i) {
util::stream::ChainConfig read_backoffs(config.read_backoffs);
- // Add 1 because here we are skipping unigrams
- if(config.prune_thresholds[i + 1] > 0)
+ if(config.prune_vocab || config.prune_thresholds[i + 1] > 0)
read_backoffs.entry_size = sizeof(HashGamma);
else
read_backoffs.entry_size = sizeof(float);
@@ -280,14 +280,14 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
}
- master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.output_q);
+ master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q);
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
}
} // namespace
-void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
+void Pipeline(PipelineConfig &config, int text_file, Output &output) {
// Some fail-fast sanity checks.
if (config.sort.buffer_size * 4 > config.TotalMemory()) {
config.sort.buffer_size = config.TotalMemory() / 4;
@@ -310,27 +310,30 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
util::scoped_fd vocab_file(config.vocab_file.empty() ?
util::MakeTemp(config.TempPrefix()) :
util::CreateOrThrow(config.vocab_file.c_str()));
+ output.SetVocabFD(vocab_file.get());
uint64_t token_count;
std::string text_file_name;
- CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
-
+
+ std::vector<bool> prune_words;
+ CountText(text_file, vocab_file.get(), master, token_count, text_file_name, prune_words);
+
std::vector<uint64_t> counts;
std::vector<uint64_t> counts_pruned;
std::vector<Discount> discounts;
- master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, config.discount, discounts);
+ master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, prune_words, config.discount, discounts);
{
util::FixedArray<util::stream::FileBuffer> gammas;
Sorts<SuffixOrder> primary;
- InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds);
+ InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab);
InterpolateProbabilities(counts_pruned, master, primary, gammas);
}
std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
- VocabReconstitute vocab(vocab_file.get());
- UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
- HeaderInfo header_info(text_file_name, token_count);
- master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
+
+ output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned));
+ output.Apply(PROB_SEQUENTIAL_HOOK, master.MutableChains());
+ master >> util::stream::kRecycle;
master.MutableChains().Wait(true);
} catch (const util::Exception &e) {
std::cerr << e.what() << std::endl;
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index 09e1a4d52..8f4d82103 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -14,6 +14,8 @@
namespace lm { namespace builder {
+class Output;
+
struct PipelineConfig {
std::size_t order;
std::string vocab_file;
@@ -21,9 +23,6 @@ struct PipelineConfig {
InitialProbabilitiesConfig initial_probs;
util::stream::ChainConfig read_backoffs;
- // Include a header in the ARPA with some statistics?
- bool verbose_header;
-
// Estimated vocabulary size. Used for sizing CorpusCount memory and
// initial probing hash table sizing, also in CorpusCount.
lm::WordIndex vocab_estimate;
@@ -37,6 +36,8 @@ struct PipelineConfig {
// n-gram count thresholds for pruning. 0 values means no pruning for
// corresponding n-gram order
std::vector<uint64_t> prune_thresholds; //mjd
+ bool prune_vocab;
+ std::string prune_vocab_file;
// What to do with discount failures.
DiscountConfig discount;
@@ -67,7 +68,7 @@ struct PipelineConfig {
};
// Takes ownership of text_file and out_arpa.
-void Pipeline(PipelineConfig config, int text_file, int out_arpa);
+void Pipeline(PipelineConfig &config, int text_file, Output &output);
}} // namespaces
#endif // LM_BUILDER_PIPELINE_H
diff --git a/lm/builder/print.cc b/lm/builder/print.cc
index aee6e1341..bb9483333 100644
--- a/lm/builder/print.cc
+++ b/lm/builder/print.cc
@@ -24,35 +24,34 @@ VocabReconstitute::VocabReconstitute(int fd) {
map_.push_back(i);
}
-PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd)
- : vocab_(vocab), out_fd_(out_fd) {
- std::stringstream stream;
+void PrintARPA::Run(const util::stream::ChainPositions &positions) {
+ VocabReconstitute vocab(GetVocabFD());
- if (header_info) {
- stream << "# Input file: " << header_info->input_file << '\n';
- stream << "# Token count: " << header_info->token_count << '\n';
- stream << "# Smoothing: Modified Kneser-Ney" << '\n';
- }
- stream << "\\data\\\n";
- for (size_t i = 0; i < counts.size(); ++i) {
- stream << "ngram " << (i+1) << '=' << counts[i] << '\n';
+ // Write header. TODO: integers in FakeOFStream.
+ {
+ std::stringstream stream;
+ if (verbose_header_) {
+ stream << "# Input file: " << GetHeader().input_file << '\n';
+ stream << "# Token count: " << GetHeader().token_count << '\n';
+ stream << "# Smoothing: Modified Kneser-Ney" << '\n';
+ }
+ stream << "\\data\\\n";
+ for (size_t i = 0; i < positions.size(); ++i) {
+ stream << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n';
+ }
+ stream << '\n';
+ std::string as_string(stream.str());
+ util::WriteOrThrow(out_fd_.get(), as_string.data(), as_string.size());
}
- stream << '\n';
- std::string as_string(stream.str());
- util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
-}
-void PrintARPA::Run(const util::stream::ChainPositions &positions) {
- util::scoped_fd closer(out_fd_);
- UTIL_TIMER("(%w s) Wrote ARPA file\n");
- util::FakeOFStream out(out_fd_);
+ util::FakeOFStream out(out_fd_.get());
for (unsigned order = 1; order <= positions.size(); ++order) {
out << "\\" << order << "-grams:" << '\n';
for (NGramStream stream(positions[order - 1]); stream; ++stream) {
// Correcting for numerical precision issues. Take that IRST.
- out << stream->Value().complete.prob << '\t' << vocab_.Lookup(*stream->begin());
+ out << stream->Value().complete.prob << '\t' << vocab.Lookup(*stream->begin());
for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
- out << ' ' << vocab_.Lookup(*i);
+ out << ' ' << vocab.Lookup(*i);
}
if (order != positions.size())
out << '\t' << stream->Value().complete.backoff;
diff --git a/lm/builder/print.hh b/lm/builder/print.hh
index 9856cea85..ba57f060a 100644
--- a/lm/builder/print.hh
+++ b/lm/builder/print.hh
@@ -3,7 +3,8 @@
#include "lm/builder/ngram.hh"
#include "lm/builder/ngram_stream.hh"
-#include "lm/builder/header_info.hh"
+#include "lm/builder/output.hh"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/string_piece.hh"
@@ -43,60 +44,71 @@ class VocabReconstitute {
};
// Not defined, only specialized.
-template <class T> void PrintPayload(std::ostream &to, const Payload &payload);
-template <> inline void PrintPayload<uint64_t>(std::ostream &to, const Payload &payload) {
- to << payload.count;
+template <class T> void PrintPayload(util::FakeOFStream &to, const Payload &payload);
+template <> inline void PrintPayload<uint64_t>(util::FakeOFStream &to, const Payload &payload) {
+ // TODO slow
+ to << boost::lexical_cast<std::string>(payload.count);
}
-template <> inline void PrintPayload<Uninterpolated>(std::ostream &to, const Payload &payload) {
+template <> inline void PrintPayload<Uninterpolated>(util::FakeOFStream &to, const Payload &payload) {
to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
}
-template <> inline void PrintPayload<ProbBackoff>(std::ostream &to, const Payload &payload) {
+template <> inline void PrintPayload<ProbBackoff>(util::FakeOFStream &to, const Payload &payload) {
to << payload.complete.prob << ' ' << payload.complete.backoff;
}
// template parameter is the type stored.
template <class V> class Print {
public:
- explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
+ static void DumpSeparateFiles(const VocabReconstitute &vocab, const std::string &file_base, util::stream::Chains &chains) {
+ for (unsigned int i = 0; i < chains.size(); ++i) {
+ std::string file(file_base + boost::lexical_cast<std::string>(i));
+ chains[i] >> Print(vocab, util::CreateOrThrow(file.c_str()));
+ }
+ }
+
+ explicit Print(const VocabReconstitute &vocab, int fd) : vocab_(vocab), to_(fd) {}
void Run(const util::stream::ChainPositions &chains) {
+ util::scoped_fd fd(to_);
+ util::FakeOFStream out(to_);
NGramStreams streams(chains);
for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
- DumpStream(*s);
+ DumpStream(*s, out);
}
}
void Run(const util::stream::ChainPosition &position) {
+ util::scoped_fd fd(to_);
+ util::FakeOFStream out(to_);
NGramStream stream(position);
- DumpStream(stream);
+ DumpStream(stream, out);
}
private:
- void DumpStream(NGramStream &stream) {
+ void DumpStream(NGramStream &stream, util::FakeOFStream &to) {
for (; stream; ++stream) {
- PrintPayload<V>(to_, stream->Value());
+ PrintPayload<V>(to, stream->Value());
for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
- to_ << ' ' << vocab_.Lookup(*w) << '=' << *w;
+ to << ' ' << vocab_.Lookup(*w) << '=' << *w;
}
- to_ << '\n';
+ to << '\n';
}
}
const VocabReconstitute &vocab_;
- std::ostream &to_;
+ int to_;
};
-class PrintARPA {
+class PrintARPA : public OutputHook {
public:
- // header_info may be NULL to disable the header.
- // Takes ownership of out_fd upon Run().
- explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
+ explicit PrintARPA(int fd, bool verbose_header)
+ : OutputHook(PROB_SEQUENTIAL_HOOK), out_fd_(fd), verbose_header_(verbose_header) {}
void Run(const util::stream::ChainPositions &positions);
private:
- const VocabReconstitute &vocab_;
- int out_fd_;
+ util::scoped_fd out_fd_;
+ bool verbose_header_;
};
}} // namespaces
diff --git a/lm/config.cc b/lm/config.cc
index 9520c41c8..6c695edfb 100644
--- a/lm/config.cc
+++ b/lm/config.cc
@@ -15,7 +15,7 @@ Config::Config() :
unknown_missing_logprob(-100.0),
probing_multiplier(1.5),
building_memory(1073741824ULL), // 1 GB
- temporary_directory_prefix(NULL),
+ temporary_directory_prefix(""),
arpa_complain(ALL),
write_mmap(NULL),
write_method(WRITE_AFTER),
diff --git a/lm/config.hh b/lm/config.hh
index dab281238..a4238cd9a 100644
--- a/lm/config.hh
+++ b/lm/config.hh
@@ -66,9 +66,9 @@ struct Config {
// Template for temporary directory appropriate for passing to mkdtemp.
// The characters XXXXXX are appended before passing to mkdtemp. Only
- // applies to trie. If NULL, defaults to write_mmap. If that's NULL,
+ // applies to trie. If empty, defaults to write_mmap. If that's NULL,
// defaults to input file name.
- const char *temporary_directory_prefix;
+ std::string temporary_directory_prefix;
// Level of complaining to do when loading from ARPA instead of binary format.
enum ARPALoadComplain {ALL, EXPENSIVE, NONE};
diff --git a/lm/max_order.hh b/lm/max_order.hh
index f7344cde2..5f181f3fc 100644
--- a/lm/max_order.hh
+++ b/lm/max_order.hh
@@ -1,13 +1,13 @@
#ifndef LM_MAX_ORDER_H
#define LM_MAX_ORDER_H
-/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER_H, THEN CHANGE THE BUILD SYSTEM.
+/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
* If not, this is the default maximum order.
* Having this limit means that State can be
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
#ifndef KENLM_ORDER_MESSAGE
-#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif
#endif // LM_MAX_ORDER_H
diff --git a/lm/search_trie.cc b/lm/search_trie.cc
index 7fc70f4eb..5b0f55fc8 100644
--- a/lm/search_trie.cc
+++ b/lm/search_trie.cc
@@ -577,7 +577,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
std::string temporary_prefix;
- if (config.temporary_directory_prefix) {
+ if (!config.temporary_directory_prefix.empty()) {
temporary_prefix = config.temporary_directory_prefix;
} else if (config.write_mmap) {
temporary_prefix = config.write_mmap;