diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-10-18 21:54:38 +0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-10-18 21:54:38 +0400 |
commit | 39664ac8f48c44fa4a81ff63de7ff2d31a956b5e (patch) | |
tree | f0928f6ff3f63bb1f68ca0b08009173a9968d85d /search | |
parent | 05e91def21d90d848b5f804f2fc9bd974bc496df (diff) |
Support arbitrary arity
Also, some windows fixes in util.
Diffstat (limited to 'search')
-rw-r--r-- | search/arity.hh | 8 | ||||
-rw-r--r-- | search/context.hh | 10 | ||||
-rw-r--r-- | search/edge.hh | 64 | ||||
-rw-r--r-- | search/edge_generator.cc | 2 | ||||
-rw-r--r-- | search/edge_generator.hh | 4 | ||||
-rw-r--r-- | search/final.hh | 40 | ||||
-rw-r--r-- | search/header.hh | 57 | ||||
-rw-r--r-- | search/vertex.cc | 4 | ||||
-rw-r--r-- | search/vertex.hh | 33 | ||||
-rw-r--r-- | search/vertex_generator.cc | 22 |
10 files changed, 126 insertions, 118 deletions
diff --git a/search/arity.hh b/search/arity.hh deleted file mode 100644 index 09c2c671d..000000000 --- a/search/arity.hh +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SEARCH_ARITY__ -#define SEARCH_ARITY__ -namespace search { - -const unsigned int kMaxArity = 2; - -} // namespace search -#endif // SEARCH_ARITY__ diff --git a/search/context.hh b/search/context.hh index 27940053b..62163144f 100644 --- a/search/context.hh +++ b/search/context.hh @@ -7,6 +7,7 @@ #include "search/types.hh" #include "search/vertex.hh" #include "util/exception.hh" +#include "util/pool.hh" #include <boost/pool/object_pool.hpp> #include <boost/ptr_container/ptr_vector.hpp> @@ -21,10 +22,8 @@ class ContextBase { public: explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - Final *NewFinal() { - Final *ret = final_pool_.construct(); - assert(ret); - return ret; + util::Pool &FinalPool() { + return final_pool_; } VertexNode *NewVertexNode() { @@ -42,7 +41,8 @@ class ContextBase { const Weights &GetWeights() const { return weights_; } private: - boost::object_pool<Final> final_pool_; + util::Pool final_pool_; + boost::object_pool<VertexNode> vertex_node_pool_; unsigned int pop_limit_; diff --git a/search/edge.hh b/search/edge.hh index d92578761..187904bf9 100644 --- a/search/edge.hh +++ b/search/edge.hh @@ -2,6 +2,7 @@ #define SEARCH_EDGE__ #include "lm/state.hh" +#include "search/header.hh" #include "search/types.hh" #include "search/vertex.hh" #include "util/pool.hh" @@ -13,76 +14,39 @@ namespace search { // Copyable, but the copy will be shallow. -class PartialEdge { +class PartialEdge : public Header { public: // Allow default construction for STL. - PartialEdge() : base_(NULL) {} - bool Valid() const { return base_; } + PartialEdge() {} - Score GetScore() const { - return *reinterpret_cast<const float*>(base_); - } - void SetScore(Score to) { - *reinterpret_cast<float*>(base_) = to; - } - bool operator<(const PartialEdge &other) const { - return GetScore() < other.GetScore(); - } - - Arity GetArity() const { - return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); - } - - Note GetNote() const { - return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity)); - } - void SetNote(Note to) { - *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; - } + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} // Non-terminals const PartialVertex *NT() const { - return reinterpret_cast<const PartialVertex*>(base_ + kHeaderSize); + return reinterpret_cast<const PartialVertex*>(After()); } PartialVertex *NT() { - return reinterpret_cast<PartialVertex*>(base_ + kHeaderSize); + return reinterpret_cast<PartialVertex*>(After()); } const lm::ngram::ChartState &CompletedState() const { return *Between(); } const lm::ngram::ChartState *Between() const { - return reinterpret_cast<const lm::ngram::ChartState*>(base_ + kHeaderSize + GetArity() * sizeof(PartialVertex)); + return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); } lm::ngram::ChartState *Between() { - return reinterpret_cast<lm::ngram::ChartState*>(base_ + kHeaderSize + GetArity() * sizeof(PartialVertex)); + return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); } private: - static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); - - friend class PartialEdgePool; - PartialEdge(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { - *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); } - - uint8_t *base_; -}; - -class PartialEdgePool { - public: - PartialEdge Allocate(Arity arity, Arity chart_states) { - return PartialEdge( - pool_.Allocate(PartialEdge::kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState)), - arity); - } - - PartialEdge Allocate(Arity arity) { - return Allocate(arity, arity + 1); - } - - private: - util::Pool pool_; }; diff --git a/search/edge_generator.cc b/search/edge_generator.cc index baa91ed1b..260159b1f 100644 --- a/search/edge_generator.cc +++ b/search/edge_generator.cc @@ -75,7 +75,7 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { PartialVertex old_value(top_nt[victim]); PartialVertex alternate_changed; if (top_nt[victim].Split(alternate_changed)) { - PartialEdge alternate = partial_edge_pool_.Allocate(arity, incomplete + 1); + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); alternate.SetNote(top.GetNote()); diff --git a/search/edge_generator.hh b/search/edge_generator.hh index 43f257929..582c78b7b 100644 --- a/search/edge_generator.hh +++ b/search/edge_generator.hh @@ -22,7 +22,7 @@ class EdgeGenerator { EdgeGenerator() {} PartialEdge AllocateEdge(Arity arity) { - return partial_edge_pool_.Allocate(arity); + return PartialEdge(partial_edge_pool_, arity); } void AddEdge(PartialEdge edge) { @@ -47,7 +47,7 @@ class EdgeGenerator { } private: - PartialEdgePool partial_edge_pool_; + util::Pool partial_edge_pool_; typedef std::priority_queue<PartialEdge> Generate; Generate generate_; diff --git a/search/final.hh b/search/final.hh index fc86e0f98..50e62cf2e 100644 --- a/search/final.hh +++ b/search/final.hh @@ -1,34 +1,34 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/arity.hh" -#include "search/note.hh" -#include "search/types.hh" - -#include <boost/array.hpp> +#include "search/header.hh" +#include "util/pool.hh" namespace search { -class Final { +// A full hypothesis with pointers to children. +class Final : public Header { public: - const Final **Reset(Score bound, Note note) { - bound_ = bound; - note_ = note; - return children_; - } - - const Final *const *Children() const { return children_; } + Final() {} - Note GetNote() const { return note_; } + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); + } - Score Bound() const { return bound_; } + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast<Final*>(After()); + } + const Final *Children() const { + return reinterpret_cast<const Final*>(After()); + } private: - Score bound_; - - Note note_; - - const Final *children_[2]; + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } }; } // namespace search diff --git a/search/header.hh b/search/header.hh new file mode 100644 index 000000000..25550dbed --- /dev/null +++ b/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include <stdint.h> + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast<const float*>(base_); + } + void SetScore(Score to) { + *reinterpret_cast<float*>(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { + *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/search/vertex.cc b/search/vertex.cc index ed3631352..11f4631fa 100644 --- a/search/vertex.cc +++ b/search/vertex.cc @@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { if (Complete()) { - assert(end_); + assert(end_.Valid()); assert(extend_.empty()); - bound_ = end_->Bound(); + bound_ = end_.GetScore(); return; } if (extend_.size() == 1 && parent_ptr) { diff --git a/search/vertex.hh b/search/vertex.hh index 2c2e46d3c..52bc1dfe7 100644 --- a/search/vertex.hh +++ b/search/vertex.hh @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() : end_(NULL) {} + VertexNode() {} void InitRoot() { extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - end_ = NULL; + end_ = Final(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -36,19 +36,20 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final *end) { end_ = end; } + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } - Final &MutableEnd() { return *end_; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return !end_.Valid() && extend_.empty(); } bool Complete() const { - return end_; + return end_.Valid(); } const lm::ngram::ChartState &State() const { return state_; } @@ -62,8 +63,8 @@ class VertexNode { return state_.left.length + state_.right.length; } - // May be NULL. - const Final *End() const { return end_; } + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -80,7 +81,7 @@ class VertexNode { bool right_full_; Score bound_; - Final *end_; + Final end_; }; class PartialVertex { @@ -96,7 +97,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -120,8 +121,8 @@ class PartialVertex { return ret; } - const Final &End() const { - return *back_->End(); + const Final End() const { + return back_->End(); } private: @@ -135,16 +136,16 @@ class Vertex { PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final *BestChild() const { + const Final BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return NULL; + return Final(); } else { PartialVertex continuation; while (!top.Complete()) { top.Split(continuation); } - return &top.End(); + return top.End(); } } diff --git a/search/vertex_generator.cc b/search/vertex_generator.cc index 4113ae1d9..0945fe55d 100644 --- a/search/vertex_generator.cc +++ b/search/vertex_generator.cc @@ -16,15 +16,6 @@ namespace { const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); -void FillFinal(PartialEdge partial, Final &out) { - const Final **final_out = out.Reset(partial.GetScore(), partial.GetNote()); - const PartialVertex *part = partial.NT(); - const PartialVertex *const part_end_loop = part + partial.GetArity(); - for (; part != part_end_loop; ++part, ++final_out) { - *final_out = &part->End(); - } -} - // Parallel structure to VertexNode. struct Trie { Trie() : under(NULL) {} @@ -49,11 +40,14 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n } void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { - Final *final = context.NewFinal(); - FillFinal(partial, *final); - VertexNode &node = *starter.under; - assert(!node.End()); - node.SetEnd(final); + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); } void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { |