Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/training/phrase-extract/extract-rules.cpp')
-rw-r--r--scripts/training/phrase-extract/extract-rules.cpp32
1 files changed, 27 insertions, 5 deletions
diff --git a/scripts/training/phrase-extract/extract-rules.cpp b/scripts/training/phrase-extract/extract-rules.cpp
index 2cc9dc54d..a00667b82 100644
--- a/scripts/training/phrase-extract/extract-rules.cpp
+++ b/scripts/training/phrase-extract/extract-rules.cpp
@@ -90,7 +90,7 @@ void addHieroRule( int startT, int endT, int startS, int endS
void printHieroPhrase( int startT, int endT, int startS, int endS
, HoleCollection &holeColl, LabelIndex &labelIndex);
string printTargetHieroPhrase( int startT, int endT, int startS, int endS
- , WordIndex &indexT, HoleCollection &holeColl, const LabelIndex &labelIndex);
+ , WordIndex &indexT, HoleCollection &holeColl, const LabelIndex &labelIndex, double &logPCFGScore);
string printSourceHieroPhrase( int startT, int endT, int startS, int endS
, HoleCollection &holeColl, const LabelIndex &labelIndex);
void preprocessSourceHieroPhrase( int startT, int endT, int startS, int endS
@@ -257,6 +257,8 @@ int main(int argc, char* argv[])
// if an source phrase is paired with two target phrases, then count(t|s) = 0.5
else if (strcmp(argv[i],"--NoFractionalCounting") == 0) {
options.fractionalCounting = false;
+ } else if (strcmp(argv[i],"--PCFG") == 0) {
+ options.pcfgScore = true;
} else if (strcmp(argv[i],"--OutputNTLengths") == 0) {
options.outputNTLengths = true;
#ifdef WITH_THREADS
@@ -517,7 +519,7 @@ void ExtractTask::preprocessSourceHieroPhrase( int startT, int endT, int startS,
}
string ExtractTask::printTargetHieroPhrase( int startT, int endT, int startS, int endS
- , WordIndex &indexT, HoleCollection &holeColl, const LabelIndex &labelIndex)
+ , WordIndex &indexT, HoleCollection &holeColl, const LabelIndex &labelIndex, double &logPCFGScore)
{
HoleList::iterator iterHoleList = holeColl.GetHoles().begin();
assert(iterHoleList != holeColl.GetHoles().end());
@@ -545,6 +547,11 @@ string ExtractTask::printTargetHieroPhrase( int startT, int endT, int startS, in
out += "[" + sourceLabel + "][" + targetLabel + "] ";
+ if (m_options.pcfgScore) {
+ double score = m_sentence->targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]->GetPcfgScore();
+ logPCFGScore -= score;
+ }
+
currPos = hole.GetEnd(1);
hole.SetPos(outPos, 1);
++iterHoleList;
@@ -658,8 +665,16 @@ void ExtractTask::printHieroPhrase( int startT, int endT, int startS, int endS
preprocessSourceHieroPhrase(startT, endT, startS, endS, indexS, holeColl, labelIndex);
// target
- rule.target = printTargetHieroPhrase(startT, endT, startS, endS, indexT, holeColl, labelIndex)
+ if (m_options.pcfgScore) {
+ double logPCFGScore = m_sentence->targetTree.GetNodes(startT,endT)[labelIndex[0]]->GetPcfgScore();
+ rule.target = printTargetHieroPhrase(startT, endT, startS, endS, indexT, holeColl, labelIndex, logPCFGScore)
+ + " [" + targetLabel + "]";
+ rule.pcfgScore = std::exp(logPCFGScore);
+ } else {
+ double logPCFGScore = 0.0f;
+ rule.target = printTargetHieroPhrase(startT, endT, startS, endS, indexT, holeColl, labelIndex, logPCFGScore)
+ " [" + targetLabel + "]";
+ }
// source
// holeColl.SortSourceHoles();
@@ -877,6 +892,11 @@ void ExtractTask::addRule( int startT, int endT, int startS, int endS, RuleExist
rule.target += m_sentence->target[ti] + " ";
rule.target += "[" + targetLabel + "]";
+ if (m_options.pcfgScore) {
+ double logPCFGScore = m_sentence->targetTree.GetNodes(startT,endT)[0]->GetPcfgScore();
+ rule.pcfgScore = std::exp(logPCFGScore);
+ }
+
// alignment
for(int ti=startT; ti<=endT; ti++) {
for(unsigned int i=0; i<m_sentence->alignedToT[ti].size(); i++) {
@@ -957,11 +977,13 @@ void ExtractTask::writeRulesToFile()
out << rule->source << " ||| "
<< rule->target << " ||| "
<< rule->alignment << " ||| "
- << rule->count;
+ << rule->count << " ||| ";
if (m_options.outputNTLengths) {
- out << " ||| ";
rule->OutputNTLengths(out);
}
+ if (m_options.pcfgScore) {
+ out << " ||| " << rule->pcfgScore;
+ }
out << "\n";
if (!m_options.onlyDirectFlag) {