Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Phil Williams <philip.williams@mac.com>, 2012-05-27 15:43:16 +0400
committer: Phil Williams <philip.williams@mac.com>, 2012-05-27 15:43:16 +0400
commit: e3e62846bfe84d9a7edd78affd23f020d8ae2468 (patch)
tree: db38abb1a8ee914158068423062b604f8a961d55 /scripts/training
parent: 82580280bc0b30607b00a55ffe0f22d5665269a3 (diff)
train-model.perl: add -alt-direct-rule-score-1 and
-alt-direct-rule-score-2 options, which use either p(RHS_t|RHS_s,LHS) or p(LHS,RHS_t|RHS_s), respectively, as a grammar rule's direct translation score.
Diffstat (limited to 'scripts/training')
-rw-r--r-- scripts/training/phrase-extract/RuleExtractionOptions.h 4
-rw-r--r-- scripts/training/phrase-extract/extract-ghkm/ExtractGHKM.cpp 3
-rw-r--r-- scripts/training/phrase-extract/extract-ghkm/Options.h 2
-rw-r--r-- scripts/training/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp 12
-rw-r--r-- scripts/training/phrase-extract/extract-rules.cpp 43
-rw-r--r-- scripts/training/phrase-extract/score.cpp 94
-rw-r--r-- scripts/training/phrase-extract/score.h 8
-rwxr-xr-x scripts/training/train-model.perl.missing_bin_dir 9
8 files changed, 139 insertions, 36 deletions
diff --git a/scripts/training/phrase-extract/RuleExtractionOptions.h b/scripts/training/phrase-extract/RuleExtractionOptions.h
index f9123de86..272af2c76 100644
--- a/scripts/training/phrase-extract/RuleExtractionOptions.h
+++ b/scripts/training/phrase-extract/RuleExtractionOptions.h
@@ -48,6 +48,8 @@ public:
bool pcfgScore;
bool outputNTLengths;
bool gzOutput;
+ bool unpairedExtractFormat;
+ bool conditionOnTargetLhs;
RuleExtractionOptions()
: maxSpan(10)
@@ -78,6 +80,8 @@ public:
, pcfgScore(false)
, outputNTLengths(false)
, gzOutput(false)
+ , unpairedExtractFormat(false)
+ , conditionOnTargetLhs(false)
{}
};
diff --git a/scripts/training/phrase-extract/extract-ghkm/ExtractGHKM.cpp b/scripts/training/phrase-extract/extract-ghkm/ExtractGHKM.cpp
index 397ce1e3c..6b6fbb7eb 100644
--- a/scripts/training/phrase-extract/extract-ghkm/ExtractGHKM.cpp
+++ b/scripts/training/phrase-extract/extract-ghkm/ExtractGHKM.cpp
@@ -357,6 +357,9 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[],
if (vm.count("AllowUnary")) {
options.allowUnary = true;
}
+ if (vm.count("ConditionOnTargetLHS")) {
+ options.conditionOnTargetLhs = true;
+ }
if (vm.count("GZOutput")) {
options.gzOutput = true;
}
diff --git a/scripts/training/phrase-extract/extract-ghkm/Options.h b/scripts/training/phrase-extract/extract-ghkm/Options.h
index c4b57f311..362fc95d2 100644
--- a/scripts/training/phrase-extract/extract-ghkm/Options.h
+++ b/scripts/training/phrase-extract/extract-ghkm/Options.h
@@ -30,6 +30,7 @@ struct Options {
public:
Options()
: allowUnary(false)
+ , conditionOnTargetLhs(false)
, gzOutput(false)
, maxNodes(15)
, maxRuleDepth(3)
@@ -47,6 +48,7 @@ struct Options {
// All other options
bool allowUnary;
+ bool conditionOnTargetLhs;
std::string glueGrammarFile;
bool gzOutput;
int maxNodes;
diff --git a/scripts/training/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp b/scripts/training/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp
index d5d16b790..cd993d6e8 100644
--- a/scripts/training/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp
+++ b/scripts/training/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp
@@ -101,7 +101,11 @@ void ScfgRuleWriter::WriteStandardFormat(const ScfgRule &rule,
}
sourceSS << " ";
}
- WriteSymbol(rule.GetSourceLHS(), sourceSS);
+ if (m_options.conditionOnTargetLhs) {
+ WriteSymbol(rule.GetTargetLHS(), sourceSS);
+ } else {
+ WriteSymbol(rule.GetSourceLHS(), sourceSS);
+ }
// Write the target side of the rule to targetSS.
i = 0;
@@ -131,7 +135,11 @@ void ScfgRuleWriter::WriteUnpairedFormat(const ScfgRule &rule,
WriteSymbol(*p, sourceSS);
sourceSS << " ";
}
- WriteSymbol(rule.GetSourceLHS(), sourceSS);
+ if (m_options.conditionOnTargetLhs) {
+ WriteSymbol(rule.GetTargetLHS(), sourceSS);
+ } else {
+ WriteSymbol(rule.GetSourceLHS(), sourceSS);
+ }
// Write the target side of the rule to targetSS.
i = 0;
diff --git a/scripts/training/phrase-extract/extract-rules.cpp b/scripts/training/phrase-extract/extract-rules.cpp
index a00667b82..997038224 100644
--- a/scripts/training/phrase-extract/extract-rules.cpp
+++ b/scripts/training/phrase-extract/extract-rules.cpp
@@ -140,7 +140,9 @@ int main(int argc, char* argv[])
<< " | --MaxNonTerm[" << options.maxNonTerm << "]"
<< " | --MaxScope[" << options.maxScope << "]"
<< " | --SourceSyntax | --TargetSyntax"
- << " | --AllowOnlyUnalignedWords | --DisallowNonTermConsecTarget |--NonTermConsecSource | --NoNonTermFirstWord | --NoFractionalCounting ]\n";
+ << " | --AllowOnlyUnalignedWords | --DisallowNonTermConsecTarget |--NonTermConsecSource | --NoNonTermFirstWord | --NoFractionalCounting"
+ << " | --UnpairedExtractFormat"
+ << " | --ConditionOnTargetLHS ]\n";
exit(1);
}
char* &fileNameT = argv[1];
@@ -261,6 +263,10 @@ int main(int argc, char* argv[])
options.pcfgScore = true;
} else if (strcmp(argv[i],"--OutputNTLengths") == 0) {
options.outputNTLengths = true;
+ } else if (strcmp(argv[i],"--UnpairedExtractFormat") == 0) {
+ options.unpairedExtractFormat = true;
+ } else if (strcmp(argv[i],"--ConditionOnTargetLHS") == 0) {
+ options.conditionOnTargetLhs = true;
#ifdef WITH_THREADS
} else if (strcmp(argv[i],"-threads") == 0 ||
strcmp(argv[i],"--threads") == 0 ||
@@ -545,7 +551,11 @@ string ExtractTask::printTargetHieroPhrase( int startT, int endT, int startS, in
m_sentence->targetTree.GetNodes(currPos,hole.GetEnd(1))[ labelI ]->GetLabel() : "X";
hole.SetLabel(targetLabel, 1);
- out += "[" + sourceLabel + "][" + targetLabel + "] ";
+ if (m_options.unpairedExtractFormat) {
+ out += "[" + targetLabel + "] ";
+ } else {
+ out += "[" + sourceLabel + "][" + targetLabel + "] ";
+ }
if (m_options.pcfgScore) {
double score = m_sentence->targetTree.GetNodes(currPos,hole.GetEnd(1))[labelI]->GetPcfgScore();
@@ -591,7 +601,11 @@ string ExtractTask::printSourceHieroPhrase( int startT, int endT, int startS, in
assert(targetLabel != "");
const string &sourceLabel = hole.GetLabel(0);
- out += "[" + sourceLabel + "][" + targetLabel + "] ";
+ if (m_options.unpairedExtractFormat) {
+ out += "[" + sourceLabel + "] ";
+ } else {
+ out += "[" + sourceLabel + "][" + targetLabel + "] ";
+ }
currPos = hole.GetEnd(0);
hole.SetPos(outPos, 0);
@@ -659,7 +673,6 @@ void ExtractTask::printHieroPhrase( int startT, int endT, int startS, int endS
m_sentence->targetTree.GetNodes(startT,endT)[ labelIndex[0] ]->GetLabel() : "X";
string sourceLabel = m_options.sourceSyntax ?
m_sentence->sourceTree.GetNodes(startS,endS)[ labelIndex[1] ]->GetLabel() : "X";
- //string sourceLabel = "X";
// create non-terms on the source side
preprocessSourceHieroPhrase(startT, endT, startS, endS, indexS, holeColl, labelIndex);
@@ -677,9 +690,12 @@ void ExtractTask::printHieroPhrase( int startT, int endT, int startS, int endS
}
// source
- // holeColl.SortSourceHoles();
- rule.source = printSourceHieroPhrase(startT, endT, startS, endS, holeColl, labelIndex)
- + " [" + sourceLabel + "]";
+ rule.source = printSourceHieroPhrase(startT, endT, startS, endS, holeColl, labelIndex);
+ if (m_options.conditionOnTargetLhs) {
+ rule.source += " [" + targetLabel + "]";
+ } else {
+ rule.source += " [" + sourceLabel + "]";
+ }
// alignment
printHieroAlignment(startT, endT, startS, endS, indexS, indexT, holeColl, rule);
@@ -875,10 +891,15 @@ void ExtractTask::addRule( int startT, int endT, int startS, int endS, RuleExist
// phrase labels
string targetLabel,sourceLabel;
- sourceLabel = m_options.sourceSyntax ?
- m_sentence->sourceTree.GetNodes(startS,endS)[0]->GetLabel() : "X";
- targetLabel = m_options.targetSyntax ?
- m_sentence->targetTree.GetNodes(startT,endT)[0]->GetLabel() : "X";
+ if (m_options.targetSyntax && m_options.conditionOnTargetLhs) {
+ sourceLabel = targetLabel = m_sentence->targetTree.GetNodes(startT,endT)[0]->GetLabel();
+ }
+ else {
+ sourceLabel = m_options.sourceSyntax ?
+ m_sentence->sourceTree.GetNodes(startS,endS)[0]->GetLabel() : "X";
+ targetLabel = m_options.targetSyntax ?
+ m_sentence->targetTree.GetNodes(startT,endT)[0]->GetLabel() : "X";
+ }
// source
rule.source = "";
diff --git a/scripts/training/phrase-extract/score.cpp b/scripts/training/phrase-extract/score.cpp
index c5fb0b99f..5e0ade627 100644
--- a/scripts/training/phrase-extract/score.cpp
+++ b/scripts/training/phrase-extract/score.cpp
@@ -69,10 +69,15 @@ double computeUnalignedFWPenalty( const PHRASE &, const PHRASE &, PhraseAlignmen
void calcNTLengthProb(const vector< PhraseAlignment* > &phrasePairs
, map<size_t, map<size_t, float> > &sourceProb
, map<size_t, map<size_t, float> > &targetProb);
+void printSourcePhrase(const PHRASE &, const PHRASE &, const PhraseAlignment &, ostream &);
+void printTargetPhrase(const PHRASE &, const PHRASE &, const PhraseAlignment &, ostream &);
+
LexicalTable lexTable;
bool inverseFlag = false;
bool hierarchicalFlag = false;
bool pcfgFlag = false;
+bool unpairedExtractFormatFlag = false;
+bool conditionOnTargetLhsFlag = false;
bool wordAlignmentFlag = false;
bool goodTuringFlag = false;
bool kneserNeyFlag = false;
@@ -93,7 +98,7 @@ int main(int argc, char* argv[])
<< "scoring methods for extracted rules\n";
if (argc < 4) {
- cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--WordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--OutputNTLengths] \n";
+ cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--WordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--OutputNTLengths] [--PCFG] [--UnpairedExtractFormat] [--ConditionOnTargetLHS]\n";
exit(1);
}
char* fileNameExtract = argv[1];
@@ -112,6 +117,12 @@ int main(int argc, char* argv[])
} else if (strcmp(argv[i],"--PCFG") == 0) {
pcfgFlag = true;
cerr << "including PCFG scores\n";
+ } else if (strcmp(argv[i],"--UnpairedExtractFormat") == 0) {
+ unpairedExtractFormatFlag = true;
+ cerr << "processing unpaired extract format\n";
+ } else if (strcmp(argv[i],"--ConditionOnTargetLHS") == 0) {
+ conditionOnTargetLhsFlag = true;
+ cerr << "processing unpaired extract format\n";
} else if (strcmp(argv[i],"--WordAlignment") == 0) {
wordAlignmentFlag = true;
cerr << "outputing word alignment" << endl;
@@ -470,27 +481,18 @@ void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float totalCo
// source phrase (unless inverse)
if (! inverseFlag) {
- for(size_t j=0; j<phraseS.size(); j++) {
- phraseTableFile << vcbS.getWord( phraseS[j] );
- phraseTableFile << " ";
- }
- phraseTableFile << "||| ";
+ printSourcePhrase(phraseS, phraseT, *bestAlignment, phraseTableFile);
+ phraseTableFile << " ||| ";
}
// target phrase
- for(size_t j=0; j<phraseT.size(); j++) {
- phraseTableFile << vcbT.getWord( phraseT[j] );
- phraseTableFile << " ";
- }
- phraseTableFile << "||| ";
+ printTargetPhrase(phraseS, phraseT, *bestAlignment, phraseTableFile);
+ phraseTableFile << " ||| ";
// source phrase (if inverse)
if (inverseFlag) {
- for(size_t j=0; j<phraseS.size(); j++) {
- phraseTableFile << vcbS.getWord( phraseS[j] );
- phraseTableFile << " ";
- }
- phraseTableFile << "||| ";
+ printSourcePhrase(phraseS, phraseT, *bestAlignment, phraseTableFile);
+ phraseTableFile << " ||| ";
}
// lexical translation probability
@@ -683,6 +685,66 @@ void LexicalTable::load( char *fileName )
cerr << endl;
}
+void printSourcePhrase(const PHRASE &phraseS, const PHRASE &phraseT,
+ const PhraseAlignment &bestAlignment, ostream &out)
+{ // Writes the source side of a rule to `out`: every symbol of phraseS except the last, then a root symbol. No trailing separator is written.
+ // Emit all source symbols except the last one (the root), each followed by a single space.
+ for (std::size_t i = 0; i < phraseS.size()-1; ++i) { // NOTE(review): if phraseS can be empty, size()-1 would presumably wrap (unsigned) - confirm callers never pass an empty phrase
+ const std::string &word = vcbS.getWord(phraseS[i]);
+ if (!unpairedExtractFormatFlag || !isNonTerminal(word)) { // written unchanged unless unpaired mode is on AND the symbol is a non-terminal
+ out << word << " ";
+ continue;
+ }
+ // Unpaired mode, non-terminal: look up the single aligned target position and write the two labels back to back.
+ std::set<std::size_t> alignmentPoints = bestAlignment.alignedToS[i]; // NOTE(review): this copies the set; a const reference would presumably suffice - confirm
+ assert(alignmentPoints.size() == 1); // exactly one alignment point is expected; the check disappears when NDEBUG is defined
+ int j = *(alignmentPoints.begin()); // NOTE(review): std::size_t narrowed to int
+ if (inverseFlag) { // inverse mode writes the target label first, otherwise the source label first
+ out << vcbT.getWord(phraseT[j]) << word << " ";
+ } else {
+ out << word << vcbT.getWord(phraseT[j]) << " ";
+ }
+ }
+ // Emit the root: the literal "[X]" when conditioning on the target LHS in the non-inverse direction, otherwise the last symbol of phraseS.
+ if (conditionOnTargetLhsFlag && !inverseFlag) {
+ out << "[X]";
+ } else {
+ out << vcbS.getWord(phraseS.back());
+ }
+}
+
+void printTargetPhrase(const PHRASE &phraseS, const PHRASE &phraseT,
+ const PhraseAlignment &bestAlignment, ostream &out)
+{ // Writes the target side of a rule to `out`: every symbol of phraseT except the last, then a root symbol. No trailing separator is written.
+ // Emit all target symbols except the last one (the root), each followed by a single space.
+ for (std::size_t i = 0; i < phraseT.size()-1; ++i) { // NOTE(review): if phraseT can be empty, size()-1 would presumably wrap (unsigned) - confirm callers never pass an empty phrase
+ const std::string &word = vcbT.getWord(phraseT[i]);
+ if (!unpairedExtractFormatFlag || !isNonTerminal(word)) { // written unchanged unless unpaired mode is on AND the symbol is a non-terminal
+ out << word << " ";
+ continue;
+ }
+ // Unpaired mode, non-terminal: look up the single aligned source position and write the two labels back to back.
+ std::set<std::size_t> alignmentPoints = bestAlignment.alignedToT[i]; // NOTE(review): this copies the set; a const reference would presumably suffice - confirm
+ assert(alignmentPoints.size() == 1); // exactly one alignment point is expected; the check disappears when NDEBUG is defined
+ int j = *(alignmentPoints.begin()); // NOTE(review): std::size_t narrowed to int
+ if (inverseFlag) { // inverse mode writes the target label first, otherwise the source label first
+ out << word << vcbS.getWord(phraseS[j]) << " ";
+ } else {
+ out << vcbS.getWord(phraseS[j]) << word << " ";
+ }
+ }
+ // Emit the root: without conditioning it is the last symbol of phraseT; with conditioning it is "[X]" (inverse) or the last symbol of phraseS.
+ if (conditionOnTargetLhsFlag) {
+ if (inverseFlag) {
+ out << "[X]";
+ } else {
+ out << vcbS.getWord(phraseS.back()); // NOTE(review): reads the root from the SOURCE phrase - presumably the extractor stored the target LHS label there under --ConditionOnTargetLHS; confirm
+ }
+ } else {
+ out << vcbT.getWord(phraseT.back());
+ }
+}
+
std::pair<PhrasePairGroup::Coll::iterator,bool> PhrasePairGroup::insert ( const PhraseAlignmentCollection& obj )
{
std::pair<iterator,bool> ret = m_coll.insert(obj);
diff --git a/scripts/training/phrase-extract/score.h b/scripts/training/phrase-extract/score.h
index dc94ecfde..9faa144c5 100644
--- a/scripts/training/phrase-extract/score.h
+++ b/scripts/training/phrase-extract/score.h
@@ -59,11 +59,7 @@ private:
};
// other functions *********************************************
-inline bool isNonTerminal( std::string &word )
+inline bool isNonTerminal( const std::string &word ) // Tests whether `word` has the bracketed form checked below; the parameter is now a const reference.
{
- return (word.length()>=3 &&
- word.substr(0,1).compare("[") == 0 &&
- word.substr(word.length()-1,1).compare("]") == 0);
+ return (word.length()>=3 && word[0] == '[' && word[word.length()-1] == ']'); // true iff at least 3 characters long, first is '[' and last is ']'
}
-
-
diff --git a/scripts/training/train-model.perl.missing_bin_dir b/scripts/training/train-model.perl.missing_bin_dir
index aac6cef96..0db2ee437 100755
--- a/scripts/training/train-model.perl.missing_bin_dir
+++ b/scripts/training/train-model.perl.missing_bin_dir
@@ -31,6 +31,7 @@ my($_ROOT_DIR, $_CORPUS_DIR, $_GIZA_E2F, $_GIZA_F2E, $_MODEL_DIR, $_TEMP_DIR, $_
@_REORDERING_TABLE, @_GENERATION_TABLE, @_GENERATION_TYPE, $_GENERATION_CORPUS,
$_DONT_ZIP, $_MGIZA, $_MGIZA_CPUS, $_SNT2COOC, $_HMM_ALIGN, $_CONFIG,
$_HIERARCHICAL,$_XML,$_SOURCE_SYNTAX,$_TARGET_SYNTAX,$_GLUE_GRAMMAR,$_GLUE_GRAMMAR_FILE,$_UNKNOWN_WORD_LABEL_FILE,$_GHKM,$_PCFG,$_EXTRACT_OPTIONS,$_SCORE_OPTIONS,
+ $_ALT_DIRECT_RULE_SCORE_1, $_ALT_DIRECT_RULE_SCORE_2,
$_PHRASE_WORD_ALIGNMENT,$_FORCE_FACTORED_FILENAMES,
$_MEMSCORE, $_FINAL_ALIGNMENT_MODEL,
$_CONTINUE,$_MAX_LEXICAL_REORDERING,$_DO_STEPS,
@@ -106,6 +107,8 @@ $_HELP = 1
'unknown-word-label-file=s' => \$_UNKNOWN_WORD_LABEL_FILE,
'ghkm' => \$_GHKM,
'pcfg' => \$_PCFG,
+ 'alt-direct-rule-score-1' => \$_ALT_DIRECT_RULE_SCORE_1,
+ 'alt-direct-rule-score-2' => \$_ALT_DIRECT_RULE_SCORE_2,
'extract-options=s' => \$_EXTRACT_OPTIONS,
'score-options=s' => \$_SCORE_OPTIONS,
'source-syntax' => \$_SOURCE_SYNTAX,
@@ -1375,6 +1378,8 @@ sub extract_phrase {
$cmd .= " --GlueGrammar $___GLUE_GRAMMAR_FILE" if $_GLUE_GRAMMAR;
$cmd .= " --UnknownWordLabel $_UNKNOWN_WORD_LABEL_FILE" if $_TARGET_SYNTAX && defined($_UNKNOWN_WORD_LABEL_FILE);
$cmd .= " --PCFG" if $_PCFG;
+ $cmd .= " --UnpairedExtractFormat" if $_ALT_DIRECT_RULE_SCORE_1 || $_ALT_DIRECT_RULE_SCORE_2;
+ $cmd .= " --ConditionOnTargetLHS" if $_ALT_DIRECT_RULE_SCORE_1;
if (!defined($_GHKM)) {
$cmd .= " --SourceSyntax" if $_SOURCE_SYNTAX;
$cmd .= " --TargetSyntax" if $_TARGET_SYNTAX;
@@ -1506,10 +1511,12 @@ sub score_phrase_phrase_extract {
$cmd .= " --UnalignedFunctionWordPenalty ".($inverse ? $UNALIGNED_FW_F : $UNALIGNED_FW_E) if $UNALIGNED_FW_COUNT;
$cmd .= " --MinCountHierarchical $MIN_COUNT_HIERARCHICAL" if $MIN_COUNT_HIERARCHICAL;
$cmd .= " --PCFG" if $_PCFG;
+ $cmd .= " --UnpairedExtractFormat" if $_ALT_DIRECT_RULE_SCORE_1 || $_ALT_DIRECT_RULE_SCORE_2;
+ $cmd .= " --ConditionOnTargetLHS" if $_ALT_DIRECT_RULE_SCORE_1;
$cmd .= " $CORE_SCORE_OPTIONS" if defined($_SCORE_OPTIONS);
# sorting
- if ($direction eq "e2f") {
+ if ($direction eq "e2f" || $_ALT_DIRECT_RULE_SCORE_1 || $_ALT_DIRECT_RULE_SCORE_2) {
$cmd .= " 1 ";
}
else {