From 05b31b53f22abfabb8141b6cd7b0246890d521a1 Mon Sep 17 00:00:00 2001 From: Phil Williams Date: Mon, 13 Apr 2015 16:31:58 +0100 Subject: Implement -output-unknowns for search algorithms 7 and 9 (T2S/F2S) --- moses/Syntax/F2S/HyperTreeLoader.cpp | 24 +++++++++++++++++++++--- moses/Syntax/F2S/HyperTreeLoader.h | 10 +++++++++- moses/Syntax/F2S/Manager-inl.h | 24 +++++++++++++++++++++++- moses/Syntax/F2S/Manager.h | 3 +++ moses/Syntax/RuleTableFF.cpp | 3 ++- moses/Syntax/RuleTableFF.h | 7 +++++++ 6 files changed, 65 insertions(+), 6 deletions(-) diff --git a/moses/Syntax/F2S/HyperTreeLoader.cpp b/moses/Syntax/F2S/HyperTreeLoader.cpp index f3caa2cec..bd19cbace 100644 --- a/moses/Syntax/F2S/HyperTreeLoader.cpp +++ b/moses/Syntax/F2S/HyperTreeLoader.cpp @@ -40,12 +40,12 @@ bool HyperTreeLoader::Load(const std::vector &input, const std::vector &output, const std::string &inFile, const RuleTableFF &ff, - HyperTree &trie) + HyperTree &trie, + boost::unordered_set &sourceTermSet) { PrintUserTime(std::string("Start loading HyperTree")); - // const StaticData &staticData = StaticData::Instance(); - // const std::string &factorDelimiter = staticData.GetFactorDelimiter(); + sourceTermSet.clear(); std::size_t count = 0; @@ -106,6 +106,7 @@ bool HyperTreeLoader::Load(const std::vector &input, // Source-side HyperPath sourceFragment; hyperPathLoader.Load(sourceString, sourceFragment); + ExtractSourceTerminalSetFromHyperPath(sourceFragment, sourceTermSet); // Target-side TargetPhrase *targetPhrase = new TargetPhrase(&ff); @@ -144,6 +145,23 @@ bool HyperTreeLoader::Load(const std::vector &input, return true; } +void HyperTreeLoader::ExtractSourceTerminalSetFromHyperPath( + const HyperPath &hp, boost::unordered_set &sourceTerminalSet) +{ + for (std::vector::const_iterator p = hp.nodeSeqs.begin(); + p != hp.nodeSeqs.end(); ++p) { + for (std::vector::const_iterator q = p->begin(); + q != p->end(); ++q) { + const std::size_t factorId = *q; + if (factorId >= moses_MaxNumNonterminals && + factorId != HyperPath::kComma && + factorId != HyperPath::kEpsilon) { + sourceTerminalSet.insert(factorId); + } + } + } +} + } // namespace F2S } // namespace Syntax } // namespace Moses diff --git a/moses/Syntax/F2S/HyperTreeLoader.h b/moses/Syntax/F2S/HyperTreeLoader.h index ea009022d..088c7eaf5 100644 --- a/moses/Syntax/F2S/HyperTreeLoader.h +++ b/moses/Syntax/F2S/HyperTreeLoader.h @@ -3,9 +3,12 @@ #include #include +#include + #include "moses/TypeDef.h" #include "moses/Syntax/RuleTableFF.h" +#include "HyperPath.h" #include "HyperTree.h" #include "HyperTreeCreator.h" @@ -23,7 +26,12 @@ public: const std::vector &output, const std::string &inFile, const RuleTableFF &, - HyperTree &); + HyperTree &, + boost::unordered_set &); + +private: + void ExtractSourceTerminalSetFromHyperPath( + const HyperPath &, boost::unordered_set &); }; } // namespace F2S diff --git a/moses/Syntax/F2S/Manager-inl.h b/moses/Syntax/F2S/Manager-inl.h index a422e8085..f7f8f0ae9 100644 --- a/moses/Syntax/F2S/Manager-inl.h +++ b/moses/Syntax/F2S/Manager-inl.h @@ -38,6 +38,7 @@ Manager::Manager(const InputType &source) if (const ForestInput *p = dynamic_cast(&source)) { m_forest = p->GetForest(); m_rootVertex = p->GetRootVertex(); + m_sentenceLength = p->GetSize(); } else if (const TreeInput *p = dynamic_cast(&source)) { T2S::InputTreeBuilder builder; T2S::InputTree tmpTree; @@ -45,6 +46,7 @@ Manager::Manager(const InputType &source) boost::shared_ptr forest = boost::make_shared(); m_rootVertex = T2S::InputTreeToForest(tmpTree, *forest); m_forest = forest; + m_sentenceLength = p->GetSize(); } else { UTIL_THROW2("ERROR: F2S::Manager requires input to be a tree or forest"); } @@ -82,8 +84,13 @@ void Manager::Decode() p = sortedVertices.begin(); p != sortedVertices.end(); ++p) { const Forest::Vertex &vertex = **p; - // Skip terminal vertices. + // Skip terminal vertices (after checking if they are OOVs). if (vertex.incoming.empty()) { + if (vertex.pvertex.span.GetStartPos() > 0 && + vertex.pvertex.span.GetEndPos() < m_sentenceLength-1 && + IsUnknownSourceWord(vertex.pvertex.symbol)) { + m_oovs.insert(vertex.pvertex.symbol); + } continue; } @@ -189,6 +196,21 @@ void Manager::InitializeStacks() } } +template +bool Manager::IsUnknownSourceWord(const Word &w) const +{ + const std::size_t factorId = w[0]->GetId(); + const std::vector &ffs = RuleTableFF::Instances(); + for (std::size_t i = 0; i < ffs.size(); ++i) { + RuleTableFF *ff = ffs[i]; + const boost::unordered_set &sourceTerms = + ff->GetSourceTerminalSet(); + if (sourceTerms.find(factorId) != sourceTerms.end()) { + return false; + } + } + return true; +} template const SHyperedge *Manager::GetBestSHyperedge() const diff --git a/moses/Syntax/F2S/Manager.h b/moses/Syntax/F2S/Manager.h index 3c7ff8da1..90f34c04b 100644 --- a/moses/Syntax/F2S/Manager.h +++ b/moses/Syntax/F2S/Manager.h @@ -51,10 +51,13 @@ private: void InitializeStacks(); + bool IsUnknownSourceWord(const Word &) const; + void RecombineAndSort(const std::vector &, SVertexStack &); boost::shared_ptr m_forest; const Forest::Vertex *m_rootVertex; + std::size_t m_sentenceLength; // Includes and PVertexToStackMap m_stackMap; boost::shared_ptr m_glueRuleTrie; std::vector > m_mainRuleMatchers; diff --git a/moses/Syntax/RuleTableFF.cpp b/moses/Syntax/RuleTableFF.cpp index f4e06f489..37063e048 100644 --- a/moses/Syntax/RuleTableFF.cpp +++ b/moses/Syntax/RuleTableFF.cpp @@ -35,7 +35,8 @@ void RuleTableFF::Load() staticData.GetSearchAlgorithm() == SyntaxT2S) { F2S::HyperTree *trie = new F2S::HyperTree(this); F2S::HyperTreeLoader loader; - loader.Load(m_input, m_output, m_filePath, *this, *trie); + loader.Load(m_input, m_output, m_filePath, *this, *trie, + m_sourceTerminalSet); m_table = trie; } else if (staticData.GetSearchAlgorithm() == SyntaxS2T) { S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm(); diff --git a/moses/Syntax/RuleTableFF.h b/moses/Syntax/RuleTableFF.h index 4d6132e86..25e7d8428 100644 --- a/moses/Syntax/RuleTableFF.h +++ b/moses/Syntax/RuleTableFF.h @@ -43,10 +43,17 @@ public: return 0; } + // Get the source terminal vocabulary for this table's grammar (as a set of + // factor IDs) + const boost::unordered_set &GetSourceTerminalSet() const { + return m_sourceTerminalSet; + } + private: static std::vector s_instances; const RuleTable *m_table; + boost::unordered_set m_sourceTerminalSet; }; } // Syntax -- cgit v1.2.3