diff options
author | Lane Schwartz <dowobeha@gmail.com> | 2013-02-20 20:03:23 +0400 |
---|---|---|
committer | Lane Schwartz <dowobeha@gmail.com> | 2013-02-23 01:28:47 +0400 |
commit | e7563111de02c5e39ff297e58641b612ff02fb4b (patch) | |
tree | 7814a5b856f423535ff1b0609e2240df6b0ca07c /moses | |
parent | e106e04dc3c3fe609f82780cbd8286d042a5e47d (diff) |
More work on outputting HTK lattice format
Diffstat (limited to 'moses')
-rw-r--r-- | moses/Manager.cpp | 149 | ||||
-rw-r--r-- | moses/Manager.h | 7 |
2 files changed, 150 insertions, 6 deletions
diff --git a/moses/Manager.cpp b/moses/Manager.cpp index 39eb7f917..ce214c414 100644 --- a/moses/Manager.cpp +++ b/moses/Manager.cpp @@ -53,12 +53,12 @@ using namespace std; namespace Moses { Manager::Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system) - :m_lineNumber(lineNumber) - ,m_system(system) + :m_system(system) ,m_transOptColl(source.CreateTranslationOptionCollection(system)) ,m_search(Search::CreateSearch(*this, source, searchAlgorithm, *m_transOptColl)) ,interrupted_flag(0) ,m_hypoId(0) + ,m_lineNumber(lineNumber) ,m_source(source) { m_system->InitializeBeforeSentenceProcessing(source); @@ -630,6 +630,140 @@ void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const } +void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const +{ + outputSearchGraphStream.setf(std::ios::fixed); + outputSearchGraphStream.precision(6); + + const StaticData& staticData = StaticData::Instance(); + const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT); + const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions(); + const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions(); + size_t featureIndex = 1; + for (size_t i = 0; i < sff.size(); ++i) { + featureIndex = OutputFeatureWeightsForSLF(featureIndex, sff[i], outputSearchGraphStream); + } + for (size_t i = 0; i < slf.size(); ++i) { + if (slf[i]->GetScoreProducerWeightShortName() != "u" && + slf[i]->GetScoreProducerWeightShortName() != "tm" && + slf[i]->GetScoreProducerWeightShortName() != "I" && + slf[i]->GetScoreProducerWeightShortName() != "g") + { + featureIndex = OutputFeatureWeightsForSLF(featureIndex, slf[i], outputSearchGraphStream); + } + } + const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries(); + for( size_t i=0; i<pds.size(); i++ ) { + featureIndex = OutputFeatureWeightsForSLF(featureIndex, pds[i], outputSearchGraphStream); + } + const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries(); + for( size_t i=0; i<gds.size(); i++ ) { + featureIndex = OutputFeatureWeightsForSLF(featureIndex, gds[i], outputSearchGraphStream); + } + +} + + +void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const +{ + outputSearchGraphStream.setf(std::ios::fixed); + outputSearchGraphStream.precision(6); + + // outputSearchGraphStream << endl; + // outputSearchGraphStream << (*hypo) << endl; + // const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown(); + // outputSearchGraphStream << scoreCollection << endl; + + const StaticData& staticData = StaticData::Instance(); + const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT); + const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions(); + const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions(); + size_t featureIndex = 1; + for (size_t i = 0; i < sff.size(); ++i) { + featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, sff[i], outputSearchGraphStream); + } + for (size_t i = 0; i < slf.size(); ++i) { + if (slf[i]->GetScoreProducerWeightShortName() != "u" && + slf[i]->GetScoreProducerWeightShortName() != "tm" && + slf[i]->GetScoreProducerWeightShortName() != "I" && + slf[i]->GetScoreProducerWeightShortName() != "g") + { + featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, slf[i], outputSearchGraphStream); + } + } + const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries(); + for( size_t i=0; i<pds.size(); i++ ) { + featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, pds[i], outputSearchGraphStream); + } + const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries(); + for( size_t i=0; i<gds.size(); i++ ) { + featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, gds[i], outputSearchGraphStream); + } + +} + + +size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const +{ + size_t numScoreComps = ff->GetNumScoreComponents(); + if (numScoreComps != ScoreProducer::unlimited) { + vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff); + for (size_t i = 0; i < numScoreComps; ++i) { + outputSearchGraphStream << "# " << ff->GetScoreProducerDescription() + << " " << ff->GetScoreProducerWeightShortName() + << " " << (i+1) << " of " << numScoreComps << endl + << "x" << (index+i) << "scale=" << values[i] << endl; + } + return index+numScoreComps; + } else { + cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl; + assert(false); + return 0; + } +} + +size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const +{ + + // { const FeatureFunction* sp = ff; + // const FVector& m_scores = scoreCollection.GetScoresVector(); + // FVector& scores = const_cast<FVector&>(m_scores); + // std::string prefix = sp->GetScoreProducerDescription() + FName::SEP; + // // std::cout << "prefix==" << prefix << endl; + // // cout << "m_scores==" << m_scores << endl; + // // cout << "m_scores.size()==" << m_scores.size() << endl; + // // cout << "m_scores.coreSize()==" << m_scores.coreSize() << endl; + // // cout << "m_scores.cbegin() ?= m_scores.cend()\t" << (m_scores.cbegin() == m_scores.cend()) << endl; + + + // // for(FVector::FNVmap::const_iterator i = m_scores.cbegin(); i != m_scores.cend(); i++) { + // // std::cout<<prefix << "\t" << (i->first) << "\t" << (i->second) << std::endl; + // // } + // for(int i=0, n=v.size(); i<n; i+=1) { + // // outputSearchGraphStream << prefix << i << "==" << v[i] << std::endl; + + // } + // } + + // FVector featureValues = scoreCollection.GetVectorForProducer(ff); + // outputSearchGraphStream << featureValues << endl; + const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown(); + + vector<float> featureValues = scoreCollection.GetScoresForProducer(ff); + size_t numScoreComps = featureValues.size();//featureValues.coreSize(); + // if (numScoreComps != ScoreProducer::unlimited) { + // vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff); + for (size_t i = 0; i < numScoreComps; ++i) { + outputSearchGraphStream << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " "; + } + return index+numScoreComps; + // } else { + // cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl; + // assert(false); + // return 0; + // } +} + /**! Output search graph in HTK standard lattice format (SLF) */ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const { @@ -673,10 +807,12 @@ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSea outputSearchGraphStream << "UTTERANCE=Sentence_" << translationId << endl; outputSearchGraphStream << "VERSION=1.1" << endl; - outputSearchGraphStream << "base=e" << endl; + outputSearchGraphStream << "base=2.71828182845905" << endl; outputSearchGraphStream << "NODES=" << (numNodes+1) << endl; outputSearchGraphStream << "LINKS=" << numArcs << endl; + OutputFeatureWeightsForSLF(outputSearchGraphStream); + // const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder(); for (size_t arcNumber = 0, lineNumber = 0; lineNumber < searchGraph.size(); ++lineNumber) { @@ -709,8 +845,11 @@ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSea } outputSearchGraphStream << " E=" << endNode - (x-1) //(startNode + targetWordIndex + 1) - << " W=" << targetPhrase.GetWord(targetWordIndex) - << endl; + << " W=" << targetPhrase.GetWord(targetWordIndex); + + OutputFeatureValuesForSLF(thisHypo, (targetWordIndex>0), outputSearchGraphStream); + + outputSearchGraphStream << endl; arcNumber += 1; } diff --git a/moses/Manager.h b/moses/Manager.h index 0ae7cd6f1..c5f54847b 100644 --- a/moses/Manager.h +++ b/moses/Manager.h @@ -93,6 +93,11 @@ class Manager Manager(Manager const&); void operator=(Manager const&); const TranslationSystem* m_system; +private: + void OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const; + size_t OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const; + void OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const; + size_t OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const; protected: // data // InputType const& m_source; /**< source sentence to be translated */ @@ -103,6 +108,7 @@ protected: size_t interrupted_flag; std::auto_ptr<SentenceStats> m_sentenceStats; int m_hypoId; //used to number the hypos as they are created. + size_t m_lineNumber; void GetConnectedGraph( std::map< int, bool >* pConnected, @@ -113,7 +119,6 @@ protected: public: - size_t m_lineNumber; InputType const& m_source; /**< source sentence to be translated */ Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system); ~Manager(); |