Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/kpu/kenlm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2014-09-14 15:47:17 +0400
committerKenneth Heafield <github@kheafield.com>2014-09-14 15:47:17 +0400
commit8cf2bfe53c9847f6e47f3997fb32ccf187d61337 (patch)
tree4c12f1cd0ca3923f7796a8c63548dd34180f249e
parent0e891289fb49b6d4d72bbe0374dfec10a89e4c67 (diff)
Isolate prob+backoff vs q computation
-rw-r--r--lm/builder/interpolate.cc81
1 file changed, 65 insertions, 16 deletions
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index 3e1225d..cc89802 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -13,10 +13,60 @@
namespace lm { namespace builder {
namespace {
-class Callback {
+/* Calculate q, the collapsed probability and backoff, as defined in
+ * @inproceedings{Heafield-rest,
+ * author = {Kenneth Heafield and Philipp Koehn and Alon Lavie},
+ * title = {Language Model Rest Costs and Space-Efficient Storage},
+ * year = {2012},
+ * month = {July},
+ * booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning},
+ * address = {Jeju Island, Korea},
+ * pages = {1169--1178},
+ * url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf},
+ * }
+ * This is particularly convenient to calculate during interpolation because
+ * the needed backoff terms are already accessed at the same time.
+ */
+class OutputQ {
+ public:
+ explicit OutputQ(std::size_t order) : q_delta_(order) {}
+
+ void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) {
+ float &q_del = q_delta_[order_minus_1];
+ if (order_minus_1) {
+ // Divide by context's backoff (which comes in as out.backoff)
+ q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff;
+ } else {
+ q_del = full_backoff;
+ }
+ out.prob = log10(out.prob * q_del);
+ // TODO: stop wastefully outputting this!
+ out.backoff = 0.0;
+ }
+
+ private:
+ // Product of backoffs in the numerator divided by backoffs in the
+ // denominator. Does not include the n-gram probability itself; that is
+ // multiplied in separately in Gram(). NOTE(review): the original comment
+ // was truncated after "Does not include" — confirm intended wording.
+ std::vector<float> q_delta_;
+};
+
+/* Default: output probability and backoff */
+class OutputProbBackoff {
+ public:
+ explicit OutputProbBackoff(std::size_t /*order*/) {}
+
+ void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const {
+ out.prob = log10(out.prob);
+ out.backoff = log10(full_backoff);
+ }
+};
+
+template <class Output> class Callback {
public:
Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds)
- : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds) {
+ : backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
+ prune_thresholds_(prune_thresholds),
+ output_(backoffs.size() + 1 /* order */) {
probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) {
backoffs_.push_back(backoffs[i]);
@@ -40,15 +90,9 @@ class Callback {
Payload &pay = gram.Value();
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
- pay.complete.prob = log10(pay.complete.prob);
-
- if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
- // This skips over ngrams if backoffs have been exhausted.
- if(!backoffs_[order_minus_1]) {
- pay.complete.backoff = 0.0;
- return;
- }
+ float out_backoff;
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
if(prune_thresholds_[order_minus_1 + 1] > 0) {
//Compute hash value for current context
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
@@ -58,20 +102,22 @@ class Callback {
hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
if(current_hash == hashed_backoff->hash_value) {
- pay.complete.backoff = log10(hashed_backoff->gamma);
+ out_backoff = hashed_backoff->gamma;
++backoffs_[order_minus_1];
} else {
// Has been pruned away so it is not a context anymore
- pay.complete.backoff = 0.0;
+ out_backoff = 1.0;
}
} else {
- pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get());
++backoffs_[order_minus_1];
}
} else {
// Not a context.
- pay.complete.backoff = 0.0;
+ out_backoff = 1.0;
}
+
+ output_.Gram(order_minus_1, out_backoff, pay.complete);
}
void Exit(unsigned, const NGram &) const {}
@@ -81,6 +127,8 @@ class Callback {
std::vector<float> probs_;
const std::vector<uint64_t>& prune_thresholds_;
+
+ Output output_;
};
} // namespace
@@ -92,8 +140,9 @@ Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions
// perform order-wise interpolation
void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1);
- Callback callback(uniform_prob_, backoffs_, prune_thresholds_);
- JointOrder<Callback, SuffixOrder>(positions, callback);
+ typedef Callback<OutputProbBackoff> C;
+ C callback(uniform_prob_, backoffs_, prune_thresholds_);
+ JointOrder<C, SuffixOrder>(positions, callback);
}
}} // namespaces