github.com/moses-smt/mosesdecoder.git
author     Kenneth Heafield <github@kheafield.com>  2014-10-08 16:42:53 +0400
committer  Kenneth Heafield <github@kheafield.com>  2014-10-08 16:42:53 +0400
commit     36da8d1e0c17cbdecb7613398464bb9cabb56d20 (patch)
tree       640641c55ac102f2fb5178fc816e260c1c0965b7 /lm
parent     e9f0ae951e59b5853ef677b1c627b4e7bc5b48b7 (diff)

KenLM 370f97fa549f02e162a3a0f17bf3ad6cce2c3813
Diffstat (limited to 'lm')
 -rw-r--r--  lm/builder/adjust_counts.cc          | 50
 -rw-r--r--  lm/builder/adjust_counts.hh          | 28
 -rw-r--r--  lm/builder/adjust_counts_test.cc     |  5
 -rw-r--r--  lm/builder/initial_probabilities.cc  |  7
 -rw-r--r--  lm/builder/interpolate.cc            | 94
 -rw-r--r--  lm/builder/interpolate.hh            |  3
 -rw-r--r--  lm/builder/lmplz_main.cc             | 36
 -rw-r--r--  lm/builder/pipeline.cc               |  4
 -rw-r--r--  lm/builder/pipeline.hh               |  9
 -rw-r--r--  lm/builder/print.cc                  |  2
 10 files changed, 188 insertions, 50 deletions
diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc
index 080b438a4..803c557d0 100644
--- a/lm/builder/adjust_counts.cc
+++ b/lm/builder/adjust_counts.cc
@@ -29,28 +29,44 @@ class StatCollector {
~StatCollector() {}
- void CalculateDiscounts() {
+ void CalculateDiscounts(const DiscountConfig &config) {
counts_.resize(orders_.size());
counts_pruned_.resize(orders_.size());
- discounts_.resize(orders_.size());
for (std::size_t i = 0; i < orders_.size(); ++i) {
const OrderStat &s = orders_[i];
counts_[i] = s.count;
counts_pruned_[i] = s.count_pruned;
+ }
- for (unsigned j = 1; j < 4; ++j) {
- // TODO: Specialize error message for j == 3, meaning 3+
- UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
- << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
- << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
- }
-
- // See equation (26) in Chen and Goodman.
- discounts_[i].amount[0] = 0.0;
- float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
- for (unsigned j = 1; j < 4; ++j) {
- discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
- UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ discounts_ = config.overwrite;
+ discounts_.resize(orders_.size());
+ for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) {
+ const OrderStat &s = orders_[i];
+ try {
+ for (unsigned j = 1; j < 4; ++j) {
+ // TODO: Specialize error message for j == 3, meaning 3+
+ UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+ << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+ << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+ }
+
+ // See equation (26) in Chen and Goodman.
+ discounts_[i].amount[0] = 0.0;
+ float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+ for (unsigned j = 1; j < 4; ++j) {
+ discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+ UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ }
+ } catch (const BadDiscountException &e) {
+ switch (config.bad_action) {
+ case THROW_UP:
+ throw;
+ case COMPLAIN:
+ std::cerr << e.what() << " Substituting fallback discounts D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl;
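+ // Fall through: after complaining, substitute the fallback just like SILENT.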
+ case SILENT:
+ break;
+ }
+ discounts_[i] = config.fallback;
}
}
}
@@ -179,7 +195,7 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
for (NGramStream full(positions[0]); full; ++full)
stats.AddFull(full->Count());
- stats.CalculateDiscounts();
+ stats.CalculateDiscounts(discount_config_);
return;
}
@@ -262,7 +278,7 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
s->Poison();
- stats.CalculateDiscounts();
+ stats.CalculateDiscounts(discount_config_);
// NOTE: See special early-return case for unigrams near the top of this function
}
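
For reference, the closed-form estimate used in CalculateDiscounts above is equation (26) of Chen and Goodman: with n_j the number of n-grams of a given order whose adjusted count is j, it sets Y = n_1 / (n_1 + 2 n_2) and D_j = j - (j+1) Y n_{j+1} / n_j for j in {1, 2, 3+}. A minimal standalone sketch of that computation, with hypothetical names and a local Discount struct rather than the one from discount.hh:

    #include <cstdint>
    #include <stdexcept>

    struct Discount { float amount[4]; };  // amount[j] discounts adjusted count j; [3] covers 3+

    // n[j] = number of n-grams of this order with adjusted count j, for j in [1,4].
    Discount EstimateKneserNeyDiscounts(const uint64_t n[5]) {
      for (unsigned j = 1; j < 4; ++j)
        if (n[j] == 0)
          throw std::runtime_error("No n-grams with this adjusted count: small or artificial data?");
      Discount d;
      d.amount[0] = 0.0;
      float y = static_cast<float>(n[1]) / static_cast<float>(n[1] + 2.0 * n[2]);
      for (unsigned j = 1; j < 4; ++j) {
        d.amount[j] = static_cast<float>(j)
                      - static_cast<float>(j + 1) * y * static_cast<float>(n[j + 1]) / static_cast<float>(n[j]);
        if (d.amount[j] < 0.0 || d.amount[j] > static_cast<float>(j))
          throw std::runtime_error("Discount out of range");
      }
      return d;
    }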
diff --git a/lm/builder/adjust_counts.hh b/lm/builder/adjust_counts.hh
index 60198e8f8..a5435c282 100644
--- a/lm/builder/adjust_counts.hh
+++ b/lm/builder/adjust_counts.hh
@@ -2,6 +2,7 @@
#define LM_BUILDER_ADJUST_COUNTS_H
#include "lm/builder/discount.hh"
+#include "lm/lm_exception.hh"
#include "util/exception.hh"
#include <vector>
@@ -19,6 +20,16 @@ class BadDiscountException : public util::Exception {
~BadDiscountException() throw();
};
+struct DiscountConfig {
+ // Overrides discounts for orders [1, overwrite.size()].
+ std::vector<Discount> overwrite;
+ // If discounting fails for an order, the discounts are copied from here.
+ Discount fallback;
+ // What to do when discounts are out of range or would trigger division by
+ // zero. If it does something other than THROW_UP, fallback is used.
+ WarningAction bad_action;
+};
+
/* Compute adjusted counts.
* Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
* Output: [1,N]-grams with adjusted counts.
@@ -27,17 +38,28 @@ class BadDiscountException : public util::Exception {
*/
class AdjustCounts {
public:
- AdjustCounts(std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts, std::vector<uint64_t> &prune_thresholds)
- : counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts), prune_thresholds_(prune_thresholds)
+ // counts: output
+ // counts_pruned: output
+ // discounts: mostly output. If the input already has entries, they will be kept.
+ // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned.
+ AdjustCounts(
+ const std::vector<uint64_t> &prune_thresholds,
+ std::vector<uint64_t> &counts,
+ std::vector<uint64_t> &counts_pruned,
+ const DiscountConfig &discount_config,
+ std::vector<Discount> &discounts)
+ : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts)
{}
void Run(const util::stream::ChainPositions &positions);
private:
+ const std::vector<uint64_t> &prune_thresholds_;
std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
+
+ DiscountConfig discount_config_;
std::vector<Discount> &discounts_;
- std::vector<uint64_t> &prune_thresholds_;
};
} // namespace builder
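
To illustrate the new knobs, a hypothetical caller (not from the patch) that pins unigram discounts, supplies a fallback, and downgrades estimation failures to warnings:

    lm::builder::DiscountConfig config;
    lm::builder::Discount custom;
    custom.amount[0] = 0.0;  // adjusted count 0 is never discounted
    custom.amount[1] = 0.5;  // D1
    custom.amount[2] = 1.0;  // D2
    custom.amount[3] = 1.5;  // D3+
    config.overwrite.assign(1, custom);  // force these discounts for unigrams only
    config.fallback = custom;            // substituted if estimation fails at a higher order
    config.bad_action = lm::COMPLAIN;    // warn and substitute instead of aborting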
diff --git a/lm/builder/adjust_counts_test.cc b/lm/builder/adjust_counts_test.cc
index 9d8ef65b6..073c5dfeb 100644
--- a/lm/builder/adjust_counts_test.cc
+++ b/lm/builder/adjust_counts_test.cc
@@ -75,7 +75,10 @@ BOOST_AUTO_TEST_CASE(Simple) {
chains >> util::stream::kRecycle;
std::vector<uint64_t> counts_pruned(4);
std::vector<uint64_t> prune_thresholds(4);
- BOOST_CHECK_THROW(AdjustCounts(counts, counts_pruned, discount, prune_thresholds).Run(for_adjust), BadDiscountException);
+ DiscountConfig discount_config;
+ discount_config.fallback = Discount();
+ discount_config.bad_action = THROW_UP;
+ BOOST_CHECK_THROW(AdjustCounts(prune_thresholds, counts, counts_pruned, discount_config, discount).Run(for_adjust), BadDiscountException);
}
BOOST_REQUIRE_EQUAL(4UL, counts.size());
BOOST_CHECK_EQUAL(4UL, counts[0]);
diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc
index f6ee334c7..5d19a8973 100644
--- a/lm/builder/initial_probabilities.cc
+++ b/lm/builder/initial_probabilities.cc
@@ -69,9 +69,12 @@ class PruneNGramStream {
block_->SetValidSize(dest_.Base() - block_base);
++block_;
StartBlock();
+ if (block_) {
+ currentCount_ = current_.CutoffCount();
+ }
+ } else {
+ currentCount_ = current_.CutoffCount();
}
-
- currentCount_ = current_.CutoffCount();
return *this;
}
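
The change above fixes a use-after-advance: StartBlock() can exhaust the chain, after which current_ no longer points at live data, so CutoffCount() may only be read while a block is held. The invariant in miniature, with a toy stream type standing in for PruneNGramStream:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Toy stand-in: valid() plays the role of PruneNGramStream's operator bool.
    struct ToyStream {
      std::vector<int> data;
      std::size_t pos = 0;
      bool valid() const { return pos < data.size(); }
      int get() const { return data[pos]; }  // only legal while valid()
      void next() { ++pos; }                 // may invalidate the stream
    };

    int main() {
      ToyStream s{{3, 1, 4}};
      while (s.valid()) {             // check validity before every read
        std::printf("%d\n", s.get());
        s.next();                     // advancing may exhaust; re-check before reading
      }
    }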
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index 3e1225d9e..a7947a422 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -9,14 +9,66 @@
#include "util/murmur_hash.hh"
#include <assert.h>
+#include <math.h>
namespace lm { namespace builder {
namespace {
-class Callback {
+/* Calculate q, the collapsed probability and backoff, as defined in
+ * @inproceedings{Heafield-rest,
+ * author = {Kenneth Heafield and Philipp Koehn and Alon Lavie},
+ * title = {Language Model Rest Costs and Space-Efficient Storage},
+ * year = {2012},
+ * month = {July},
+ * booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning},
+ * address = {Jeju Island, Korea},
+ * pages = {1169--1178},
+ * url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf},
+ * }
+ * This is particularly convenient to calculate during interpolation because
+ * the needed backoff terms are already accessed at the same time.
+ */
+class OutputQ {
+ public:
+ explicit OutputQ(std::size_t order) : q_delta_(order) {}
+
+ void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) {
+ float &q_del = q_delta_[order_minus_1];
+ if (order_minus_1) {
+ // Divide by context's backoff (which comes in as out.backoff)
+ q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff;
+ } else {
+ q_del = full_backoff;
+ }
+ out.prob = log10f(out.prob * q_del);
+ // TODO: stop wastefully outputting this!
+ out.backoff = 0.0;
+ }
+
+ private:
+ // Product of backoffs in the numerator divided by backoffs in the
+ // denominator. Does not include
+ std::vector<float> q_delta_;
+};
+
+/* Default: output probability and backoff */
+class OutputProbBackoff {
+ public:
+ explicit OutputProbBackoff(std::size_t /*order*/) {}
+
+ void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const {
+ // Correcting for numerical precision issues. Take that IRST.
+ out.prob = std::min(0.0f, log10f(out.prob));
+ out.backoff = log10f(full_backoff);
+ }
+};
+
+template <class Output> class Callback {
public:
Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds)
- : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds) {
+ : backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
+ prune_thresholds_(prune_thresholds),
+ output_(backoffs.size() + 1 /* order */) {
probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) {
backoffs_.push_back(backoffs[i]);
@@ -40,15 +92,9 @@ class Callback {
Payload &pay = gram.Value();
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
- pay.complete.prob = log10(pay.complete.prob);
-
- if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
- // This skips over ngrams if backoffs have been exhausted.
- if(!backoffs_[order_minus_1]) {
- pay.complete.backoff = 0.0;
- return;
- }
+ float out_backoff;
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
if(prune_thresholds_[order_minus_1 + 1] > 0) {
//Compute hash value for current context
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
@@ -58,20 +104,22 @@ class Callback {
hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
if(current_hash == hashed_backoff->hash_value) {
- pay.complete.backoff = log10(hashed_backoff->gamma);
+ out_backoff = hashed_backoff->gamma;
++backoffs_[order_minus_1];
} else {
// Has been pruned away so it is not a context anymore
- pay.complete.backoff = 0.0;
+ out_backoff = 1.0;
}
} else {
- pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get());
++backoffs_[order_minus_1];
}
} else {
// Not a context.
- pay.complete.backoff = 0.0;
+ out_backoff = 1.0;
}
+
+ output_.Gram(order_minus_1, out_backoff, pay.complete);
}
void Exit(unsigned, const NGram &) const {}
@@ -81,19 +129,29 @@ class Callback {
std::vector<float> probs_;
const std::vector<uint64_t>& prune_thresholds_;
+
+ Output output_;
};
} // namespace
-Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
+Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool output_q)
: uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
backoffs_(backoffs),
- prune_thresholds_(prune_thresholds) {}
+ prune_thresholds_(prune_thresholds),
+ output_q_(output_q) {}
// perform order-wise interpolation
void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1);
- Callback callback(uniform_prob_, backoffs_, prune_thresholds_);
- JointOrder<Callback, SuffixOrder>(positions, callback);
+ if (output_q_) {
+ typedef Callback<OutputQ> C;
+ C callback(uniform_prob_, backoffs_, prune_thresholds_);
+ JointOrder<C, SuffixOrder>(positions, callback);
+ } else {
+ typedef Callback<OutputProbBackoff> C;
+ C callback(uniform_prob_, backoffs_, prune_thresholds_);
+ JointOrder<C, SuffixOrder>(positions, callback);
+ }
}
}} // namespaces
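
The collapsed value computed by OutputQ follows a short recurrence: writing b(g) for an n-gram's own backoff and b(c) for its context's backoff, q_delta is b(g) at the unigram level and q_delta(lower order) * b(g) / b(c) above that, and the emitted value is log10 of the interpolated probability times q_delta. Note also that the refactor keeps backoffs in linear space until output: missing or pruned contexts now contribute out_backoff = 1.0, which OutputProbBackoff turns back into the previously stored log10(1.0) = 0.0. A standalone sketch of the recurrence, with a hypothetical harness mirroring OutputQ::Gram:

    #include <math.h>
    #include <cstddef>
    #include <vector>

    class CollapsedQ {
     public:
      explicit CollapsedQ(std::size_t order) : q_delta_(order) {}

      // prob: interpolated probability of the n-gram.
      // full_backoff: the n-gram's own backoff; context_backoff: its context's.
      float Gram(unsigned order_minus_1, float prob, float full_backoff, float context_backoff) {
        float &q_del = q_delta_[order_minus_1];
        q_del = order_minus_1
            ? q_delta_[order_minus_1 - 1] / context_backoff * full_backoff
            : full_backoff;
        return log10f(prob * q_del);  // the single value stored per n-gram
      }

     private:
      std::vector<float> q_delta_;  // one running delta per order
    };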
diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh
index 55a55428f..0acece926 100644
--- a/lm/builder/interpolate.hh
+++ b/lm/builder/interpolate.hh
@@ -18,7 +18,7 @@ class Interpolate {
public:
// Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
// be larger when the user specifies a consistent vocabulary size.
- explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds);
+ explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool output_q_);
void Run(const util::stream::ChainPositions &positions);
@@ -26,6 +26,7 @@ class Interpolate {
float uniform_prob_;
util::stream::ChainPositions backoffs_;
const std::vector<uint64_t> prune_thresholds_;
+ bool output_q_;
};
}} // namespaces
diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc
index e1ae2d417..265dd2164 100644
--- a/lm/builder/lmplz_main.cc
+++ b/lm/builder/lmplz_main.cc
@@ -33,7 +33,6 @@ std::vector<uint64_t> ParsePruning(const std::vector<std::string> &param, std::s
// convert to vector of integers
std::vector<uint64_t> prune_thresholds;
prune_thresholds.reserve(order);
- std::cerr << "Pruning ";
for (std::vector<std::string>::const_iterator it(param.begin()); it != param.end(); ++it) {
try {
prune_thresholds.push_back(boost::lexical_cast<uint64_t>(*it));
@@ -66,6 +65,18 @@ std::vector<uint64_t> ParsePruning(const std::vector<std::string> &param, std::s
return prune_thresholds;
}
+lm::builder::Discount ParseDiscountFallback(const std::vector<std::string> &param) {
+ lm::builder::Discount ret;
+ UTIL_THROW_IF(param.size() > 3, util::Exception, "Specify at most three fallback discounts: 1, 2, and 3+");
+ UTIL_THROW_IF(param.empty(), util::Exception, "Fallback discounting enabled, but no discount specified");
+ ret.amount[0] = 0.0;
+ for (unsigned i = 0; i < 3; ++i) {
+ float discount = boost::lexical_cast<float>(param[i < param.size() ? i : (param.size() - 1)]);
+ UTIL_THROW_IF(discount < 0.0 || discount > static_cast<float>(i+1), util::Exception, "The discount for count " << (i+1) << " was parsed as " << discount << " which is not in the range [0, " << (i+1) << "].");
+ ret.amount[i + 1] = discount;
+ }
+ return ret;
+}
} // namespace
@@ -77,7 +88,11 @@ int main(int argc, char *argv[]) {
std::string text, arpa;
std::vector<std::string> pruning;
-
+ std::vector<std::string> discount_fallback;
+ std::vector<std::string> discount_fallback_default;
+ discount_fallback_default.push_back("0.5");
+ discount_fallback_default.push_back("1");
+ discount_fallback_default.push_back("1.5");
options.add_options()
("help,h", po::bool_switch(), "Show this help message")
@@ -86,7 +101,7 @@ int main(int argc, char *argv[]) {
->required()
#endif
, "Order of the model")
- ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("interpolate_unigrams", po::value<bool>(&pipeline.initial_probs.interpolate_unigrams)->default_value(true)->implicit_value(true), "Interpolate the unigrams (default) as opposed to giving lots of mass to <unk> like SRI. If you want SRI's behavior with a large <unk> and the old lmplz default, use --interpolate_unigrams 0.")
("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
@@ -99,7 +114,9 @@ int main(int argc, char *argv[]) {
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
- ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.");
+ ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
+ ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.")
+ ("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail.");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, options), vm);
@@ -143,7 +160,7 @@ int main(int argc, char *argv[]) {
#endif
if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) {
- std::cerr << "--vocab_pad requires --interpolate_unigrams" << std::endl;
+ std::cerr << "--vocab_pad requires --interpolate_unigrams be on" << std::endl;
return 1;
}
@@ -153,6 +170,15 @@ int main(int argc, char *argv[]) {
pipeline.disallowed_symbol_action = lm::THROW_UP;
}
+ if (vm.count("discount_fallback")) {
+ pipeline.discount.fallback = ParseDiscountFallback(discount_fallback);
+ pipeline.discount.bad_action = lm::COMPLAIN;
+ } else {
+ // Unused, just here to prevent the compiler from complaining about an uninitialized value.
+ pipeline.discount.fallback = lm::builder::Discount();
+ pipeline.discount.bad_action = lm::THROW_UP;
+ }
+
// parse pruning thresholds. These depend on order, so it is not done as a notifier.
pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);
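
A note on ParseDiscountFallback's expansion rule: when fewer than three values are given, the last one repeats, and the discount for adjusted count j must lie in [0, j]. Hypothetical invocations:

    // --discount_fallback 0.5        =>  D1 = D2 = D3+ = 0.5
    // --discount_fallback 0.5 1      =>  D1 = 0.5, D2 = D3+ = 1.0
    // --discount_fallback 0.5 1 1.5  =>  D1 = 0.5, D2 = 1.0, D3+ = 1.5  (the implicit default)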
diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc
index e91870808..21064ab3a 100644
--- a/lm/builder/pipeline.cc
+++ b/lm/builder/pipeline.cc
@@ -280,7 +280,7 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
}
- master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds);
+ master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.output_q);
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
}
@@ -317,7 +317,7 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
std::vector<uint64_t> counts;
std::vector<uint64_t> counts_pruned;
std::vector<Discount> discounts;
- master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds);
+ master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, config.discount, discounts);
{
util::FixedArray<util::stream::FileBuffer> gammas;
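
Caller-side, the new PipelineConfig fields slot in alongside the existing ones. A hypothetical minimal setup, with the unrelated fields elided:

    lm::builder::PipelineConfig config;
    // ... order, sort, initial_probs, read_backoffs, etc. as before ...
    config.prune_thresholds.assign(config.order, 0);  // all zeros: no pruning
    config.discount.fallback = lm::builder::Discount();
    config.discount.bad_action = lm::THROW_UP;        // abort on bad discounts
    config.output_q = false;                          // emit prob/backoff, not collapsed q
    lm::builder::Pipeline(config, 0 /* text fd */, 1 /* ARPA fd */);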
diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh
index 4395622ed..09e1a4d52 100644
--- a/lm/builder/pipeline.hh
+++ b/lm/builder/pipeline.hh
@@ -1,6 +1,7 @@
#ifndef LM_BUILDER_PIPELINE_H
#define LM_BUILDER_PIPELINE_H
+#include "lm/builder/adjust_counts.hh"
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
#include "lm/lm_exception.hh"
@@ -19,6 +20,8 @@ struct PipelineConfig {
util::stream::SortConfig sort;
InitialProbabilitiesConfig initial_probs;
util::stream::ChainConfig read_backoffs;
+
+ // Include a header in the ARPA with some statistics?
bool verbose_header;
// Estimated vocabulary size. Used for sizing CorpusCount memory and
@@ -34,6 +37,12 @@ struct PipelineConfig {
// n-gram count thresholds for pruning. A value of 0 means no pruning for
// the corresponding n-gram order
std::vector<uint64_t> prune_thresholds; //mjd
+
+ // What to do with discount failures.
+ DiscountConfig discount;
+
+ // Compute collapsed q values instead of probability and backoff
+ bool output_q;
/* Computing the perplexity of LMs with different vocabularies is hard. For
* example, the lowest perplexity is attained by a unigram model that
diff --git a/lm/builder/print.cc b/lm/builder/print.cc
index 75f15f0a6..aee6e1341 100644
--- a/lm/builder/print.cc
+++ b/lm/builder/print.cc
@@ -50,7 +50,7 @@ void PrintARPA::Run(const util::stream::ChainPositions &positions) {
out << "\\" << order << "-grams:" << '\n';
for (NGramStream stream(positions[order - 1]); stream; ++stream) {
// Correcting for numerical precision issues. Take that IRST.
- out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
+ out << stream->Value().complete.prob << '\t' << vocab_.Lookup(*stream->begin());
for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
out << ' ' << vocab_.Lookup(*i);
}