diff options
author | Lane Schwartz <dowobeha@gmail.com> | 2013-02-22 21:24:35 +0400 |
---|---|---|
committer | Lane Schwartz <dowobeha@gmail.com> | 2013-02-23 01:28:48 +0400 |
commit | 04f107fbb02442638928c190dd3fa2f13225d570 (patch) | |
tree | 577dec3cc61465632d5a7ee081ae000209dda239 /moses | |
parent | e7563111de02c5e39ff297e58641b612ff02fb4b (diff) |
Add flag to output search graph in Kenneth's hypergraph format.
Diffstat (limited to 'moses')
-rw-r--r-- | moses/Manager.cpp | 204 | ||||
-rw-r--r-- | moses/Manager.h | 11 | ||||
-rw-r--r-- | moses/Parameter.cpp | 1 | ||||
-rw-r--r-- | moses/StaticData.cpp | 7 | ||||
-rw-r--r-- | moses/StaticData.h | 4 |
5 files changed, 227 insertions, 0 deletions
diff --git a/moses/Manager.cpp b/moses/Manager.cpp index ce214c414..21f116f42 100644 --- a/moses/Manager.cpp +++ b/moses/Manager.cpp @@ -663,6 +663,39 @@ void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) } +void Manager::OutputFeatureWeightsForHypergraph(std::ostream &outputSearchGraphStream) const +{ + outputSearchGraphStream.setf(std::ios::fixed); + outputSearchGraphStream.precision(6); + + const StaticData& staticData = StaticData::Instance(); + const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT); + const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions(); + const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions(); + size_t featureIndex = 1; + for (size_t i = 0; i < sff.size(); ++i) { + featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, sff[i], outputSearchGraphStream); + } + for (size_t i = 0; i < slf.size(); ++i) { + if (slf[i]->GetScoreProducerWeightShortName() != "u" && + slf[i]->GetScoreProducerWeightShortName() != "tm" && + slf[i]->GetScoreProducerWeightShortName() != "I" && + slf[i]->GetScoreProducerWeightShortName() != "g") + { + featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, slf[i], outputSearchGraphStream); + } + } + const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries(); + for( size_t i=0; i<pds.size(); i++ ) { + featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, pds[i], outputSearchGraphStream); + } + const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries(); + for( size_t i=0; i<gds.size(); i++ ) { + featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, gds[i], outputSearchGraphStream); + } + +} + void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const { @@ -702,6 +735,39 @@ void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std: } +void Manager::OutputFeatureValuesForHypergraph(const Hypothesis* hypo, std::ostream &outputSearchGraphStream) const +{ + outputSearchGraphStream.setf(std::ios::fixed); + outputSearchGraphStream.precision(6); + + const StaticData& staticData = StaticData::Instance(); + const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT); + const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions(); + const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions(); + size_t featureIndex = 1; + for (size_t i = 0; i < sff.size(); ++i) { + featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, sff[i], outputSearchGraphStream); + } + for (size_t i = 0; i < slf.size(); ++i) { + if (slf[i]->GetScoreProducerWeightShortName() != "u" && + slf[i]->GetScoreProducerWeightShortName() != "tm" && + slf[i]->GetScoreProducerWeightShortName() != "I" && + slf[i]->GetScoreProducerWeightShortName() != "g") + { + featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, slf[i], outputSearchGraphStream); + } + } + const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries(); + for( size_t i=0; i<pds.size(); i++ ) { + featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, pds[i], outputSearchGraphStream); + } + const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries(); + for( size_t i=0; i<gds.size(); i++ ) { + featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, gds[i], outputSearchGraphStream); + } + +} + size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const { @@ -722,6 +788,30 @@ size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* } } +size_t Manager::OutputFeatureWeightsForHypergraph(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const +{ + size_t numScoreComps = ff->GetNumScoreComponents(); + if (numScoreComps != ScoreProducer::unlimited) { + vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff); + if (numScoreComps > 1) { + for (size_t i = 0; i < numScoreComps; ++i) { + outputSearchGraphStream << ff->GetScoreProducerWeightShortName() + << i + << "=" << values[i] << endl; + } + } else { + outputSearchGraphStream << ff->GetScoreProducerWeightShortName() + << "=" << values[0] << endl; + } + return index+numScoreComps; + } else { + cerr << "Sparse features are not yet supported when outputting hypergraph format" << endl; + assert(false); + return 0; + } +} + + size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const { @@ -764,6 +854,120 @@ size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypoth // } } +size_t Manager::OutputFeatureValuesForHypergraph(size_t index, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const +{ + + const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown(); + + vector<float> featureValues = scoreCollection.GetScoresForProducer(ff); + size_t numScoreComps = featureValues.size(); + + if (numScoreComps > 1) { + for (size_t i = 0; i < numScoreComps; ++i) { + outputSearchGraphStream << ff->GetScoreProducerWeightShortName() << i << "=" << featureValues[i] << " "; + } + } else { + outputSearchGraphStream << ff->GetScoreProducerWeightShortName() << "=" << featureValues[0] << " "; + } + + return index+numScoreComps; +} + +void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream, + const SearchGraphNode& searchNode); +/**! Output search graph in hypergraph format of Kenneth Heafield's lazy hypergraph decoder */ +void Manager::OutputSearchGraphAsHypergraph(long translationId, std::ostream &outputSearchGraphStream) const +{ + vector<SearchGraphNode> searchGraph; + GetSearchGraph(searchGraph); +outputSearchGraphStream << "searchGraph.size() == " << searchGraph.size() << endl; + // long numArcs = 0; + long numNodes = 0; + + map<int,int> nodes; + set<int> terminalNodes; + multimap<int,int> nodeToLines; + + // Unique start node + // nodes[0] = 0; + //numNodes += 1; + for (size_t arcNumber = 0, size=searchGraph.size(); arcNumber < size; ++arcNumber) { +OutputSearchNode(translationId,outputSearchGraphStream,searchGraph[arcNumber]); + // Record that this arc ends at this node + // numArcs += 1; + nodeToLines.insert(pair<int,int>(numNodes,arcNumber)); + + int hypothesisID = searchGraph[arcNumber].hypo->GetId(); + if (nodes.count(hypothesisID) == 0) { + + nodes[hypothesisID] = numNodes; + numNodes += 1; + + bool terminalNode = (searchGraph[arcNumber].forward == -1); + if (terminalNode) { + terminalNodes.insert(numNodes); + // numArcs += 1; // Final arc to end node, representing the end of the sentence </s> + } + } + + } + + // Unique end node + nodes[numNodes] = numNodes; + numNodes += 1; + + long numArcs = searchGraph.size() + terminalNodes.size(); + // Unique start node + // numNodes += 1; + + // Print number of nodes and arcs + outputSearchGraphStream << numNodes << " " << numArcs << "(" << searchGraph.size() << ", " << terminalNodes.size() << ")" << endl; + + // Print node and arc for beginning of sentence <s> + // outputSearchGraphStream << 1 << endl; + // outputSearchGraphStream << "<s> ||| " << endl; + + for (int nodeNumber=0; nodeNumber <= numNodes; nodeNumber+=1) { + + size_t count = nodeToLines.count(nodeNumber); + if (count > 0) { + outputSearchGraphStream << count << endl; + + pair<multimap<int,int>::iterator, multimap<int,int>::iterator> range = nodeToLines.equal_range(nodeNumber); + for (multimap<int,int>::iterator it=range.first; it!=range.second; ++it) { + int lineNumber = (*it).second; + const Hypothesis *thisHypo = searchGraph[lineNumber].hypo; + const Hypothesis *prevHypo = thisHypo->GetPrevHypo(); + if (prevHypo==NULL) { + outputSearchGraphStream << "<s> ||| " << endl; + } else { + int startNode = nodes[prevHypo->GetId()]; + + const TargetPhrase &targetPhrase = thisHypo->GetCurrTargetPhrase(); + int targetWordCount = targetPhrase.GetSize(); + + outputSearchGraphStream << "[" << startNode << "]"; + for (int targetWordIndex=0; targetWordIndex<targetWordCount; targetWordIndex+=1) { + outputSearchGraphStream << " " << targetPhrase.GetWord(targetWordIndex); + } + outputSearchGraphStream << " ||| "; + OutputFeatureValuesForHypergraph(thisHypo, outputSearchGraphStream); + outputSearchGraphStream << endl; + } + + } + } + } + + // Print node and arc(s) for end of sentence </s> + outputSearchGraphStream << terminalNodes.size() << endl; + for (set<int>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); ++it) { + outputSearchGraphStream << "[" << (*it) << "] </s> ||| " << endl; + } + +} + + /**! Output search graph in HTK standard lattice format (SLF) */ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const { diff --git a/moses/Manager.h b/moses/Manager.h index c5f54847b..d580674b4 100644 --- a/moses/Manager.h +++ b/moses/Manager.h @@ -94,10 +94,20 @@ class Manager void operator=(Manager const&); const TranslationSystem* m_system; private: + + // Helper functions to output search graph in HTK standard lattice format void OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const; size_t OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const; void OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const; size_t OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const; + + // Helper functions to output search graph in the hypergraph format of Kenneth Heafield's lazy hypergraph decoder + void OutputFeatureWeightsForHypergraph(std::ostream &outputSearchGraphStream) const; + size_t OutputFeatureWeightsForHypergraph(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const; + void OutputFeatureValuesForHypergraph(const Hypothesis* hypo, std::ostream &outputSearchGraphStream) const; + size_t OutputFeatureValuesForHypergraph(size_t index, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const; + + protected: // data // InputType const& m_source; /**< source sentence to be translated */ @@ -143,6 +153,7 @@ public: void OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const; void OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const; + void OutputSearchGraphAsHypergraph(long translationId, std::ostream &outputSearchGraphStream) const; void GetSearchGraph(std::vector<SearchGraphNode>& searchGraph) const; const InputType& GetSource() const { return m_source; diff --git a/moses/Parameter.cpp b/moses/Parameter.cpp index 876cbd224..359174280 100644 --- a/moses/Parameter.cpp +++ b/moses/Parameter.cpp @@ -131,6 +131,7 @@ Parameter::Parameter() AddParam("output-search-graph-extended", "osgx", "Output connected hypotheses of search into specified filename, in extended format"); AddParam("unpruned-search-graph", "usg", "When outputting chart search graph, do not exclude dead ends. Note: stack pruning may have eliminated some hypotheses"); AddParam("output-search-graph-slf", "slf", "Output connected hypotheses of search into specified directory, one file per sentence, in HTK standard lattice format (SLF)"); + AddParam("output-search-graph-hypergraph", "Output connected hypotheses of search into specified directory, one file per sentence, in a hypergraph format (see Kenneth Heafield's lazy hypergraph decoder)"); AddParam("include-lhs-in-search-graph", "lhssg", "When outputting chart search graph, include the label of the LHS of the rule (useful when using syntax)"); #ifdef HAVE_PROTOBUF AddParam("output-search-graph-pb", "pb", "Write phrase lattice to protocol buffer objects in the specified path."); diff --git a/moses/StaticData.cpp b/moses/StaticData.cpp index 1d9d4907c..cf797582b 100644 --- a/moses/StaticData.cpp +++ b/moses/StaticData.cpp @@ -240,6 +240,13 @@ bool StaticData::LoadData(Parameter *parameter) } if (m_parameter->GetParam("output-search-graph-slf").size() > 0) { m_outputSearchGraphSLF = true; + } else { + m_outputSearchGraphSLF = false; + } + if (m_parameter->GetParam("output-search-graph-hypergraph").size() > 0) { + m_outputSearchGraphHypergraph = true; + } else { + m_outputSearchGraphHypergraph = false; } #ifdef HAVE_PROTOBUF if (m_parameter->GetParam("output-search-graph-pb").size() > 0) { diff --git a/moses/StaticData.h b/moses/StaticData.h index d644e59f7..8a9e65162 100644 --- a/moses/StaticData.h +++ b/moses/StaticData.h @@ -217,6 +217,7 @@ protected: bool m_outputSearchGraph; //! whether to output search graph bool m_outputSearchGraphExtended; //! ... in extended format bool m_outputSearchGraphSLF; //! whether to output search graph in HTK standard lattice format (SLF) + bool m_outputSearchGraphHypergraph; //! whether to output search graph in hypergraph #ifdef HAVE_PROTOBUF bool m_outputSearchGraphPB; //! whether to output search graph as a protobuf #endif @@ -635,6 +636,9 @@ public: bool GetOutputSearchGraphSLF() const { return m_outputSearchGraphSLF; } + bool GetOutputSearchGraphHypergraph() const { + return m_outputSearchGraphHypergraph; + } #ifdef HAVE_PROTOBUF bool GetOutputSearchGraphPB() const { return m_outputSearchGraphPB; |