Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/search
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-10-18 21:54:38 +0400
committerKenneth Heafield <github@kheafield.com>2012-10-18 21:54:38 +0400
commit39664ac8f48c44fa4a81ff63de7ff2d31a956b5e (patch)
treef0928f6ff3f63bb1f68ca0b08009173a9968d85d /search
parent05e91def21d90d848b5f804f2fc9bd974bc496df (diff)
Support arbitrary arity
Also, some windows fixes in util.
Diffstat (limited to 'search')
-rw-r--r--search/arity.hh8
-rw-r--r--search/context.hh10
-rw-r--r--search/edge.hh64
-rw-r--r--search/edge_generator.cc2
-rw-r--r--search/edge_generator.hh4
-rw-r--r--search/final.hh40
-rw-r--r--search/header.hh57
-rw-r--r--search/vertex.cc4
-rw-r--r--search/vertex.hh33
-rw-r--r--search/vertex_generator.cc22
10 files changed, 126 insertions, 118 deletions
diff --git a/search/arity.hh b/search/arity.hh
deleted file mode 100644
index 09c2c671d..000000000
--- a/search/arity.hh
+++ /dev/null
@@ -1,8 +0,0 @@
-#ifndef SEARCH_ARITY__
-#define SEARCH_ARITY__
-namespace search {
-
-const unsigned int kMaxArity = 2;
-
-} // namespace search
-#endif // SEARCH_ARITY__
diff --git a/search/context.hh b/search/context.hh
index 27940053b..62163144f 100644
--- a/search/context.hh
+++ b/search/context.hh
@@ -7,6 +7,7 @@
#include "search/types.hh"
#include "search/vertex.hh"
#include "util/exception.hh"
+#include "util/pool.hh"
#include <boost/pool/object_pool.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
@@ -21,10 +22,8 @@ class ContextBase {
public:
explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
- Final *NewFinal() {
- Final *ret = final_pool_.construct();
- assert(ret);
- return ret;
+ util::Pool &FinalPool() {
+ return final_pool_;
}
VertexNode *NewVertexNode() {
@@ -42,7 +41,8 @@ class ContextBase {
const Weights &GetWeights() const { return weights_; }
private:
- boost::object_pool<Final> final_pool_;
+ util::Pool final_pool_;
+
boost::object_pool<VertexNode> vertex_node_pool_;
unsigned int pop_limit_;
diff --git a/search/edge.hh b/search/edge.hh
index d92578761..187904bf9 100644
--- a/search/edge.hh
+++ b/search/edge.hh
@@ -2,6 +2,7 @@
#define SEARCH_EDGE__
#include "lm/state.hh"
+#include "search/header.hh"
#include "search/types.hh"
#include "search/vertex.hh"
#include "util/pool.hh"
@@ -13,76 +14,39 @@
namespace search {
// Copyable, but the copy will be shallow.
-class PartialEdge {
+class PartialEdge : public Header {
public:
// Allow default construction for STL.
- PartialEdge() : base_(NULL) {}
- bool Valid() const { return base_; }
+ PartialEdge() {}
- Score GetScore() const {
- return *reinterpret_cast<const float*>(base_);
- }
- void SetScore(Score to) {
- *reinterpret_cast<float*>(base_) = to;
- }
- bool operator<(const PartialEdge &other) const {
- return GetScore() < other.GetScore();
- }
-
- Arity GetArity() const {
- return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
- }
-
- Note GetNote() const {
- return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity));
- }
- void SetNote(Note to) {
- *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
- }
+ PartialEdge(util::Pool &pool, Arity arity)
+ : Header(pool.Allocate(Size(arity, arity + 1)), arity) {}
+
+ PartialEdge(util::Pool &pool, Arity arity, Arity chart_states)
+ : Header(pool.Allocate(Size(arity, chart_states)), arity) {}
// Non-terminals
const PartialVertex *NT() const {
- return reinterpret_cast<const PartialVertex*>(base_ + kHeaderSize);
+ return reinterpret_cast<const PartialVertex*>(After());
}
PartialVertex *NT() {
- return reinterpret_cast<PartialVertex*>(base_ + kHeaderSize);
+ return reinterpret_cast<PartialVertex*>(After());
}
const lm::ngram::ChartState &CompletedState() const {
return *Between();
}
const lm::ngram::ChartState *Between() const {
- return reinterpret_cast<const lm::ngram::ChartState*>(base_ + kHeaderSize + GetArity() * sizeof(PartialVertex));
+ return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
}
lm::ngram::ChartState *Between() {
- return reinterpret_cast<lm::ngram::ChartState*>(base_ + kHeaderSize + GetArity() * sizeof(PartialVertex));
+ return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
}
private:
- static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note);
-
- friend class PartialEdgePool;
- PartialEdge(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
- *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
+ static std::size_t Size(Arity arity, Arity chart_states) {
+ return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState);
}
-
- uint8_t *base_;
-};
-
-class PartialEdgePool {
- public:
- PartialEdge Allocate(Arity arity, Arity chart_states) {
- return PartialEdge(
- pool_.Allocate(PartialEdge::kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState)),
- arity);
- }
-
- PartialEdge Allocate(Arity arity) {
- return Allocate(arity, arity + 1);
- }
-
- private:
- util::Pool pool_;
};
diff --git a/search/edge_generator.cc b/search/edge_generator.cc
index baa91ed1b..260159b1f 100644
--- a/search/edge_generator.cc
+++ b/search/edge_generator.cc
@@ -75,7 +75,7 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
PartialVertex old_value(top_nt[victim]);
PartialVertex alternate_changed;
if (top_nt[victim].Split(alternate_changed)) {
- PartialEdge alternate = partial_edge_pool_.Allocate(arity, incomplete + 1);
+ PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1);
alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound());
alternate.SetNote(top.GetNote());
diff --git a/search/edge_generator.hh b/search/edge_generator.hh
index 43f257929..582c78b7b 100644
--- a/search/edge_generator.hh
+++ b/search/edge_generator.hh
@@ -22,7 +22,7 @@ class EdgeGenerator {
EdgeGenerator() {}
PartialEdge AllocateEdge(Arity arity) {
- return partial_edge_pool_.Allocate(arity);
+ return PartialEdge(partial_edge_pool_, arity);
}
void AddEdge(PartialEdge edge) {
@@ -47,7 +47,7 @@ class EdgeGenerator {
}
private:
- PartialEdgePool partial_edge_pool_;
+ util::Pool partial_edge_pool_;
typedef std::priority_queue<PartialEdge> Generate;
Generate generate_;
diff --git a/search/final.hh b/search/final.hh
index fc86e0f98..50e62cf2e 100644
--- a/search/final.hh
+++ b/search/final.hh
@@ -1,34 +1,34 @@
#ifndef SEARCH_FINAL__
#define SEARCH_FINAL__
-#include "search/arity.hh"
-#include "search/note.hh"
-#include "search/types.hh"
-
-#include <boost/array.hpp>
+#include "search/header.hh"
+#include "util/pool.hh"
namespace search {
-class Final {
+// A full hypothesis with pointers to children.
+class Final : public Header {
public:
- const Final **Reset(Score bound, Note note) {
- bound_ = bound;
- note_ = note;
- return children_;
- }
-
- const Final *const *Children() const { return children_; }
+ Final() {}
- Note GetNote() const { return note_; }
+ Final(util::Pool &pool, Score score, Arity arity, Note note)
+ : Header(pool.Allocate(Size(arity)), arity) {
+ SetScore(score);
+ SetNote(note);
+ }
- Score Bound() const { return bound_; }
+ // These are arrays of length GetArity().
+ Final *Children() {
+ return reinterpret_cast<Final*>(After());
+ }
+ const Final *Children() const {
+ return reinterpret_cast<const Final*>(After());
+ }
private:
- Score bound_;
-
- Note note_;
-
- const Final *children_[2];
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Final);
+ }
};
} // namespace search
diff --git a/search/header.hh b/search/header.hh
new file mode 100644
index 000000000..25550dbed
--- /dev/null
+++ b/search/header.hh
@@ -0,0 +1,57 @@
+#ifndef SEARCH_HEADER__
+#define SEARCH_HEADER__
+
+// Header consisting of Score, Arity, and Note
+
+#include "search/note.hh"
+#include "search/types.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+// Copying is shallow.
+class Header {
+ public:
+ bool Valid() const { return base_; }
+
+ Score GetScore() const {
+ return *reinterpret_cast<const float*>(base_);
+ }
+ void SetScore(Score to) {
+ *reinterpret_cast<float*>(base_) = to;
+ }
+ bool operator<(const Header &other) const {
+ return GetScore() < other.GetScore();
+ }
+
+ Arity GetArity() const {
+ return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
+ }
+
+ Note GetNote() const {
+ return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity));
+ }
+ void SetNote(Note to) {
+ *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
+ }
+
+ protected:
+ Header() : base_(NULL) {}
+
+ Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
+ *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
+ }
+
+ static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note);
+
+ uint8_t *After() { return base_ + kHeaderSize; }
+ const uint8_t *After() const { return base_ + kHeaderSize; }
+
+ private:
+ uint8_t *base_;
+};
+
+} // namespace search
+
+#endif // SEARCH_HEADER__
diff --git a/search/vertex.cc b/search/vertex.cc
index ed3631352..11f4631fa 100644
--- a/search/vertex.cc
+++ b/search/vertex.cc
@@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve
void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
if (Complete()) {
- assert(end_);
+ assert(end_.Valid());
assert(extend_.empty());
- bound_ = end_->Bound();
+ bound_ = end_.GetScore();
return;
}
if (extend_.size() == 1 && parent_ptr) {
diff --git a/search/vertex.hh b/search/vertex.hh
index 2c2e46d3c..52bc1dfe7 100644
--- a/search/vertex.hh
+++ b/search/vertex.hh
@@ -18,7 +18,7 @@ class ContextBase;
class VertexNode {
public:
- VertexNode() : end_(NULL) {}
+ VertexNode() {}
void InitRoot() {
extend_.clear();
@@ -26,7 +26,7 @@ class VertexNode {
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
- end_ = NULL;
+ end_ = Final();
}
lm::ngram::ChartState &MutableState() { return state_; }
@@ -36,19 +36,20 @@ class VertexNode {
extend_.push_back(next);
}
- void SetEnd(Final *end) { end_ = end; }
+ void SetEnd(Final end) {
+ assert(!end_.Valid());
+ end_ = end;
+ }
- Final &MutableEnd() { return *end_; }
-
void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return !end_.Valid() && extend_.empty();
}
bool Complete() const {
- return end_;
+ return end_.Valid();
}
const lm::ngram::ChartState &State() const { return state_; }
@@ -62,8 +63,8 @@ class VertexNode {
return state_.left.length + state_.right.length;
}
- // May be NULL.
- const Final *End() const { return end_; }
+ // Will be invalid unless this is a leaf.
+ const Final End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
@@ -80,7 +81,7 @@ class VertexNode {
bool right_full_;
Score bound_;
- Final *end_;
+ Final end_;
};
class PartialVertex {
@@ -96,7 +97,7 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }
+ Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
unsigned char Length() const { return back_->Length(); }
@@ -120,8 +121,8 @@ class PartialVertex {
return ret;
}
- const Final &End() const {
- return *back_->End();
+ const Final End() const {
+ return back_->End();
}
private:
@@ -135,16 +136,16 @@ class Vertex {
PartialVertex RootPartial() const { return PartialVertex(root_); }
- const Final *BestChild() const {
+ const Final BestChild() const {
PartialVertex top(RootPartial());
if (top.Empty()) {
- return NULL;
+ return Final();
} else {
PartialVertex continuation;
while (!top.Complete()) {
top.Split(continuation);
}
- return &top.End();
+ return top.End();
}
}
diff --git a/search/vertex_generator.cc b/search/vertex_generator.cc
index 4113ae1d9..0945fe55d 100644
--- a/search/vertex_generator.cc
+++ b/search/vertex_generator.cc
@@ -16,15 +16,6 @@ namespace {
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-void FillFinal(PartialEdge partial, Final &out) {
- const Final **final_out = out.Reset(partial.GetScore(), partial.GetNote());
- const PartialVertex *part = partial.NT();
- const PartialVertex *const part_end_loop = part + partial.GetArity();
- for (; part != part_end_loop; ++part, ++final_out) {
- *final_out = &part->End();
- }
-}
-
// Parallel structure to VertexNode.
struct Trie {
Trie() : under(NULL) {}
@@ -49,11 +40,14 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n
}
void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
- Final *final = context.NewFinal();
- FillFinal(partial, *final);
- VertexNode &node = *starter.under;
- assert(!node.End());
- node.SetEnd(final);
+ Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
+ Final *child_out = final.Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = part->End();
+
+ starter.under->SetEnd(final);
}
void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {