diff options
Diffstat (limited to 'search/edge_generator.cc')
-rw-r--r-- | search/edge_generator.cc | 138 |
1 files changed, 69 insertions, 69 deletions
diff --git a/search/edge_generator.cc b/search/edge_generator.cc index 56239dfbb..654f9e36f 100644 --- a/search/edge_generator.cc +++ b/search/edge_generator.cc @@ -10,111 +10,111 @@ namespace search { -EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { -/* for (unsigned char i = 0; i < edge.Arity(); ++i) { - root.nt[i] = edge.GetVertex(i).RootPartial(); - } - for (unsigned char i = edge.Arity(); i < 2; ++i) { - root.nt[i] = kBlankPartialVertex; - }*/ - generate_.push(&root); - top_score_ = root.score; +EdgeGenerator::EdgeGenerator(PartialEdge root, Note note) : top_score_(root.GetScore()), arity_(root.GetArity()), note_(note) { + generate_.push(root); } namespace { -template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { - memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); +template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialEdge previous, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + const lm::ngram::ChartState *previous_between = previous.Between(); + const search::PartialVertex &previous_vertex = previous.NT()[victim]; - float ret = 0.0; - lm::ngram::ChartState *before, *after; - if (victim == 0) { - before = &update.between[0]; - after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; - } else { - assert(victim == 1); - assert(arity == 2); - before = &update.between[previous.nt[0].Complete() ? 0 : 1]; - after = &update.between[2]; - } - const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); - const PartialVertex &update_nt = update.nt[victim]; + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + // copy [0, after] + memcpy(between, previous_between, sizeof(lm::ngram::ChartState) * (before_idx + 2)); + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; const lm::ngram::ChartState &update_reveal = update_nt.State(); - float just_after = 0.0; if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { - just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); } - if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { - ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); } if (update_nt.Complete()) { if (update_reveal.left.full) { before->left.full = true; } else { assert(update_reveal.left.length == update_reveal.right.length); - ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); - } - if (victim == 0) { - update.between[0].right = after->right; - } else { - update.between[2].left = before->left; + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); } + before->right = after->right; + // Copy the others shifted one down, covering after. + memcpy(after, previous_between + before_idx + 2, sizeof(lm::ngram::ChartState) * (incomplete + 1 - before_idx - 2)); + } else { + // Copy [after + 1, incomplete] + memcpy(after + 1, previous_between + before_idx + 2, sizeof(lm::ngram::ChartState) * (incomplete + 1 - before_idx - 2)); } - return previous.score + (ret + just_after) * context.GetWeights().LM(); + update.SetScore(previous.GetScore() + adjustment * context.GetWeights().LM()); } } // namespace -template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) { +template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context, PartialEdgePool &partial_edge_pool) { assert(!generate_.empty()); - PartialEdge &top = *generate_.top(); + PartialEdge top = generate_.top(); generate_.pop(); - unsigned int victim = 0; - unsigned char lowest_length = 255; - for (unsigned char i = 0; i != arity_; ++i) { - if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { - lowest_length = top.nt[i].Length(); - victim = i; + PartialVertex *top_nt = top.NT(); + + Arity victim = 0; + Arity victim_completed; + Arity completed = 0; + // Select victim or return if complete. + { + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity_; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } + } + if (lowest_length == 255) { + // Now top.between[0] is the full edge state. + top_score_ = generate_.empty() ? -kScoreInf : generate_.top().GetScore(); + return top; } - } - if (lowest_length == 255) { - // All states report complete. - top.between[0].right = top.between[arity_].right; - // Now top.between[0] is the full edge state. - top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return ⊤ } - unsigned int stay = !victim; - PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc()); - float old_bound = top.nt[victim].Bound(); - // The alternate's score will change because alternate.nt[victim] changes. - bool split = top.nt[victim].Split(continuation.nt[victim]); + float old_bound = top_nt[victim].Bound(); + PartialEdge continuation = partial_edge_pool.Allocate(arity_); + PartialVertex *continuation_nt = continuation.NT(); + // The alternate's score will change because the nt changes. + bool split = top_nt[victim].Split(continuation_nt[victim]); // top is now the alternate. - continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, arity_, top, continuation); + for (Arity i = 0; i < victim; ++i) continuation_nt[i] = top_nt[i]; + for (Arity i = victim + 1; i < arity_; ++i) continuation_nt[i] = top_nt[i]; + FastScore(context, victim, victim - victim_completed, arity_ - completed, top, continuation); // TODO: dedupe? - generate_.push(&continuation); + generate_.push(continuation); if (split) { // We have an alternate. - top.score += top.nt[victim].Bound() - old_bound; + top.SetScore(top_nt[victim].Bound() - old_bound); // TODO: dedupe? - generate_.push(&top); + generate_.push(top); } else { - partial_edge_pool.free(&top); + // TODO should free top here. + // Better would be changing Split. } - top_score_ = generate_.top()->score; - return NULL; + top_score_ = generate_.top().GetScore(); + // Invalid indicates no new hypothesis generated. + return PartialEdge(); } -template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, PartialEdgePool &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, PartialEdgePool &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, PartialEdgePool &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, PartialEdgePool &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, PartialEdgePool &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, PartialEdgePool &partial_edge_pool); } // namespace search |