diff options
author | Kenneth Heafield <github@kheafield.com> | 2014-09-18 16:38:29 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2014-09-18 16:38:29 +0400 |
commit | f3a61c7883d51c12312218a3ee7093c97a2bce94 (patch) | |
tree | f1742383fd8ef0e7e62b29699ed36b056680e4d6 | |
parent | 2216424eadd12b5095be4302fc4ce0c777273c21 (diff) | |
parent | dad19cb55a6905ba309f87383b90b51c47e9ef9c (diff) |
Merge branch 'master' of https://github.com/kpu/kenlm
-rw-r--r-- | lm/builder/interpolate.cc | 81 | ||||
-rw-r--r-- | lm/builder/pipeline.hh | 2 |
2 files changed, 67 insertions, 16 deletions
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc index 3e1225d..cc89802 100644 --- a/lm/builder/interpolate.cc +++ b/lm/builder/interpolate.cc @@ -13,10 +13,60 @@ namespace lm { namespace builder { namespace { -class Callback { +/* Calculate q, the collapsed probability and backoff, as defined in + * @inproceedings{Heafield-rest, + * author = {Kenneth Heafield and Philipp Koehn and Alon Lavie}, + * title = {Language Model Rest Costs and Space-Efficient Storage}, + * year = {2012}, + * month = {July}, + * booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning}, + * address = {Jeju Island, Korea}, + * pages = {1169--1178}, + * url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf}, + * } + * This is particularly convenient to calculate during interpolation because + * the needed backoff terms are already accessed at the same time. + */ +class OutputQ { + public: + explicit OutputQ(std::size_t order) : q_delta_(order) {} + + void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) { + float &q_del = q_delta_[order_minus_1]; + if (order_minus_1) { + // Divide by context's backoff (which comes in as out.backoff) + q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff; + } else { + q_del = full_backoff; + } + out.prob = log10(out.prob * q_del); + // TODO: stop wastefully outputting this! + out.backoff = 0.0; + } + + private: + // Product of backoffs in the numerator divided by backoffs in the + // denominator. Does not include + std::vector<float> q_delta_; +}; + +/* Default: output probability and backoff */ +class OutputProbBackoff { + public: + explicit OutputProbBackoff(std::size_t /*order*/) {} + + void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const { + out.prob = log10(out.prob); + out.backoff = log10(full_backoff); + } +}; + +template <class Output> class Callback { public: Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds) - : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds) { + : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), + prune_thresholds_(prune_thresholds), + output_(backoffs.size() + 1 /* order */) { probs_[0] = uniform_prob; for (std::size_t i = 0; i < backoffs.size(); ++i) { backoffs_.push_back(backoffs[i]); @@ -40,15 +90,9 @@ class Callback { Payload &pay = gram.Value(); pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; probs_[order_minus_1 + 1] = pay.complete.prob; - pay.complete.prob = log10(pay.complete.prob); - - if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { - // This skips over ngrams if backoffs have been exhausted. - if(!backoffs_[order_minus_1]) { - pay.complete.backoff = 0.0; - return; - } + float out_backoff; + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { if(prune_thresholds_[order_minus_1 + 1] > 0) { //Compute hash value for current context uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex)); @@ -58,20 +102,22 @@ class Callback { hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get()); if(current_hash == hashed_backoff->hash_value) { - pay.complete.backoff = log10(hashed_backoff->gamma); + out_backoff = hashed_backoff->gamma; ++backoffs_[order_minus_1]; } else { // Has been pruned away so it is not a context anymore - pay.complete.backoff = 0.0; + out_backoff = 1.0; } } else { - pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); + out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get()); ++backoffs_[order_minus_1]; } } else { // Not a context. - pay.complete.backoff = 0.0; + out_backoff = 1.0; } + + output_.Gram(order_minus_1, out_backoff, pay.complete); } void Exit(unsigned, const NGram &) const {} @@ -81,6 +127,8 @@ class Callback { std::vector<float> probs_; const std::vector<uint64_t>& prune_thresholds_; + + Output output_; }; } // namespace @@ -92,8 +140,9 @@ Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions // perform order-wise interpolation void Interpolate::Run(const util::stream::ChainPositions &positions) { assert(positions.size() == backoffs_.size() + 1); - Callback callback(uniform_prob_, backoffs_, prune_thresholds_); - JointOrder<Callback, SuffixOrder>(positions, callback); + typedef Callback<OutputProbBackoff> C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder<C, SuffixOrder>(positions, callback); } }} // namespaces diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh index 3d1cd7f..961eb51 100644 --- a/lm/builder/pipeline.hh +++ b/lm/builder/pipeline.hh @@ -20,6 +20,8 @@ struct PipelineConfig { util::stream::SortConfig sort; InitialProbabilitiesConfig initial_probs; util::stream::ChainConfig read_backoffs; + + // Include a header in the ARPA with some statistics? bool verbose_header; // Estimated vocabulary size. Used for sizing CorpusCount memory and |