/* Efficient left and right language model state for sentence fragments.
 * Intended usage:
 * Store ChartState with every chart entry.
 * To do a rule application:
 * 1. Make a ChartState object for your new entry.
 * 2. Construct RuleScore.
 * 3. Going from left to right, call Terminal or NonTerminal.
 *   For terminals, just pass the vocab id.
 *   For non-terminals, pass that non-terminal's ChartState.
 *     If your decoder expects scores inclusive of subtree scores (i.e. you
 *     label entries with the highest-scoring path), pass the non-terminal's
 *     score as prob.
 *     If your decoder expects relative scores and will walk the chart later,
 *     pass prob = 0.0.
 *     In other words, the only effect of prob is that it gets added to the
 *     returned log probability.
 * 4. Call Finish.  It returns the log probability.
 *
 * There are a couple more details:
 * Do not pass <s> to Terminal, as it is formally not a word in the sentence,
 * only context.  Instead, call BeginSentence.  If called, it must be the
 * first call after RuleScore is constructed (since <s> is always the
 * leftmost).
 *
 * If the leftmost RHS symbol is a non-terminal, it's faster to call
 * BeginNonTerminal.  It, too, must be the first call after RuleScore is
 * constructed, because it overwrites the state accumulated so far.
 *
 * Hashing and sorting comparison operators are provided.  All state objects
 * are POD.  If you intend to use memcmp on raw state objects, you must call
 * ZeroRemaining first, as the value of array entries beyond length is
 * otherwise undefined.  ZeroRemaining does not clear struct padding bytes,
 * which memcmp would also compare.
 *
 * Usage is of course not limited to chart decoding.  Anything that generates
 * sentence fragments missing left context could benefit.  For example, a
 * phrase-based decoder could pre-score phrases, storing ChartState with each
 * phrase, even if hypotheses are generated left-to-right.
*/ #ifndef LM_LEFT__ #define LM_LEFT__ #include "lm/max_order.hh" #include "lm/model.hh" #include "lm/return.hh" #include "util/murmur_hash.hh" #include namespace lm { namespace ngram { struct Left { bool operator==(const Left &other) const { return (length == other.length) && pointers[length - 1] == other.pointers[length - 1]; } int Compare(const Left &other) const { if (length != other.length) return length < other.length ? -1 : 1; if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return 0; } bool operator<(const Left &other) const { if (length != other.length) return length < other.length; return pointers[length - 1] < other.pointers[length - 1]; } void ZeroRemaining() { for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i) *i = 0; } unsigned char length; uint64_t pointers[kMaxOrder - 1]; }; inline size_t hash_value(const Left &left) { return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); } struct ChartState { bool operator==(const ChartState &other) { return (left == other.left) && (right == other.right) && (full == other.full); } int Compare(const ChartState &other) const { int lres = left.Compare(other.left); if (lres) return lres; int rres = right.Compare(other.right); if (rres) return rres; return (int)full - (int)other.full; } bool operator<(const ChartState &other) const { return Compare(other) == -1; } void ZeroRemaining() { left.ZeroRemaining(); right.ZeroRemaining(); } Left left; bool full; State right; }; inline size_t hash_value(const ChartState &state) { size_t hashes[2]; hashes[0] = hash_value(state.left); hashes[1] = hash_value(state.right); return util::MurmurHashNative(hashes, sizeof(size_t), state.full); } template class RuleScore { public: explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } void BeginSentence() { 
out_.right = model_.BeginSentenceState(); // out_.left is empty. left_done_ = true; } void Terminal(WordIndex word) { State copy(out_.right); FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); prob_ += ret.prob; if (left_done_) return; if (ret.independent_left) { left_done_ = true; return; } out_.left.pointers[out_.left.length++] = ret.extend_left; if (out_.right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob) { prob_ = prob; out_ = in; left_done_ = in.full; } void NonTerminal(const ChartState &in, float prob) { prob_ += prob; if (!in.left.length) { if (in.full) { for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; left_done_ = true; out_.right = in.right; } return; } if (!out_.right.length) { out_.right = in.right; if (left_done_) return; if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; left_done_ = in.full; } return; } float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; float *back = backoffs, *back2 = backoffs2; unsigned char next_use = out_.right.length; // First word if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; // Words after the first, so extending a bigram to begin with for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { if (ExtendLeft(in, next_use, extend_length, back, back2)) return; std::swap(back, back2); } if (in.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; out_.right = in.right; return; } // Right state was minimized, so it's already independent of the new words to the left. if (in.right.length < in.left.length) { out_.right = in.right; return; } // Shift exisiting words down. 
for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { *(i + in.right.length) = *i; } // Add words from in.right. std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); std::copy(back, back + next_use, out_.right.backoff + in.right.length); out_.right.length = in.right.length + next_use; } float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. out_.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( out_.right.words, out_.right.words + next_use, // Words to extend into back_in, // Backoffs to use in.left.pointers[extend_length - 1], extend_length, // Words to be extended back_out, // Backoffs for the next score next_use)); // Length of n-gram to use in next scoring. if (next_use != out_.right.length) { left_done_ = true; if (!next_use) { out_.right = in.right; // Early exit. return true; } } // Continue scoring. return false; } void ProcessRet(const FullScoreReturn &ret) { prob_ += ret.prob; if (left_done_) return; if (ret.independent_left) { left_done_ = true; return; } out_.left.pointers[out_.left.length++] = ret.extend_left; } const M &model_; ChartState &out_; bool left_done_; float prob_; }; } // namespace ngram } // namespace lm #endif // LM_LEFT__