diff options
Diffstat (limited to 'moses/src/PDTAimp.h')
-rw-r--r--[-rwxr-xr-x] | moses/src/PDTAimp.h | 41 |
1 files changed, 21 insertions, 20 deletions
diff --git a/moses/src/PDTAimp.h b/moses/src/PDTAimp.h index 7bff39c07..0c2a984b0 100755..100644 --- a/moses/src/PDTAimp.h +++ b/moses/src/PDTAimp.h @@ -33,13 +33,11 @@ class PDTAimp protected: PDTAimp(PhraseDictionaryTreeAdaptor *p,unsigned nis) - : m_languageModels(0),m_weightWP(0.0),m_dict(0), + : m_languageModels(0),m_dict(0), m_obj(p),useCache(1),m_numInputScores(nis),totalE(0),distinctE(0) {} -public: - std::vector<float> m_weights; + public: LMList const* m_languageModels; - float m_weightWP; std::vector<FactorType> m_input,m_output; PhraseDictionaryTree *m_dict; typedef std::vector<TargetPhraseCollection const*> vTPC; @@ -139,6 +137,10 @@ public: return 0; } + const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT); + std::vector<float> weights = system.GetTranslationWeights(m_obj->GetDictIndex()); + float weightWP = system.GetWeightWordPenalty(); + std::vector<TargetPhrase> tCands; tCands.reserve(cands.size()); std::vector<std::pair<float,size_t> > costs; @@ -165,7 +167,7 @@ public: *(cands[i].fnames[j]), cands[i].fvalues[j]); } } - CreateTargetPhrase(targetPhrase,factorStrings,scoreVector, sparseFeatures, wacands[i],&src); + CreateTargetPhrase(targetPhrase,factorStrings,scoreVector, sparseFeatures, wacands[i], weights, weightWP, &src); costs.push_back(std::make_pair(-targetPhrase.GetFutureScore(),tCands.size())); tCands.push_back(targetPhrase); } @@ -188,9 +190,8 @@ public: void Create(const std::vector<FactorType> &input , const std::vector<FactorType> &output , const std::string &filePath - , const std::vector<float> &weight + , const std::vector<float> &weight , const LMList &languageModels - , float weightWP ) { // set my members @@ -198,8 +199,6 @@ public: m_input=input; m_output=output; m_languageModels=&languageModels; - m_weightWP=weightWP; - m_weights=weight; const StaticData &staticData = StaticData::Instance(); m_dict->UseWordAlignment(staticData.UseAlignmentInfo()); @@ -262,8 +261,10 @@ public: Scores const& scoreVector, const ScoreComponentCollection& sparseFeatures, const std::string& alignmentString, + std::vector<float> &weights, + float weightWP, Phrase const* srcPtr=0) const { - CreateTargetPhrase(targetPhrase, factorStrings, scoreVector, sparseFeatures, srcPtr); + CreateTargetPhrase(targetPhrase, factorStrings, scoreVector, sparseFeatures, weights, weightWP, srcPtr); targetPhrase.SetAlignmentInfo(alignmentString); } @@ -272,6 +273,8 @@ public: StringTgtCand::Tokens const& factorStrings, Scores const& scoreVector, const ScoreComponentCollection& sparseFeatures, + std::vector<float> &weights, + float weightWP, Phrase const* srcPtr=0) const { FactorCollection &factorCollection = FactorCollection::Instance(); @@ -284,7 +287,7 @@ public: } } - targetPhrase.SetScore(m_obj->GetFeature(), scoreVector, sparseFeatures, m_weights, m_weightWP, *m_languageModels); + targetPhrase.SetScore(m_obj->GetFeature(), scoreVector, sparseFeatures, weights, weightWP, *m_languageModels); targetPhrase.SetSourcePhrase(*srcPtr); } @@ -366,6 +369,10 @@ public: for(Position i=0 ; i < srcSize ; ++i) stack.push_back(State(i, i, m_dict->GetRoot(), std::vector<float>(m_numInputScores,0.0))); + const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT); + std::vector<float> weightT = system.GetTranslationWeights(m_obj->GetDictIndex()); + float weightWP = system.GetWeightWordPenalty(); + while(!stack.empty()) { State curr(stack.back()); stack.pop_back(); @@ -440,19 +447,13 @@ public: //put in phrase table scores, logging as we insert std::transform(tcands[i].scores.begin(),tcands[i].scores.end(),nscores.begin() + m_numInputScores,TransformScore); - - CHECK(nscores.size()==m_weights.size()); - - const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT); + CHECK(nscores.size()==weightT.size()); + //tally up - std::vector<float> weightT = system.GetTranslationWeights(); - //float score=std::inner_product(nscores.begin(), nscores.end(), m_weights.begin(), 0.0f); float score=std::inner_product(nscores.begin(), nscores.end(), weightT.begin(), 0.0f); //count word penalty - float weightWP = system.GetWeightWordPenalty(); - //score-=tcands[i].tokens.size() * m_weightWP; score-=tcands[i].tokens.size() * weightWP; std::pair<E2Costs::iterator,bool> p=e2costs.insert(std::make_pair(tcands[i].tokens,TScores())); @@ -501,7 +502,7 @@ public: for(E2Costs::const_iterator j=i->second.begin(); j!=i->second.end(); ++j) { TScores const & scores=j->second; TargetPhrase targetPhrase(Output); - CreateTargetPhrase(targetPhrase,j->first,scores.trans,ScoreComponentCollection(),scores.src); + CreateTargetPhrase(targetPhrase,j->first,scores.trans,ScoreComponentCollection(),weightT,weightWP,scores.src); costs.push_back(std::make_pair(-targetPhrase.GetFutureScore(),tCands.size())); tCands.push_back(targetPhrase); //std::cerr << i->first.first << "-" << i->first.second << ": " << targetPhrase << std::endl; |