Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/search
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-02-14 17:11:53 +0400
committerKenneth Heafield <github@kheafield.com>2013-02-14 17:11:53 +0400
commit8ef095e8faa81e0c218ecb5229adc1bd5b70fe29 (patch)
treead753593da49eac18bc4ae5c6c96155e040e6a76 /search
parent10012fac15327d564e276a563daf5b12c5d43534 (diff)
Update incremental search, cuts runtime by a third
Diffstat (limited to 'search')
-rw-r--r--search/Jamfile2
-rw-r--r--search/edge_generator.cc12
-rw-r--r--search/vertex.cc204
-rw-r--r--search/vertex.hh121
-rw-r--r--search/vertex_generator.cc68
-rw-r--r--search/vertex_generator.hh41
6 files changed, 270 insertions, 178 deletions
diff --git a/search/Jamfile b/search/Jamfile
index f6433e0e3..1dee51cec 100644
--- a/search/Jamfile
+++ b/search/Jamfile
@@ -1 +1 @@
-fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc vertex_generator.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
+fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
diff --git a/search/edge_generator.cc b/search/edge_generator.cc
index eacf5de5c..dd9d61e41 100644
--- a/search/edge_generator.cc
+++ b/search/edge_generator.cc
@@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
Arity victim = 0;
Arity victim_completed;
Arity incomplete;
+ unsigned char lowest_niceness = 255;
// Select victim or return if complete.
{
Arity completed = 0;
- unsigned char lowest_length = 255;
for (Arity i = 0; i != arity; ++i) {
if (top_nt[i].Complete()) {
++completed;
- } else if (top_nt[i].Length() < lowest_length) {
- lowest_length = top_nt[i].Length();
+ } else if (top_nt[i].Niceness() < lowest_niceness) {
+ lowest_niceness = top_nt[i].Niceness();
victim = i;
victim_completed = completed;
}
}
- if (lowest_length == 255) {
+ if (lowest_niceness == 255) {
return top;
}
incomplete = arity - completed;
@@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
generate_.push(alternate);
}
+#ifndef NDEBUG
+ Score before = top.GetScore();
+#endif
// top is now the continuation.
FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
// TODO: dedupe?
generate_.push(top);
+ assert(lowest_niceness != 254 || top.GetScore() == before);
// Invalid indicates no new hypothesis generated.
return PartialEdge();
diff --git a/search/vertex.cc b/search/vertex.cc
index 45842982c..bf40810e9 100644
--- a/search/vertex.cc
+++ b/search/vertex.cc
@@ -2,6 +2,8 @@
#include "search/context.hh"
+#include <boost/unordered_map.hpp>
+
#include <algorithm>
#include <functional>
@@ -11,45 +13,193 @@ namespace search {
namespace {
-struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
- bool operator()(const VertexNode *first, const VertexNode *second) const {
- return first->Bound() > second->Bound();
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+
+class DivideLeft {
+ public:
+ explicit DivideLeft(unsigned char index)
+ : index_(index) {}
+
+ uint64_t operator()(const lm::ngram::ChartState &state) const {
+ return (index_ < state.left.length) ?
+ state.left.pointers[index_] :
+ (kCompleteAdd - state.left.full);
+ }
+
+ private:
+ unsigned char index_;
+};
+
+class DivideRight {
+ public:
+ explicit DivideRight(unsigned char index)
+ : index_(index) {}
+
+ uint64_t operator()(const lm::ngram::ChartState &state) const {
+ return (index_ < state.right.length) ?
+ static_cast<uint64_t>(state.right.words[index_]) :
+ (kCompleteAdd - state.left.full);
+ }
+
+ private:
+ unsigned char index_;
+};
+
+template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
+ // Map from divider to index in extend.
+ typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
+ Lookup lookup;
+ for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
+ uint64_t key = divider(i->state);
+ std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
+ if (res.second) {
+ extend.resize(extend.size() + 1);
+ extend.back().AppendHypothesis(*i);
+ } else {
+ extend[res.first->second].AppendHypothesis(*i);
+ }
}
+ //assert((extend.size() != 1) || (hypos.size() == 1));
+}
+
+lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
+ return right.words[index];
+}
+
+uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
+ return left.pointers[index];
+}
+
+template <class Side> class DetermineSame {
+ public:
+ DetermineSame(const Side &side, unsigned char guaranteed)
+ : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}
+
+ void Consider(const Side &other) {
+ if (shared_ != other.length) {
+ complete_ = false;
+ if (shared_ > other.length)
+ shared_ = other.length;
+ }
+ for (unsigned char i = guaranteed_; i < shared_; ++i) {
+ if (Identify(side_, i) != Identify(other, i)) {
+ shared_ = i;
+ complete_ = false;
+ return;
+ }
+ }
+ }
+
+ unsigned char Shared() const { return shared_; }
+
+ bool Complete() const { return complete_; }
+
+ private:
+ const Side &side_;
+ unsigned char guaranteed_, shared_;
+ bool complete_;
};
+// Custom enum to save memory: valid values of policy_.
+// Alternate and there is still alternation to do.
+const unsigned char kPolicyAlternate = 0;
+// Branch based on left state only, because right ran out or this is a left tree.
+const unsigned char kPolicyOneLeft = 1;
+// Branch based on right state only.
+const unsigned char kPolicyOneRight = 2;
+// Reveal everything in the next branch. Used to terminate the left/right policies.
+// static const unsigned char kPolicyEverything = 3;
+
+} // namespace
+
+namespace {
+struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
+ bool operator()(const HypoState &first, const HypoState &second) const {
+ return first.score > second.score;
+ }
+};
} // namespace
-void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
- if (Complete()) {
- assert(end_);
- assert(extend_.empty());
- return;
+void VertexNode::FinishRoot() {
+ std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
+ extend_.clear();
+ // HACK: extend to one hypo so that root can be blank.
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ niceness_ = 0;
+ policy_ = kPolicyAlternate;
+ if (hypos_.size() == 1) {
+ extend_.resize(1);
+ extend_.front().AppendHypothesis(hypos_.front());
+ extend_.front().FinishedAppending(0, 0);
+ }
+ if (hypos_.empty()) {
+ bound_ = -INFINITY;
+ } else {
+ bound_ = hypos_.front().score;
}
- if (extend_.size() == 1) {
- parent_ptr = extend_[0];
- extend_[0]->RecursiveSortAndSet(context, parent_ptr);
- context.DeleteVertexNode(this);
- return;
+}
+
+void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
+ assert(!hypos_.empty());
+ assert(extend_.empty());
+ bound_ = hypos_.front().score;
+ state_ = hypos_.front().state;
+ bool all_full = state_.left.full;
+ bool all_non_full = !state_.left.full;
+ DetermineSame<lm::ngram::Left> left(state_.left, common_left);
+ DetermineSame<lm::ngram::Right> right(state_.right, common_right);
+ for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
+ all_full &= i->state.left.full;
+ all_non_full &= !i->state.left.full;
+ left.Consider(i->state.left);
+ right.Consider(i->state.right);
}
- for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->RecursiveSortAndSet(context, *i);
+ state_.left.full = all_full && left.Complete();
+ right_full_ = all_full && right.Complete();
+ state_.left.length = left.Shared();
+ state_.right.length = right.Shared();
+
+ if (!all_full && !all_non_full) {
+ policy_ = kPolicyAlternate;
+ } else if (left.Complete()) {
+ policy_ = kPolicyOneRight;
+ } else if (right.Complete()) {
+ policy_ = kPolicyOneLeft;
+ } else {
+ policy_ = kPolicyAlternate;
}
- std::sort(extend_.begin(), extend_.end(), GreaterByBound());
- bound_ = extend_.front()->Bound();
+ niceness_ = state_.left.length + state_.right.length;
}
-void VertexNode::SortAndSet(ContextBase &context) {
- // This is the root. The root might be empty.
- if (extend_.empty()) {
- bound_ = -INFINITY;
- return;
+void VertexNode::BuildExtend() {
+ // Already built.
+ if (!extend_.empty()) return;
+ // Nothing to build since this is a leaf.
+ if (hypos_.size() <= 1) return;
+ bool left_branch = true;
+ switch (policy_) {
+ case kPolicyAlternate:
+ left_branch = (state_.left.length <= state_.right.length);
+ break;
+ case kPolicyOneLeft:
+ left_branch = true;
+ break;
+ case kPolicyOneRight:
+ left_branch = false;
+ break;
+ }
+ if (left_branch) {
+ Split(DivideLeft(state_.left.length), hypos_, extend_);
+ } else {
+ Split(DivideRight(state_.right.length), hypos_, extend_);
}
- // The root cannot be replaced. There's always one transition.
- for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->RecursiveSortAndSet(context, *i);
+ for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ // TODO: provide more here for branching?
+ i->FinishedAppending(state_.left.length, state_.right.length);
}
- std::sort(extend_.begin(), extend_.end(), GreaterByBound());
- bound_ = extend_.front()->Bound();
}
} // namespace search
diff --git a/search/vertex.hh b/search/vertex.hh
index 10b3339b9..baeae3469 100644
--- a/search/vertex.hh
+++ b/search/vertex.hh
@@ -16,59 +16,74 @@ namespace search {
class ContextBase;
+struct HypoState {
+ History history;
+ lm::ngram::ChartState state;
+ Score score;
+};
+
class VertexNode {
public:
- VertexNode() : end_() {}
-
- void InitRoot() {
- extend_.clear();
- state_.left.full = false;
- state_.left.length = 0;
- state_.right.length = 0;
- right_full_ = false;
- end_ = History();
+ VertexNode() {}
+
+ void InitRoot() { hypos_.clear(); }
+
+ /* The steps of building a VertexNode:
+ * 1. Default construct.
+ * 2. AppendHypothesis at least once, possibly multiple times.
+ * 3. FinishAppending with the number of words on left and right guaranteed
+ * to be common.
+ * 4. If !Complete(), call BuildExtend to construct the extensions
+ */
+ // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending.
+ void AppendHypothesis(const NBestComplete &best) {
+ assert(hypos_.empty() || !(hypos_.front().state == *best.state));
+ HypoState hypo;
+ hypo.history = best.history;
+ hypo.state = *best.state;
+ hypo.score = best.score;
+ hypos_.push_back(hypo);
+ }
+ void AppendHypothesis(const HypoState &hypo) {
+ hypos_.push_back(hypo);
}
- lm::ngram::ChartState &MutableState() { return state_; }
- bool &MutableRightFull() { return right_full_; }
+ // Sort hypotheses for the root.
+ void FinishRoot();
- void AddExtend(VertexNode *next) {
- extend_.push_back(next);
- }
+ void FinishedAppending(const unsigned char common_left, const unsigned char common_right);
- void SetEnd(History end, Score score) {
- assert(!end_);
- end_ = end;
- bound_ = score;
- }
-
- void SortAndSet(ContextBase &context);
+ void BuildExtend();
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return hypos_.empty() && extend_.empty();
}
bool Complete() const {
- return end_;
+ // HACK: prevent root from being complete. TODO: allow root to be complete.
+ return hypos_.size() == 1 && extend_.empty();
}
const lm::ngram::ChartState &State() const { return state_; }
bool RightFull() const { return right_full_; }
+ // Priority relative to other non-terminals. 0 is highest.
+ unsigned char Niceness() const { return niceness_; }
+
Score Bound() const {
return bound_;
}
- unsigned char Length() const {
- return state_.left.length + state_.right.length;
- }
-
// Will be invalid unless this is a leaf.
- const History End() const { return end_; }
+ const History End() const {
+ assert(hypos_.size() == 1);
+ return hypos_.front().history;
+ }
- const VertexNode &operator[](size_t index) const {
- return *extend_[index];
+ VertexNode &operator[](size_t index) {
+ assert(!extend_.empty());
+ return extend_[index];
}
size_t Size() const {
@@ -76,22 +91,26 @@ class VertexNode {
}
private:
- void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+ // Hypotheses to be split.
+ std::vector<HypoState> hypos_;
- std::vector<VertexNode*> extend_;
+ std::vector<VertexNode> extend_;
lm::ngram::ChartState state_;
bool right_full_;
+ unsigned char niceness_;
+
+ unsigned char policy_;
+
Score bound_;
- History end_;
};
class PartialVertex {
public:
PartialVertex() {}
- explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+ explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}
bool Empty() const { return back_->Empty(); }
@@ -100,17 +119,14 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
-
- unsigned char Length() const { return back_->Length(); }
+ Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); }
- bool HasAlternative() const {
- return index_ + 1 < back_->Size();
- }
+ unsigned char Niceness() const { return back_->Niceness(); }
// Split into continuation and alternative, rendering this the continuation.
bool Split(PartialVertex &alternative) {
assert(!Complete());
+ back_->BuildExtend();
bool ret;
if (index_ + 1 < back_->Size()) {
alternative.index_ = index_ + 1;
@@ -129,7 +145,7 @@ class PartialVertex {
}
private:
- const VertexNode *back_;
+ VertexNode *back_;
unsigned int index_;
};
@@ -139,10 +155,21 @@ class Vertex {
public:
Vertex() {}
- PartialVertex RootPartial() const { return PartialVertex(root_); }
+ //PartialVertex RootFirst() const { return PartialVertex(right_); }
+ PartialVertex RootAlternate() { return PartialVertex(root_); }
+ //PartialVertex RootLast() const { return PartialVertex(left_); }
+
+ bool Empty() const {
+ return root_.Empty();
+ }
+
+ Score Bound() const {
+ return root_.Bound();
+ }
- const History BestChild() const {
- PartialVertex top(RootPartial());
+ const History BestChild() {
+ // left_ and right_ are not set at the root.
+ PartialVertex top(RootAlternate());
if (top.Empty()) {
return History();
} else {
@@ -158,6 +185,12 @@ class Vertex {
template <class Output> friend class VertexGenerator;
template <class Output> friend class RootVertexGenerator;
VertexNode root_;
+
+ // These will not be set for the root vertex.
+ // Branches only on left state.
+ //VertexNode left_;
+ // Branches only on right state.
+ //VertexNode right_;
};
} // namespace search
diff --git a/search/vertex_generator.cc b/search/vertex_generator.cc
deleted file mode 100644
index 73139ffc5..000000000
--- a/search/vertex_generator.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-#include "search/vertex_generator.hh"
-
-#include "lm/left.hh"
-#include "search/context.hh"
-#include "search/edge.hh"
-
-#include <boost/unordered_map.hpp>
-#include <boost/version.hpp>
-
-#include <stdint.h>
-
-namespace search {
-
-#if BOOST_VERSION > 104200
-namespace {
-
-const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-
-Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
- Trie &next = node.extend[added];
- if (!next.under) {
- next.under = context.NewVertexNode();
- lm::ngram::ChartState &writing = next.under->MutableState();
- writing = state;
- writing.left.full &= left_full && state.left.full;
- next.under->MutableRightFull() = right_full && state.left.full;
- writing.left.length = left;
- writing.right.length = right;
- node.under->AddExtend(next.under);
- }
- return next;
-}
-
-} // namespace
-
-void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) {
- const lm::ngram::ChartState &state = *end.state;
-
- unsigned char left = 0, right = 0;
- Trie *node = &root;
- while (true) {
- if (left == state.left.length) {
- node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
- for (; right < state.right.length; ++right) {
- node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
- }
- break;
- }
- node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
- left++;
- if (right == state.right.length) {
- node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
- for (; left < state.left.length; ++left) {
- node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
- }
- break;
- }
- node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
- right++;
- }
-
- node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- node->under->SetEnd(end.history, end.score);
-}
-
-#endif // BOOST_VERSION
-
-} // namespace search
diff --git a/search/vertex_generator.hh b/search/vertex_generator.hh
index d0e0dacc9..6fce508d6 100644
--- a/search/vertex_generator.hh
+++ b/search/vertex_generator.hh
@@ -5,13 +5,6 @@
#include "search/types.hh"
#include "search/vertex.hh"
-#include <boost/unordered_map.hpp>
-#include <boost/version.hpp>
-
-#if BOOST_VERSION <= 104200
-#include "util/exception.hh"
-#endif
-
namespace lm {
namespace ngram {
class ChartState;
@@ -22,45 +15,25 @@ namespace search {
class ContextBase;
-#if BOOST_VERSION > 104200
-// Parallel structure to VertexNode.
-struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
-};
-
-void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
-
-#endif // BOOST_VERSION
-
// Output makes the single-best or n-best list.
template <class Output> class VertexGenerator {
public:
- VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
- gen.root_.InitRoot();
- }
+ VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {}
void NewHypothesis(PartialEdge partial) {
nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
}
void FinishedSearch() {
-#if BOOST_VERSION > 104200
- Trie root;
- root.under = &gen_.root_;
+ gen_.root_.InitRoot();
for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
- AddHypothesis(context_, root, nbest_.Complete(i->second));
+ gen_.root_.AppendHypothesis(nbest_.Complete(i->second));
}
existing_.clear();
- root.under->SortAndSet(context_);
-#else
- UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
-#endif
+ gen_.root_.FinishRoot();
}
- const Vertex &Generating() const { return gen_; }
+ Vertex &Generating() { return gen_; }
private:
ContextBase &context_;
@@ -87,8 +60,8 @@ template <class Output> class RootVertexGenerator {
void FinishedSearch() {
gen_.root_.InitRoot();
- NBestComplete completed(out_.Complete(combine_));
- gen_.root_.SetEnd(completed.history, completed.score);
+ gen_.root_.AppendHypothesis(out_.Complete(combine_));
+ gen_.root_.FinishRoot();
}
private: