Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'lm/builder/interpolate.cc')
-rw-r--r-- lm/builder/interpolate.cc | 54
1 files changed, 43 insertions, 11 deletions
diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc
index 500268069..db8537448 100644
--- a/lm/builder/interpolate.cc
+++ b/lm/builder/interpolate.cc
@@ -1,9 +1,12 @@
#include "lm/builder/interpolate.hh"
+#include "lm/builder/hash_gamma.hh"
#include "lm/builder/joint_order.hh"
-#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram_stream.hh"
#include "lm/builder/sort.hh"
#include "lm/lm_exception.hh"
+#include "util/fixed_array.hh"
+#include "util/murmur_hash.hh"
#include <assert.h>
@@ -12,7 +15,8 @@ namespace {
class Callback {
public:
- Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+ Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds)
+ : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds) {
probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) {
backoffs_.push_back(backoffs[i]);
@@ -33,12 +37,37 @@ class Callback {
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
pay.complete.prob = log10(pay.complete.prob);
- // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+
if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
- pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
- ++backoffs_[order_minus_1];
+ // This skips over ngrams if backoffs have been exhausted.
+ if(!backoffs_[order_minus_1]) {
+ pay.complete.backoff = 0.0;
+ return;
+ }
+
+ if(prune_thresholds_[order_minus_1 + 1] > 0) {
+ //Compute hash value for current context
+ uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
+
+ const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
+ while(backoffs_[order_minus_1] && current_hash != hashed_backoff->hash_value) {
+ hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
+ ++backoffs_[order_minus_1];
+ }
+
+ if(current_hash == hashed_backoff->hash_value) {
+ pay.complete.backoff = log10(hashed_backoff->gamma);
+ ++backoffs_[order_minus_1];
+ } else {
+ // Has been pruned away so it is not a context anymore
+ pay.complete.backoff = 0.0;
+ }
+ } else {
+ pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ ++backoffs_[order_minus_1];
+ }
} else {
- // Not a context.
+ // Not a context.
pay.complete.backoff = 0.0;
}
}
@@ -46,19 +75,22 @@ class Callback {
void Exit(unsigned, const NGram &) const {}
private:
- FixedArray<util::stream::Stream> backoffs_;
+ util::FixedArray<util::stream::Stream> backoffs_;
std::vector<float> probs_;
+ const std::vector<uint64_t>& prune_thresholds_;
};
} // namespace
-Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
- : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
+ : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
+ backoffs_(backoffs),
+ prune_thresholds_(prune_thresholds) {}
// perform order-wise interpolation
-void Interpolate::Run(const ChainPositions &positions) {
+void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1);
- Callback callback(uniform_prob_, backoffs_);
+ Callback callback(uniform_prob_, backoffs_, prune_thresholds_);
JointOrder<Callback, SuffixOrder>(positions, callback);
}