Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'moses/src/PDTAimp.h')
-rw-r--r-- [-rwxr-xr-x] moses/src/PDTAimp.h 41
1 file changed, 21 insertions, 20 deletions
diff --git a/moses/src/PDTAimp.h b/moses/src/PDTAimp.h
index 7bff39c07..0c2a984b0 100755..100644
--- a/moses/src/PDTAimp.h
+++ b/moses/src/PDTAimp.h
@@ -33,13 +33,11 @@ class PDTAimp
protected:
PDTAimp(PhraseDictionaryTreeAdaptor *p,unsigned nis)
- : m_languageModels(0),m_weightWP(0.0),m_dict(0),
+ : m_languageModels(0),m_dict(0),
m_obj(p),useCache(1),m_numInputScores(nis),totalE(0),distinctE(0) {}
-public:
- std::vector<float> m_weights;
+ public:
LMList const* m_languageModels;
- float m_weightWP;
std::vector<FactorType> m_input,m_output;
PhraseDictionaryTree *m_dict;
typedef std::vector<TargetPhraseCollection const*> vTPC;
@@ -139,6 +137,10 @@ public:
return 0;
}
+ const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT);
+ std::vector<float> weights = system.GetTranslationWeights(m_obj->GetDictIndex());
+ float weightWP = system.GetWeightWordPenalty();
+
std::vector<TargetPhrase> tCands;
tCands.reserve(cands.size());
std::vector<std::pair<float,size_t> > costs;
@@ -165,7 +167,7 @@ public:
*(cands[i].fnames[j]), cands[i].fvalues[j]);
}
}
- CreateTargetPhrase(targetPhrase,factorStrings,scoreVector, sparseFeatures, wacands[i],&src);
+ CreateTargetPhrase(targetPhrase,factorStrings,scoreVector, sparseFeatures, wacands[i], weights, weightWP, &src);
costs.push_back(std::make_pair(-targetPhrase.GetFutureScore(),tCands.size()));
tCands.push_back(targetPhrase);
}
@@ -188,9 +190,8 @@ public:
void Create(const std::vector<FactorType> &input
, const std::vector<FactorType> &output
, const std::string &filePath
- , const std::vector<float> &weight
+ , const std::vector<float> &weight
, const LMList &languageModels
- , float weightWP
) {
// set my members
@@ -198,8 +199,6 @@ public:
m_input=input;
m_output=output;
m_languageModels=&languageModels;
- m_weightWP=weightWP;
- m_weights=weight;
const StaticData &staticData = StaticData::Instance();
m_dict->UseWordAlignment(staticData.UseAlignmentInfo());
@@ -262,8 +261,10 @@ public:
Scores const& scoreVector,
const ScoreComponentCollection& sparseFeatures,
const std::string& alignmentString,
+ std::vector<float> &weights,
+ float weightWP,
Phrase const* srcPtr=0) const {
- CreateTargetPhrase(targetPhrase, factorStrings, scoreVector, sparseFeatures, srcPtr);
+ CreateTargetPhrase(targetPhrase, factorStrings, scoreVector, sparseFeatures, weights, weightWP, srcPtr);
targetPhrase.SetAlignmentInfo(alignmentString);
}
@@ -272,6 +273,8 @@ public:
StringTgtCand::Tokens const& factorStrings,
Scores const& scoreVector,
const ScoreComponentCollection& sparseFeatures,
+ std::vector<float> &weights,
+ float weightWP,
Phrase const* srcPtr=0) const {
FactorCollection &factorCollection = FactorCollection::Instance();
@@ -284,7 +287,7 @@ public:
}
}
- targetPhrase.SetScore(m_obj->GetFeature(), scoreVector, sparseFeatures, m_weights, m_weightWP, *m_languageModels);
+ targetPhrase.SetScore(m_obj->GetFeature(), scoreVector, sparseFeatures, weights, weightWP, *m_languageModels);
targetPhrase.SetSourcePhrase(*srcPtr);
}
@@ -366,6 +369,10 @@ public:
for(Position i=0 ; i < srcSize ; ++i)
stack.push_back(State(i, i, m_dict->GetRoot(), std::vector<float>(m_numInputScores,0.0)));
+ const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT);
+ std::vector<float> weightT = system.GetTranslationWeights(m_obj->GetDictIndex());
+ float weightWP = system.GetWeightWordPenalty();
+
while(!stack.empty()) {
State curr(stack.back());
stack.pop_back();
@@ -440,19 +447,13 @@ public:
//put in phrase table scores, logging as we insert
std::transform(tcands[i].scores.begin(),tcands[i].scores.end(),nscores.begin() + m_numInputScores,TransformScore);
-
- CHECK(nscores.size()==m_weights.size());
-
- const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT);
+ CHECK(nscores.size()==weightT.size());
+
//tally up
- std::vector<float> weightT = system.GetTranslationWeights();
- //float score=std::inner_product(nscores.begin(), nscores.end(), m_weights.begin(), 0.0f);
float score=std::inner_product(nscores.begin(), nscores.end(), weightT.begin(), 0.0f);
//count word penalty
- float weightWP = system.GetWeightWordPenalty();
- //score-=tcands[i].tokens.size() * m_weightWP;
score-=tcands[i].tokens.size() * weightWP;
std::pair<E2Costs::iterator,bool> p=e2costs.insert(std::make_pair(tcands[i].tokens,TScores()));
@@ -501,7 +502,7 @@ public:
for(E2Costs::const_iterator j=i->second.begin(); j!=i->second.end(); ++j) {
TScores const & scores=j->second;
TargetPhrase targetPhrase(Output);
- CreateTargetPhrase(targetPhrase,j->first,scores.trans,ScoreComponentCollection(),scores.src);
+ CreateTargetPhrase(targetPhrase,j->first,scores.trans,ScoreComponentCollection(),weightT,weightWP,scores.src);
costs.push_back(std::make_pair(-targetPhrase.GetFutureScore(),tCands.size()));
tCands.push_back(targetPhrase);
//std::cerr << i->first.first << "-" << i->first.second << ": " << targetPhrase << std::endl;