diff options
author | Phil Williams <philip.williams@mac.com> | 2015-02-02 20:00:42 +0300 |
---|---|---|
committer | Phil Williams <philip.williams@mac.com> | 2015-02-02 20:07:54 +0300 |
commit | ac8f01bc3d2eae44c945d2da8d5fcd7a951fe009 (patch) | |
tree | 0454f784b8e787ce0f1bb46cee50be7b4c47ba01 /moses/Syntax | |
parent | c8ad84fa55f1f136d99fc4ff712e2aa421867c5d (diff) |
Partial merge of t2s branch (implements t2s and f2s algorithms)
Implements search algorithms 7, 8, and 9:
-search-algorithm 7
tree-to-string (STSG-based, currently a special-case of forest-to-string)
-search-algorithm 8
tree-to-string (SCFG-based)
-search-algorithm 9
forest-to-string (STSG-based)
Diffstat (limited to 'moses/Syntax')
45 files changed, 3502 insertions, 3 deletions
diff --git a/moses/Syntax/F2S/DerivationWriter.cpp b/moses/Syntax/F2S/DerivationWriter.cpp new file mode 100644 index 000000000..efa3c3d47 --- /dev/null +++ b/moses/Syntax/F2S/DerivationWriter.cpp @@ -0,0 +1,101 @@ +#include "DerivationWriter.h" + +#include "moses/Factor.h" +#include "moses/Syntax/PVertex.h" +#include "moses/Syntax/SHyperedge.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// 1-best version. +void DerivationWriter::Write(const SHyperedge ­peredge, + std::size_t sentNum, std::ostream &out) +{ + WriteLine(shyperedge, sentNum, out); + for (std::size_t i = 0; i < shyperedge.tail.size(); ++i) { + const SVertex &pred = *(shyperedge.tail[i]); + if (pred.best) { + Write(*pred.best, sentNum, out); + } + } +} + +// k-best derivation. +void DerivationWriter::Write(const KBestExtractor::Derivation &derivation, + std::size_t sentNum, std::ostream &out) +{ + WriteLine(derivation.edge->shyperedge, sentNum, out); + for (std::size_t i = 0; i < derivation.subderivations.size(); ++i) { + Write(*(derivation.subderivations[i]), sentNum, out); + } +} + +void DerivationWriter::WriteLine(const SHyperedge ­peredge, + std::size_t sentNum, std::ostream &out) +{ + // Sentence number. + out << sentNum << " |||"; + + // Source LHS. + out << " "; + WriteSymbol(shyperedge.head->pvertex->symbol, out); + out << " ->"; + + // Source RHS symbols. + for (std::size_t i = 0; i < shyperedge.tail.size(); ++i) { + const Word &symbol = shyperedge.tail[i]->pvertex->symbol; + out << " "; + WriteSymbol(symbol, out); + } + out << " |||"; + + // Target RHS. + out << " [X] ->"; + + // Target RHS symbols. + const TargetPhrase &phrase = *(shyperedge.label.translation); + for (std::size_t i = 0; i < phrase.GetSize(); ++i) { + const Word &symbol = phrase.GetWord(i); + out << " "; + if (symbol.IsNonTerminal()) { + out << "[X]"; + } else { + WriteSymbol(symbol, out); + } + } + out << " |||"; + + // Non-terminal alignments + const AlignmentInfo &a = phrase.GetAlignNonTerm(); + for (AlignmentInfo::const_iterator p = a.begin(); p != a.end(); ++p) { + out << " " << p->first << "-" << p->second; + } + out << " |||"; + + // Spans covered by source RHS symbols. + for (std::size_t i = 0; i < shyperedge.tail.size(); ++i) { + const SVertex *child = shyperedge.tail[i]; + const WordsRange &span = child->pvertex->span; + out << " " << span.GetStartPos() << ".." << span.GetEndPos(); + } + + out << "\n"; +} + +void DerivationWriter::WriteSymbol(const Word &symbol, std::ostream &out) +{ + const Factor *f = symbol[0]; + if (symbol.IsNonTerminal()) { + out << "[" << f->GetString() << "]"; + } else { + out << f->GetString(); + } +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/DerivationWriter.h b/moses/Syntax/F2S/DerivationWriter.h new file mode 100644 index 000000000..76ca14313 --- /dev/null +++ b/moses/Syntax/F2S/DerivationWriter.h @@ -0,0 +1,36 @@ +#pragma once + +#include <ostream> + +#include "moses/Syntax/KBestExtractor.h" +#include "moses/Word.h" + +namespace Moses +{ +namespace Syntax +{ +struct SHyperedge; + +namespace F2S +{ + +// Writes a string representation of a derivation to a std::ostream. This is +// used by the -translation-details / -T option. +// TODO Merge this with S2T::DerivationWriter. +class DerivationWriter +{ + public: + // 1-best version. + static void Write(const SHyperedge&, std::size_t, std::ostream &); + + // k-best version. + static void Write(const KBestExtractor::Derivation &, std::size_t, + std::ostream &); + private: + static void WriteLine(const SHyperedge &, std::size_t, std::ostream &); + static void WriteSymbol(const Word &, std::ostream &); +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/Forest.cpp b/moses/Syntax/F2S/Forest.cpp new file mode 100644 index 000000000..e130d5ec2 --- /dev/null +++ b/moses/Syntax/F2S/Forest.cpp @@ -0,0 +1,34 @@ +#include "Forest.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +Forest::~Forest() +{ + Clear(); +} + +void Forest::Clear() +{ + for (std::vector<Vertex *>::iterator p = vertices.begin(); + p != vertices.end(); ++p) { + delete *p; + } + vertices.clear(); +} + +Forest::Vertex::~Vertex() +{ + for (std::vector<Hyperedge *>::iterator p = incoming.begin(); + p != incoming.end(); ++p) { + delete *p; + } +} + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/GlueRuleSynthesizer.cpp b/moses/Syntax/F2S/GlueRuleSynthesizer.cpp new file mode 100644 index 000000000..7c7d35beb --- /dev/null +++ b/moses/Syntax/F2S/GlueRuleSynthesizer.cpp @@ -0,0 +1,85 @@ +#include "GlueRuleSynthesizer.h" + +#include <sstream> + +#include "moses/FF/UnknownWordPenaltyProducer.h" +#include "moses/StaticData.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +GlueRuleSynthesizer::GlueRuleSynthesizer(HyperTree &trie) + : m_hyperTree(trie) +{ + const std::vector<FactorType> &inputFactorOrder = + StaticData::Instance().GetInputFactorOrder(); + Word *lhs = NULL; + m_dummySourcePhrase.CreateFromString(Input, inputFactorOrder, "hello", &lhs); + delete lhs; +} + +void GlueRuleSynthesizer::SynthesizeRule(const Forest::Hyperedge &e) +{ + HyperPath source; + SynthesizeHyperPath(e, source); + TargetPhrase *tp = SynthesizeTargetPhrase(e); + TargetPhraseCollection &tpc = GetOrCreateTargetPhraseCollection(m_hyperTree, + source); + tpc.Add(tp); +} + +void GlueRuleSynthesizer::SynthesizeHyperPath(const Forest::Hyperedge &e, + HyperPath &path) +{ + path.nodeSeqs.clear(); + path.nodeSeqs.resize(2); + path.nodeSeqs[0].push_back(e.head->pvertex.symbol[0]->GetId()); + for (std::vector<Forest::Vertex*>::const_iterator p = e.tail.begin(); + p != e.tail.end(); ++p) { + const Forest::Vertex &child = **p; + path.nodeSeqs[1].push_back(child.pvertex.symbol[0]->GetId()); + } +} + +TargetPhrase *GlueRuleSynthesizer::SynthesizeTargetPhrase( + const Forest::Hyperedge &e) +{ + const StaticData &staticData = StaticData::Instance(); + + const UnknownWordPenaltyProducer &unknownWordPenaltyProducer = + UnknownWordPenaltyProducer::Instance(); + + TargetPhrase *targetPhrase = new TargetPhrase(); + + std::ostringstream alignmentSS; + for (std::size_t i = 0; i < e.tail.size(); ++i) { + const Word &symbol = e.tail[i]->pvertex.symbol; + if (symbol.IsNonTerminal()) { + targetPhrase->AddWord(staticData.GetOutputDefaultNonTerminal()); + } else { + // TODO Check this + Word &targetWord = targetPhrase->AddWord(); + targetWord.CreateUnknownWord(symbol); + } + alignmentSS << i << "-" << i << " "; + } + + // Assign the lowest possible score so that glue rules are only used when + // absolutely required. + float score = LOWEST_SCORE; + targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, score); + targetPhrase->EvaluateInIsolation(m_dummySourcePhrase); + Word *targetLhs = new Word(staticData.GetOutputDefaultNonTerminal()); + targetPhrase->SetTargetLHS(targetLhs); + targetPhrase->SetAlignmentInfo(alignmentSS.str()); + + return targetPhrase; +} + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/GlueRuleSynthesizer.h b/moses/Syntax/F2S/GlueRuleSynthesizer.h new file mode 100644 index 000000000..77b454f87 --- /dev/null +++ b/moses/Syntax/F2S/GlueRuleSynthesizer.h @@ -0,0 +1,37 @@ +#pragma once + +#include "moses/Phrase.h" +#include "moses/TargetPhrase.h" + +#include "HyperTree.h" +#include "HyperTreeCreator.h" +#include "Forest.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +class GlueRuleSynthesizer : public HyperTreeCreator +{ + public: + GlueRuleSynthesizer(HyperTree &); + + // Synthesize the minimal, monotone rule that can be applied to the given + // hyperedge and add it to the rule trie. + void SynthesizeRule(const Forest::Hyperedge &); + + private: + void SynthesizeHyperPath(const Forest::Hyperedge &, HyperPath &); + + TargetPhrase *SynthesizeTargetPhrase(const Forest::Hyperedge &); + + HyperTree &m_hyperTree; + Phrase m_dummySourcePhrase; +}; + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/HyperPath.cpp b/moses/Syntax/F2S/HyperPath.cpp new file mode 100644 index 000000000..e60b4f411 --- /dev/null +++ b/moses/Syntax/F2S/HyperPath.cpp @@ -0,0 +1,20 @@ +#include "HyperPath.h" + +#include <limits> + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +const std::size_t HyperPath::kEpsilon = + std::numeric_limits<std::size_t>::max()-1; + +const std::size_t HyperPath::kComma = + std::numeric_limits<std::size_t>::max()-2; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperPath.h b/moses/Syntax/F2S/HyperPath.h new file mode 100644 index 000000000..4a11990e8 --- /dev/null +++ b/moses/Syntax/F2S/HyperPath.h @@ -0,0 +1,35 @@ +#pragma once + +#include <vector> + +#include "moses/Factor.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// A HyperPath for representing the source-side tree fragment of a +// tree-to-string rule. See this paper: +// +// Hui Zhang, Min Zhang, Haizhou Li, and Chew Lim Tan +// "Fast Translation Rule Matching for Syntax-based Statistical Machine +// Translation" +// In proceedings of EMNLP 2009 +// +struct HyperPath +{ + public: + typedef std::vector<std::size_t> NodeSeq; + + static const std::size_t kEpsilon; + static const std::size_t kComma; + + std::vector<NodeSeq> nodeSeqs; +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperPathLoader.cpp b/moses/Syntax/F2S/HyperPathLoader.cpp new file mode 100644 index 000000000..e4f22ae07 --- /dev/null +++ b/moses/Syntax/F2S/HyperPathLoader.cpp @@ -0,0 +1,172 @@ +#include "HyperPathLoader.h" + +#include "TreeFragmentTokenizer.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +HyperPathLoader::HyperPathLoader(FactorDirection direction, + const std::vector<FactorType> &factorOrder) + : m_direction(direction) + , m_factorOrder(factorOrder) +{ +} + +void HyperPathLoader::Load(const StringPiece &s, HyperPath &path) +{ + path.nodeSeqs.clear(); + // Tokenize the string and store the tokens in m_tokenSeq. + m_tokenSeq.clear(); + for (TreeFragmentTokenizer p(s); p != TreeFragmentTokenizer(); ++p) { + m_tokenSeq.push_back(*p); + } + // Determine the height of the tree fragment. + int height = DetermineHeight(); + // Ensure path contains the correct number of elements. + path.nodeSeqs.resize(height+1); + // Generate the fragment's NodeTuple sequence and store it in m_nodeTupleSeq. + GenerateNodeTupleSeq(height); + // Fill the HyperPath. + for (int depth = 0; depth <= height; ++depth) { + int prevParent = -1; +// TODO Generate one node tuple sequence for each depth instead of one +// TODO sequence that contains node tuples at every depth + for (std::vector<NodeTuple>::const_iterator p = m_nodeTupleSeq.begin(); + p != m_nodeTupleSeq.end(); ++p) { + const NodeTuple &tuple = *p; + if (tuple.depth != depth) { + continue; + } + if (prevParent != -1 && tuple.parent != prevParent) { + path.nodeSeqs[depth].push_back(HyperPath::kComma); + } + path.nodeSeqs[depth].push_back(tuple.symbol); + prevParent = tuple.parent; + } + } +} + +int HyperPathLoader::DetermineHeight() const +{ + int height = 0; + int maxHeight = 0; + std::size_t numTokens = m_tokenSeq.size(); + for (std::size_t i = 0; i < numTokens; ++i) { + if (m_tokenSeq[i].type == TreeFragmentToken_LSB) { + assert(i+2 < numTokens); + // Does this bracket indicate the start of a subtree or the start of + // a non-terminal leaf? + if (m_tokenSeq[i+2].type != TreeFragmentToken_RSB) { // It's a subtree. + maxHeight = std::max(++height, maxHeight); + } else { // It's a non-terminal leaf: jump to its end. + i += 2; + } + } else if (m_tokenSeq[i].type == TreeFragmentToken_RSB) { + --height; + } + } + return maxHeight; +} + +void HyperPathLoader::GenerateNodeTupleSeq(int height) +{ + m_nodeTupleSeq.clear(); + + // Initialize the stack of parent indices. + assert(m_parentStack.empty()); + m_parentStack.push(-1); + + // Initialize a temporary tuple that tracks the state as we iterate over + // the tree fragment tokens. + NodeTuple tuple; + tuple.index = -1; + tuple.parent = -1; + tuple.depth = -1; + tuple.symbol = HyperPath::kEpsilon; + + // Iterate over the tree fragment tokens. + std::size_t numTokens = m_tokenSeq.size(); + for (std::size_t i = 0; i < numTokens; ++i) { + if (m_tokenSeq[i].type == TreeFragmentToken_LSB) { + assert(i+2 < numTokens); + // Does this bracket indicate the start of a subtree or the start of + // a non-terminal leaf? + if (m_tokenSeq[i+2].type != TreeFragmentToken_RSB) { // It's a subtree. + ++tuple.index; + tuple.parent = m_parentStack.top(); + m_parentStack.push(tuple.index); + ++tuple.depth; + tuple.symbol = AddNonTerminalFactor(m_tokenSeq[++i].value)->GetId(); + m_nodeTupleSeq.push_back(tuple); + } else { // It's a non-terminal leaf. + ++tuple.index; + tuple.parent = m_parentStack.top(); + ++tuple.depth; + tuple.symbol = AddNonTerminalFactor(m_tokenSeq[++i].value)->GetId(); + m_nodeTupleSeq.push_back(tuple); + // Add virtual nodes if required. + if (tuple.depth < height) { + int origDepth = tuple.depth; + m_parentStack.push(tuple.index); + for (int depth = origDepth+1; depth <= height; ++depth) { + ++tuple.index; + tuple.parent = m_parentStack.top(); + m_parentStack.push(tuple.index); + tuple.depth = depth; + tuple.symbol = HyperPath::kEpsilon; + m_nodeTupleSeq.push_back(tuple); + } + for (int depth = origDepth; depth <= height; ++depth) { + m_parentStack.pop(); + } + tuple.depth = origDepth; + } + --tuple.depth; + // Skip over the closing bracket. + ++i; + } + } else if (m_tokenSeq[i].type == TreeFragmentToken_WORD) { + // Token i is a word that doesn't follow a bracket. This must be a + // terminal since all non-terminals are either non-leaves (which follow + // an opening bracket) or are enclosed in brackets. + ++tuple.index; + tuple.parent = m_parentStack.top(); + ++tuple.depth; + tuple.symbol = AddTerminalFactor(m_tokenSeq[i].value)->GetId(); + m_nodeTupleSeq.push_back(tuple); + // Add virtual nodes if required. + if (m_tokenSeq[i+1].type == TreeFragmentToken_RSB && + tuple.depth < height) { + int origDepth = tuple.depth; + m_parentStack.push(tuple.index); + for (int depth = origDepth+1; depth <= height; ++depth) { + ++tuple.index; + tuple.parent = m_parentStack.top(); + m_parentStack.push(tuple.index); + tuple.depth = depth; + tuple.symbol = HyperPath::kEpsilon; + m_nodeTupleSeq.push_back(tuple); + } + for (int depth = origDepth; depth <= height; ++depth) { + m_parentStack.pop(); + } + tuple.depth = origDepth; + } + --tuple.depth; + } else if (m_tokenSeq[i].type == TreeFragmentToken_RSB) { + m_parentStack.pop(); + --tuple.depth; + } + } + + // Remove the -1 parent index. + m_parentStack.pop(); +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperPathLoader.h b/moses/Syntax/F2S/HyperPathLoader.h new file mode 100644 index 000000000..27cd7c306 --- /dev/null +++ b/moses/Syntax/F2S/HyperPathLoader.h @@ -0,0 +1,70 @@ +#pragma once + +#include <stack> +#include <vector> + +#include "util/string_piece.hh" + +#include "moses/FactorCollection.h" +#include "moses/TypeDef.h" + +#include "HyperPath.h" +#include "TreeFragmentTokenizer.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// Parses a string representation of a tree fragment, adding the terminals +// and non-terminals to FactorCollection::Instance() and building a +// HyperPath object. +// +// This class is designed to be used during rule table loading. Since every +// rule has a tree fragment on the source-side, Load() may be called millions +// of times. The algorithm therefore sacrifices readability for speed and +// shoehorns everything into two passes over the input token sequence. +// +class HyperPathLoader +{ + public: + HyperPathLoader(FactorDirection, const std::vector<FactorType> &); + + void Load(const StringPiece &, HyperPath &); + + private: + struct NodeTuple { + int index; // Preorder index of the node. + int parent; // Preorder index of the node's parent. + int depth; // Depth of the node. + std::size_t symbol; // Either the factor ID of a tree terminal/non-terminal + // or for virtual nodes, HyperPath::kEpsilon. + }; + + // Determine the height of the current tree fragment (stored in m_tokenSeq). + int DetermineHeight() const; + + // Generate the preorder sequence of NodeTuples for the current tree fragment, + // including virtual nodes. + void GenerateNodeTupleSeq(int height); + + const Factor *AddTerminalFactor(const StringPiece &s) { + return FactorCollection::Instance().AddFactor(s, false); + } + + const Factor *AddNonTerminalFactor(const StringPiece &s) { + return FactorCollection::Instance().AddFactor(s, true); + } + + FactorDirection m_direction; + const std::vector<FactorType> &m_factorOrder; + std::vector<TreeFragmentToken> m_tokenSeq; + std::vector<NodeTuple> m_nodeTupleSeq; + std::stack<int> m_parentStack; +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperTree.cpp b/moses/Syntax/F2S/HyperTree.cpp new file mode 100644 index 000000000..cf28f275e --- /dev/null +++ b/moses/Syntax/F2S/HyperTree.cpp @@ -0,0 +1,70 @@ +#include "HyperTree.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +void HyperTree::Node::Prune(std::size_t tableLimit) +{ + // Recusively prune child nodes. + for (Map::iterator p = m_map.begin(); p != m_map.end(); ++p) { + p->second.Prune(tableLimit); + } + // Prune TargetPhraseCollection at this node. + m_targetPhraseCollection.Prune(true, tableLimit); +} + +void HyperTree::Node::Sort(std::size_t tableLimit) +{ + // Recusively sort child nodes. + for (Map::iterator p = m_map.begin(); p != m_map.end(); ++p) { + p->second.Sort(tableLimit); + } + // Sort TargetPhraseCollection at this node. + m_targetPhraseCollection.Sort(true, tableLimit); +} + +HyperTree::Node *HyperTree::Node::GetOrCreateChild( + const HyperPath::NodeSeq &nodeSeq) +{ + return &m_map[nodeSeq]; +} + +const HyperTree::Node *HyperTree::Node::GetChild( + const HyperPath::NodeSeq &nodeSeq) const +{ + Map::const_iterator p = m_map.find(nodeSeq); + return (p == m_map.end()) ? NULL : &p->second; +} + +TargetPhraseCollection &HyperTree::GetOrCreateTargetPhraseCollection( + const HyperPath &hyperPath) +{ + Node &node = GetOrCreateNode(hyperPath); + return node.GetTargetPhraseCollection(); +} + +HyperTree::Node &HyperTree::GetOrCreateNode(const HyperPath &hyperPath) +{ + const std::size_t height = hyperPath.nodeSeqs.size(); + Node *node = &m_root; + for (std::size_t i = 0; i < height; ++i) { + const HyperPath::NodeSeq &nodeSeq = hyperPath.nodeSeqs[i]; + node = node->GetOrCreateChild(nodeSeq); + } + return *node; +} + +void HyperTree::SortAndPrune(std::size_t tableLimit) +{ + if (tableLimit) { + m_root.Sort(tableLimit); + } +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperTree.h b/moses/Syntax/F2S/HyperTree.h new file mode 100644 index 000000000..75706712f --- /dev/null +++ b/moses/Syntax/F2S/HyperTree.h @@ -0,0 +1,79 @@ +#pragma once + +#include <map> +#include <vector> + +#include <boost/unordered_map.hpp> + +#include "moses/Syntax/RuleTable.h" +#include "moses/TargetPhraseCollection.h" + +#include "HyperPath.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// A HyperTree for representing a tree-to-string rule table. See this paper: +// +// Hui Zhang, Min Zhang, Haizhou Li, and Chew Lim Tan +// "Fast Translation Rule Matching for Syntax-based Statistical Machine +// Translation" +// In proceedings of EMNLP 2009 +// +class HyperTree : public RuleTable +{ + public: + class Node + { + public: + typedef boost::unordered_map<HyperPath::NodeSeq, Node> Map; + + bool IsLeaf() const { return m_map.empty(); } + + bool HasRules() const { return !m_targetPhraseCollection.IsEmpty(); } + + void Prune(std::size_t tableLimit); + void Sort(std::size_t tableLimit); + + Node *GetOrCreateChild(const HyperPath::NodeSeq &); + + const Node *GetChild(const HyperPath::NodeSeq &) const; + + const TargetPhraseCollection &GetTargetPhraseCollection() const { + return m_targetPhraseCollection; + } + + TargetPhraseCollection &GetTargetPhraseCollection() { + return m_targetPhraseCollection; + } + + const Map &GetMap() const { return m_map; } + + private: + Map m_map; + TargetPhraseCollection m_targetPhraseCollection; + }; + + HyperTree(const RuleTableFF *ff) : RuleTable(ff) {} + + const Node &GetRootNode() const { return m_root; } + + private: + friend class HyperTreeCreator; + + TargetPhraseCollection &GetOrCreateTargetPhraseCollection(const HyperPath &); + + Node &GetOrCreateNode(const HyperPath &); + + void SortAndPrune(std::size_t); + + Node m_root; +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperTreeCreator.h b/moses/Syntax/F2S/HyperTreeCreator.h new file mode 100644 index 000000000..bbae6e5c7 --- /dev/null +++ b/moses/Syntax/F2S/HyperTreeCreator.h @@ -0,0 +1,32 @@ +#pragma once + +#include "HyperTree.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// Base for classes that create a HyperTree (currently HyperTreeLoader and +// GlueRuleSynthesizer). HyperTreeCreator is a friend of HyperTree. +class HyperTreeCreator +{ + protected: + // Provide access to HyperTree's private SortAndPrune function. + void SortAndPrune(HyperTree &trie, std::size_t limit) { + trie.SortAndPrune(limit); + } + + // Provide access to HyperTree's private GetOrCreateTargetPhraseCollection + // function. + TargetPhraseCollection &GetOrCreateTargetPhraseCollection( + HyperTree &trie, const HyperPath &fragment) { + return trie.GetOrCreateTargetPhraseCollection(fragment); + } +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperTreeLoader.cpp b/moses/Syntax/F2S/HyperTreeLoader.cpp new file mode 100644 index 000000000..8dcadef55 --- /dev/null +++ b/moses/Syntax/F2S/HyperTreeLoader.cpp @@ -0,0 +1,148 @@ +#include "HyperTreeLoader.h" + +#include <sys/stat.h> +#include <stdlib.h> + +#include <fstream> +#include <string> +#include <iterator> +#include <algorithm> +#include <iostream> + +#include "moses/FactorCollection.h" +#include "moses/Word.h" +#include "moses/Util.h" +#include "moses/InputFileStream.h" +#include "moses/StaticData.h" +#include "moses/WordsRange.h" +#include "moses/ChartTranslationOptionList.h" +#include "moses/FactorCollection.h" +#include "moses/Syntax/RuleTableFF.h" +#include "util/file_piece.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/double-conversion/double-conversion.h" +#include "util/exception.hh" + +#include "HyperPath.h" +#include "HyperPathLoader.h" +#include "HyperTree.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +bool HyperTreeLoader::Load(const std::vector<FactorType> &input, + const std::vector<FactorType> &output, + const std::string &inFile, + const RuleTableFF &ff, + HyperTree &trie) +{ + PrintUserTime(std::string("Start loading HyperTree")); + + const StaticData &staticData = StaticData::Instance(); + const std::string &factorDelimiter = staticData.GetFactorDelimiter(); + + std::size_t count = 0; + + std::ostream *progress = NULL; + IFVERBOSE(1) progress = &std::cerr; + util::FilePiece in(inFile.c_str(), progress); + + // reused variables + std::vector<float> scoreVector; + StringPiece line; + + double_conversion::StringToDoubleConverter converter(double_conversion::StringToDoubleConverter::NO_FLAGS, NAN, NAN, "inf", "nan"); + + HyperPathLoader hyperPathLoader(Input, input); + + Phrase dummySourcePhrase; + { + Word *lhs = NULL; + dummySourcePhrase.CreateFromString(Input, input, "hello", &lhs); + delete lhs; + } + + while(true) { + try { + line = in.ReadLine(); + } catch (const util::EndOfFileException &e) { + break; + } + + util::TokenIter<util::MultiCharacter> pipes(line, "|||"); + StringPiece sourceString(*pipes); + StringPiece targetString(*++pipes); + StringPiece scoreString(*++pipes); + + StringPiece alignString; + if (++pipes) { + StringPiece temp(*pipes); + alignString = temp; + } + + if (++pipes) { + StringPiece str(*pipes); //counts + } + + scoreVector.clear(); + for (util::TokenIter<util::AnyCharacter, true> s(scoreString, " \t"); s; ++s) { + int processed; + float score = converter.StringToFloat(s->data(), s->length(), &processed); + UTIL_THROW_IF2(isnan(score), "Bad score " << *s << " on line " << count); + scoreVector.push_back(FloorScore(TransformScore(score))); + } + const std::size_t numScoreComponents = ff.GetNumScoreComponents(); + if (scoreVector.size() != numScoreComponents) { + UTIL_THROW2("Size of scoreVector != number (" << scoreVector.size() << "!=" + << numScoreComponents << ") of score components on line " << count); + } + + // Source-side + HyperPath sourceFragment; + hyperPathLoader.Load(sourceString, sourceFragment); + + // Target-side + TargetPhrase *targetPhrase = new TargetPhrase(&ff); + Word *targetLHS = NULL; + targetPhrase->CreateFromString(Output, output, targetString, &targetLHS); + targetPhrase->SetTargetLHS(targetLHS); + targetPhrase->SetAlignmentInfo(alignString); + + if (++pipes) { + StringPiece sparseString(*pipes); + targetPhrase->SetSparseScore(&ff, sparseString); + } + + if (++pipes) { + StringPiece propertiesString(*pipes); + targetPhrase->SetProperties(propertiesString); + } + + targetPhrase->GetScoreBreakdown().Assign(&ff, scoreVector); + targetPhrase->EvaluateInIsolation(dummySourcePhrase, + ff.GetFeaturesToApply()); + + // Add rule to trie. + TargetPhraseCollection &phraseColl = GetOrCreateTargetPhraseCollection( + trie, sourceFragment); + phraseColl.Add(targetPhrase); + + count++; + } + + // sort and prune each target phrase collection + if (ff.GetTableLimit()) { + SortAndPrune(trie, ff.GetTableLimit()); + } + + return true; +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/HyperTreeLoader.h b/moses/Syntax/F2S/HyperTreeLoader.h new file mode 100644 index 000000000..b760834d3 --- /dev/null +++ b/moses/Syntax/F2S/HyperTreeLoader.h @@ -0,0 +1,31 @@ +#pragma once + +#include <istream> +#include <vector> + +#include "moses/TypeDef.h" +#include "moses/Syntax/RuleTableFF.h" + +#include "HyperTree.h" +#include "HyperTreeCreator.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +class HyperTreeLoader : public HyperTreeCreator +{ + public: + bool Load(const std::vector<FactorType> &input, + const std::vector<FactorType> &output, + const std::string &inFile, + const RuleTableFF &, + HyperTree &); +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/Manager-inl.h b/moses/Syntax/F2S/Manager-inl.h new file mode 100644 index 000000000..5eb722cf7 --- /dev/null +++ b/moses/Syntax/F2S/Manager-inl.h @@ -0,0 +1,318 @@ +#pragma once + +#include "moses/DecodeGraph.h" +#include "moses/ForestInput.h" +#include "moses/StaticData.h" +#include "moses/Syntax/BoundedPriorityContainer.h" +#include "moses/Syntax/CubeQueue.h" +#include "moses/Syntax/PHyperedge.h" +#include "moses/Syntax/RuleTable.h" +#include "moses/Syntax/RuleTableFF.h" +#include "moses/Syntax/SHyperedgeBundle.h" +#include "moses/Syntax/SVertex.h" +#include "moses/Syntax/SVertexRecombinationOrderer.h" +#include "moses/Syntax/SymbolEqualityPred.h" +#include "moses/Syntax/SymbolHasher.h" +#include "moses/Syntax/T2S/InputTree.h" +#include "moses/Syntax/T2S/InputTreeBuilder.h" +#include "moses/Syntax/T2S/InputTreeToForest.h" +#include "moses/TreeInput.h" + +#include "DerivationWriter.h" +#include "GlueRuleSynthesizer.h" +#include "HyperTree.h" +#include "RuleMatcherCallback.h" +#include "TopologicalSorter.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +template<typename RuleMatcher> +Manager<RuleMatcher>::Manager(const InputType &source) + : Syntax::Manager(source) +{ + if (const ForestInput *p = dynamic_cast<const ForestInput*>(&source)) { + m_forest = p->GetForest(); + m_rootVertex = p->GetRootVertex(); + } else if (const TreeInput *p = dynamic_cast<const TreeInput*>(&source)) { + T2S::InputTreeBuilder builder; + T2S::InputTree tmpTree; + builder.Build(*p, "Q", tmpTree); + boost::shared_ptr<Forest> forest = boost::make_shared<Forest>(); + m_rootVertex = T2S::InputTreeToForest(tmpTree, *forest); + m_forest = forest; + } +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::Decode() +{ + const StaticData &staticData = StaticData::Instance(); + + // Get various pruning-related constants. + const std::size_t popLimit = staticData.GetCubePruningPopLimit(); + const std::size_t ruleLimit = staticData.GetRuleLimit(); + const std::size_t stackLimit = staticData.GetMaxHypoStackSize(); + + // Initialize the stacks. + InitializeStacks(); + + // Initialize the rule matchers. + InitializeRuleMatchers(); + + // Create a callback to process the PHyperedges produced by the rule matchers. + RuleMatcherCallback callback(m_stackMap, ruleLimit); + + // Create a glue rule synthesizer. + GlueRuleSynthesizer glueRuleSynthesizer(*m_glueRuleTrie); + + // Sort the input forest's vertices into bottom-up topological order. + std::vector<const Forest::Vertex *> sortedVertices; + TopologicalSorter sorter; + sorter.Sort(*m_forest, sortedVertices); + + // Visit each vertex of the input forest in topological order. + for (std::vector<const Forest::Vertex *>::const_iterator + p = sortedVertices.begin(); p != sortedVertices.end(); ++p) { + const Forest::Vertex &vertex = **p; + + // Skip terminal vertices. + if (vertex.incoming.empty()) { + continue; + } + + // Call the rule matchers to generate PHyperedges for this vertex and + // convert each one to a SHyperedgeBundle (via the callback). The + // callback prunes the SHyperedgeBundles and keeps the best ones (up + // to ruleLimit). + callback.ClearContainer(); + for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator + q = m_mainRuleMatchers.begin(); q != m_mainRuleMatchers.end(); ++q) { + (*q)->EnumerateHyperedges(vertex, callback); + } + + // Retrieve the (pruned) set of SHyperedgeBundles from the callback. + const BoundedPriorityContainer<SHyperedgeBundle> &bundles = + callback.GetContainer(); + + // Check if any rules were matched. If not then for each incoming + // hyperedge, synthesize a glue rule that is guaranteed to match. + if (bundles.Size() == 0) { + for (std::vector<Forest::Hyperedge *>::const_iterator p = + vertex.incoming.begin(); p != vertex.incoming.end(); ++p) { + glueRuleSynthesizer.SynthesizeRule(**p); + } + m_glueRuleMatcher->EnumerateHyperedges(vertex, callback); + // FIXME This assertion occasionally fails -- why? + // assert(bundles.Size() == vertex.incoming.size()); + } + + // Use cube pruning to extract SHyperedges from SHyperedgeBundles and + // collect the SHyperedges in a buffer. + CubeQueue cubeQueue(bundles.Begin(), bundles.End()); + std::size_t count = 0; + std::vector<SHyperedge*> buffer; + while (count < popLimit && !cubeQueue.IsEmpty()) { + SHyperedge *hyperedge = cubeQueue.Pop(); + // FIXME See corresponding code in S2T::Manager + // BEGIN{HACK} + hyperedge->head->pvertex = &(vertex.pvertex); + // END{HACK} + buffer.push_back(hyperedge); + ++count; + } + + // Recombine SVertices and sort into a stack. + SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; + RecombineAndSort(buffer, stack); + + // Prune stack. + if (stackLimit > 0 && stack.size() > stackLimit) { + stack.resize(stackLimit); + } + } +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::InitializeRuleMatchers() +{ + const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); + for (std::size_t i = 0; i < ffs.size(); ++i) { + RuleTableFF *ff = ffs[i]; + // This may change in the future, but currently we assume that every + // RuleTableFF is associated with a static, file-based rule table of + // some sort and that the table should have been loaded into a RuleTable + // by this point. + const RuleTable *table = ff->GetTable(); + assert(table); + RuleTable *nonConstTable = const_cast<RuleTable*>(table); + HyperTree *trie = dynamic_cast<HyperTree*>(nonConstTable); + assert(trie); + boost::shared_ptr<RuleMatcher> p(new RuleMatcher(*trie)); + m_mainRuleMatchers.push_back(p); + } + + // Create an additional rule trie + matcher for glue rules (which are + // synthesized on demand). + // FIXME Add a hidden RuleTableFF for the glue rule trie(?) + m_glueRuleTrie.reset(new HyperTree(ffs[0])); + m_glueRuleMatcher = boost::shared_ptr<RuleMatcher>( + new RuleMatcher(*m_glueRuleTrie)); +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::InitializeStacks() +{ + // Check that m_forest has been initialized. + assert(!m_forest->vertices.empty()); + + for (std::vector<Forest::Vertex *>::const_iterator + p = m_forest->vertices.begin(); p != m_forest->vertices.end(); ++p) { + const Forest::Vertex &vertex = **p; + + // Create an empty stack. + SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; + + // For terminals only, add a single SVertex. + if (vertex.incoming.empty()) { + boost::shared_ptr<SVertex> v(new SVertex()); + v->best = 0; + v->pvertex = &(vertex.pvertex); + stack.push_back(v); + } + } +} + + +template<typename RuleMatcher> +const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const +{ + PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); + assert(p != m_stackMap.end()); + const SVertexStack &stack = p->second; + assert(!stack.empty()); + return stack[0]->best; +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::ExtractKBest( + std::size_t k, + std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, + bool onlyDistinct) const +{ + kBestList.clear(); + if (k == 0 || m_source.GetSize() == 0) { + return; + } + + // Get the top-level SVertex stack. + PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); + assert(p != m_stackMap.end()); + const SVertexStack &stack = p->second; + assert(!stack.empty()); + + KBestExtractor extractor; + + if (!onlyDistinct) { + // Return the k-best list as is, including duplicate translations. + extractor.Extract(stack, k, kBestList); + return; + } + + // Determine how many derivations to extract. If the k-best list is + // restricted to distinct translations then this limit should be bigger + // than k. The k-best factor determines how much bigger the limit should be, + // with 0 being 'unlimited.' This actually sets a large-ish limit in case + // too many translations are identical. + const StaticData &staticData = StaticData::Instance(); + const std::size_t nBestFactor = staticData.GetNBestFactor(); + std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; + + // Extract the derivations. + KBestExtractor::KBestVec bigList; + bigList.reserve(numDerivations); + extractor.Extract(stack, numDerivations, bigList); + + // Copy derivations into kBestList, skipping ones with repeated translations. + std::set<Phrase> distinct; + for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); + kBestList.size() < k && p != bigList.end(); ++p) { + boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; + Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); + if (distinct.insert(translation).second) { + kBestList.push_back(derivation); + } + } +} + +// TODO Move this function into parent directory (Recombiner class?) and +// TODO share with S2T +template<typename RuleMatcher> +void Manager<RuleMatcher>::RecombineAndSort( + const std::vector<SHyperedge*> &buffer, SVertexStack &stack) +{ + // Step 1: Create a map containing a single instance of each distinct vertex + // (where distinctness is defined by the state value). The hyperedges' + // head pointers are updated to point to the vertex instances in the map and + // any 'duplicate' vertices are deleted. +// TODO Set? + typedef std::map<SVertex *, SVertex *, SVertexRecombinationOrderer> Map; + Map map; + for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); + p != buffer.end(); ++p) { + SHyperedge *h = *p; + SVertex *v = h->head; + assert(v->best == h); + assert(v->recombined.empty()); + std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); + if (result.second) { + continue; // v's recombination value hasn't been seen before. + } + // v is a duplicate (according to the recombination rules). + // Compare the score of h against the score of the best incoming hyperedge + // for the stored vertex. + SVertex *storedVertex = result.first->second; + if (h->label.score > storedVertex->best->label.score) { + // h's score is better. + storedVertex->recombined.push_back(storedVertex->best); + storedVertex->best = h; + } else { + storedVertex->recombined.push_back(h); + } + h->head->best = 0; + delete h->head; + h->head = storedVertex; + } + + // Step 2: Copy the vertices from the map to the stack. + stack.clear(); + stack.reserve(map.size()); + for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { + stack.push_back(boost::shared_ptr<SVertex>(p->first)); + } + + // Step 3: Sort the vertices in the stack. + std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::OutputDetailedTranslationReport( + OutputCollector *collector) const +{ + const SHyperedge *best = GetBestSHyperedge(); + if (best == NULL || collector == NULL) { + return; + } + long translationId = m_source.GetTranslationId(); + std::ostringstream out; + DerivationWriter::Write(*best, translationId, out); + collector->Write(translationId, out.str()); +} + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/Manager.h b/moses/Syntax/F2S/Manager.h new file mode 100644 index 000000000..1705d4f64 --- /dev/null +++ b/moses/Syntax/F2S/Manager.h @@ -0,0 +1,69 @@ +#pragma once + +#include <set> +#include <vector> + +#include <boost/shared_ptr.hpp> +#include <boost/unordered_map.hpp> + +#include "moses/InputType.h" +#include "moses/Syntax/KBestExtractor.h" +#include "moses/Syntax/Manager.h" +#include "moses/Syntax/SVertexStack.h" +#include "moses/Word.h" + +#include "Forest.h" +#include "HyperTree.h" +#include "PVertexToStackMap.h" + +namespace Moses +{ +namespace Syntax +{ + +struct SHyperedge; + +namespace F2S +{ + +template<typename RuleMatcher> +class Manager : public Syntax::Manager +{ + public: + Manager(const InputType &); + + void Decode(); + + // Get the SHyperedge for the 1-best derivation. + const SHyperedge *GetBestSHyperedge() const; + + void ExtractKBest( + std::size_t k, + std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, + bool onlyDistinct=false) const; + + void OutputDetailedTranslationReport(OutputCollector *collector) const; + + private: + const Forest::Vertex &FindRootNode(const Forest &); + + void InitializeRuleMatchers(); + + void InitializeStacks(); + + void RecombineAndSort(const std::vector<SHyperedge*> &, SVertexStack &); + + boost::shared_ptr<const Forest> m_forest; + const Forest::Vertex *m_rootVertex; + PVertexToStackMap m_stackMap; + boost::shared_ptr<HyperTree> m_glueRuleTrie; + std::vector<boost::shared_ptr<RuleMatcher> > m_mainRuleMatchers; + boost::shared_ptr<RuleMatcher> m_glueRuleMatcher; +}; + +} // F2S +} // Syntax +} // Moses + +// Implementation +#include "Manager-inl.h" diff --git a/moses/Syntax/F2S/PHyperedgeToSHyperedgeBundle.h b/moses/Syntax/F2S/PHyperedgeToSHyperedgeBundle.h new file mode 100644 index 000000000..81c6f3da7 --- /dev/null +++ b/moses/Syntax/F2S/PHyperedgeToSHyperedgeBundle.h @@ -0,0 +1,32 @@ +#pragma once + +#include "moses/Syntax/PHyperedge.h" +#include "moses/Syntax/PVertex.h" +#include "moses/Syntax/SHyperedgeBundle.h" + +#include "PVertexToStackMap.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// Given a PHyperedge object and SStackSet produces a SHyperedgeBundle object. +inline void PHyperedgeToSHyperedgeBundle(const PHyperedge &hyperedge, + const PVertexToStackMap &stackMap, + SHyperedgeBundle &bundle) { + bundle.translations = hyperedge.label.translations; + bundle.stacks.clear(); + for (std::vector<PVertex*>::const_iterator p = hyperedge.tail.begin(); + p != hyperedge.tail.end(); ++p) { + PVertexToStackMap::const_iterator q = stackMap.find(*p); + const SVertexStack &stack = q->second; + bundle.stacks.push_back(&stack); + } +} + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/PVertexToStackMap.h b/moses/Syntax/F2S/PVertexToStackMap.h new file mode 100644 index 000000000..9e3142492 --- /dev/null +++ b/moses/Syntax/F2S/PVertexToStackMap.h @@ -0,0 +1,20 @@ +#pragma once + +#include <boost/unordered_map.hpp> + +#include "moses/Syntax/PVertex.h" +#include "moses/Syntax/SVertexStack.h" + + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +typedef boost::unordered_map<const PVertex *, SVertexStack> PVertexToStackMap; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/RuleMatcher.h b/moses/Syntax/F2S/RuleMatcher.h new file mode 100644 index 000000000..ac3a4c065 --- /dev/null +++ b/moses/Syntax/F2S/RuleMatcher.h @@ -0,0 +1,24 @@ +#pragma once + +#include "Forest.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// Base class for rule matchers. +template<typename Callback> +class RuleMatcher +{ + public: + virtual ~RuleMatcher() {} + + virtual void EnumerateHyperedges(const Forest::Vertex &, Callback &) = 0; +}; + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/RuleMatcherCallback.h b/moses/Syntax/F2S/RuleMatcherCallback.h new file mode 100644 index 000000000..c240b87db --- /dev/null +++ b/moses/Syntax/F2S/RuleMatcherCallback.h @@ -0,0 +1,46 @@ +#pragma once + +#include "moses/Syntax/BoundedPriorityContainer.h" +#include "moses/Syntax/PHyperedge.h" +#include "moses/Syntax/PVertex.h" +#include "moses/Syntax/SHyperedgeBundle.h" +#include "moses/Syntax/SHyperedgeBundleScorer.h" + +#include "PHyperedgeToSHyperedgeBundle.h" +#include "PVertexToStackMap.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +class RuleMatcherCallback { + private: + typedef BoundedPriorityContainer<SHyperedgeBundle> Container; + + public: + RuleMatcherCallback(const PVertexToStackMap &stackMap, std::size_t ruleLimit) + : m_stackMap(stackMap) + , m_container(ruleLimit) {} + + void operator()(const PHyperedge &hyperedge) { + PHyperedgeToSHyperedgeBundle(hyperedge, m_stackMap, m_tmpBundle); + float score = SHyperedgeBundleScorer::Score(m_tmpBundle); + m_container.SwapIn(m_tmpBundle, score); + } + + void ClearContainer() { m_container.LazyClear(); } + + const Container &GetContainer() { return m_container; } + + private: + const PVertexToStackMap &m_stackMap; + SHyperedgeBundle m_tmpBundle; + BoundedPriorityContainer<SHyperedgeBundle> m_container; +}; + +} // F2S +} // Syntax +} // Moses diff --git a/moses/Syntax/F2S/RuleMatcherHyperTree-inl.h b/moses/Syntax/F2S/RuleMatcherHyperTree-inl.h new file mode 100644 index 000000000..456594873 --- /dev/null +++ b/moses/Syntax/F2S/RuleMatcherHyperTree-inl.h @@ -0,0 +1,192 @@ +#pragma once + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +template<typename Callback> +RuleMatcherHyperTree<Callback>::RuleMatcherHyperTree(const HyperTree &ruleTrie) + : m_ruleTrie(ruleTrie) +{ +} + +template<typename Callback> +void RuleMatcherHyperTree<Callback>::EnumerateHyperedges( + const Forest::Vertex &v, Callback &callback) +{ + const HyperTree::Node &root = m_ruleTrie.GetRootNode(); + HyperPath::NodeSeq nodeSeq(1, v.pvertex.symbol[0]->GetId()); + const HyperTree::Node *child = root.GetChild(nodeSeq); + if (!child) { + return; + } + + m_hyperedge.head = const_cast<PVertex*>(&v.pvertex); + + // Initialize the queue. + MatchItem item; + item.annotatedFNS.fns = FNS(1, &v); + item.trieNode = child; + m_queue.push(item); + + while (!m_queue.empty()) { + MatchItem item = m_queue.front(); + m_queue.pop(); + if (item.trieNode->HasRules()) { + const FNS &fns = item.annotatedFNS.fns; + m_hyperedge.tail.clear(); + for (FNS::const_iterator p = fns.begin(); p != fns.end(); ++p) { + const Forest::Vertex *v = *p; + m_hyperedge.tail.push_back(const_cast<PVertex *>(&(v->pvertex))); + } + m_hyperedge.label.translations = + &(item.trieNode->GetTargetPhraseCollection()); + callback(m_hyperedge); + } + PropagateNextLexel(item); + } +} + +template<typename Callback> +void RuleMatcherHyperTree<Callback>::PropagateNextLexel(const MatchItem &item) +{ + std::vector<AnnotatedFNS> tfns; + std::vector<AnnotatedFNS> rfns; + std::vector<AnnotatedFNS> rfns2; + + const HyperTree::Node &trieNode = *(item.trieNode); + const HyperTree::Node::Map &map = trieNode.GetMap(); + + for (HyperTree::Node::Map::const_iterator p = map.begin(); + p != map.end(); ++p) { + const HyperPath::NodeSeq &edgeLabel = p->first; + const HyperTree::Node &child = p->second; + + const int numSubSeqs = CountCommas(edgeLabel) + 1; + + std::size_t pos = 0; + for (int i = 0; i < numSubSeqs; ++i) { + const FNS &fns = item.annotatedFNS.fns; + tfns.clear(); + if (edgeLabel[pos] == HyperPath::kEpsilon) { + AnnotatedFNS x; + x.fns = FNS(1, fns[i]); + tfns.push_back(x); + pos += 2; + } else { + const int subSeqLength = SubSeqLength(edgeLabel, pos); + const std::vector<Forest::Hyperedge*> &incoming = fns[i]->incoming; + for (std::vector<Forest::Hyperedge *>::const_iterator q = + incoming.begin(); q != incoming.end(); ++q) { + const Forest::Hyperedge &edge = **q; + if (MatchChildren(edge.tail, edgeLabel, pos, subSeqLength)) { + tfns.resize(tfns.size()+1); + tfns.back().fns.assign(edge.tail.begin(), edge.tail.end()); + tfns.back().fragment.push_back(&edge); + } + } + pos += subSeqLength + 1; + } + if (tfns.empty()) { + rfns.clear(); + break; + } else if (i == 0) { + rfns.swap(tfns); + } else { + CartesianProduct(rfns, tfns, rfns2); + rfns.swap(rfns2); + } + } + + for (typename std::vector<AnnotatedFNS>::const_iterator q = rfns.begin(); + q != rfns.end(); ++q) { + MatchItem newItem; + newItem.annotatedFNS.fns = q->fns; + newItem.annotatedFNS.fragment = item.annotatedFNS.fragment; + newItem.annotatedFNS.fragment.insert(newItem.annotatedFNS.fragment.end(), + q->fragment.begin(), + q->fragment.end()); + newItem.trieNode = &child; + m_queue.push(newItem); + } + } +} + +template<typename Callback> +void RuleMatcherHyperTree<Callback>::CartesianProduct( + const std::vector<AnnotatedFNS> &x, + const std::vector<AnnotatedFNS> &y, + std::vector<AnnotatedFNS> &z) +{ + z.clear(); + z.reserve(x.size() * y.size()); + for (typename std::vector<AnnotatedFNS>::const_iterator p = x.begin(); + p != x.end(); ++p) { + const AnnotatedFNS &a = *p; + for (typename std::vector<AnnotatedFNS>::const_iterator q = y.begin(); + q != y.end(); ++q) { + const AnnotatedFNS &b = *q; + // Create a new AnnotatedFNS. + z.resize(z.size()+1); + AnnotatedFNS &c = z.back(); + // Combine frontier node sequences from a and b. + c.fns.reserve(a.fns.size() + b.fns.size()); + c.fns.assign(a.fns.begin(), a.fns.end()); + c.fns.insert(c.fns.end(), b.fns.begin(), b.fns.end()); + // Combine tree fragments from a and b. + c.fragment.reserve(a.fragment.size() + b.fragment.size()); + c.fragment.assign(a.fragment.begin(), a.fragment.end()); + c.fragment.insert(c.fragment.end(), b.fragment.begin(), b.fragment.end()); + } + } +} + +template<typename Callback> +bool RuleMatcherHyperTree<Callback>::MatchChildren( + const std::vector<Forest::Vertex *> &children, + const HyperPath::NodeSeq &edgeLabel, + std::size_t pos, + std::size_t subSeqSize) +{ + if (children.size() != subSeqSize) { + return false; + } + for (int i = 0; i < subSeqSize; ++i) { + if (edgeLabel[pos+i] != children[i]->pvertex.symbol[0]->GetId()) { + return false; + } + } + return true; +} + +template<typename Callback> +int RuleMatcherHyperTree<Callback>::CountCommas(const HyperPath::NodeSeq &seq) +{ + int count = 0; + for (std::vector<std::size_t>::const_iterator p = seq.begin(); + p != seq.end(); ++p) { + if (*p == HyperPath::kComma) { + ++count; + } + } + return count; +} + +template<typename Callback> +int RuleMatcherHyperTree<Callback>::SubSeqLength(const HyperPath::NodeSeq &seq, + int pos) +{ + int length = 0; + while (pos != seq.size() && seq[pos] != HyperPath::kComma) { + ++pos; + ++length; + } + return length; +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/RuleMatcherHyperTree.h b/moses/Syntax/F2S/RuleMatcherHyperTree.h new file mode 100644 index 000000000..406d794ed --- /dev/null +++ b/moses/Syntax/F2S/RuleMatcherHyperTree.h @@ -0,0 +1,78 @@ +#pragma once + +#include "moses/Syntax/PHyperedge.h" + +#include "Forest.h" +#include "HyperTree.h" +#include "RuleMatcher.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +// Rule matcher based on the algorithm from this paper: +// +// Hui Zhang, Min Zhang, Haizhou Li, and Chew Lim Tan +// "Fast Translation Rule Matching for Syntax-based Statistical Machine +// Translation" +// In proceedings of EMNLP 2009 +// +template<typename Callback> +class RuleMatcherHyperTree : public RuleMatcher<Callback> +{ + public: + RuleMatcherHyperTree(const HyperTree &); + + ~RuleMatcherHyperTree() {} + + void EnumerateHyperedges(const Forest::Vertex &, Callback &); + + private: + // Frontier node sequence. + typedef std::vector<const Forest::Vertex *> FNS; + + // An AnnotatedFNS is a FNS annotated with the set of forest hyperedges that + // constitute the tree fragment from which it was derived. + struct AnnotatedFNS { + FNS fns; + std::vector<const Forest::Hyperedge *> fragment; + }; + + // A MatchItem is like the FP structure in Zhang et al. (2009), but it also + // records the set of forest hyperedges that constitute the matched tree + // fragment. + struct MatchItem { + AnnotatedFNS annotatedFNS; + const HyperTree::Node *trieNode; + }; + + // Implements the Cartsian product operation from line 16 of Algorithm 4 + // (Zhang et al., 2009), which in this implementation also involves + // combining the fragment information associated with the FNS objects. + void CartesianProduct(const std::vector<AnnotatedFNS> &, + const std::vector<AnnotatedFNS> &, + std::vector<AnnotatedFNS> &); + + int CountCommas(const HyperPath::NodeSeq &); + + bool MatchChildren(const std::vector<Forest::Vertex *> &, + const HyperPath::NodeSeq &, std::size_t, std::size_t); + + void PropagateNextLexel(const MatchItem &); + + int SubSeqLength(const HyperPath::NodeSeq &, int); + + const HyperTree &m_ruleTrie; + PHyperedge m_hyperedge; + std::queue<MatchItem> m_queue; // Called "SFP" in Zhang et al. (2009) +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses + +// Implementation +#include "RuleMatcherHyperTree-inl.h" diff --git a/moses/Syntax/F2S/TopologicalSorter.cpp b/moses/Syntax/F2S/TopologicalSorter.cpp new file mode 100644 index 000000000..4821177b3 --- /dev/null +++ b/moses/Syntax/F2S/TopologicalSorter.cpp @@ -0,0 +1,55 @@ +#include "TopologicalSorter.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +void TopologicalSorter::Sort(const Forest &forest, + std::vector<const Forest::Vertex *> &permutation) +{ + permutation.clear(); + BuildPredSets(forest); + m_visited.clear(); + for (std::vector<Forest::Vertex *>::const_iterator + p = forest.vertices.begin(); p != forest.vertices.end(); ++p) { + if (m_visited.find(*p) == m_visited.end()) { + Visit(**p, permutation); + } + } +} + +void TopologicalSorter::BuildPredSets(const Forest &forest) +{ + m_predSets.clear(); + for (std::vector<Forest::Vertex *>::const_iterator + p = forest.vertices.begin(); p != forest.vertices.end(); ++p) { + const Forest::Vertex *head = *p; + for (std::vector<Forest::Hyperedge *>::const_iterator + q = head->incoming.begin(); q != head->incoming.end(); ++q) { + for (std::vector<Forest::Vertex *>::const_iterator + r = (*q)->tail.begin(); r != (*q)->tail.end(); ++r) { + m_predSets[head].insert(*r); + } + } + } +} + +void TopologicalSorter::Visit(const Forest::Vertex &v, + std::vector<const Forest::Vertex *> &permutation) +{ + m_visited.insert(&v); + const VertexSet &predSet = m_predSets[&v]; + for (VertexSet::const_iterator p = predSet.begin(); p != predSet.end(); ++p) { + if (m_visited.find(*p) == m_visited.end()) { + Visit(**p, permutation); + } + } + permutation.push_back(&v); +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/TopologicalSorter.h b/moses/Syntax/F2S/TopologicalSorter.h new file mode 100644 index 000000000..9dbb874ec --- /dev/null +++ b/moses/Syntax/F2S/TopologicalSorter.h @@ -0,0 +1,34 @@ +#pragma once + +#include <vector> + +#include <boost/unordered_map.hpp> +#include <boost/unordered_set.hpp> + +#include "Forest.h" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +class TopologicalSorter +{ + public: + void Sort(const Forest &, std::vector<const Forest::Vertex *> &); + + private: + typedef boost::unordered_set<const Forest::Vertex *> VertexSet; + + void BuildPredSets(const Forest &); + void Visit(const Forest::Vertex &, std::vector<const Forest::Vertex *> &); + + boost::unordered_set<const Forest::Vertex *> m_visited; + boost::unordered_map<const Forest::Vertex *, VertexSet> m_predSets; +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/TreeFragmentTokenizer.cpp b/moses/Syntax/F2S/TreeFragmentTokenizer.cpp new file mode 100644 index 000000000..1d10a47ad --- /dev/null +++ b/moses/Syntax/F2S/TreeFragmentTokenizer.cpp @@ -0,0 +1,93 @@ +#include "TreeFragmentTokenizer.h" + +#include <cctype> + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +TreeFragmentToken::TreeFragmentToken(TreeFragmentTokenType t, + StringPiece v, std::size_t p) + : type(t) + , value(v) + , pos(p) { +} + +TreeFragmentTokenizer::TreeFragmentTokenizer() + : value_(TreeFragmentToken_EOS, "", -1) { +} + +TreeFragmentTokenizer::TreeFragmentTokenizer(const StringPiece &s) + : str_(s) + , value_(TreeFragmentToken_EOS, "", -1) + , iter_(s.begin()) + , end_(s.end()) + , pos_(0) { + ++(*this); +} + +TreeFragmentTokenizer &TreeFragmentTokenizer::operator++() { + while (iter_ != end_ && (*iter_ == ' ' || *iter_ == '\t')) { + ++iter_; + ++pos_; + } + + if (iter_ == end_) { + value_ = TreeFragmentToken(TreeFragmentToken_EOS, "", pos_); + return *this; + } + + if (*iter_ == '[') { + value_ = TreeFragmentToken(TreeFragmentToken_LSB, "[", pos_); + ++iter_; + ++pos_; + } else if (*iter_ == ']') { + value_ = TreeFragmentToken(TreeFragmentToken_RSB, "]", pos_); + ++iter_; + ++pos_; + } else { + std::size_t start = pos_; + while (true) { + ++iter_; + ++pos_; + if (iter_ == end_ || *iter_ == ' ' || *iter_ == '\t') { + break; + } + if (*iter_ == '[' || *iter_ == ']') { + break; + } + } + StringPiece word = str_.substr(start, pos_-start); + value_ = TreeFragmentToken(TreeFragmentToken_WORD, word, start); + } + + return *this; +} + +TreeFragmentTokenizer TreeFragmentTokenizer::operator++(int) { + TreeFragmentTokenizer tmp(*this); + ++*this; + return tmp; +} + +bool operator==(const TreeFragmentTokenizer &lhs, + const TreeFragmentTokenizer &rhs) { + if (lhs.value_.type == TreeFragmentToken_EOS || + rhs.value_.type == TreeFragmentToken_EOS) { + return lhs.value_.type == TreeFragmentToken_EOS && + rhs.value_.type == TreeFragmentToken_EOS; + } + return lhs.iter_ == rhs.iter_; +} + +bool operator!=(const TreeFragmentTokenizer &lhs, + const TreeFragmentTokenizer &rhs) { + return !(lhs == rhs); +} + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/F2S/TreeFragmentTokenizer.h b/moses/Syntax/F2S/TreeFragmentTokenizer.h new file mode 100644 index 000000000..3924c9bed --- /dev/null +++ b/moses/Syntax/F2S/TreeFragmentTokenizer.h @@ -0,0 +1,73 @@ +#pragma once + +#include "util/string_piece.hh" + +namespace Moses +{ +namespace Syntax +{ +namespace F2S +{ + +enum TreeFragmentTokenType { + TreeFragmentToken_EOS, + TreeFragmentToken_LSB, + TreeFragmentToken_RSB, + TreeFragmentToken_WORD +}; + +struct TreeFragmentToken { + public: + TreeFragmentToken(TreeFragmentTokenType, StringPiece, std::size_t); + TreeFragmentTokenType type; + StringPiece value; + std::size_t pos; +}; + +// Tokenizes tree fragment strings in Moses format. +// +// For example, the string "[NP [NP [NN a]] [NP]]" is tokenized to the sequence: +// +// 1 LSB "[" +// 2 WORD "NP" +// 3 LSB "[" +// 4 WORD "NP" +// 5 LSB "[" +// 6 WORD "NN" +// 7 WORD "a" +// 8 RSB "]" +// 9 RSB "]" +// 10 LSB "[" +// 11 WORD "NP" +// 12 RSB "]" +// 13 RSB "]" +// 14 EOS undefined +// +class TreeFragmentTokenizer { + public: + TreeFragmentTokenizer(); + TreeFragmentTokenizer(const StringPiece &); + + const TreeFragmentToken &operator*() const { return value_; } + const TreeFragmentToken *operator->() const { return &value_; } + + TreeFragmentTokenizer &operator++(); + TreeFragmentTokenizer operator++(int); + + friend bool operator==(const TreeFragmentTokenizer &, + const TreeFragmentTokenizer &); + + friend bool operator!=(const TreeFragmentTokenizer &, + const TreeFragmentTokenizer &); + + private: + StringPiece str_; + TreeFragmentToken value_; + StringPiece::const_iterator iter_; + StringPiece::const_iterator end_; + std::size_t pos_; +}; + +} // namespace F2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/RuleTableFF.cpp b/moses/Syntax/RuleTableFF.cpp index 192863926..f4e06f489 100644 --- a/moses/Syntax/RuleTableFF.cpp +++ b/moses/Syntax/RuleTableFF.cpp @@ -1,9 +1,13 @@ #include "RuleTableFF.h" #include "moses/StaticData.h" +#include "moses/Syntax/F2S/HyperTree.h" +#include "moses/Syntax/F2S/HyperTreeLoader.h" #include "moses/Syntax/S2T/RuleTrieCYKPlus.h" #include "moses/Syntax/S2T/RuleTrieLoader.h" #include "moses/Syntax/S2T/RuleTrieScope3.h" +#include "moses/Syntax/T2S/RuleTrie.h" +#include "moses/Syntax/T2S/RuleTrieLoader.h" namespace Moses { @@ -27,9 +31,13 @@ void RuleTableFF::Load() SetFeaturesToApply(); const StaticData &staticData = StaticData::Instance(); - if (!staticData.GetSearchAlgorithm() == SyntaxS2T) { - UTIL_THROW2("ERROR: RuleTableFF currently only supports the S2T search algorithm"); - } else { + if (staticData.GetSearchAlgorithm() == SyntaxF2S || + staticData.GetSearchAlgorithm() == SyntaxT2S) { + F2S::HyperTree *trie = new F2S::HyperTree(this); + F2S::HyperTreeLoader loader; + loader.Load(m_input, m_output, m_filePath, *this, *trie); + m_table = trie; + } else if (staticData.GetSearchAlgorithm() == SyntaxS2T) { S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm(); if (algorithm == RecursiveCYKPlus) { S2T::RuleTrieCYKPlus *trie = new S2T::RuleTrieCYKPlus(this); @@ -44,6 +52,14 @@ void RuleTableFF::Load() } else { UTIL_THROW2("ERROR: unhandled S2T parsing algorithm"); } + } else if (staticData.GetSearchAlgorithm() == SyntaxT2S_SCFG) { + T2S::RuleTrie *trie = new T2S::RuleTrie(this); + T2S::RuleTrieLoader loader; + loader.Load(m_input, m_output, m_filePath, *this, *trie); + m_table = trie; + } else { + UTIL_THROW2( + "ERROR: RuleTableFF currently only supports the S2T, T2S, T2S_SCFG, and F2S search algorithms"); } } diff --git a/moses/Syntax/T2S/GlueRuleSynthesizer.cpp b/moses/Syntax/T2S/GlueRuleSynthesizer.cpp new file mode 100644 index 000000000..ec60af5f0 --- /dev/null +++ b/moses/Syntax/T2S/GlueRuleSynthesizer.cpp @@ -0,0 +1,77 @@ +#include "GlueRuleSynthesizer.h" + +#include <sstream> + +#include "moses/FF/UnknownWordPenaltyProducer.h" +#include "moses/StaticData.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +void GlueRuleSynthesizer::SynthesizeRule(const InputTree::Node &node) +{ + const Word &sourceLhs = node.pvertex.symbol; + boost::scoped_ptr<Phrase> sourceRhs(SynthesizeSourcePhrase(node)); + TargetPhrase *tp = SynthesizeTargetPhrase(node, *sourceRhs); + TargetPhraseCollection &tpc = GetOrCreateTargetPhraseCollection( + m_ruleTrie, sourceLhs, *sourceRhs); + tpc.Add(tp); +} + +Phrase *GlueRuleSynthesizer::SynthesizeSourcePhrase(const InputTree::Node &node) +{ + Phrase *phrase = new Phrase(node.children.size()); + for (std::vector<InputTree::Node*>::const_iterator p = node.children.begin(); + p != node.children.end(); ++p) { + phrase->AddWord((*p)->pvertex.symbol); + } +/* +TODO What counts as an OOV? + phrase->AddWord() = sourceWord; + phrase->GetWord(0).SetIsOOV(true); +*/ + return phrase; +} + +TargetPhrase *GlueRuleSynthesizer::SynthesizeTargetPhrase( + const InputTree::Node &node, const Phrase &sourceRhs) +{ + const StaticData &staticData = StaticData::Instance(); + + const UnknownWordPenaltyProducer &unknownWordPenaltyProducer = + UnknownWordPenaltyProducer::Instance(); + + TargetPhrase *targetPhrase = new TargetPhrase(); + + std::ostringstream alignmentSS; + for (std::size_t i = 0; i < node.children.size(); ++i) { + const Word &symbol = node.children[i]->pvertex.symbol; + if (symbol.IsNonTerminal()) { + targetPhrase->AddWord(staticData.GetOutputDefaultNonTerminal()); + } else { + // TODO Check this + Word &targetWord = targetPhrase->AddWord(); + targetWord.CreateUnknownWord(symbol); + } + alignmentSS << i << "-" << i << " "; + } + + // Assign the lowest possible score so that glue rules are only used when + // absolutely required. + float score = LOWEST_SCORE; + targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, score); + targetPhrase->EvaluateInIsolation(sourceRhs); + Word *targetLhs = new Word(staticData.GetOutputDefaultNonTerminal()); + targetPhrase->SetTargetLHS(targetLhs); + targetPhrase->SetAlignmentInfo(alignmentSS.str()); + + return targetPhrase; +} + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/GlueRuleSynthesizer.h b/moses/Syntax/T2S/GlueRuleSynthesizer.h new file mode 100644 index 000000000..95942004c --- /dev/null +++ b/moses/Syntax/T2S/GlueRuleSynthesizer.h @@ -0,0 +1,35 @@ +#pragma once + +#include "moses/Phrase.h" +#include "moses/TargetPhrase.h" + +#include "InputTree.h" +#include "RuleTrie.h" +#include "RuleTrieCreator.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +class GlueRuleSynthesizer : public RuleTrieCreator +{ + public: + GlueRuleSynthesizer(RuleTrie &trie) : m_ruleTrie(trie) {} + + // Synthesize the minimal, montone rule that can be applied to the given node + // and add it to the rule trie. + void SynthesizeRule(const InputTree::Node &); + + private: + Phrase *SynthesizeSourcePhrase(const InputTree::Node &); + TargetPhrase *SynthesizeTargetPhrase(const InputTree::Node &, const Phrase &); + + RuleTrie &m_ruleTrie; +}; + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/HyperTree.h b/moses/Syntax/T2S/HyperTree.h new file mode 100644 index 000000000..745b2d26e --- /dev/null +++ b/moses/Syntax/T2S/HyperTree.h @@ -0,0 +1,81 @@ +#pragma once + +#include <map> +#include <vector> + +#include <boost/functional/hash.hpp> +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + +#include "moses/Syntax/RuleTable.h" +#include "moses/Syntax/SymbolEqualityPred.h" +#include "moses/Syntax/SymbolHasher.h" +#include "moses/TargetPhrase.h" +#include "moses/TargetPhraseCollection.h" +#include "moses/Terminal.h" +#include "moses/Util.h" +#include "moses/Word.h" + +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +class HyperTree: public RuleTable +{ + public: + class Node + { + public: + typedef boost::unordered_map<std::vector<Factor*>, Node> Map; + + bool IsLeaf() const { return m_map.empty(); } + + bool HasRules() const { return !m_targetPhraseCollection.IsEmpty(); } + + void Prune(std::size_t tableLimit); + void Sort(std::size_t tableLimit); + + Node *GetOrCreateChild(const HyperPath::NodeSeq &); + + const Node *GetChild(const HyperPath::NodeSeq &) const; + + const TargetPhraseCollection &GetTargetPhraseCollection() const + return m_targetPhraseCollection; + } + + TargetPhraseCollection &GetTargetPhraseCollection() + return m_targetPhraseCollection; + } + + const Map &GetMap() const { return m_map; } + + private: + Map m_map; + TargetPhraseCollection m_targetPhraseCollection; + }; + + HyperTree(const RuleTableFF *ff) : RuleTable(ff) {} + + const Node &GetRootNode() const { return m_root; } + + private: + friend class RuleTrieCreator; + + TargetPhraseCollection &GetOrCreateTargetPhraseCollection( + const Word &sourceLHS, const Phrase &sourceRHS); + + Node &GetOrCreateNode(const Phrase &sourceRHS); + + void SortAndPrune(std::size_t); + + Node m_root; +}; + +} // namespace T2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/T2S/InputTree.h b/moses/Syntax/T2S/InputTree.h new file mode 100644 index 000000000..93b7516e6 --- /dev/null +++ b/moses/Syntax/T2S/InputTree.h @@ -0,0 +1,38 @@ +#pragma once + +#include <vector> + +#include "moses/Syntax/PVertex.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +struct InputTree +{ + public: + struct Node { + Node(const PVertex &v, const std::vector<Node*> &c) + : pvertex(v) + , children(c) {} + + Node(const PVertex &v) : pvertex(v) {} + + PVertex pvertex; + std::vector<Node*> children; + }; + + // All tree nodes in post-order. + std::vector<Node> nodes; + + // Tree nodes arranged by starting position (i.e. the vector nodes[i] + // contains the subset of tree nodes with span [i,j] (for any j).) + std::vector<std::vector<Node*> > nodesAtPos; +}; + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/InputTreeBuilder.cpp b/moses/Syntax/T2S/InputTreeBuilder.cpp new file mode 100644 index 000000000..ecded8e91 --- /dev/null +++ b/moses/Syntax/T2S/InputTreeBuilder.cpp @@ -0,0 +1,171 @@ +#include "InputTreeBuilder.h" + +#include "moses/StaticData.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +InputTreeBuilder::InputTreeBuilder() + : m_outputFactorOrder(StaticData::Instance().GetOutputFactorOrder()) +{ +} + +void InputTreeBuilder::Build(const TreeInput &in, + const std::string &topLevelLabel, + InputTree &out) +{ + CreateNodes(in, topLevelLabel, out); + ConnectNodes(out); +} + +// Create the InputTree::Node objects but do not connect them. +void InputTreeBuilder::CreateNodes(const TreeInput &in, + const std::string &topLevelLabel, + InputTree &out) +{ + // Get the input sentence word count. This includes the <s> and </s> symbols. + const std::size_t numWords = in.GetSize(); + + // Get the parse tree non-terminal nodes. The parse tree covers the original + // sentence only, not the <s> and </s> symbols, so at this point there is + // no top-level node. + std::vector<XMLParseOutput> xmlNodes = in.GetLabelledSpans(); + + // Sort the XML nodes into post-order. Prior to sorting they will be in the + // order that TreeInput created them. Usually that will be post-order, but + // if, for example, the tree was binarized by relax-parse then it won't be. + // In all cases, we assume that if two nodes cover the same span then the + // first one is the lowest. + SortXmlNodesIntoPostOrder(xmlNodes); + + // Copy the parse tree non-terminal nodes, but offset the ranges by 1 (to + // allow for the <s> symbol at position 0). + std::vector<XMLParseOutput> nonTerms; + nonTerms.reserve(xmlNodes.size()+1); + for (std::vector<XMLParseOutput>::const_iterator p = xmlNodes.begin(); + p != xmlNodes.end(); ++p) { + std::size_t start = p->m_range.GetStartPos(); + std::size_t end = p->m_range.GetEndPos(); + nonTerms.push_back(XMLParseOutput(p->m_label, WordsRange(start+1, end+1))); + } + // Add a top-level node that also covers <s> and </s>. + nonTerms.push_back(XMLParseOutput(topLevelLabel, WordsRange(0, numWords-1))); + + // Allocate space for the InputTree nodes. In the case of out.nodes, this + // step is essential because once created the PVertex objects must not be + // moved around (through vector resizing) because InputTree keeps pointers + // to them. + out.nodes.reserve(numWords + nonTerms.size()); + out.nodesAtPos.resize(numWords); + + // Create the InputTree::Node objects. + int prevStart = -1; + int prevEnd = -1; + for (std::vector<XMLParseOutput>::const_iterator p = nonTerms.begin(); + p != nonTerms.end(); ++p) { + int start = static_cast<int>(p->m_range.GetStartPos()); + int end = static_cast<int>(p->m_range.GetEndPos()); + + // Check if we've started ascending a new subtree. + if (start != prevStart && end != prevEnd) { + // Add a node for each terminal to the left of or below the first + // nonTerm child of the subtree. + for (int i = prevEnd+1; i <= end; ++i) { + PVertex v(WordsRange(i, i), in.GetWord(i)); + out.nodes.push_back(InputTree::Node(v)); + out.nodesAtPos[i].push_back(&out.nodes.back()); + } + } + // Add a node for the non-terminal. + Word w(true); + w.CreateFromString(Moses::Output, m_outputFactorOrder, p->m_label, true); + PVertex v(WordsRange(start, end), w); + out.nodes.push_back(InputTree::Node(v)); + out.nodesAtPos[start].push_back(&out.nodes.back()); + + prevStart = start; + prevEnd = end; + } +} + +// Connect the nodes by filling in the node.children vectors. +void InputTreeBuilder::ConnectNodes(InputTree &out) +{ + // Create a vector that records the parent of each node (except the root). + std::vector<InputTree::Node*> parents(out.nodes.size(), NULL); + for (std::size_t i = 0; i < out.nodes.size()-1; ++i) { + const InputTree::Node &node = out.nodes[i]; + std::size_t start = node.pvertex.span.GetStartPos(); + std::size_t end = node.pvertex.span.GetEndPos(); + // Find the next node (in post-order) that completely covers node's span. + std::size_t j = i+1; + while (true) { + const InputTree::Node &succ = out.nodes[j]; + std::size_t succStart = succ.pvertex.span.GetStartPos(); + std::size_t succEnd = succ.pvertex.span.GetEndPos(); + if (succStart <= start && succEnd >= end) { + break; + } + ++j; + } + parents[i] = &(out.nodes[j]); + } + + // Add each node to its parent's list of children (except the root). + for (std::size_t i = 0; i < out.nodes.size()-1; ++i) { + InputTree::Node &child = out.nodes[i]; + InputTree::Node &parent = *(parents[i]); + parent.children.push_back(&child); + } +} + +void InputTreeBuilder::SortXmlNodesIntoPostOrder( + std::vector<XMLParseOutput> &nodes) +{ + // Sorting is based on both the value of a node and its original position, + // so for each node construct a pair containing both pieces of information. + std::vector<std::pair<XMLParseOutput *, int> > pairs; + pairs.reserve(nodes.size()); + for (std::size_t i = 0; i < nodes.size(); ++i) { + pairs.push_back(std::make_pair(&(nodes[i]), i)); + } + + // Sort the pairs. + std::sort(pairs.begin(), pairs.end(), PostOrderComp); + + // Replace the original node sequence with the correctly sorted sequence. + std::vector<XMLParseOutput> tmp; + tmp.reserve(nodes.size()); + for (std::size_t i = 0; i < pairs.size(); ++i) { + tmp.push_back(nodes[pairs[i].second]); + } + nodes.swap(tmp); +} + +// Comparison function used by SortXmlNodesIntoPostOrder. +bool InputTreeBuilder::PostOrderComp(const std::pair<XMLParseOutput *, int> &x, + const std::pair<XMLParseOutput *, int> &y) +{ + std::size_t xStart = x.first->m_range.GetStartPos(); + std::size_t xEnd = x.first->m_range.GetEndPos(); + std::size_t yStart = y.first->m_range.GetStartPos(); + std::size_t yEnd = y.first->m_range.GetEndPos(); + + if (xEnd == yEnd) { + if (xStart == yStart) { + return x.second < y.second; + } else { + return xStart > yStart; + } + } else { + return xEnd < yEnd; + } +} + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/InputTreeBuilder.h b/moses/Syntax/T2S/InputTreeBuilder.h new file mode 100644 index 000000000..24b107f81 --- /dev/null +++ b/moses/Syntax/T2S/InputTreeBuilder.h @@ -0,0 +1,39 @@ +#pragma once + +#include <vector> + +#include "moses/TreeInput.h" +#include "moses/TypeDef.h" + +#include "InputTree.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +class InputTreeBuilder +{ + public: + InputTreeBuilder(); + + // Constructs a Moses::T2S::InputTree given a Moses::TreeInput and a label + // for the top-level node (which covers <s> and </s>). + void Build(const TreeInput &, const std::string &, InputTree &); + + private: + static bool PostOrderComp(const std::pair<XMLParseOutput *, int> &, + const std::pair<XMLParseOutput *, int> &); + + void CreateNodes(const TreeInput &, const std::string &, InputTree &); + void ConnectNodes(InputTree &); + void SortXmlNodesIntoPostOrder(std::vector<XMLParseOutput> &); + + const std::vector<FactorType> &m_outputFactorOrder; +}; + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/InputTreeToForest.cpp b/moses/Syntax/T2S/InputTreeToForest.cpp new file mode 100644 index 000000000..fda988d57 --- /dev/null +++ b/moses/Syntax/T2S/InputTreeToForest.cpp @@ -0,0 +1,52 @@ +#include "InputTreeToForest.h" + +#include <boost/unordered_map.hpp> + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +const F2S::Forest::Vertex *InputTreeToForest(const InputTree &tree, + F2S::Forest &forest) +{ + forest.Clear(); + + // Map from tree vertices to forest vertices. + boost::unordered_map<const InputTree::Node*, F2S::Forest::Vertex*> vertexMap; + + // Create forest vertices (but not hyperedges) and fill in vertexMap. + for (std::vector<InputTree::Node>::const_iterator p = tree.nodes.begin(); + p != tree.nodes.end(); ++p) { + F2S::Forest::Vertex *v = new F2S::Forest::Vertex(p->pvertex); + forest.vertices.push_back(v); + vertexMap[&*p] = v; + } + + // Create the forest hyperedges. + for (std::vector<InputTree::Node>::const_iterator p = tree.nodes.begin(); + p != tree.nodes.end(); ++p) { + const InputTree::Node &treeVertex = *p; + const std::vector<InputTree::Node*> &treeChildren = treeVertex.children; + if (treeChildren.empty()) { + continue; + } + F2S::Forest::Hyperedge *e = new F2S::Forest::Hyperedge(); + e->head = vertexMap[&treeVertex]; + e->tail.reserve(treeChildren.size()); + for (std::vector<InputTree::Node*>::const_iterator q = treeChildren.begin(); + q != treeChildren.end(); ++q) { + e->tail.push_back(vertexMap[*q]); + } + e->head->incoming.push_back(e); + } + + // Return a pointer to the forest's root vertex. + return forest.vertices.back(); +} + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/InputTreeToForest.h b/moses/Syntax/T2S/InputTreeToForest.h new file mode 100644 index 000000000..e8532c6f2 --- /dev/null +++ b/moses/Syntax/T2S/InputTreeToForest.h @@ -0,0 +1,19 @@ +#pragma once + +#include "moses/Syntax/F2S/Forest.h" + +#include "InputTree.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +// Constructs a F2S::Forest given a T2S::InputTree. +const F2S::Forest::Vertex *InputTreeToForest(const InputTree &, F2S::Forest &); + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/Manager-inl.h b/moses/Syntax/T2S/Manager-inl.h new file mode 100644 index 000000000..778d1048f --- /dev/null +++ b/moses/Syntax/T2S/Manager-inl.h @@ -0,0 +1,301 @@ +#pragma once + +#include "moses/DecodeGraph.h" +#include "moses/StaticData.h" +#include "moses/Syntax/BoundedPriorityContainer.h" +#include "moses/Syntax/CubeQueue.h" +#include "moses/Syntax/F2S/DerivationWriter.h" +#include "moses/Syntax/F2S/RuleMatcherCallback.h" +#include "moses/Syntax/PHyperedge.h" +#include "moses/Syntax/RuleTable.h" +#include "moses/Syntax/RuleTableFF.h" +#include "moses/Syntax/SHyperedgeBundle.h" +#include "moses/Syntax/SVertex.h" +#include "moses/Syntax/SVertexRecombinationOrderer.h" +#include "moses/Syntax/SymbolEqualityPred.h" +#include "moses/Syntax/SymbolHasher.h" + +#include "GlueRuleSynthesizer.h" +#include "InputTreeBuilder.h" +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +template<typename RuleMatcher> +Manager<RuleMatcher>::Manager(const TreeInput &source) + : Syntax::Manager(source) + , m_treeSource(source) +{ +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::InitializeRuleMatchers() +{ + const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); + for (std::size_t i = 0; i < ffs.size(); ++i) { + RuleTableFF *ff = ffs[i]; + // This may change in the future, but currently we assume that every + // RuleTableFF is associated with a static, file-based rule table of + // some sort and that the table should have been loaded into a RuleTable + // by this point. + const RuleTable *table = ff->GetTable(); + assert(table); + RuleTable *nonConstTable = const_cast<RuleTable*>(table); + RuleTrie *trie = dynamic_cast<RuleTrie*>(nonConstTable); + assert(trie); + boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *trie)); + m_ruleMatchers.push_back(p); + } + + // Create an additional rule trie + matcher for glue rules (which are + // synthesized on demand). + // FIXME Add a hidden RuleTableFF for the glue rule trie(?) + m_glueRuleTrie.reset(new RuleTrie(ffs[0])); + boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *m_glueRuleTrie)); + m_ruleMatchers.push_back(p); + m_glueRuleMatcher = p.get(); +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::InitializeStacks() +{ + // Check that m_inputTree has been initialized. + assert(!m_inputTree.nodes.empty()); + + for (std::vector<InputTree::Node>::const_iterator p = + m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) { + const InputTree::Node &node = *p; + + // Create an empty stack. + SVertexStack &stack = m_stackMap[&(node.pvertex)]; + + // For terminals only, add a single SVertex. + if (node.children.empty()) { + boost::shared_ptr<SVertex> v(new SVertex()); + v->best = 0; + v->pvertex = &(node.pvertex); + stack.push_back(v); + } + } +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::Decode() +{ + const StaticData &staticData = StaticData::Instance(); + + // Get various pruning-related constants. + const std::size_t popLimit = staticData.GetCubePruningPopLimit(); + const std::size_t ruleLimit = staticData.GetRuleLimit(); + const std::size_t stackLimit = staticData.GetMaxHypoStackSize(); + + // Construct the InputTree. + InputTreeBuilder builder; + builder.Build(m_treeSource, "Q", m_inputTree); + + // Initialize the stacks. + InitializeStacks(); + + // Initialize the rule matchers. + InitializeRuleMatchers(); + + // Create a callback to process the PHyperedges produced by the rule matchers. + F2S::RuleMatcherCallback callback(m_stackMap, ruleLimit); + + // Create a glue rule synthesizer. + GlueRuleSynthesizer glueRuleSynthesizer(*m_glueRuleTrie); + + // Visit each node of the input tree in post-order. + for (std::vector<InputTree::Node>::const_iterator p = + m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) { + + const InputTree::Node &node = *p; + + // Skip terminal nodes. + if (node.children.empty()) { + continue; + } + + // Call the rule matchers to generate PHyperedges for this node and + // convert each one to a SHyperedgeBundle (via the callback). The + // callback prunes the SHyperedgeBundles and keeps the best ones (up + // to ruleLimit). + callback.ClearContainer(); + for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator + q = m_ruleMatchers.begin(); q != m_ruleMatchers.end(); ++q) { + (*q)->EnumerateHyperedges(node, callback); + } + + // Retrieve the (pruned) set of SHyperedgeBundles from the callback. + const BoundedPriorityContainer<SHyperedgeBundle> &bundles = + callback.GetContainer(); + + // Check if any rules were matched. If not then synthesize a glue rule + // that is guaranteed to match. + if (bundles.Size() == 0) { + glueRuleSynthesizer.SynthesizeRule(node); + m_glueRuleMatcher->EnumerateHyperedges(node, callback); + assert(bundles.Size() == 1); + } + + // Use cube pruning to extract SHyperedges from SHyperedgeBundles and + // collect the SHyperedges in a buffer. + CubeQueue cubeQueue(bundles.Begin(), bundles.End()); + std::size_t count = 0; + std::vector<SHyperedge*> buffer; + while (count < popLimit && !cubeQueue.IsEmpty()) { + SHyperedge *hyperedge = cubeQueue.Pop(); + // FIXME See corresponding code in S2T::Manager + // BEGIN{HACK} + hyperedge->head->pvertex = &(node.pvertex); + // END{HACK} + buffer.push_back(hyperedge); + ++count; + } + + // Recombine SVertices and sort into a stack. + SVertexStack &stack = m_stackMap[&(node.pvertex)]; + RecombineAndSort(buffer, stack); + + // Prune stack. + if (stackLimit > 0 && stack.size() > stackLimit) { + stack.resize(stackLimit); + } + } +} + +template<typename RuleMatcher> +const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const +{ + const InputTree::Node &rootNode = m_inputTree.nodes.back(); + F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex); + assert(p != m_stackMap.end()); + const SVertexStack &stack = p->second; + assert(!stack.empty()); + return stack[0]->best; +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::ExtractKBest( + std::size_t k, + std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, + bool onlyDistinct) const +{ + kBestList.clear(); + if (k == 0 || m_source.GetSize() == 0) { + return; + } + + // Get the top-level SVertex stack. + const InputTree::Node &rootNode = m_inputTree.nodes.back(); + F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex); + assert(p != m_stackMap.end()); + const SVertexStack &stack = p->second; + assert(!stack.empty()); + + KBestExtractor extractor; + + if (!onlyDistinct) { + // Return the k-best list as is, including duplicate translations. + extractor.Extract(stack, k, kBestList); + return; + } + + // Determine how many derivations to extract. If the k-best list is + // restricted to distinct translations then this limit should be bigger + // than k. The k-best factor determines how much bigger the limit should be, + // with 0 being 'unlimited.' This actually sets a large-ish limit in case + // too many translations are identical. + const StaticData &staticData = StaticData::Instance(); + const std::size_t nBestFactor = staticData.GetNBestFactor(); + std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; + + // Extract the derivations. + KBestExtractor::KBestVec bigList; + bigList.reserve(numDerivations); + extractor.Extract(stack, numDerivations, bigList); + + // Copy derivations into kBestList, skipping ones with repeated translations. + std::set<Phrase> distinct; + for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); + kBestList.size() < k && p != bigList.end(); ++p) { + boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; + Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); + if (distinct.insert(translation).second) { + kBestList.push_back(derivation); + } + } +} + +// TODO Move this function into parent directory (Recombiner class?) and +// TODO share with S2T +template<typename RuleMatcher> +void Manager<RuleMatcher>::RecombineAndSort( + const std::vector<SHyperedge*> &buffer, SVertexStack &stack) +{ + // Step 1: Create a map containing a single instance of each distinct vertex + // (where distinctness is defined by the state value). The hyperedges' + // head pointers are updated to point to the vertex instances in the map and + // any 'duplicate' vertices are deleted. +// TODO Set? + typedef std::map<SVertex *, SVertex *, SVertexRecombinationOrderer> Map; + Map map; + for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); + p != buffer.end(); ++p) { + SHyperedge *h = *p; + SVertex *v = h->head; + assert(v->best == h); + assert(v->recombined.empty()); + std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); + if (result.second) { + continue; // v's recombination value hasn't been seen before. + } + // v is a duplicate (according to the recombination rules). + // Compare the score of h against the score of the best incoming hyperedge + // for the stored vertex. + SVertex *storedVertex = result.first->second; + if (h->label.score > storedVertex->best->label.score) { + // h's score is better. + storedVertex->recombined.push_back(storedVertex->best); + storedVertex->best = h; + } else { + storedVertex->recombined.push_back(h); + } + h->head->best = 0; + delete h->head; + h->head = storedVertex; + } + + // Step 2: Copy the vertices from the map to the stack. + stack.clear(); + stack.reserve(map.size()); + for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { + stack.push_back(boost::shared_ptr<SVertex>(p->first)); + } + + // Step 3: Sort the vertices in the stack. + std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); +} + +template<typename RuleMatcher> +void Manager<RuleMatcher>::OutputDetailedTranslationReport( + OutputCollector *collector) const +{ + const SHyperedge *best = GetBestSHyperedge(); + if (best == NULL || collector == NULL) { + return; + } + long translationId = m_source.GetTranslationId(); + std::ostringstream out; + F2S::DerivationWriter::Write(*best, translationId, out); + collector->Write(translationId, out.str()); +} + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/Manager.h b/moses/Syntax/T2S/Manager.h new file mode 100644 index 000000000..0082e1038 --- /dev/null +++ b/moses/Syntax/T2S/Manager.h @@ -0,0 +1,67 @@ +#pragma once + +#include <set> +#include <vector> + +#include <boost/shared_ptr.hpp> +#include <boost/unordered_map.hpp> + +#include "moses/Syntax/F2S/PVertexToStackMap.h" +#include "moses/Syntax/KBestExtractor.h" +#include "moses/Syntax/Manager.h" +#include "moses/Syntax/SVertexStack.h" +#include "moses/TreeInput.h" +#include "moses/Word.h" + +#include "InputTree.h" +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ + +struct SHyperedge; + +namespace T2S +{ + +template<typename RuleMatcher> +class Manager : public Syntax::Manager +{ + public: + Manager(const TreeInput &); + + void Decode(); + + // Get the SHyperedge for the 1-best derivation. + const SHyperedge *GetBestSHyperedge() const; + + void ExtractKBest( + std::size_t k, + std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, + bool onlyDistinct=false) const; + + void OutputDetailedTranslationReport(OutputCollector *collector) const; + + private: + void InitializeRuleMatchers(); + + void InitializeStacks(); + + void RecombineAndSort(const std::vector<SHyperedge*> &, SVertexStack &); + + const TreeInput &m_treeSource; + InputTree m_inputTree; + F2S::PVertexToStackMap m_stackMap; + boost::shared_ptr<RuleTrie> m_glueRuleTrie; + std::vector<boost::shared_ptr<RuleMatcher> > m_ruleMatchers; + RuleMatcher *m_glueRuleMatcher; +}; + +} // T2S +} // Syntax +} // Moses + +// Implementation +#include "Manager-inl.h" diff --git a/moses/Syntax/T2S/RuleMatcher.h b/moses/Syntax/T2S/RuleMatcher.h new file mode 100644 index 000000000..2f7d4c99a --- /dev/null +++ b/moses/Syntax/T2S/RuleMatcher.h @@ -0,0 +1,24 @@ +#pragma once + +#include "InputTree.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +// Base class for rule matchers. +template<typename Callback> +class RuleMatcher +{ + public: + virtual ~RuleMatcher() {} + + virtual void EnumerateHyperedges(const InputTree::Node &, Callback &) = 0; +}; + +} // T2S +} // Syntax +} // Moses diff --git a/moses/Syntax/T2S/RuleMatcherSCFG-inl.h b/moses/Syntax/T2S/RuleMatcherSCFG-inl.h new file mode 100644 index 000000000..c1d8db63b --- /dev/null +++ b/moses/Syntax/T2S/RuleMatcherSCFG-inl.h @@ -0,0 +1,107 @@ +#pragma once + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +template<typename Callback> +RuleMatcherSCFG<Callback>::RuleMatcherSCFG(const InputTree &inputTree, + const RuleTrie &ruleTrie) + : m_inputTree(inputTree) + , m_ruleTrie(ruleTrie) +{ +} + +template<typename Callback> +void RuleMatcherSCFG<Callback>::EnumerateHyperedges(const InputTree::Node &node, + Callback &callback) +{ + const int start = static_cast<int>(node.pvertex.span.GetStartPos()); + m_hyperedge.head = const_cast<PVertex*>(&node.pvertex); + m_hyperedge.tail.clear(); + Match(node, m_ruleTrie.GetRootNode(), start, callback); +} + +template<typename Callback> +void RuleMatcherSCFG<Callback>::Match(const InputTree::Node &inNode, + const RuleTrie::Node &trieNode, + int start, Callback &callback) +{ + // Try to extend the current hyperedge tail by adding a tree node that is a + // descendent of inNode and has a span starting at start. + const std::vector<InputTree::Node*> &nodes = m_inputTree.nodesAtPos[start]; + for (std::vector<InputTree::Node*>::const_iterator p = nodes.begin(); + p != nodes.end(); ++p) { + InputTree::Node &candidate = **p; + // Is candidate a descendent of inNode? + if (!IsDescendent(candidate, inNode)) { + continue; + } + // Get the appropriate SymbolMap (according to whether candidate is a + // terminal or non-terminal map) from the current rule trie node. + const RuleTrie::Node::SymbolMap *map = NULL; + if (candidate.children.empty()) { + map = &(trieNode.GetTerminalMap()); + } else { + map = &(trieNode.GetNonTerminalMap()); + } + // Test if the current rule prefix can be extended by candidate's symbol. + RuleTrie::Node::SymbolMap::const_iterator q = + map->find(candidate.pvertex.symbol); + if (q == map->end()) { + continue; + } + const RuleTrie::Node &newTrieNode = q->second; + // Add the candidate node to the tail. + m_hyperedge.tail.push_back(&candidate.pvertex); + // Have we now covered the full span of inNode? + if (candidate.pvertex.span.GetEndPos() == inNode.pvertex.span.GetEndPos()) { + // Check if the trie node has any rules with a LHS that match inNode. + const Word &lhs = inNode.pvertex.symbol; + const TargetPhraseCollection *tpc = + newTrieNode.GetTargetPhraseCollection(lhs); + if (tpc) { + m_hyperedge.label.translations = tpc; + callback(m_hyperedge); + } + } else { + // Recursive step. + int newStart = candidate.pvertex.span.GetEndPos()+1; + Match(inNode, newTrieNode, newStart, callback); + } + m_hyperedge.tail.pop_back(); + } +} + +// Return true iff x is a descendent of y; false otherwise. +template<typename Callback> +bool RuleMatcherSCFG<Callback>::IsDescendent(const InputTree::Node &x, + const InputTree::Node &y) +{ + const std::size_t xStart = x.pvertex.span.GetStartPos(); + const std::size_t yStart = y.pvertex.span.GetStartPos(); + const std::size_t xEnd = x.pvertex.span.GetEndPos(); + const std::size_t yEnd = y.pvertex.span.GetEndPos(); + if (xStart < yStart || xEnd > yEnd) { + return false; + } + if (xStart > yStart || xEnd < yEnd) { + return true; + } + // x and y both cover the same span. + const InputTree::Node *z = &y; + while (z->children.size() == 1) { + z = z->children[0]; + if (z == &x) { + return true; + } + } + return false; +} + +} // namespace T2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/T2S/RuleMatcherSCFG.h b/moses/Syntax/T2S/RuleMatcherSCFG.h new file mode 100644 index 000000000..078388f5f --- /dev/null +++ b/moses/Syntax/T2S/RuleMatcherSCFG.h @@ -0,0 +1,42 @@ +#pragma once + +#include "moses/Syntax/PHyperedge.h" + +#include "RuleMatcher.h" +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +// TODO +// +template<typename Callback> +class RuleMatcherSCFG : public RuleMatcher<Callback> +{ + public: + RuleMatcherSCFG(const InputTree &, const RuleTrie &); + + ~RuleMatcherSCFG() {} + + void EnumerateHyperedges(const InputTree::Node &, Callback &); + + private: + bool IsDescendent(const InputTree::Node &, const InputTree::Node &); + + void Match(const InputTree::Node &, const RuleTrie::Node &, int, Callback &); + + const InputTree &m_inputTree; + const RuleTrie &m_ruleTrie; + PHyperedge m_hyperedge; +}; + +} // namespace T2S +} // namespace Syntax +} // namespace Moses + +// Implementation +#include "RuleMatcherSCFG-inl.h" diff --git a/moses/Syntax/T2S/RuleTrie.cpp b/moses/Syntax/T2S/RuleTrie.cpp new file mode 100644 index 000000000..981543290 --- /dev/null +++ b/moses/Syntax/T2S/RuleTrie.cpp @@ -0,0 +1,139 @@ +#include "RuleTrie.h" + +#include <map> +#include <vector> + +#include <boost/functional/hash.hpp> +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + +#include "moses/NonTerminal.h" +#include "moses/TargetPhrase.h" +#include "moses/TargetPhraseCollection.h" +#include "moses/Util.h" +#include "moses/Word.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +void RuleTrie::Node::Prune(std::size_t tableLimit) +{ + // Recusively prune child nodes. + for (SymbolMap::iterator p = m_sourceTermMap.begin(); + p != m_sourceTermMap.end(); ++p) { + p->second.Prune(tableLimit); + } + for (SymbolMap::iterator p = m_nonTermMap.begin(); + p != m_nonTermMap.end(); ++p) { + p->second.Prune(tableLimit); + } + + // Prune TargetPhraseCollections at this node. + for (TPCMap::iterator p = m_targetPhraseCollections.begin(); + p != m_targetPhraseCollections.end(); ++p) { + p->second.Prune(true, tableLimit); + } +} + +void RuleTrie::Node::Sort(std::size_t tableLimit) +{ + // Recusively sort child nodes. + for (SymbolMap::iterator p = m_sourceTermMap.begin(); + p != m_sourceTermMap.end(); ++p) { + p->second.Sort(tableLimit); + } + for (SymbolMap::iterator p = m_nonTermMap.begin(); + p != m_nonTermMap.end(); ++p) { + p->second.Sort(tableLimit); + } + + // Sort TargetPhraseCollections at this node. + for (TPCMap::iterator p = m_targetPhraseCollections.begin(); + p != m_targetPhraseCollections.end(); ++p) { + p->second.Sort(true, tableLimit); + } +} + +RuleTrie::Node *RuleTrie::Node::GetOrCreateChild( + const Word &sourceTerm) +{ + return &m_sourceTermMap[sourceTerm]; +} + +RuleTrie::Node *RuleTrie::Node::GetOrCreateNonTerminalChild(const Word &targetNonTerm) +{ + UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(), + "Not a non-terminal: " << targetNonTerm); + + return &m_nonTermMap[targetNonTerm]; +} + +TargetPhraseCollection &RuleTrie::Node::GetOrCreateTargetPhraseCollection( + const Word &sourceLHS) +{ + UTIL_THROW_IF2(!sourceLHS.IsNonTerminal(), + "Not a non-terminal: " << sourceLHS); + return m_targetPhraseCollections[sourceLHS]; +} + +const RuleTrie::Node *RuleTrie::Node::GetChild( + const Word &sourceTerm) const +{ + UTIL_THROW_IF2(sourceTerm.IsNonTerminal(), + "Not a terminal: " << sourceTerm); + + SymbolMap::const_iterator p = m_sourceTermMap.find(sourceTerm); + return (p == m_sourceTermMap.end()) ? NULL : &p->second; +} + +const RuleTrie::Node *RuleTrie::Node::GetNonTerminalChild( + const Word &targetNonTerm) const +{ + UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(), + "Not a non-terminal: " << targetNonTerm); + + SymbolMap::const_iterator p = m_nonTermMap.find(targetNonTerm); + return (p == m_nonTermMap.end()) ? NULL : &p->second; +} + +TargetPhraseCollection &RuleTrie::GetOrCreateTargetPhraseCollection( + const Word &sourceLHS, const Phrase &sourceRHS) +{ + Node &currNode = GetOrCreateNode(sourceRHS); + return currNode.GetOrCreateTargetPhraseCollection(sourceLHS); +} + +RuleTrie::Node &RuleTrie::GetOrCreateNode(const Phrase &sourceRHS) +{ + const std::size_t size = sourceRHS.GetSize(); + + Node *currNode = &m_root; + for (std::size_t pos = 0 ; pos < size ; ++pos) { + const Word& word = sourceRHS.GetWord(pos); + + if (word.IsNonTerminal()) { + currNode = currNode->GetOrCreateNonTerminalChild(word); + } else { + currNode = currNode->GetOrCreateChild(word); + } + + UTIL_THROW_IF2(currNode == NULL, "Node not found at position " << pos); + } + + return *currNode; +} + +void RuleTrie::SortAndPrune(std::size_t tableLimit) +{ + if (tableLimit) { + m_root.Sort(tableLimit); + } +} + +} // namespace T2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/T2S/RuleTrie.h b/moses/Syntax/T2S/RuleTrie.h new file mode 100644 index 000000000..564c0cc80 --- /dev/null +++ b/moses/Syntax/T2S/RuleTrie.h @@ -0,0 +1,90 @@ +#pragma once + +#include <map> +#include <vector> + +#include <boost/functional/hash.hpp> +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + +#include "moses/Syntax/RuleTable.h" +#include "moses/Syntax/SymbolEqualityPred.h" +#include "moses/Syntax/SymbolHasher.h" +#include "moses/TargetPhrase.h" +#include "moses/TargetPhraseCollection.h" +#include "moses/Terminal.h" +#include "moses/Util.h" +#include "moses/Word.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +class RuleTrie: public RuleTable +{ + public: + class Node + { + public: + typedef boost::unordered_map<Word, Node, SymbolHasher, + SymbolEqualityPred> SymbolMap; + + typedef boost::unordered_map<Word, TargetPhraseCollection, + SymbolHasher, SymbolEqualityPred> TPCMap; + + bool IsLeaf() const { + return m_sourceTermMap.empty() && m_nonTermMap.empty(); + } + + bool HasRules() const { return !m_targetPhraseCollections.empty(); } + + void Prune(std::size_t tableLimit); + void Sort(std::size_t tableLimit); + + Node *GetOrCreateChild(const Word &sourceTerm); + Node *GetOrCreateNonTerminalChild(const Word &targetNonTerm); + TargetPhraseCollection &GetOrCreateTargetPhraseCollection(const Word &); + + const Node *GetChild(const Word &sourceTerm) const; + const Node *GetNonTerminalChild(const Word &targetNonTerm) const; + + const TargetPhraseCollection *GetTargetPhraseCollection( + const Word &sourceLHS) const { + TPCMap::const_iterator p = m_targetPhraseCollections.find(sourceLHS); + return p == m_targetPhraseCollections.end() ? 0 : &(p->second); + } + + // FIXME IS there any reason to distinguish these two for T2S? + const SymbolMap &GetTerminalMap() const { return m_sourceTermMap; } + + const SymbolMap &GetNonTerminalMap() const { return m_nonTermMap; } + + private: + SymbolMap m_sourceTermMap; + SymbolMap m_nonTermMap; + TPCMap m_targetPhraseCollections; + }; + + RuleTrie(const RuleTableFF *ff) : RuleTable(ff) {} + + const Node &GetRootNode() const { return m_root; } + + private: + friend class RuleTrieCreator; + + TargetPhraseCollection &GetOrCreateTargetPhraseCollection( + const Word &sourceLHS, const Phrase &sourceRHS); + + Node &GetOrCreateNode(const Phrase &sourceRHS); + + void SortAndPrune(std::size_t); + + Node m_root; +}; + +} // namespace T2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/T2S/RuleTrieCreator.h b/moses/Syntax/T2S/RuleTrieCreator.h new file mode 100644 index 000000000..b474a88e0 --- /dev/null +++ b/moses/Syntax/T2S/RuleTrieCreator.h @@ -0,0 +1,32 @@ +#pragma once + +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +// Base for classes that create a RuleTrie (currently RuleTrieLoader and +// OovHandler). RuleTrieCreator is a friend of RuleTrie. +class RuleTrieCreator +{ + protected: + // Provide access to RuleTrie's private SortAndPrune function. + void SortAndPrune(RuleTrie &trie, std::size_t limit) { + trie.SortAndPrune(limit); + } + + // Provide access to RuleTrie's private + // GetOrCreateTargetPhraseCollection function. + TargetPhraseCollection &GetOrCreateTargetPhraseCollection( + RuleTrie &trie, const Word &sourceLHS, const Phrase &sourceRHS) { + return trie.GetOrCreateTargetPhraseCollection(sourceLHS, sourceRHS); + } +}; + +} // namespace T2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/T2S/RuleTrieLoader.cpp b/moses/Syntax/T2S/RuleTrieLoader.cpp new file mode 100644 index 000000000..9feaefc94 --- /dev/null +++ b/moses/Syntax/T2S/RuleTrieLoader.cpp @@ -0,0 +1,154 @@ +#include "RuleTrieLoader.h" + +#include <sys/stat.h> +#include <stdlib.h> + +#include <fstream> +#include <string> +#include <iterator> +#include <algorithm> +#include <iostream> + +#include "moses/FactorCollection.h" +#include "moses/Word.h" +#include "moses/Util.h" +#include "moses/InputFileStream.h" +#include "moses/StaticData.h" +#include "moses/WordsRange.h" +#include "moses/ChartTranslationOptionList.h" +#include "moses/FactorCollection.h" +#include "moses/Syntax/RuleTableFF.h" +#include "util/file_piece.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/double-conversion/double-conversion.h" +#include "util/exception.hh" + +#include "RuleTrie.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +bool RuleTrieLoader::Load(const std::vector<FactorType> &input, + const std::vector<FactorType> &output, + const std::string &inFile, + const RuleTableFF &ff, + RuleTrie &trie) +{ + PrintUserTime(std::string("Start loading text phrase table. Moses format")); + + const StaticData &staticData = StaticData::Instance(); + const std::string &factorDelimiter = staticData.GetFactorDelimiter(); + + std::size_t count = 0; + + std::ostream *progress = NULL; + IFVERBOSE(1) progress = &std::cerr; + util::FilePiece in(inFile.c_str(), progress); + + // reused variables + std::vector<float> scoreVector; + StringPiece line; + + double_conversion::StringToDoubleConverter converter(double_conversion::StringToDoubleConverter::NO_FLAGS, NAN, NAN, "inf", "nan"); + + while(true) { + try { + line = in.ReadLine(); + } catch (const util::EndOfFileException &e) { + break; + } + + util::TokenIter<util::MultiCharacter> pipes(line, "|||"); + StringPiece sourcePhraseString(*pipes); + StringPiece targetPhraseString(*++pipes); + StringPiece scoreString(*++pipes); + + StringPiece alignString; + if (++pipes) { + StringPiece temp(*pipes); + alignString = temp; + } + + if (++pipes) { + StringPiece str(*pipes); //counts + } + + bool isLHSEmpty = (sourcePhraseString.find_first_not_of(" \t", 0) == std::string::npos); + if (isLHSEmpty && !staticData.IsWordDeletionEnabled()) { + TRACE_ERR( ff.GetFilePath() << ":" << count << ": pt entry contains empty target, skipping\n"); + continue; + } + + scoreVector.clear(); + for (util::TokenIter<util::AnyCharacter, true> s(scoreString, " \t"); s; ++s) { + int processed; + float score = converter.StringToFloat(s->data(), s->length(), &processed); + UTIL_THROW_IF2(isnan(score), "Bad score " << *s << " on line " << count); + scoreVector.push_back(FloorScore(TransformScore(score))); + } + const std::size_t numScoreComponents = ff.GetNumScoreComponents(); + if (scoreVector.size() != numScoreComponents) { + UTIL_THROW2("Size of scoreVector != number (" << scoreVector.size() << "!=" + << numScoreComponents << ") of score components on line " << count); + } + + // parse source & find pt node + + // constituent labels + Word *sourceLHS = NULL; + Word *targetLHS; + + // create target phrase obj + TargetPhrase *targetPhrase = new TargetPhrase(&ff); + // targetPhrase->CreateFromString(Output, output, targetPhraseString, factorDelimiter, &targetLHS); + targetPhrase->CreateFromString(Output, output, targetPhraseString, &targetLHS); + // source + Phrase sourcePhrase; + // sourcePhrase.CreateFromString(Input, input, sourcePhraseString, factorDelimiter, &sourceLHS); + sourcePhrase.CreateFromString(Input, input, sourcePhraseString, &sourceLHS); + + // rest of target phrase + targetPhrase->SetAlignmentInfo(alignString); + targetPhrase->SetTargetLHS(targetLHS); + + //targetPhrase->SetDebugOutput(string("New Format pt ") + line); + + if (++pipes) { + StringPiece sparseString(*pipes); + targetPhrase->SetSparseScore(&ff, sparseString); + } + + if (++pipes) { + StringPiece propertiesString(*pipes); + targetPhrase->SetProperties(propertiesString); + } + + targetPhrase->GetScoreBreakdown().Assign(&ff, scoreVector); + targetPhrase->EvaluateInIsolation(sourcePhrase, ff.GetFeaturesToApply()); + + TargetPhraseCollection &phraseColl = GetOrCreateTargetPhraseCollection( + trie, *sourceLHS, sourcePhrase); + phraseColl.Add(targetPhrase); + + // not implemented correctly in memory pt. just delete it for now + delete sourceLHS; + + count++; + } + + // sort and prune each target phrase collection + if (ff.GetTableLimit()) { + SortAndPrune(trie, ff.GetTableLimit()); + } + + return true; +} + +} // namespace T2S +} // namespace Syntax +} // namespace Moses diff --git a/moses/Syntax/T2S/RuleTrieLoader.h b/moses/Syntax/T2S/RuleTrieLoader.h new file mode 100644 index 000000000..d3fa4ec60 --- /dev/null +++ b/moses/Syntax/T2S/RuleTrieLoader.h @@ -0,0 +1,31 @@ +#pragma once + +#include <istream> +#include <vector> + +#include "moses/TypeDef.h" +#include "moses/Syntax/RuleTableFF.h" + +#include "RuleTrie.h" +#include "RuleTrieCreator.h" + +namespace Moses +{ +namespace Syntax +{ +namespace T2S +{ + +class RuleTrieLoader : public RuleTrieCreator +{ + public: + bool Load(const std::vector<FactorType> &input, + const std::vector<FactorType> &output, + const std::string &inFile, + const RuleTableFF &, + RuleTrie &); +}; + +} // namespace T2S +} // namespace Syntax +} // namespace Moses |