diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-10-18 19:49:54 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-10-18 19:49:54 +0400 |
commit | 05e91def21d90d848b5f804f2fc9bd974bc496df (patch) | |
tree | 8b2e7449c5856f7838f7224692a165afd9bf0b0e /search | |
parent | bedc7136b830d11f1ef32664038658d44688e032 (diff) |
Change VertexGenerator to batch, Remove kScoreInf
Diffstat (limited to 'search')
-rw-r--r-- | search/source.hh | 48 | ||||
-rw-r--r-- | search/types.hh | 3 | ||||
-rw-r--r-- | search/vertex.hh | 1 | ||||
-rw-r--r-- | search/vertex_generator.cc | 91 | ||||
-rw-r--r-- | search/vertex_generator.hh | 33 |
5 files changed, 59 insertions, 117 deletions
diff --git a/search/source.hh b/search/source.hh deleted file mode 100644 index 11839f7bc..000000000 --- a/search/source.hh +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef SEARCH_SOURCE__ -#define SEARCH_SOURCE__ - -#include "search/types.hh" - -#include <assert.h> -#include <vector> - -namespace search { - -template <class Final> class Source { - public: - Source() : bound_(kScoreInf) {} - - Index Size() const { - return final_.size(); - } - - Score Bound() const { - return bound_; - } - - const Final &operator[](Index index) const { - return *final_[index]; - } - - Score ScoreOrBound(Index index) const { - return Size() > index ? final_[index]->Total() : Bound(); - } - - protected: - void AddFinal(const Final &store) { - final_.push_back(&store); - } - - void SetBound(Score to) { - assert(to <= bound_ + 0.001); - bound_ = to; - } - - private: - std::vector<const Final *> final_; - - Score bound_; -}; - -} // namespace search -#endif // SEARCH_SOURCE__ diff --git a/search/types.hh b/search/types.hh index 46bc95288..06eb5bfa2 100644 --- a/search/types.hh +++ b/search/types.hh @@ -1,14 +1,11 @@ #ifndef SEARCH_TYPES__ #define SEARCH_TYPES__ -#include <cmath> - #include <stdint.h> namespace search { typedef float Score; -const Score kScoreInf = INFINITY; typedef uint32_t Arity; diff --git a/search/vertex.hh b/search/vertex.hh index 1cdcae30d..2c2e46d3c 100644 --- a/search/vertex.hh +++ b/search/vertex.hh @@ -26,7 +26,6 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - bound_ = -kScoreInf; end_ = NULL; } diff --git a/search/vertex_generator.cc b/search/vertex_generator.cc index 53220bc55..4113ae1d9 100644 --- a/search/vertex_generator.cc +++ b/search/vertex_generator.cc @@ -10,7 +10,6 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { @@ -26,68 +25,76 @@ void FillFinal(PartialEdge partial, Final &out) { } } -} // namespace +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} -void VertexGenerator::NewHypothesis(PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.GetScore()) - FillFinal(partial, exists); - return; + VertexNode *under; + boost::unordered_map<uint64_t, Trie> extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final *final = context.NewFinal(); + FillFinal(partial, *final); + VertexNode &node = *starter.under; + assert(!node.End()); + node.SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, PartialEdge partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - FillFinal(partial, *final); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search diff --git a/search/vertex_generator.hh b/search/vertex_generator.hh index 96df3e0a8..8122aaa5f 100644 --- a/search/vertex_generator.hh +++ b/search/vertex_generator.hh @@ -1,13 +1,11 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/note.hh" +#include "search/edge.hh" #include "search/vertex.hh" #include <boost/unordered_map.hpp> -#include <queue> - namespace lm { namespace ngram { class ChartState; @@ -18,40 +16,29 @@ namespace search { class ContextBase; class Final; -struct PartialEdge; class VertexGenerator { public: VertexGenerator(ContextBase &context, Vertex &gen); - void NewHypothesis(PartialEdge partial); - - void FinishedSearch() { - root_.under->SortAndSet(context_, NULL); + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (ret.second && ret.first->second < partial) { + ret.first->second = partial; + } } + void FinishedSearch(); + const Vertex &Generating() const { return gen_; } private: - // Parallel structure to VertexNode. - struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map<uint64_t, Trie> extend; - }; - - Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, PartialEdge partial); - ContextBase &context_; Vertex &gen_; - Trie root_; - - typedef boost::unordered_map<uint64_t, Final*> Existing; + typedef boost::unordered_map<uint64_t, PartialEdge> Existing; Existing existing_; }; |