diff options
-rw-r--r-- | phrase-extract/ExtractedRule.h | 12 | ||||
-rw-r--r-- | phrase-extract/PhraseExtractionOptions.h | 10 | ||||
-rw-r--r-- | phrase-extract/RuleExtractionOptions.h | 5 | ||||
-rw-r--r-- | phrase-extract/extract-main.cpp | 96 | ||||
-rw-r--r-- | phrase-extract/extract-rules-main.cpp | 112 | ||||
-rwxr-xr-x | scripts/generic/extract-parallel.perl | 15 | ||||
-rwxr-xr-x | scripts/generic/score-parallel.perl | 84 | ||||
-rwxr-xr-x | scripts/training/flexibility_score.py | 171 | ||||
-rwxr-xr-x | scripts/training/train-model.perl | 12 |
9 files changed, 501 insertions, 16 deletions
diff --git a/phrase-extract/ExtractedRule.h b/phrase-extract/ExtractedRule.h index 8e3513582..a6cd5074d 100644 --- a/phrase-extract/ExtractedRule.h +++ b/phrase-extract/ExtractedRule.h @@ -41,6 +41,12 @@ public: std::string alignmentInv; std::string orientation; std::string orientationForward; + std::string sourceContextLeft; + std::string sourceContextRight; + std::string targetContextLeft; + std::string targetContextRight; + std::string sourceHoleString; + std::string targetHoleString; int startT; int endT; int startS; @@ -57,6 +63,12 @@ public: , alignmentInv() , orientation() , orientationForward() + , sourceContextLeft() + , sourceContextRight() + , targetContextLeft() + , targetContextRight() + , sourceHoleString() + , targetHoleString() , startT(sT) , endT(eT) , startS(sS) diff --git a/phrase-extract/PhraseExtractionOptions.h b/phrase-extract/PhraseExtractionOptions.h index 38063108b..a410473f4 100644 --- a/phrase-extract/PhraseExtractionOptions.h +++ b/phrase-extract/PhraseExtractionOptions.h @@ -48,6 +48,7 @@ private: bool onlyOutputSpanInfo; bool gzOutput; std::string instanceWeightsFile; //weights for each sentence + bool flexScoreFlag; public: std::vector<std::string> placeholders; @@ -65,7 +66,8 @@ public: translationFlag(true), includeSentenceIdFlag(false), onlyOutputSpanInfo(false), - gzOutput(false) {} + gzOutput(false), + flexScoreFlag(false) {} //functions for initialization of options void initAllModelsOutputFlag(const bool initallModelsOutputFlag) { @@ -107,6 +109,9 @@ public: void initInstanceWeightsFile(const char* initInstanceWeightsFile) { instanceWeightsFile = std::string(initInstanceWeightsFile); } + void initFlexScoreFlag(const bool initflexScoreFlag){ + flexScoreFlag=initflexScoreFlag; + } // functions for getting values bool isAllModelsOutputFlag() const { @@ -148,6 +153,9 @@ public: std::string getInstanceWeightsFile() const { return instanceWeightsFile; } + bool isFlexScoreFlag() const { + return flexScoreFlag; + } }; } diff --git a/phrase-extract/RuleExtractionOptions.h b/phrase-extract/RuleExtractionOptions.h index d437c679c..a9b0ce9e6 100644 --- a/phrase-extract/RuleExtractionOptions.h +++ b/phrase-extract/RuleExtractionOptions.h @@ -54,6 +54,7 @@ public: bool unpairedExtractFormat; bool conditionOnTargetLhs; bool boundaryRules; + bool flexScoreFlag; RuleExtractionOptions() : maxSpan(10) @@ -86,8 +87,8 @@ public: , gzOutput(false) , unpairedExtractFormat(false) , conditionOnTargetLhs(false) - , boundaryRules(false) { - } + , boundaryRules(false) + , flexScoreFlag(false) {} }; } diff --git a/phrase-extract/extract-main.cpp b/phrase-extract/extract-main.cpp index 4804f83b8..5bc018173 100644 --- a/phrase-extract/extract-main.cpp +++ b/phrase-extract/extract-main.cpp @@ -80,6 +80,8 @@ int sentenceOffset = 0; std::vector<std::string> Tokenize(const std::string& str, const std::string& delimiters = " \t"); +bool flexScoreFlag = false; + } namespace MosesTraining @@ -88,18 +90,22 @@ namespace MosesTraining class ExtractTask { public: - ExtractTask(size_t id, SentenceAlignment &sentence,PhraseExtractionOptions &initoptions, Moses::OutputFileStream &extractFile, Moses::OutputFileStream &extractFileInv,Moses::OutputFileStream &extractFileOrientation): + ExtractTask(size_t id, SentenceAlignment &sentence,PhraseExtractionOptions &initoptions, Moses::OutputFileStream &extractFile, Moses::OutputFileStream &extractFileInv,Moses::OutputFileStream &extractFileOrientation, Moses::OutputFileStream &extractFileContext, Moses::OutputFileStream &extractFileContextInv): m_sentence(sentence), m_options(initoptions), m_extractFile(extractFile), m_extractFileInv(extractFileInv), - m_extractFileOrientation(extractFileOrientation) {} + m_extractFileOrientation(extractFileOrientation), + m_extractFileContext(extractFileContext), + m_extractFileContextInv(extractFileContextInv) {} void Run(); private: vector< string > m_extractedPhrases; vector< string > m_extractedPhrasesInv; vector< string > m_extractedPhrasesOri; vector< string > m_extractedPhrasesSid; + vector< string > m_extractedPhrasesContext; + vector< string > m_extractedPhrasesContextInv; void extractBase(SentenceAlignment &); void extract(SentenceAlignment &); void addPhrase(SentenceAlignment &, int, int, int, int, string &); @@ -112,6 +118,8 @@ private: Moses::OutputFileStream &m_extractFile; Moses::OutputFileStream &m_extractFileInv; Moses::OutputFileStream &m_extractFileOrientation; + Moses::OutputFileStream &m_extractFileContext; + Moses::OutputFileStream &m_extractFileContextInv; }; } @@ -129,6 +137,8 @@ int main(int argc, char* argv[]) Moses::OutputFileStream extractFile; Moses::OutputFileStream extractFileInv; Moses::OutputFileStream extractFileOrientation; + Moses::OutputFileStream extractFileContext; + Moses::OutputFileStream extractFileContextInv; const char* const &fileNameE = argv[1]; const char* const &fileNameF = argv[2]; const char* const &fileNameA = argv[3]; @@ -140,6 +150,8 @@ int main(int argc, char* argv[]) options.initOnlyOutputSpanInfo(true); } else if (strcmp(argv[i],"orientation") == 0 || strcmp(argv[i],"--Orientation") == 0) { options.initOrientationFlag(true); + } else if (strcmp(argv[i],"--FlexibilityScore") == 0) { + options.initFlexScoreFlag(true); } else if (strcmp(argv[i],"--NoTTable") == 0) { options.initTranslationFlag(false); } else if (strcmp(argv[i], "--IncludeSentenceId") == 0) { @@ -254,8 +266,15 @@ int main(int argc, char* argv[]) string fileNameExtractOrientation = fileNameExtract + ".o" + (options.isGzOutput()?".gz":""); extractFileOrientation.Open(fileNameExtractOrientation.c_str()); } + if (options.isFlexScoreFlag()) { + string fileNameExtractContext = fileNameExtract + ".context" + (options.isGzOutput()?".gz":""); + string fileNameExtractContextInv = fileNameExtract + ".context.inv" + (options.isGzOutput()?".gz":""); + extractFileContext.Open(fileNameExtractContext.c_str()); + extractFileContextInv.Open(fileNameExtractContextInv.c_str()); + } int i = sentenceOffset; + while(true) { i++; if (i%10000 == 0) cerr << "." << flush; @@ -280,7 +299,7 @@ int main(int argc, char* argv[]) cout << "LOG: PHRASES_BEGIN:" << endl; } if (sentence.create( englishString, foreignString, alignmentString, weightString, i, false)) { - ExtractTask *task = new ExtractTask(i-1, sentence, options, extractFile , extractFileInv, extractFileOrientation); + ExtractTask *task = new ExtractTask(i-1, sentence, options, extractFile , extractFileInv, extractFileOrientation, extractFileContext, extractFileContextInv); task->Run(); delete task; @@ -302,6 +321,11 @@ int main(int argc, char* argv[]) if (options.isOrientationFlag()) { extractFileOrientation.Close(); } + + if (options.isFlexScoreFlag()) { + extractFileContext.Close(); + extractFileContextInv.Close(); + } } } @@ -315,6 +339,8 @@ void ExtractTask::Run() m_extractedPhrasesInv.clear(); m_extractedPhrasesOri.clear(); m_extractedPhrasesSid.clear(); + m_extractedPhrasesContext.clear(); + m_extractedPhrasesContextInv.clear(); } @@ -680,6 +706,8 @@ void ExtractTask::addPhrase( SentenceAlignment &sentence, int startE, int endE, ostringstream outextractstr; ostringstream outextractstrInv; ostringstream outextractstrOrientation; + ostringstream outextractstrContext; + ostringstream outextractstrContextInv; if (m_options.isOnlyOutputSpanInfo()) { cout << startF << " " << endF << " " << startE << " " << endE << endl; @@ -693,19 +721,25 @@ void ExtractTask::addPhrase( SentenceAlignment &sentence, int startE, int endE, for(int fi=startF; fi<=endF; fi++) { if (m_options.isTranslationFlag()) outextractstr << sentence.source[fi] << " "; if (m_options.isOrientationFlag()) outextractstrOrientation << sentence.source[fi] << " "; + if (m_options.isFlexScoreFlag()) outextractstrContext << sentence.source[fi] << " "; } if (m_options.isTranslationFlag()) outextractstr << "||| "; if (m_options.isOrientationFlag()) outextractstrOrientation << "||| "; + if (m_options.isFlexScoreFlag()) outextractstrContext << "||| "; // target for(int ei=startE; ei<=endE; ei++) { if (m_options.isTranslationFlag()) outextractstr << sentence.target[ei] << " "; if (m_options.isTranslationFlag()) outextractstrInv << sentence.target[ei] << " "; if (m_options.isOrientationFlag()) outextractstrOrientation << sentence.target[ei] << " "; + if (m_options.isFlexScoreFlag()) outextractstrContext << sentence.target[ei] << " "; + if (m_options.isFlexScoreFlag()) outextractstrContextInv << sentence.target[ei] << " "; } if (m_options.isTranslationFlag()) outextractstr << "|||"; if (m_options.isTranslationFlag()) outextractstrInv << "||| "; if (m_options.isOrientationFlag()) outextractstrOrientation << "||| "; + if (m_options.isFlexScoreFlag()) outextractstrContext << "||| "; + if (m_options.isFlexScoreFlag()) outextractstrContextInv << "||| "; // source (for inverse) @@ -714,6 +748,12 @@ void ExtractTask::addPhrase( SentenceAlignment &sentence, int startE, int endE, outextractstrInv << sentence.source[fi] << " "; outextractstrInv << "|||"; } + if (m_options.isFlexScoreFlag()) { + for(int fi=startF; fi<=endF; fi++) + outextractstrContextInv << sentence.source[fi] << " "; + outextractstrContextInv << "|||"; + } + // alignment if (m_options.isTranslationFlag()) { for(int ei=startE; ei<=endE; ei++) { @@ -743,6 +783,46 @@ void ExtractTask::addPhrase( SentenceAlignment &sentence, int startE, int endE, } + + // generate two lines for every extracted phrase: + // once with left, once with right context + if (m_options.isFlexScoreFlag()) { + + string strContext = outextractstrContext.str(); + string strContextInv = outextractstrContextInv.str(); + + ostringstream outextractstrContextRight(strContext, ostringstream::app); + ostringstream outextractstrContextRightInv(strContextInv, ostringstream::app); + + // write context to left + outextractstrContext << "< "; + if (startF == 0) outextractstrContext << "<s>"; + else outextractstrContext << sentence.source[startF-1]; + + outextractstrContextInv << " < "; + if (startE == 0) outextractstrContextInv << "<s>"; + else outextractstrContextInv << sentence.target[startE-1]; + + // write context to right + outextractstrContextRight << "> "; + if (endF+1 == sentence.source.size()) outextractstrContextRight << "<s>"; + else outextractstrContextRight << sentence.source[endF+1]; + + outextractstrContextRightInv << " > "; + if (endE+1 == sentence.target.size()) outextractstrContextRightInv << "<s>"; + else outextractstrContextRightInv << sentence.target[endE+1]; + + outextractstrContext << "\n"; + outextractstrContextInv << "\n"; + outextractstrContextRight << "\n"; + outextractstrContextRightInv << "\n"; + + m_extractedPhrasesContext.push_back(outextractstrContext.str()); + m_extractedPhrasesContextInv.push_back(outextractstrContextInv.str()); + m_extractedPhrasesContext.push_back(outextractstrContextRight.str()); + m_extractedPhrasesContextInv.push_back(outextractstrContextRightInv.str()); + } + if (m_options.isTranslationFlag()) outextractstr << "\n"; if (m_options.isTranslationFlag()) outextractstrInv << "\n"; if (m_options.isOrientationFlag()) outextractstrOrientation << "\n"; @@ -760,6 +840,8 @@ void ExtractTask::writePhrasesToFile() ostringstream outextractFile; ostringstream outextractFileInv; ostringstream outextractFileOrientation; + ostringstream outextractFileContext; + ostringstream outextractFileContextInv; for(vector<string>::const_iterator phrase=m_extractedPhrases.begin(); phrase!=m_extractedPhrases.end(); phrase++) { outextractFile<<phrase->data(); @@ -770,10 +852,18 @@ void ExtractTask::writePhrasesToFile() for(vector<string>::const_iterator phrase=m_extractedPhrasesOri.begin(); phrase!=m_extractedPhrasesOri.end(); phrase++) { outextractFileOrientation<<phrase->data(); } + for(vector<string>::const_iterator phrase=m_extractedPhrasesContext.begin();phrase!=m_extractedPhrasesContext.end();phrase++){ + outextractFileContext<<phrase->data(); + } + for(vector<string>::const_iterator phrase=m_extractedPhrasesContextInv.begin();phrase!=m_extractedPhrasesContextInv.end();phrase++){ + outextractFileContextInv<<phrase->data(); + } m_extractFile << outextractFile.str(); m_extractFileInv << outextractFileInv.str(); m_extractFileOrientation << outextractFileOrientation.str(); + m_extractFileContext << outextractFileContext.str(); + m_extractFileContextInv << outextractFileContextInv.str(); } // if proper conditioning, we need the number of times a source phrase occured diff --git a/phrase-extract/extract-rules-main.cpp b/phrase-extract/extract-rules-main.cpp index f8e315e2c..97a593085 100644 --- a/phrase-extract/extract-rules-main.cpp +++ b/phrase-extract/extract-rules-main.cpp @@ -62,6 +62,8 @@ private: const RuleExtractionOptions &m_options; Moses::OutputFileStream& m_extractFile; Moses::OutputFileStream& m_extractFileInv; + Moses::OutputFileStream& m_extractFileContext; + Moses::OutputFileStream& m_extractFileContextInv; vector< ExtractedRule > m_extractedRules; @@ -94,11 +96,13 @@ private: } public: - ExtractTask(SentenceAlignmentWithSyntax &sentence, const RuleExtractionOptions &options, Moses::OutputFileStream &extractFile, Moses::OutputFileStream &extractFileInv): + ExtractTask(SentenceAlignmentWithSyntax &sentence, const RuleExtractionOptions &options, Moses::OutputFileStream &extractFile, Moses::OutputFileStream &extractFileInv, Moses::OutputFileStream &extractFileContext, Moses::OutputFileStream &extractFileContextInv): m_sentence(sentence), m_options(options), m_extractFile(extractFile), - m_extractFileInv(extractFileInv) {} + m_extractFileInv(extractFileInv), + m_extractFileContext(extractFileContext), + m_extractFileContextInv(extractFileContextInv) {} void Run(); }; @@ -138,7 +142,8 @@ int main(int argc, char* argv[]) << " | --AllowOnlyUnalignedWords | --DisallowNonTermConsecTarget |--NonTermConsecSource | --NoNonTermFirstWord | --NoFractionalCounting" << " | --UnpairedExtractFormat" << " | --ConditionOnTargetLHS ]" - << " | --BoundaryRules[" << options.boundaryRules << "]"; + << " | --BoundaryRules[" << options.boundaryRules << "]" + << " | --FlexibilityScore\n"; exit(1); } @@ -263,6 +268,8 @@ int main(int argc, char* argv[]) options.unpairedExtractFormat = true; } else if (strcmp(argv[i],"--ConditionOnTargetLHS") == 0) { options.conditionOnTargetLhs = true; + } else if (strcmp(argv[i],"--FlexibilityScore") == 0) { + options.flexScoreFlag = true; } else if (strcmp(argv[i],"-threads") == 0 || strcmp(argv[i],"--threads") == 0 || strcmp(argv[i],"--Threads") == 0) { @@ -301,10 +308,20 @@ int main(int argc, char* argv[]) string fileNameExtractInv = fileNameExtract + ".inv" + (options.gzOutput?".gz":""); Moses::OutputFileStream extractFile; Moses::OutputFileStream extractFileInv; + Moses::OutputFileStream extractFileContext; + Moses::OutputFileStream extractFileContextInv; extractFile.Open((fileNameExtract + (options.gzOutput?".gz":"")).c_str()); if (!options.onlyDirectFlag) extractFileInv.Open(fileNameExtractInv.c_str()); + if (options.flexScoreFlag) { + string fileNameExtractContext = fileNameExtract + ".context" + (options.gzOutput?".gz":""); + extractFileContext.Open(fileNameExtractContext.c_str()); + if (!options.onlyDirectFlag) { + string fileNameExtractContextInv = fileNameExtract + ".context.inv" + (options.gzOutput?".gz":""); + extractFileContextInv.Open(fileNameExtractContextInv.c_str()); + } + } // stats on labels for glue grammar and unknown word label probabilities set< string > targetLabelCollection, sourceLabelCollection; @@ -339,7 +356,7 @@ int main(int argc, char* argv[]) if (options.unknownWordLabelFlag) { collectWordLabelCounts(sentence); } - ExtractTask *task = new ExtractTask(sentence, options, extractFile, extractFileInv); + ExtractTask *task = new ExtractTask(sentence, options, extractFile, extractFileInv, extractFileContext, extractFileContextInv); task->Run(); delete task; } @@ -355,6 +372,11 @@ int main(int argc, char* argv[]) if (!options.onlyDirectFlag) extractFileInv.Close(); } + if (options.flexScoreFlag) { + extractFileContext.Close(); + if (!options.onlyDirectFlag) extractFileContextInv.Close(); + } + if (options.glueGrammarFlag) writeGlueGrammar(fileNameGlueGrammar, options, targetLabelCollection, targetTopLabelCollection); @@ -698,6 +720,46 @@ void ExtractTask::saveHieroPhrase( int startT, int endT, int startS, int endS // alignment saveHieroAlignment(startT, endT, startS, endS, indexS, indexT, holeColl, rule); + // context (words to left and right) + if (m_options.flexScoreFlag) { + rule.sourceContextLeft = startS == 0 ? "<s>" : m_sentence.source[startS-1]; + rule.sourceContextRight = endS+1 == m_sentence.source.size() ? "<s>" : m_sentence.source[endS+1]; + rule.targetContextLeft = startT == 0 ? "<s>" : m_sentence.target[startT-1]; + rule.targetContextRight = endT+1 == m_sentence.target.size() ? "<s>" : m_sentence.target[endT+1]; + rule.sourceHoleString = ""; + rule.targetHoleString = ""; + + HoleList::const_iterator iterHole; + for (iterHole = holeColl.GetHoles().begin(); iterHole != holeColl.GetHoles().end(); ++iterHole) { + const Hole &hole = *iterHole; + rule.sourceHoleString += hole.GetLabel(0) + ": "; + + // rule starts with nonterminal; end of NT is considered left context + if (hole.GetStart(0) == startS) { + rule.sourceContextLeft = m_sentence.source[hole.GetEnd(0)]; + } + // rule ends with nonterminal; start of NT is considered right context + else if (hole.GetEnd(0) == endS) { + rule.sourceContextRight = m_sentence.source[hole.GetStart(0)]; + } + + if (hole.GetStart(1) == startT) { + rule.targetContextLeft = m_sentence.target[hole.GetEnd(1)]; + } + else if (hole.GetEnd(1) == endT) { + rule.targetContextRight = m_sentence.target[hole.GetStart(1)]; + } + + for (int i = hole.GetStart(0); i <= hole.GetEnd(0); ++i) { + rule.sourceHoleString += m_sentence.source[i] + " "; + } + rule.targetHoleString += hole.GetLabel(1) + ": "; + for (int i = hole.GetStart(1); i <= hole.GetEnd(1); ++i) { + rule.targetHoleString += m_sentence.target[i] + " "; + } + } + } + addRuleToCollection( rule ); } @@ -938,6 +1000,14 @@ void ExtractTask::addRule( int startT, int endT, int startS, int endS, int count } } + // context (words to left and right) + if (m_options.flexScoreFlag) { + rule.sourceContextLeft = startS == 0 ? "<s>" : m_sentence.source[startS-1]; + rule.sourceContextRight = endS+1 == m_sentence.source.size() ? "<s>" : m_sentence.source[endS+1]; + rule.targetContextLeft = startT == 0 ? "<s>" : m_sentence.target[startT-1]; + rule.targetContextRight = endT+1 == m_sentence.target.size() ? "<s>" : m_sentence.target[endT+1]; + } + rule.alignment.erase(rule.alignment.size()-1); if (!m_options.onlyDirectFlag) rule.alignmentInv.erase(rule.alignmentInv.size()-1); @@ -997,6 +1067,8 @@ void ExtractTask::writeRulesToFile() vector<ExtractedRule>::const_iterator rule; ostringstream out; ostringstream outInv; + ostringstream outContext; + ostringstream outContextInv; for(rule = m_extractedRules.begin(); rule != m_extractedRules.end(); rule++ ) { if (rule->count == 0) continue; @@ -1019,9 +1091,41 @@ void ExtractTask::writeRulesToFile() << rule->alignmentInv << " ||| " << rule->count << "\n"; } + + if (m_options.flexScoreFlag) { + for(int iContext=0;iContext<2;iContext++){ + outContext << rule->source << " ||| " + << rule->target << " ||| " + << rule->alignment << " ||| "; + iContext ? outContext << "< " << rule->sourceContextLeft << "\n" : outContext << "> " << rule->sourceContextRight << "\n"; + + if (!m_options.onlyDirectFlag) { + outContextInv << rule->target << " ||| " + << rule->source << " ||| " + << rule->alignmentInv << " ||| "; + iContext ? outContextInv << "< " << rule->targetContextLeft << "\n" : outContextInv << "> " << rule->targetContextRight << "\n"; + } + } + + if (rule->sourceHoleString != "") { + outContext << rule->source << " ||| " + << rule->target << " ||| " + << rule->alignment << " ||| v " + << rule->sourceHoleString << "\n"; + } + + if (!m_options.onlyDirectFlag and rule->targetHoleString != "") { + outContextInv << rule->target << " ||| " + << rule->source << " ||| " + << rule->alignmentInv << " ||| v " + << rule->targetHoleString << "\n"; + } + } } m_extractFile << out.str(); m_extractFileInv << outInv.str(); + m_extractFileContext << outContext.str(); + m_extractFileContextInv << outContextInv.str(); } void writeGlueGrammar( const string & fileName, RuleExtractionOptions &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection ) diff --git a/scripts/generic/extract-parallel.perl b/scripts/generic/extract-parallel.perl index 253bd97b2..b663dcfe8 100755 --- a/scripts/generic/extract-parallel.perl +++ b/scripts/generic/extract-parallel.perl @@ -152,12 +152,17 @@ foreach (@children) { my $catCmd = "gunzip -c "; my $catInvCmd = $catCmd; my $catOCmd = $catCmd; +my $catContextCmd = $catCmd; +my $catContextInvCmd = $catCmd; + for (my $i = 0; $i < $numParallel; ++$i) { my $numStr = NumStr($i); $catCmd .= "$TMPDIR/extract.$numStr.gz "; $catInvCmd .= "$TMPDIR/extract.$numStr.inv.gz "; $catOCmd .= "$TMPDIR/extract.$numStr.o.gz "; + $catContextCmd .= "$TMPDIR/extract.$numStr.context "; + $catContextInvCmd .= "$TMPDIR/extract.$numStr.context.inv "; } if (defined($baselineExtract)) { my $sorted = -e "$baselineExtract.sorted.gz" ? ".sorted" : ""; @@ -169,6 +174,8 @@ if (defined($baselineExtract)) { $catCmd .= " | LC_ALL=C $sortCmd -T $TMPDIR 2>> /dev/stderr | gzip -c > $extract.sorted.gz 2>> /dev/stderr \n"; $catInvCmd .= " | LC_ALL=C $sortCmd -T $TMPDIR 2>> /dev/stderr | gzip -c > $extract.inv.sorted.gz 2>> /dev/stderr \n"; $catOCmd .= " | LC_ALL=C $sortCmd -T $TMPDIR 2>> /dev/stderr | gzip -c > $extract.o.sorted.gz 2>> /dev/stderr \n"; +$catContextCmd .= " | LC_ALL=C $sortCmd -T $TMPDIR 2>> /dev/stderr | uniq | gzip -c > $extract.context.sorted.gz 2>> /dev/stderr \n"; +$catContextInvCmd .= " | LC_ALL=C $sortCmd -T $TMPDIR 2>> /dev/stderr | uniq | gzip -c > $extract.context.inv.sorted.gz 2>> /dev/stderr \n"; @children = (); @@ -185,6 +192,14 @@ else { print STDERR "skipping extract, doing only extract.o\n"; } +if ($otherExtractArgs =~ /--FlexibilityScore/) { + $pid = RunFork($catContextCmd); + push(@children, $pid); + + $pid = RunFork($catContextInvCmd); + push(@children, $pid); + } + my $numStr = NumStr(0); if (-e "$TMPDIR/extract.$numStr.o.gz") { diff --git a/scripts/generic/score-parallel.perl b/scripts/generic/score-parallel.perl index 3f763e5d9..da37b1353 100755 --- a/scripts/generic/score-parallel.perl +++ b/scripts/generic/score-parallel.perl @@ -11,6 +11,7 @@ sub RunFork($); sub systemCheck($); sub GetSourcePhrase($); sub NumStr($); +sub CutContextFile($$$); #my $EXTRACT_SPLIT_LINES = 5000000; my $EXTRACT_SPLIT_LINES = 50000000; @@ -34,6 +35,13 @@ for (my $i = 6; $i < $#ARGV; ++$i) } #$scoreCmd $extractFile $lexFile $ptHalf $otherExtractArgs +my $FlexibilityScore = $otherExtractArgs =~ /--FlexibilityScore/; +my $FlexibilityCmd = $otherExtractArgs; +$otherExtractArgs =~ s/--FlexibilityScore=\S+//; # don't pass flexibility_score command to score program +if ($FlexibilityCmd =~ /--FlexibilityScore=(\S+)/) { + $FlexibilityCmd = $1; +} + my $doSort = $ARGV[$#ARGV]; # last arg my $TMPDIR=dirname($ptHalf) ."/tmp.$$"; @@ -41,10 +49,19 @@ mkdir $TMPDIR; my $cmd; +my $extractFileContext; +if ($FlexibilityScore) { + $extractFileContext = $extractFile; + $extractFileContext =~ s/extract./extract.context./; +} + my $fileCount = 0; if ($numParallel <= 1) { # don't do parallel. Just link the extract file into place $cmd = "ln -s $extractFile $TMPDIR/extract.0.gz"; + if ($FlexibilityScore) { + $cmd .= " && ln -s $extractFileContext $TMPDIR/extract.context.0.gz"; + } print STDERR "$cmd \n"; systemCheck($cmd); @@ -59,6 +76,17 @@ else open(IN, $extractFile) || die "can't open $extractFile"; } + my $lastlineContext; + if ($FlexibilityScore) { + $lastlineContext = ""; + if ($extractFileContext =~ /\.gz$/) { + open(IN_CONTEXT, "gunzip -c $extractFileContext |") || die "can't open pipe to $extractFileContext"; + } + else { + open(IN_CONTEXT, $extractFileContext) || die "can't open $extractFileContext"; + } + } + my $filePath = "$TMPDIR/extract.$fileCount.gz"; open (OUT, "| gzip -c > $filePath") or die "error starting gzip $!"; @@ -84,7 +112,10 @@ else else { # cut off, open next min-extract file & write to that instead close OUT; - + + if ($FlexibilityScore) { + $lastlineContext = CutContextFile($prevSourcePhrase, $fileCount, $lastlineContext); + } $prevSourcePhrase = ""; $lineCount = 0; ++$fileCount; @@ -101,6 +132,9 @@ else } close OUT; + if ($FlexibilityScore) { + $lastlineContext = CutContextFile($prevSourcePhrase, $fileCount, $lastlineContext); + } ++$fileCount; } @@ -121,8 +155,18 @@ for (my $i = 0; $i < $fileCount; ++$i) my $fileInd = $i % $numParallel; my $fh = $runFiles[$fileInd]; + my $cmd = "$scoreCmd $TMPDIR/extract.$i.gz $lexFile $TMPDIR/phrase-table.half.$numStr.gz $otherExtractArgs 2>> /dev/stderr \n"; print STDERR $cmd; + + if ($FlexibilityScore) { + $cmd .= "zcat $TMPDIR/phrase-table.half.$numStr.gz | $FlexibilityCmd $TMPDIR/extract.context.$i.gz"; + $cmd .= " --Inverse" if ($otherExtractArgs =~ /--Inverse/); + $cmd .= " --Hierarchical" if ($otherExtractArgs =~ /--Hierarchical/); + $cmd .= " | gzip -c > $TMPDIR/phrase-table.half.$numStr.flex.gz\n"; + $cmd .= "mv $TMPDIR/phrase-table.half.$numStr.flex.gz $TMPDIR/phrase-table.half.$numStr.gz\n"; + } + print $fh $cmd; } @@ -150,7 +194,7 @@ foreach (@children) { # merge & sort $cmd = "\n\nOH SHIT. This should have been filled in \n\n"; -if ($fileCount == 1 && !$doSort) +if ($fileCount == 1 && !$doSort && !$FlexibilityScore) { my $numStr = NumStr(0); $cmd = "mv $TMPDIR/phrase-table.half.$numStr.gz $ptHalf"; @@ -279,3 +323,39 @@ sub NumStr($) } +sub CutContextFile($$$) +{ + my($lastsourcePhrase, $fileCount, $lastline) = @_; + my $line; + my $sourcePhrase; + + my $filePath = "$TMPDIR/extract.context.$fileCount.gz"; + open (OUT_CONTEXT, "| gzip -c > $filePath") or die "error starting gzip $!"; + + if ($lastline ne "") { + print OUT_CONTEXT "$lastline\n"; + } + + #write all lines in context file until we meet last source phrase in extract file + while ($line=<IN_CONTEXT>) + { + chomp($line); + $sourcePhrase = GetSourcePhrase($line); + print OUT_CONTEXT "$line\n"; + if ($sourcePhrase eq $lastsourcePhrase) {last;} + } + + #write all lines in context file that correspond to last source phrase in extract file + while ($line=<IN_CONTEXT>) + { + chomp($line); + $sourcePhrase = GetSourcePhrase($line); + if ($sourcePhrase ne $lastsourcePhrase) {last;} + print OUT_CONTEXT "$line\n"; + } + + close(OUT_CONTEXT); + + return $line; + +} diff --git a/scripts/training/flexibility_score.py b/scripts/training/flexibility_score.py new file mode 100755 index 000000000..66f104605 --- /dev/null +++ b/scripts/training/flexibility_score.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# add flexibility scores to a phrase table half +# you usually don't have to call this script directly, but you can run train_model.perl with the option "--flexibility-score" (will only affect steps 5 and 6) +# usage: python flexibility_score.py extract.context(.inv).sorted [--Inverse] [--Hierarchical] < phrasetable > output_file +# author: Rico Sennrich + +from __future__ import division +from __future__ import unicode_literals + +import sys +import gzip +from collections import defaultdict + +class FlexScore: + + def __init__(self, inverted, hierarchical): + self.inverted = inverted + self.hierarchical = hierarchical + + + def store_pt(self,obj): + """store line in dictionary; if we work with inverted phrase table, swap the two phrases""" + src,target = obj[0],obj[1] + + if self.inverted: + src, target = target, src + + self.phrase_pairs[src][target] = obj + + + def update_contextcounts(self, obj): + """count the number of contexts a phrase pair occurs in""" + src,target = obj[0],obj[1] + self.context_counts[src][target] += 1 + if obj[-1].startswith(b'<'): + self.context_counts_l[src][target] += 1 + elif obj[-1].startswith(b'>'): + self.context_counts_r[src][target] += 1 + elif obj[-1].startswith(b'v'): + self.context_counts_d[src][target] += 1 + else: + sys.stderr.write(b'\nERROR in line: {0}\n'.format(b' ||| '.join(obj))) + sys.stderr.write(b'ERROR: expecting one of \'<, >, v\' as context marker in context extract file\n') + raise ValueError + + + def traverse_incrementally(self,phrasetable,flexfile): + """traverse phrase table and phrase extract file (with context information) incrementally + without storing all in memory.""" + + increment = b'' + old_increment = 1 + stack = ['']*2 + + # which phrase to use for sorting + sort_pt = 0 + if self.inverted: + sort_pt = 1 + + while old_increment != increment: + + old_increment = increment + + self.phrase_pairs = defaultdict(dict) + self.context_counts = defaultdict(lambda:defaultdict(int)) + self.context_counts_l = defaultdict(lambda:defaultdict(int)) + self.context_counts_r = defaultdict(lambda:defaultdict(int)) + self.context_counts_d = defaultdict(lambda:defaultdict(int)) + + if stack[0]: + self.store_pt(stack[0]) + stack[0] = b'' + + if stack[1]: + self.update_contextcounts(stack[1]) + stack[1] = b'' + + for line in phrasetable: + line = line.rstrip().split(b' ||| ') + if line[sort_pt] != increment: + increment = line[sort_pt] + stack[0] = line + break + else: + self.store_pt(line) + + for line in flexfile: + line = line.rstrip().split(b' ||| ') + if line[0] + b' |' <= old_increment + b' |': + self.update_contextcounts(line) + + else: + stack[1] = line + break + + yield 1 + + + def main(self,phrasetable,flexfile,output_object): + + i = 0 + sys.stderr.write('Incrementally loading phrase table and adding flexibility score...') + for block in self.traverse_incrementally(phrasetable,flexfile): + + self.flexprob_l = normalize(self.context_counts_l) + self.flexprob_r = normalize(self.context_counts_r) + self.flexprob_d = normalize(self.context_counts_d) + + for src in sorted(self.phrase_pairs, key = lambda x: x + b' |'): + for target in sorted(self.phrase_pairs[src], key = lambda x: x + b' |'): + + if not i % 1000000: + sys.stderr.write('.') + i += 1 + + outline = self.write_phrase_table(src,target) + output_object.write(outline) + sys.stderr.write('done\n') + + + def write_phrase_table(self,src,target): + + line = self.phrase_pairs[src][target] + flexscore_l = b"{0:.6g}".format(self.flexprob_l[src][target]) + flexscore_r = b"{0:.6g}".format(self.flexprob_r[src][target]) + line[2] += b' ' + flexscore_l + b' ' + flexscore_r + + if self.hierarchical: + try: + flexscore_d = b"{0:.6g}".format(self.flexprob_d[src][target]) + except KeyError: + flexscore_d = b"1" + line[2] += b' ' + flexscore_d + + return b' ||| '.join(line) + b'\n' + + + +def normalize(d): + + out_dict = defaultdict(dict) + + for src in d: + total = sum(d[src].values()) + + for target in d[src]: + out_dict[src][target] = d[src][target]/total + + return out_dict + + +if __name__ == '__main__': + + if len(sys.argv) < 1: + sys.stderr.write('Usage: python flexibility_score.py extract.context(.inv).sorted [--Inverse] [--Hierarchical] < phrasetable > output_file\n') + exit() + + flexfile = sys.argv[1] + if '--Inverse' in sys.argv: + inverted = True + else: + inverted = False + + if '--Hierarchical' in sys.argv: + hierarchical = True + else: + hierarchical = False + + FS = FlexScore(inverted, hierarchical) + FS.main(sys.stdin,gzip.open(flexfile,'r'),sys.stdout) diff --git a/scripts/training/train-model.perl b/scripts/training/train-model.perl index 7ba3d106a..2f0c4e822 100755 --- a/scripts/training/train-model.perl +++ b/scripts/training/train-model.perl @@ -39,10 +39,9 @@ my($_EXTERNAL_BINDIR, $_ROOT_DIR, $_CORPUS_DIR, $_GIZA_E2F, $_GIZA_F2E, $_MODEL_ $_CONTINUE,$_MAX_LEXICAL_REORDERING,$_DO_STEPS, @_ADDITIONAL_INI,$_ADDITIONAL_INI_FILE, @_BASELINE_ALIGNMENT_MODEL, $_BASELINE_EXTRACT, $_BASELINE_ALIGNMENT, - $_DICTIONARY, $_SPARSE_PHRASE_FEATURES, $_EPPEX, $_INSTANCE_WEIGHTS_FILE, $_LMODEL_OOV_FEATURE, $_NUM_LATTICE_FEATURES, $IGNORE); + $_DICTIONARY, $_SPARSE_PHRASE_FEATURES, $_EPPEX, $_INSTANCE_WEIGHTS_FILE, $_LMODEL_OOV_FEATURE, $_NUM_LATTICE_FEATURES, $IGNORE, $_FLEXIBILITY_SCORE); my $_BASELINE_CORPUS = ""; my $_CORES = 1; - my $debug = 0; # debug this script, do not delete any files in debug mode $_HELP = 1 @@ -138,6 +137,7 @@ $_HELP = 1 'instance-weights-file=s' => \$_INSTANCE_WEIGHTS_FILE, 'lmodel-oov-feature' => \$_LMODEL_OOV_FEATURE, 'num-lattice-features=i' => \$_NUM_LATTICE_FEATURES, + 'flexibility-score' => \$_FLEXIBILITY_SCORE, ); if ($_HELP) { @@ -323,6 +323,7 @@ my $PHRASE_SCORE = "$SCRIPTS_ROOTDIR/../bin/score"; $PHRASE_SCORE = "$SCRIPTS_ROOTDIR/generic/score-parallel.perl $_CORES \"$SORT_EXEC $__SORT_BUFFER_SIZE $__SORT_BATCH_SIZE $__SORT_COMPRESS $__SORT_PARALLEL\" $PHRASE_SCORE"; my $PHRASE_CONSOLIDATE = "$SCRIPTS_ROOTDIR/../bin/consolidate"; +my $FLEX_SCORER = "$SCRIPTS_ROOTDIR/training/flexibility_score.py"; # utilities my $ZCAT = "gzip -cd"; @@ -1436,6 +1437,7 @@ sub extract_phrase { $cmd .= " --GZOutput "; $cmd .= " --InstanceWeights $_INSTANCE_WEIGHTS_FILE " if defined $_INSTANCE_WEIGHTS_FILE; $cmd .= " --BaselineExtract $_BASELINE_EXTRACT" if defined($_BASELINE_EXTRACT) && $PHRASE_EXTRACT =~ /extract-parallel.perl/; + $cmd .= " --FlexibilityScore" if $_FLEXIBILITY_SCORE; map { die "File not found: $_" if ! -e $_ } ($alignment_file_e, $alignment_file_f, $alignment_file_a); print STDERR "$cmd\n"; @@ -1456,7 +1458,6 @@ sub extract_phrase { foreach my $f (@tempfiles) { unlink $f; } - } ### (6) PHRASE SCORING @@ -1554,7 +1555,7 @@ sub score_phrase_phrase_extract { my $inverse = ""; my $extract_filename = $extract_file; if ($direction eq "e2f") { - $inverse = " --Inverse"; + $inverse = "--Inverse"; $extract_filename = $extract_file.".inv"; } @@ -1575,6 +1576,7 @@ sub score_phrase_phrase_extract { $cmd .= " --ConditionOnTargetLHS" if $_ALT_DIRECT_RULE_SCORE_1; $cmd .= " $DOMAIN" if $DOMAIN; $cmd .= " $CORE_SCORE_OPTIONS" if defined($_SCORE_OPTIONS); + $cmd .= " --FlexibilityScore=$FLEX_SCORER" if $_FLEXIBILITY_SCORE; # sorting if ($direction eq "e2f" || $_ALT_DIRECT_RULE_SCORE_1 || $_ALT_DIRECT_RULE_SCORE_2) { @@ -1895,6 +1897,8 @@ sub create_ini { $basic_weight_count += 2**$count-1 if $method eq "Subset"; } $basic_weight_count++ if $_PCFG; + $basic_weight_count+=4 if $_FLEXIBILITY_SCORE; + $basic_weight_count+=2 if $_FLEXIBILITY_SCORE && $_HIERARCHICAL; # go over each table foreach my $f (split(/\+/,$___TRANSLATION_FACTORS)) { |