diff options
Diffstat (limited to 'moses/src/ChartTrellisNode.cpp')
-rw-r--r-- | moses/src/ChartTrellisNode.cpp | 132 |
1 files changed, 62 insertions, 70 deletions
diff --git a/moses/src/ChartTrellisNode.cpp b/moses/src/ChartTrellisNode.cpp index eb65120a7..725886c68 100644 --- a/moses/src/ChartTrellisNode.cpp +++ b/moses/src/ChartTrellisNode.cpp @@ -20,81 +20,57 @@ ***********************************************************************/ #include "ChartTrellisNode.h" + #include "ChartHypothesis.h" -#include "DotChart.h" -#include "ScoreComponentCollection.h" +#include "ChartTrellisDetour.h" +#include "ChartTrellisPath.h" #include "StaticData.h" - -using namespace std; +#include "DotChart.h" namespace Moses { -ChartTrellisNode::ChartTrellisNode(const ChartHypothesis *hypo) - :m_hypo(hypo) +ChartTrellisNode::ChartTrellisNode(const ChartHypothesis &hypo) + : m_hypo(hypo) { - const std::vector<const ChartHypothesis*> &prevHypos = hypo->GetPrevHypos(); + CreateChildren(); +} - m_edge.reserve(prevHypos.size()); - for (size_t ind = 0; ind < prevHypos.size(); ++ind) { - const ChartHypothesis *prevHypo = prevHypos[ind]; - ChartTrellisNode *child = new ChartTrellisNode(prevHypo); - m_edge.push_back(child); +ChartTrellisNode::ChartTrellisNode(const ChartTrellisDetour &detour, + ChartTrellisNode *&deviationPoint) + : m_hypo((&detour.GetBasePath().GetFinalNode() == &detour.GetSubstitutedNode()) + ? detour.GetReplacementHypo() + : detour.GetBasePath().GetFinalNode().GetHypothesis()) +{ + if (&m_hypo == &detour.GetReplacementHypo()) { + deviationPoint = this; + CreateChildren(); + } else { + CreateChildren(detour.GetBasePath().GetFinalNode(), + detour.GetSubstitutedNode(), detour.GetReplacementHypo(), + deviationPoint); } - - assert(m_hypo); } -ChartTrellisNode::ChartTrellisNode(const ChartTrellisNode &origNode - , const ChartTrellisNode &soughtNode - , const ChartHypothesis &replacementHypo - , ScoreComponentCollection &scoreChange - , const ChartTrellisNode *&nodeChanged) +ChartTrellisNode::ChartTrellisNode(const ChartTrellisNode &root, + const ChartTrellisNode &substitutedNode, + const ChartHypothesis &replacementHypo, + ChartTrellisNode *&deviationPoint) + : m_hypo((&root == &substitutedNode) + ? replacementHypo + : root.GetHypothesis()) { - if (&origNode.GetHypothesis() == &soughtNode.GetHypothesis()) { - // this node should be replaced - m_hypo = &replacementHypo; - nodeChanged = this; - - // scores - assert(scoreChange.GetWeightedScore() == 0); // should only be changing 1 node - - scoreChange = replacementHypo.GetScoreBreakdown(); - scoreChange.MinusEquals(origNode.GetHypothesis().GetScoreBreakdown()); - - float deltaScore = scoreChange.GetWeightedScore(); - assert(deltaScore <= 0.005); - - // follow prev hypos back to beginning - const std::vector<const ChartHypothesis*> &prevHypos = replacementHypo.GetPrevHypos(); - vector<const ChartHypothesis*>::const_iterator iter; - assert(m_edge.empty()); - m_edge.reserve(prevHypos.size()); - for (iter = prevHypos.begin(); iter != prevHypos.end(); ++iter) { - const ChartHypothesis *prevHypo = *iter; - ChartTrellisNode *prevNode = new ChartTrellisNode(prevHypo); - m_edge.push_back(prevNode); - } - + if (&root == &substitutedNode) { + deviationPoint = this; + CreateChildren(); } else { - // not the node we're looking for. Copy as-is and continue finding node - m_hypo = &origNode.GetHypothesis(); - NodeChildren::const_iterator iter; - assert(m_edge.empty()); - m_edge.reserve(origNode.m_edge.size()); - for (iter = origNode.m_edge.begin(); iter != origNode.m_edge.end(); ++iter) { - const ChartTrellisNode &prevNode = **iter; - ChartTrellisNode *newPrevNode = new ChartTrellisNode(prevNode, soughtNode, replacementHypo, scoreChange, nodeChanged); - m_edge.push_back(newPrevNode); - } + CreateChildren(root, substitutedNode, replacementHypo, deviationPoint); } - - assert(m_hypo); } ChartTrellisNode::~ChartTrellisNode() { - RemoveAllInColl(m_edge); + RemoveAllInColl(m_children); } Phrase ChartTrellisNode::GetOutputPhrase() const @@ -102,13 +78,13 @@ Phrase ChartTrellisNode::GetOutputPhrase() const // exactly like same fn in hypothesis, but use trellis nodes instead of prevHypos pointer Phrase ret(Output, ARRAY_SIZE_INCR); - const ChartTranslationOption &transOpt = m_hypo->GetTranslationOption(); - - VERBOSE(3, "Trans Opt:" << transOpt.GetDottedRule() << ": " << m_hypo->GetCurrTargetPhrase().GetTargetLHS() << "->" << m_hypo->GetCurrTargetPhrase() << std::endl); + const ChartTranslationOption &transOpt = m_hypo.GetTranslationOption(); - const Phrase &currTargetPhrase = m_hypo->GetCurrTargetPhrase(); + VERBOSE(3, "Trans Opt:" << transOpt.GetDottedRule() << ": " << m_hypo.GetCurrTargetPhrase().GetTargetLHS() << "->" << m_hypo.GetCurrTargetPhrase() << std::endl); + + const Phrase &currTargetPhrase = m_hypo.GetCurrTargetPhrase(); const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = - m_hypo->GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap(); + m_hypo.GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap(); for (size_t pos = 0; pos < currTargetPhrase.GetSize(); ++pos) { const Word &word = currTargetPhrase.GetWord(pos); if (word.IsNonTerminal()) { @@ -125,17 +101,33 @@ Phrase ChartTrellisNode::GetOutputPhrase() const return ret; } -std::ostream& operator<<(std::ostream &out, const ChartTrellisNode &node) +void ChartTrellisNode::CreateChildren() { - out << "* " << node.GetHypothesis() << endl; - - ChartTrellisNode::NodeChildren::const_iterator iter; - for (iter = node.GetChildren().begin(); iter != node.GetChildren().end(); ++iter) { - out << **iter; + assert(m_children.empty()); + const std::vector<const ChartHypothesis*> &prevHypos = m_hypo.GetPrevHypos(); + m_children.reserve(prevHypos.size()); + for (size_t ind = 0; ind < prevHypos.size(); ++ind) { + const ChartHypothesis *prevHypo = prevHypos[ind]; + ChartTrellisNode *child = new ChartTrellisNode(*prevHypo); + m_children.push_back(child); } - - return out; } +void ChartTrellisNode::CreateChildren(const ChartTrellisNode &rootNode, + const ChartTrellisNode &substitutedNode, + const ChartHypothesis &replacementHypo, + ChartTrellisNode *&deviationPoint) +{ + assert(m_children.empty()); + const NodeChildren &children = rootNode.GetChildren(); + m_children.reserve(children.size()); + for (size_t ind = 0; ind < children.size(); ++ind) { + const ChartTrellisNode *origChild = children[ind]; + ChartTrellisNode *child = new ChartTrellisNode(*origChild, substitutedNode, + replacementHypo, + deviationPoint); + m_children.push_back(child); + } } +} |