diff options
author | Matthias Huck <huck@i6.informatik.rwth-aachen.de> | 2014-06-11 22:27:18 +0400 |
---|---|---|
committer | Matthias Huck <huck@i6.informatik.rwth-aachen.de> | 2014-06-11 22:27:18 +0400 |
commit | d0e92da7340ae1c46c4eaa41f52bf5eaaf47961c (patch) | |
tree | aff4ce24eca81443c7c11181d08f380966355c1e | |
parent | 02848112d8bd2bc16114ad7b0dff465f083e0d4b (diff) |
GHKM extraction can add a source labels phrase property
-rw-r--r-- | phrase-extract/ExtractionPhrasePair.cpp | 143 | ||||
-rw-r--r-- | phrase-extract/ExtractionPhrasePair.h | 7 | ||||
-rw-r--r-- | phrase-extract/ScoreFeature.h | 2 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/ExtractGHKM.cpp | 277 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/ExtractGHKM.h | 10 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/Options.h | 6 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/ParseTree.h | 4 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/ScfgRule.cpp | 82 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/ScfgRule.h | 25 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/ScfgRuleWriter.cpp | 6 | ||||
-rw-r--r-- | phrase-extract/extract-ghkm/XmlTreeParser.h | 6 | ||||
-rw-r--r-- | phrase-extract/score-main.cpp | 180 | ||||
-rw-r--r-- | phrase-extract/score.h | 28 | ||||
-rw-r--r-- | phrase-extract/tables-core.h | 2 |
14 files changed, 707 insertions, 71 deletions
diff --git a/phrase-extract/ExtractionPhrasePair.cpp b/phrase-extract/ExtractionPhrasePair.cpp index 2b26c2ad6..9564b1cfe 100644 --- a/phrase-extract/ExtractionPhrasePair.cpp +++ b/phrase-extract/ExtractionPhrasePair.cpp @@ -321,5 +321,148 @@ std::string ExtractionPhrasePair::CollectAllPropertyValues(const std::string &ke } +std::string ExtractionPhrasePair::CollectAllLabelsSeparateLHSAndRHS(const std::string& propertyKey, + std::set<std::string>& labelSet, + boost::unordered_map<std::string,float>& countsLabelsLHS, + boost::unordered_map<std::string, boost::unordered_map<std::string,float>* >& jointCountsRulesTargetLHSAndLabelsLHS, + Vocabulary &vcbT) const +{ + const PROPERTY_VALUES *allPropertyValues = GetProperty( propertyKey ); + + if ( allPropertyValues == NULL ) { + return ""; + } + + std::string lhs="", rhs="", currentRhs=""; + float currentRhsCount = 0.0; + std::list< std::pair<std::string,float> > lhsGivenCurrentRhsCounts; + + std::ostringstream oss; + for (PROPERTY_VALUES::const_iterator iter=allPropertyValues->begin(); + iter!=allPropertyValues->end(); ++iter) { + + size_t space = (iter->first).find_last_of(' '); + if ( space == string::npos ) { + lhs = iter->first; + rhs.clear(); + } else { + lhs = (iter->first).substr(space+1); + rhs = (iter->first).substr(0,space); + } + + labelSet.insert(lhs); + + if ( rhs.compare(currentRhs) ) { + + if ( iter!=allPropertyValues->begin() ) { + if ( !currentRhs.empty() ) { + istringstream tokenizer(currentRhs); + std::string rhsLabel; + while ( tokenizer.peek() != EOF ) { + tokenizer >> rhsLabel; + labelSet.insert(rhsLabel); + } + oss << " " << currentRhs << " " << currentRhsCount; + } + if ( lhsGivenCurrentRhsCounts.size() > 0 ) { + if ( !currentRhs.empty() ) { + oss << " " << lhsGivenCurrentRhsCounts.size(); + } + for ( std::list< std::pair<std::string,float> >::const_iterator iter2=lhsGivenCurrentRhsCounts.begin(); + iter2!=lhsGivenCurrentRhsCounts.end(); ++iter2 ) { + oss << " " << iter2->first << " " << 
iter2->second; + + // update countsLabelsLHS and jointCountsRulesTargetLHSAndLabelsLHS + std::string ruleTargetLhs = vcbT.getWord(m_phraseTarget->back()); + ruleTargetLhs.erase(ruleTargetLhs.begin()); // strip square brackets + ruleTargetLhs.erase(ruleTargetLhs.size()-1); + + std::pair< boost::unordered_map<std::string,float>::iterator, bool > insertedCountsLabelsLHS = + countsLabelsLHS.insert(std::pair<std::string,float>(iter2->first,iter2->second)); + if (!insertedCountsLabelsLHS.second) { + (insertedCountsLabelsLHS.first)->second += iter2->second; + } + + boost::unordered_map<std::string, boost::unordered_map<std::string,float>* >::iterator jointCountsRulesTargetLHSAndLabelsLHSIter = + jointCountsRulesTargetLHSAndLabelsLHS.find(ruleTargetLhs); + if ( jointCountsRulesTargetLHSAndLabelsLHSIter == jointCountsRulesTargetLHSAndLabelsLHS.end() ) { + boost::unordered_map<std::string,float>* jointCounts = new boost::unordered_map<std::string,float>; + jointCounts->insert(std::pair<std::string,float>(iter2->first,iter2->second)); + jointCountsRulesTargetLHSAndLabelsLHS.insert(std::pair<std::string,boost::unordered_map<std::string,float>* >(ruleTargetLhs,jointCounts)); + } else { + boost::unordered_map<std::string,float>* jointCounts = jointCountsRulesTargetLHSAndLabelsLHSIter->second; + std::pair< boost::unordered_map<std::string,float>::iterator, bool > insertedJointCounts = + jointCounts->insert(std::pair<std::string,float>(iter2->first,iter2->second)); + if (!insertedJointCounts.second) { + (insertedJointCounts.first)->second += iter2->second; + } + } + + } + } + + lhsGivenCurrentRhsCounts.clear(); + } + + currentRhsCount = 0.0; + currentRhs = rhs; + } + + currentRhsCount += iter->second; + lhsGivenCurrentRhsCounts.push_back( std::pair<std::string,float>(lhs,iter->second) ); + } + + if ( !currentRhs.empty() ) { + istringstream tokenizer(currentRhs); + std::string rhsLabel; + while ( tokenizer.peek() != EOF ) { + tokenizer >> rhsLabel; + labelSet.insert(rhsLabel); + } 
+ oss << " " << currentRhs << " " << currentRhsCount; + } + if ( lhsGivenCurrentRhsCounts.size() > 0 ) { + if ( !currentRhs.empty() ) { + oss << " " << lhsGivenCurrentRhsCounts.size(); + } + for ( std::list< std::pair<std::string,float> >::const_iterator iter2=lhsGivenCurrentRhsCounts.begin(); + iter2!=lhsGivenCurrentRhsCounts.end(); ++iter2 ) { + oss << " " << iter2->first << " " << iter2->second; + + // update countsLabelsLHS and jointCountsRulesTargetLHSAndLabelsLHS + std::string ruleTargetLhs = vcbT.getWord(m_phraseTarget->back()); + ruleTargetLhs.erase(ruleTargetLhs.begin()); // strip square brackets + ruleTargetLhs.erase(ruleTargetLhs.size()-1); + + std::pair< boost::unordered_map<std::string,float>::iterator, bool > insertedCountsLabelsLHS = + countsLabelsLHS.insert(std::pair<std::string,float>(iter2->first,iter2->second)); + if (!insertedCountsLabelsLHS.second) { + (insertedCountsLabelsLHS.first)->second += iter2->second; + } + + boost::unordered_map<std::string, boost::unordered_map<std::string,float>* >::iterator jointCountsRulesTargetLHSAndLabelsLHSIter = + jointCountsRulesTargetLHSAndLabelsLHS.find(ruleTargetLhs); + if ( jointCountsRulesTargetLHSAndLabelsLHSIter == jointCountsRulesTargetLHSAndLabelsLHS.end() ) { + boost::unordered_map<std::string,float>* jointCounts = new boost::unordered_map<std::string,float>; + jointCounts->insert(std::pair<std::string,float>(iter2->first,iter2->second)); + jointCountsRulesTargetLHSAndLabelsLHS.insert(std::pair<std::string,boost::unordered_map<std::string,float>* >(ruleTargetLhs,jointCounts)); + } else { + boost::unordered_map<std::string,float>* jointCounts = jointCountsRulesTargetLHSAndLabelsLHSIter->second; + std::pair< boost::unordered_map<std::string,float>::iterator, bool > insertedJointCounts = + jointCounts->insert(std::pair<std::string,float>(iter2->first,iter2->second)); + if (!insertedJointCounts.second) { + (insertedJointCounts.first)->second += iter2->second; + } + } + + } + } + + std::string 
allPropertyValuesString(oss.str()); + return allPropertyValuesString; +} + + + } diff --git a/phrase-extract/ExtractionPhrasePair.h b/phrase-extract/ExtractionPhrasePair.h index f04984391..ba23ac1f2 100644 --- a/phrase-extract/ExtractionPhrasePair.h +++ b/phrase-extract/ExtractionPhrasePair.h @@ -23,6 +23,7 @@ #include <vector> #include <set> #include <map> +#include <boost/unordered_map.hpp> namespace MosesTraining { @@ -124,6 +125,12 @@ public: std::string CollectAllPropertyValues(const std::string &key) const; + std::string CollectAllLabelsSeparateLHSAndRHS(const std::string& propertyKey, + std::set<std::string>& sourceLabelSet, + boost::unordered_map<std::string,float>& sourceLHSCounts, + boost::unordered_map<std::string, boost::unordered_map<std::string,float>* >& sourceRHSAndLHSJointCounts, + Vocabulary &vcbT) const; + void AddProperties( const std::string &str, float count ); void AddProperty( const std::string &key, const std::string &value, float count ) diff --git a/phrase-extract/ScoreFeature.h b/phrase-extract/ScoreFeature.h index 926397e71..30e198e21 100644 --- a/phrase-extract/ScoreFeature.h +++ b/phrase-extract/ScoreFeature.h @@ -90,7 +90,7 @@ public: float count, int sentenceId) const {}; - /** Add the values for this feature function. */ + /** Add the values for this score feature. 
*/ virtual void add(const ScoreFeatureContext& context, std::vector<float>& denseValues, std::map<std::string,float>& sparseValues) const = 0; diff --git a/phrase-extract/extract-ghkm/ExtractGHKM.cpp b/phrase-extract/extract-ghkm/ExtractGHKM.cpp index 5b12203a5..b86c28586 100644 --- a/phrase-extract/extract-ghkm/ExtractGHKM.cpp +++ b/phrase-extract/extract-ghkm/ExtractGHKM.cpp @@ -30,6 +30,10 @@ #include "ScfgRule.h" #include "ScfgRuleWriter.h" #include "Span.h" +#include "SyntaxTree.h" +#include "tables-core.h" +#include "XmlException.h" +#include "XmlTree.h" #include "XmlTreeParser.h" #include <boost/program_options.hpp> @@ -63,7 +67,9 @@ int ExtractGHKM::Main(int argc, char *argv[]) OutputFileStream fwdExtractStream; OutputFileStream invExtractStream; std::ofstream glueGrammarStream; - std::ofstream unknownWordStream; + std::ofstream targetUnknownWordStream; + std::ofstream sourceUnknownWordStream; + std::ofstream sourceLabelSetStream; std::ofstream unknownWordSoftMatchesStream; std::string fwdFileName = options.extractFile; std::string invFileName = options.extractFile + std::string(".inv"); @@ -76,26 +82,44 @@ int ExtractGHKM::Main(int argc, char *argv[]) if (!options.glueGrammarFile.empty()) { OpenOutputFileOrDie(options.glueGrammarFile, glueGrammarStream); } - if (!options.unknownWordFile.empty()) { - OpenOutputFileOrDie(options.unknownWordFile, unknownWordStream); + if (!options.targetUnknownWordFile.empty()) { + OpenOutputFileOrDie(options.targetUnknownWordFile, targetUnknownWordStream); + } + if (!options.sourceUnknownWordFile.empty()) { + OpenOutputFileOrDie(options.sourceUnknownWordFile, sourceUnknownWordStream); + } + if (!options.sourceLabelSetFile.empty()) { + if (!options.sourceLabels) { + Error("SourceLabels should be active if SourceLabelSet is supposed to be written to a file"); + } + OpenOutputFileOrDie(options.sourceLabelSetFile, sourceLabelSetStream); // TODO: global sourceLabelSet cannot be determined during parallelized extraction } if 
(!options.unknownWordSoftMatchesFile.empty()) { OpenOutputFileOrDie(options.unknownWordSoftMatchesFile, unknownWordSoftMatchesStream); } // Target label sets for producing glue grammar. - std::set<std::string> labelSet; - std::map<std::string, int> topLabelSet; + std::set<std::string> targetLabelSet; + std::map<std::string, int> targetTopLabelSet; + + // Source label sets for producing glue grammar. + std::set<std::string> sourceLabelSet; + std::map<std::string, int> sourceTopLabelSet; // Word count statistics for producing unknown word labels. - std::map<std::string, int> wordCount; - std::map<std::string, std::string> wordLabel; + std::map<std::string, int> targetWordCount; + std::map<std::string, std::string> targetWordLabel; + + // Word count statistics for producing unknown word labels: source side. + std::map<std::string, int> sourceWordCount; + std::map<std::string, std::string> sourceWordLabel; std::string targetLine; std::string sourceLine; std::string alignmentLine; Alignment alignment; - XmlTreeParser xmlTreeParser(labelSet, topLabelSet); + XmlTreeParser xmlTreeParser(targetLabelSet, targetTopLabelSet); +// XmlTreeParser sourceXmlTreeParser(sourceLabelSet, sourceTopLabelSet); ScfgRuleWriter writer(fwdExtractStream, invExtractStream, options); size_t lineNum = options.sentenceOffset; while (true) { @@ -118,30 +142,71 @@ int ExtractGHKM::Main(int argc, char *argv[]) std::cerr << "skipping line " << lineNum << " with empty target tree\n"; continue; } - std::auto_ptr<ParseTree> t; + std::auto_ptr<ParseTree> targetParseTree; try { - t = xmlTreeParser.Parse(targetLine); - assert(t.get()); + targetParseTree = xmlTreeParser.Parse(targetLine); + assert(targetParseTree.get()); } catch (const Exception &e) { - std::ostringstream s; - s << "Failed to parse XML tree at line " << lineNum; + std::ostringstream oss; + oss << "Failed to parse target XML tree at line " << lineNum; if (!e.GetMsg().empty()) { - s << ": " << e.GetMsg(); + oss << ": " << e.GetMsg(); + } + 
Error(oss.str()); + } + + + // Parse source tree and construct a SyntaxTree object. + MosesTraining::SyntaxTree sourceSyntaxTree; + MosesTraining::SyntaxNode *sourceSyntaxTreeRoot=NULL; + + if (options.sourceLabels) { + try { + if (!ProcessAndStripXMLTags(sourceLine, sourceSyntaxTree, sourceLabelSet, sourceTopLabelSet, false)) { + throw Exception(""); + } + sourceSyntaxTree.ConnectNodes(); + sourceSyntaxTreeRoot = sourceSyntaxTree.GetTop(); + assert(sourceSyntaxTreeRoot); + } catch (const Exception &e) { + std::ostringstream oss; + oss << "Failed to parse source XML tree at line " << lineNum; + if (!e.GetMsg().empty()) { + oss << ": " << e.GetMsg(); + } + Error(oss.str()); } - Error(s.str()); } // Read source tokens. std::vector<std::string> sourceTokens(ReadTokens(sourceLine)); + // Construct a source ParseTree object object from the SyntaxTree object. + std::auto_ptr<ParseTree> sourceParseTree; + + if (options.sourceLabels) { + try { + sourceParseTree = XmlTreeParser::ConvertTree(*sourceSyntaxTreeRoot, sourceTokens); + assert(sourceParseTree.get()); + } catch (const Exception &e) { + std::ostringstream oss; + oss << "Failed to parse source XML tree at line " << lineNum; + if (!e.GetMsg().empty()) { + oss << ": " << e.GetMsg(); + } + Error(oss.str()); + } + } + + // Read word alignments. try { ReadAlignment(alignmentLine, alignment); } catch (const Exception &e) { - std::ostringstream s; - s << "Failed to read alignment at line " << lineNum << ": "; - s << e.GetMsg(); - Error(s.str()); + std::ostringstream oss; + oss << "Failed to read alignment at line " << lineNum << ": "; + oss << e.GetMsg(); + Error(oss.str()); } if (alignment.size() == 0) { std::cerr << "skipping line " << lineNum << " without alignment points\n"; @@ -149,13 +214,18 @@ int ExtractGHKM::Main(int argc, char *argv[]) } // Record word counts. 
- if (!options.unknownWordFile.empty()) { - CollectWordLabelCounts(*t, options, wordCount, wordLabel); + if (!options.targetUnknownWordFile.empty()) { + CollectWordLabelCounts(*targetParseTree, options, targetWordCount, targetWordLabel); + } + + // Record word counts: source side. + if (options.sourceLabels && !options.sourceUnknownWordFile.empty()) { + CollectWordLabelCounts(*sourceParseTree, options, sourceWordCount, sourceWordLabel); } // Form an alignment graph from the target tree, source words, and // alignment. - AlignmentGraph graph(t.get(), sourceTokens, alignment); + AlignmentGraph graph(targetParseTree.get(), sourceTokens, alignment); // Extract minimal rules, adding each rule to its root node's rule set. graph.ExtractMinimalRules(options); @@ -172,29 +242,54 @@ int ExtractGHKM::Main(int argc, char *argv[]) const std::vector<const Subgraph *> &rules = (*p)->GetRules(); for (std::vector<const Subgraph *>::const_iterator q = rules.begin(); q != rules.end(); ++q) { - ScfgRule r(**q); + ScfgRule *r = 0; + if (options.sourceLabels) { + r = new ScfgRule(**q, &sourceSyntaxTree); + } else { + r = new ScfgRule(**q); + } // TODO Can scope pruning be done earlier? 
- if (r.Scope() <= options.maxScope) { + if (r->Scope() <= options.maxScope) { if (!options.treeFragments) { - writer.Write(r); + writer.Write(*r); } else { - writer.Write(r,**q); + writer.Write(*r,**q); } } + delete r; } } } + std::map<std::string,size_t> sourceLabels; + if (options.sourceLabels && !options.sourceLabelSetFile.empty()) { + + sourceLabelSet.insert("XLHS"); // non-matching label (left-hand side) + sourceLabelSet.insert("XRHS"); // non-matching label (right-hand side) + sourceLabelSet.insert("TOPLABEL"); // as used in the glue grammar + sourceLabelSet.insert("SOMELABEL"); // as used in the glue grammar + size_t index = 0; + for (std::set<std::string>::const_iterator iter=sourceLabelSet.begin(); + iter!=sourceLabelSet.end(); ++iter, ++index) { + sourceLabels.insert(std::pair<std::string,size_t>(*iter,index)); + } + WriteSourceLabelSet(sourceLabels, sourceLabelSetStream); + } + if (!options.glueGrammarFile.empty()) { - WriteGlueGrammar(labelSet, topLabelSet, glueGrammarStream); + WriteGlueGrammar(targetLabelSet, targetTopLabelSet, sourceLabels, options, glueGrammarStream); } - if (!options.unknownWordFile.empty()) { - WriteUnknownWordLabel(wordCount, wordLabel, options, unknownWordStream); + if (!options.targetUnknownWordFile.empty()) { + WriteUnknownWordLabel(targetWordCount, targetWordLabel, options, targetUnknownWordStream); + } + + if (options.sourceLabels && !options.sourceUnknownWordFile.empty()) { + WriteUnknownWordLabel(sourceWordCount, sourceWordLabel, options, sourceUnknownWordStream, true); } if (!options.unknownWordSoftMatchesFile.empty()) { - WriteUnknownWordSoftMatches(labelSet, unknownWordSoftMatchesStream); + WriteUnknownWordSoftMatches(targetLabelSet, unknownWordSoftMatchesStream); } return 0; @@ -305,12 +400,20 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[], "include score based on PCFG scores in target corpus") ("TreeFragments", "output parse tree information") + ("SourceLabels", + "output source syntax label information") 
+ ("SourceLabelSet", + po::value(&options.sourceLabelSetFile), + "write source syntax label set to named file") ("SentenceOffset", po::value(&options.sentenceOffset)->default_value(options.sentenceOffset), "set sentence number offset if processing split corpus") ("UnknownWordLabel", - po::value(&options.unknownWordFile), + po::value(&options.targetUnknownWordFile), "write unknown word labels to named file") + ("SourceUnknownWordLabel", + po::value(&options.sourceUnknownWordFile), + "write source syntax unknown word labels to named file") ("UnknownWordMinRelFreq", po::value(&options.unknownWordMinRelFreq)->default_value( options.unknownWordMinRelFreq), @@ -402,6 +505,9 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[], if (vm.count("TreeFragments")) { options.treeFragments = true; } + if (vm.count("SourceLabels")) { + options.sourceLabels = true; + } if (vm.count("UnknownWordUniform")) { options.unknownWordUniform = true; } @@ -411,7 +517,10 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[], // Workaround for extract-parallel issue. 
if (options.sentenceOffset > 0) { - options.unknownWordFile.clear(); + options.targetUnknownWordFile.clear(); + } + if (options.sentenceOffset > 0) { + options.sourceUnknownWordFile.clear(); options.unknownWordSoftMatchesFile.clear(); } } @@ -422,7 +531,7 @@ void ExtractGHKM::Error(const std::string &msg) const std::exit(1); } -std::vector<std::string> ExtractGHKM::ReadTokens(const std::string &s) +std::vector<std::string> ExtractGHKM::ReadTokens(const std::string &s) const { std::vector<std::string> tokens; @@ -454,9 +563,11 @@ std::vector<std::string> ExtractGHKM::ReadTokens(const std::string &s) void ExtractGHKM::WriteGlueGrammar( const std::set<std::string> &labelSet, const std::map<std::string, int> &topLabelSet, + const std::map<std::string,size_t> &sourceLabels, + const Options &options, std::ostream &out) { - // chose a top label that is not already a label + // choose a top label that is not already a label std::string topLabel = "QQQQQQ"; for(size_t i = 1; i <= topLabel.length(); i++) { if (labelSet.find(topLabel.substr(0,i)) == labelSet.end() ) { @@ -465,23 +576,75 @@ void ExtractGHKM::WriteGlueGrammar( } } + std::string sourceTopLabel = "TOPLABEL"; + std::string sourceSLabel = "S"; + std::string sourceSomeLabel = "SOMELABEL"; + // basic rules - out << "<s> [X] ||| <s> [" << topLabel << "] ||| 1 ||| ||| ||| ||| {{Tree [" << topLabel << " <s>]}}" << std::endl; - out << "[X][" << topLabel << "] </s> [X] ||| [X][" << topLabel << "] </s> [" << topLabel << "] ||| 1 ||| 0-0 ||| ||| ||| {{Tree [" << topLabel << " [" << topLabel << "] </s>]}}" << std::endl; + out << "<s> [X] ||| <s> [" << topLabel << "] ||| 1 ||| ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " <s>]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 1 1 " << sourceTopLabel << " 1}}"; + } + out << std::endl; + + out << "[X][" << topLabel << "] </s> [X] ||| [X][" << topLabel << "] </s> [" << topLabel << "] ||| 1 ||| 0-0 ||| ||| |||"; + if 
(options.treeFragments) { + out << " {{Tree [" << topLabel << " [" << topLabel << "] </s>]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 2 1 " << sourceTopLabel << " 1 1 " << sourceTopLabel << " 1}}"; + } + out << std::endl; // top rules for (std::map<std::string, int>::const_iterator i = topLabelSet.begin(); i != topLabelSet.end(); ++i) { - out << "<s> [X][" << i->first << "] </s> [X] ||| <s> [X][" << i->first << "] </s> [" << topLabel << "] ||| 1 ||| 1-1 ||| ||| ||| {{Tree [" << topLabel << " <s> [" << i->first << "] </s>]}}" << std::endl; + out << "<s> [X][" << i->first << "] </s> [X] ||| <s> [X][" << i->first << "] </s> [" << topLabel << "] ||| 1 ||| 1-1 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " <s> [" << i->first << "] </s>]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 2 1 " << sourceSLabel << " 1 1 " << sourceTopLabel << " 1}}"; + } + out << std::endl; } // glue rules for(std::set<std::string>::const_iterator i = labelSet.begin(); i != labelSet.end(); i++ ) { - out << "[X][" << topLabel << "] [X][" << *i << "] [X] ||| [X][" << topLabel << "] [X][" << *i << "] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| ||| {{Tree [" << topLabel << " ["<< topLabel << "] [" << *i << "]]}}" << std::endl; + out << "[X][" << topLabel << "] [X][" << *i << "] [X] ||| [X][" << topLabel << "] [X][" << *i << "] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " ["<< topLabel << "] [" << *i << "]]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 3 2.718 " << sourceTopLabel << " " << sourceSomeLabel << " 2.718 1 " << sourceTopLabel << " 2.718}}"; // TODO: there should be better options than using "SOMELABEL" + } + out << std::endl; } + // glue rule for unknown word... 
- out << "[X][" << topLabel << "] [X][X] [X] ||| [X][" << topLabel << "] [X][X] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| ||| {{Tree [" << topLabel << " [" << topLabel << "] [X]]}}" << std::endl; + out << "[X][" << topLabel << "] [X][X] [X] ||| [X][" << topLabel << "] [X][X] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " [" << topLabel << "] [X]]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 3 1 " << sourceTopLabel << " " << sourceSomeLabel << " 1 1 " << sourceTopLabel << " 1}}"; // TODO: there should be better options than using "SOMELABEL" + } + out << std::endl; +} + +void ExtractGHKM::WriteSourceLabelSet( + const std::map<std::string,size_t> &sourceLabels, + std::ostream &out) +{ + out << sourceLabels.size() << std::endl; + for (std::map<std::string,size_t>::const_iterator iter=sourceLabels.begin(); + iter!=sourceLabels.end(); ++iter) { + out << iter->first << " " << iter->second << std::endl; + } } void ExtractGHKM::CollectWordLabelCounts( @@ -513,11 +676,26 @@ void ExtractGHKM::CollectWordLabelCounts( } } +std::vector<std::string> ExtractGHKM::ReadTokens(const ParseTree &root) const +{ + std::vector<std::string> tokens; + std::vector<const ParseTree*> leaves; + root.GetLeaves(std::back_inserter(leaves)); + for (std::vector<const ParseTree *>::const_iterator p = leaves.begin(); + p != leaves.end(); ++p) { + const ParseTree &leaf = **p; + const std::string &word = leaf.GetLabel(); + tokens.push_back(word); + } + return tokens; +} + void ExtractGHKM::WriteUnknownWordLabel( const std::map<std::string, int> &wordCount, const std::map<std::string, std::string> &wordLabel, const Options &options, - std::ostream &out) + std::ostream &out, + bool writeCounts) { if (!options.unknownWordSoftMatchesFile.empty()) { out << "UNK 1" << std::endl; @@ -537,12 +715,19 @@ void ExtractGHKM::WriteUnknownWordLabel( ++total; } } - for (std::map<std::string, int>::const_iterator 
p = labelCount.begin(); - p != labelCount.end(); ++p) { - double ratio = static_cast<double>(p->second) / static_cast<double>(total); - if (ratio >= options.unknownWordMinRelFreq) { - float weight = options.unknownWordUniform ? 1.0f : ratio; - out << p->first << " " << weight << std::endl; + if ( writeCounts ) { + for (std::map<std::string, int>::const_iterator p = labelCount.begin(); + p != labelCount.end(); ++p) { + out << p->first << " " << p->second << std::endl; + } + } else { + for (std::map<std::string, int>::const_iterator p = labelCount.begin(); + p != labelCount.end(); ++p) { + double ratio = static_cast<double>(p->second) / static_cast<double>(total); + if (ratio >= options.unknownWordMinRelFreq) { + float weight = options.unknownWordUniform ? 1.0f : ratio; + out << p->first << " " << weight << std::endl; + } } } } diff --git a/phrase-extract/extract-ghkm/ExtractGHKM.h b/phrase-extract/extract-ghkm/ExtractGHKM.h index 4c78923d3..44ce9fdbd 100644 --- a/phrase-extract/extract-ghkm/ExtractGHKM.h +++ b/phrase-extract/extract-ghkm/ExtractGHKM.h @@ -59,13 +59,19 @@ private: void WriteUnknownWordLabel(const std::map<std::string, int> &, const std::map<std::string, std::string> &, const Options &, - std::ostream &); + std::ostream &, + bool writeCounts=false); void WriteUnknownWordSoftMatches(const std::set<std::string> &, std::ostream &); void WriteGlueGrammar(const std::set<std::string> &, const std::map<std::string, int> &, + const std::map<std::string,size_t> &, + const Options &, std::ostream &); - std::vector<std::string> ReadTokens(const std::string &); + void WriteSourceLabelSet(const std::map<std::string,size_t> &, + std::ostream &); + std::vector<std::string> ReadTokens(const std::string &) const; + std::vector<std::string> ReadTokens(const ParseTree &root) const; void ProcessOptions(int, char *[], Options &) const; diff --git a/phrase-extract/extract-ghkm/Options.h b/phrase-extract/extract-ghkm/Options.h index ffa9bfa35..28a581802 100644 --- 
a/phrase-extract/extract-ghkm/Options.h +++ b/phrase-extract/extract-ghkm/Options.h @@ -41,6 +41,7 @@ public: , minimal(false) , pcfg(false) , treeFragments(false) + , sourceLabels(false) , sentenceOffset(0) , unpairedExtractFormat(false) , unknownWordMinRelFreq(0.03f) @@ -64,9 +65,12 @@ public: bool minimal; bool pcfg; bool treeFragments; + bool sourceLabels; + std::string sourceLabelSetFile; int sentenceOffset; bool unpairedExtractFormat; - std::string unknownWordFile; + std::string targetUnknownWordFile; + std::string sourceUnknownWordFile; std::string unknownWordSoftMatchesFile; float unknownWordMinRelFreq; bool unknownWordUniform; diff --git a/phrase-extract/extract-ghkm/ParseTree.h b/phrase-extract/extract-ghkm/ParseTree.h index 03da17735..694286c9d 100644 --- a/phrase-extract/extract-ghkm/ParseTree.h +++ b/phrase-extract/extract-ghkm/ParseTree.h @@ -63,7 +63,7 @@ public: bool IsLeaf() const; template<typename OutputIterator> - void GetLeaves(OutputIterator); + void GetLeaves(OutputIterator) const; private: // Disallow copying @@ -77,7 +77,7 @@ private: }; template<typename OutputIterator> -void ParseTree::GetLeaves(OutputIterator result) +void ParseTree::GetLeaves(OutputIterator result) const { if (IsLeaf()) { *result++ = this; diff --git a/phrase-extract/extract-ghkm/ScfgRule.cpp b/phrase-extract/extract-ghkm/ScfgRule.cpp index 2c901413d..a4dd91e0e 100644 --- a/phrase-extract/extract-ghkm/ScfgRule.cpp +++ b/phrase-extract/extract-ghkm/ScfgRule.cpp @@ -21,6 +21,7 @@ #include "Node.h" #include "Subgraph.h" +#include "SyntaxTree.h" #include <algorithm> @@ -29,11 +30,14 @@ namespace Moses namespace GHKM { -ScfgRule::ScfgRule(const Subgraph &fragment) +ScfgRule::ScfgRule(const Subgraph &fragment, + const MosesTraining::SyntaxTree *sourceSyntaxTree) : m_sourceLHS("X", NonTerminal) , m_targetLHS(fragment.GetRoot()->GetLabel(), NonTerminal) , m_pcfgScore(fragment.GetPcfgScore()) + , m_hasSourceLabels(sourceSyntaxTree) { + // Source RHS const std::set<const Node *> 
&leaves = fragment.GetLeaves(); @@ -55,6 +59,7 @@ ScfgRule::ScfgRule(const Subgraph &fragment) std::map<const Node *, std::vector<int> > sourceOrder; m_sourceRHS.reserve(sourceRHSNodes.size()); + m_numberOfNonTerminals = 0; int srcIndex = 0; for (std::vector<const Node *>::const_iterator p(sourceRHSNodes.begin()); p != sourceRHSNodes.end(); ++p, ++srcIndex) { @@ -62,6 +67,11 @@ ScfgRule::ScfgRule(const Subgraph &fragment) if (sinkNode.GetType() == TREE) { m_sourceRHS.push_back(Symbol("X", NonTerminal)); sourceOrder[&sinkNode].push_back(srcIndex); + ++m_numberOfNonTerminals; + if (sourceSyntaxTree) { + // Source syntax label + PushSourceLabel(sourceSyntaxTree,&sinkNode,"XRHS"); + } } else { assert(sinkNode.GetType() == SOURCE); m_sourceRHS.push_back(Symbol(sinkNode.GetLabel(), Terminal)); @@ -112,6 +122,76 @@ ScfgRule::ScfgRule(const Subgraph &fragment) } } } + + if (sourceSyntaxTree) { + // Source syntax label for root node (if sourceSyntaxTree available) + PushSourceLabel(sourceSyntaxTree,fragment.GetRoot(),"XLHS"); + // All non-terminal spans (including the LHS) should have obtained a label + // (a source-side syntactic constituent label if the span matches, "XLHS" otherwise) + assert(m_sourceLabels.size() == m_numberOfNonTerminals+1); + } +}
+
+// Appends exactly one entry to m_sourceLabels for the (closed) span of `node`:
+// the label of the last node that sourceSyntaxTree->GetNodes() returns for the
+// span if the source tree has a constituent there, nonMatchingLabel otherwise.
+void ScfgRule::PushSourceLabel(const MosesTraining::SyntaxTree *sourceSyntaxTree,
+                               const Node *node,
+                               const std::string &nonMatchingLabel)
+{
+  ContiguousSpan span = Closure(node->GetSpan());
+  if (sourceSyntaxTree->HasNode(span.first,span.second)) { // does a source constituent match the span?
+    std::vector<MosesTraining::SyntaxNode*> sourceLabels =
+      sourceSyntaxTree->GetNodes(span.first,span.second);
+    if (!sourceLabels.empty()) {
+      // store the topmost matching label from the source syntax tree
+      m_sourceLabels.push_back(sourceLabels.back()->GetLabel());
+      return;
+    }
+  }
+  // no matching source-side syntactic constituent (or an empty node list):
+  // store nonMatchingLabel, so that every call records a label and the
+  // label count checked by the constructor cannot silently fall behind
+  m_sourceLabels.push_back(nonMatchingLabel);
+}
+
+// TODO: rather implement the method external to ScfgRule
+//
+// For every non-terminal on the source right-hand side, adds `count` to
+// (*coocCounts[sourceLabel])[targetLabel]. sourceLabel is the matching entry
+// of m_sourceLabels (the constructor pushes the RHS labels in source order);
+// targetLabel is the value of the target RHS symbol the non-terminal is
+// aligned to.
+// Inner maps that do not exist yet are allocated here with new and stored in
+// coocCounts; this method never frees them.
+void ScfgRule::UpdateSourceLabelCoocCounts(std::map< std::string, std::map<std::string,float>* > &coocCounts, float count) const
+{
+  // m_sourceLabels is only filled when a source syntax tree was supplied
+  // (see the constructor code above); without labels the indexing below
+  // would read out of bounds, so there is nothing to count
+  if (m_sourceLabels.empty()) {
+    return;
+  }
+
+  // source RHS position of each aligned non-terminal -> target RHS position
+  std::map<int, int> sourceToTargetNTMap;
+
+  for (Alignment::const_iterator p(m_alignment.begin());
+       p != m_alignment.end(); ++p) {
+    if ( m_sourceRHS[p->first].GetType() == NonTerminal ) {
+      assert(m_targetRHS[p->second].GetType() == NonTerminal);
+      sourceToTargetNTMap[p->first] = p->second;
+    }
+  }
+
+  size_t sourceIndex = 0;
+  size_t sourceNonTerminalIndex = 0;
+  for (std::vector<Symbol>::const_iterator p=m_sourceRHS.begin();
+       p != m_sourceRHS.end(); ++p, ++sourceIndex) {
+    if ( p->GetType() == NonTerminal ) {
+      const std::string &sourceLabel = m_sourceLabels[sourceNonTerminalIndex];
+      int targetIndex = sourceToTargetNTMap[sourceIndex];
+      const std::string &targetLabel = m_targetRHS[targetIndex].GetValue();
+      ++sourceNonTerminalIndex;
+
+      // find or create the per-source-label count map
+      std::map<std::string,float>* countMap = NULL;
+      std::map< std::string, std::map<std::string,float>* >::iterator iter = coocCounts.find(sourceLabel);
+      if ( iter == coocCounts.end() ) {
+        std::map<std::string,float> *newCountMap = new std::map<std::string,float>();
+        std::pair< std::map< std::string, std::map<std::string,float>* >::iterator, bool > inserted =
+          coocCounts.insert( std::pair< std::string, std::map<std::string,float>* >(sourceLabel, newCountMap) );
+        assert(inserted.second);
+        countMap = (inserted.first)->second;
+      } else {
+        countMap = iter->second;
+      }
+      // insert the count, or accumulate if the target label is already present
+      std::pair< std::map<std::string,float>::iterator, bool > inserted =
+        countMap->insert( std::pair< std::string,float>(targetLabel, count) );
+      if ( !inserted.second ) {
+        (inserted.first)->second += count;
+      }
+    }
+  }
 }
 int ScfgRule::Scope() const diff --git a/phrase-extract/extract-ghkm/ScfgRule.h b/phrase-extract/extract-ghkm/ScfgRule.h index 21a9e9900..5f1f35a61 100644 --- a/phrase-extract/extract-ghkm/ScfgRule.h +++ b/phrase-extract/extract-ghkm/ScfgRule.h @@ -22,9 +22,13 @@ #define EXTRACT_GHKM_SCFG_RULE_H_ #include "Alignment.h" +#include "SyntaxTree.h" #include <string> #include <vector> +#include <list> +#include <memory> +#include <iostream> namespace Moses { @@ -55,7 +59,8 @@ private: class ScfgRule { public: - ScfgRule(const Subgraph &fragment); + ScfgRule(const Subgraph &fragment, + const MosesTraining::SyntaxTree *sourceSyntaxTree = 0); const Symbol &GetSourceLHS() const { return m_sourceLHS; @@ -75,18 +80,36 @@ public: float GetPcfgScore() const { return m_pcfgScore; } + bool HasSourceLabels() const { + return m_hasSourceLabels; + } + void PrintSourceLabels(std::ostream &out) const { + for (std::vector<std::string>::const_iterator it = m_sourceLabels.begin(); + it != m_sourceLabels.end(); ++it) { + out << " " << (*it); + } + } + void UpdateSourceLabelCoocCounts(std::map< std::string, std::map<std::string,float>* > &coocCounts, + float count) const; int Scope() const; private: static bool PartitionOrderComp(const Node *, const Node *); + void PushSourceLabel(const MosesTraining::SyntaxTree *sourceSyntaxTree, + const Node *node, + const std::string &nonMatchingLabel); + Symbol m_sourceLHS; Symbol m_targetLHS; std::vector<Symbol> m_sourceRHS; std::vector<Symbol> m_targetRHS; Alignment m_alignment; float m_pcfgScore; + bool m_hasSourceLabels; + std::vector<std::string> m_sourceLabels; + unsigned m_numberOfNonTerminals; }; } // namespace GHKM diff --git a/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp b/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp index
bc8fd7233..be373b67b 100644 --- a/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp +++ b/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp @@ -66,6 +66,12 @@ void ScfgRuleWriter::Write(const ScfgRule &rule, bool printEndl) m_fwd << " ||| " << std::exp(rule.GetPcfgScore()); } + if (m_options.sourceLabels && rule.HasSourceLabels()) { + m_fwd << " {{SourceLabels"; + rule.PrintSourceLabels(m_fwd); + m_fwd << "}}"; + } + if (printEndl) { m_fwd << std::endl; m_inv << std::endl; diff --git a/phrase-extract/extract-ghkm/XmlTreeParser.h b/phrase-extract/extract-ghkm/XmlTreeParser.h index d00fd7d9f..e5bf5b463 100644 --- a/phrase-extract/extract-ghkm/XmlTreeParser.h +++ b/phrase-extract/extract-ghkm/XmlTreeParser.h @@ -45,9 +45,11 @@ class XmlTreeParser public: XmlTreeParser(std::set<std::string> &, std::map<std::string, int> &); std::auto_ptr<ParseTree> Parse(const std::string &); + + static std::auto_ptr<ParseTree> ConvertTree(const MosesTraining::SyntaxNode &, + const std::vector<std::string> &); + private: - std::auto_ptr<ParseTree> ConvertTree(const MosesTraining::SyntaxNode &, - const std::vector<std::string> &); std::set<std::string> &m_labelSet; std::map<std::string, int> &m_topLabelSet; diff --git a/phrase-extract/score-main.cpp b/phrase-extract/score-main.cpp index 3ab6e2fd3..33c854274 100644 --- a/phrase-extract/score-main.cpp +++ b/phrase-extract/score-main.cpp @@ -28,6 +28,7 @@ #include <set> #include <vector> #include <algorithm> +#include <boost/unordered_map.hpp> #include "ScoreFeature.h" #include "tables-core.h" @@ -46,6 +47,10 @@ bool inverseFlag = false; bool hierarchicalFlag = false; bool pcfgFlag = false; bool treeFragmentsFlag = false; +bool sourceSyntaxLabelsFlag = false; +bool sourceSyntaxLabelSetFlag = false; +bool sourceSyntaxLabelCountsLHSFlag = false; +bool targetPreferenceLabelsFlag = false; bool unpairedExtractFormatFlag = false; bool conditionOnTargetLhsFlag = false; bool wordAlignmentFlag = true; @@ -61,13 +66,19 @@ bool crossedNonTerm = false; int 
countOfCounts[COC_MAX+1]; int totalDistinct = 0; float minCountHierarchical = 0; -std::map<std::string,float> sourceLHSCounts; -std::map<std::string, std::map<std::string,float>* > targetLHSAndSourceLHSJointCounts; +boost::unordered_map<std::string,float> sourceLHSCounts; +boost::unordered_map<std::string, boost::unordered_map<std::string,float>* > targetLHSAndSourceLHSJointCounts; std::set<std::string> sourceLabelSet; std::map<std::string,size_t> sourceLabels; std::vector<std::string> sourceLabelsByIndex; +boost::unordered_map<std::string,float> targetPreferenceLHSCounts; +boost::unordered_map<std::string, boost::unordered_map<std::string,float>* > ruleTargetLHSAndTargetPreferenceLHSJointCounts; +std::set<std::string> targetPreferenceLabelSet; +std::map<std::string,size_t> targetPreferenceLabels; +std::vector<std::string> targetPreferenceLabelsByIndex; + Vocabulary vcbT; Vocabulary vcbS; @@ -81,6 +92,11 @@ void processLine( std::string line, std::string &additionalPropertiesString, float &count, float &pcfgSum ); void writeCountOfCounts( const std::string &fileNameCountOfCounts ); +void writeLeftHandSideLabelCounts( const boost::unordered_map<std::string,float> &countsLabelLHS, + const boost::unordered_map<std::string, boost::unordered_map<std::string,float>* > &jointCountsLabelLHS, + const std::string &fileNameLeftHandSideSourceLabelCounts, + const std::string &fileNameLeftHandSideTargetSourceLabelCounts ); +void writeLabelSet( const std::set<std::string> &labelSet, const std::string &fileName ); void processPhrasePairs( std::vector< ExtractionPhrasePair* > &phrasePairsWithSameSource, ostream &phraseTableFile, const ScoreFeatureManager& featureManager, const MaybeLog& maybeLogProb ); void outputPhrasePair(const ExtractionPhrasePair &phrasePair, float, int, ostream &phraseTableFile, const ScoreFeatureManager &featureManager, const MaybeLog &maybeLog ); @@ -102,15 +118,21 @@ int main(int argc, char* argv[]) ScoreFeatureManager featureManager; if (argc < 4) { - 
std::cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--NoWordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--PCFG] [--TreeFragments] [--UnpairedExtractFormat] [--ConditionOnTargetLHS] [--CrossedNonTerm]" << std::endl; + std::cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--NoWordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--PCFG] [--TreeFragments] [--SourceLabels] [--SourceLabelSet] [--SourceLabelCountsLHS] [--TargetPreferenceLabels] [--UnpairedExtractFormat] [--ConditionOnTargetLHS] [--CrossedNonTerm]" << std::endl; std::cerr << featureManager.usage() << std::endl; exit(1); } std::string fileNameExtract = argv[1]; std::string fileNameLex = argv[2]; std::string fileNamePhraseTable = argv[3]; + std::string fileNameSourceLabelSet; std::string fileNameCountOfCounts; std::string fileNameFunctionWords; + std::string fileNameLeftHandSideSourceLabelCounts; + std::string fileNameLeftHandSideTargetSourceLabelCounts; + std::string fileNameTargetPreferenceLabelSet; + std::string fileNameLeftHandSideTargetPreferenceLabelCounts; + std::string fileNameLeftHandSideRuleTargetTargetPreferenceLabelCounts; std::vector<std::string> featureArgs; // all unknown args passed to feature manager for(int i=4; i<argc; i++) { @@ -126,6 +148,26 @@ int main(int argc, char* argv[]) } else if (strcmp(argv[i],"--TreeFragments") == 0) { treeFragmentsFlag = true; std::cerr << "including tree fragment information from syntactic parse\n"; + } else if (strcmp(argv[i],"--SourceLabels") == 0) { + sourceSyntaxLabelsFlag = true; + std::cerr << "including source label information" << std::endl; + } else if (strcmp(argv[i],"--SourceLabelSet") == 0) { + sourceSyntaxLabelSetFlag = 
true; + fileNameSourceLabelSet = std::string(fileNamePhraseTable) + ".syntaxLabels.src"; + std::cerr << "writing source syntax label set to file " << fileNameSourceLabelSet << std::endl; + } else if (strcmp(argv[i],"--SourceLabelCountsLHS") == 0) { + sourceSyntaxLabelCountsLHSFlag = true; + fileNameLeftHandSideSourceLabelCounts = std::string(fileNamePhraseTable) + ".src.lhs"; + fileNameLeftHandSideTargetSourceLabelCounts = std::string(fileNamePhraseTable) + ".tgt-src.lhs"; + std::cerr << "counting left-hand side source labels and writing them to files " << fileNameLeftHandSideSourceLabelCounts << " and " << fileNameLeftHandSideTargetSourceLabelCounts << std::endl; + } else if (strcmp(argv[i],"--TargetPreferenceLabels") == 0) { + targetPreferenceLabelsFlag = true; + std::cerr << "including target preference label information" << std::endl; + fileNameTargetPreferenceLabelSet = std::string(fileNamePhraseTable) + ".syntaxLabels.tgtpref"; + std::cerr << "writing target preference label set to file " << fileNameTargetPreferenceLabelSet << std::endl; + fileNameLeftHandSideTargetPreferenceLabelCounts = std::string(fileNamePhraseTable) + ".tgtpref.lhs"; + fileNameLeftHandSideRuleTargetTargetPreferenceLabelCounts = std::string(fileNamePhraseTable) + ".tgt-tgtpref.lhs"; + std::cerr << "counting left-hand side target preference labels and writing them to files " << fileNameLeftHandSideTargetPreferenceLabelCounts << " and " << fileNameLeftHandSideRuleTargetTargetPreferenceLabelCounts << std::endl; } else if (strcmp(argv[i],"--UnpairedExtractFormat") == 0) { unpairedExtractFormatFlag = true; std::cerr << "processing unpaired extract format" << std::endl; @@ -243,7 +285,7 @@ int main(int argc, char* argv[]) int i=0; // TODO why read only the 1st line? 
- if ( getline(extractFileP, line)) { + if ( getline(extractFileP, line) ) { ++i; tmpPhraseSource = new PHRASE(); tmpPhraseTarget = new PHRASE(); @@ -373,6 +415,26 @@ int main(int argc, char* argv[]) if (goodTuringFlag || kneserNeyFlag) { writeCountOfCounts( fileNameCountOfCounts ); } + + // source syntax labels + if (sourceSyntaxLabelsFlag && sourceSyntaxLabelSetFlag && !inverseFlag) { + writeLabelSet( sourceLabelSet, fileNameSourceLabelSet ); + } + if (sourceSyntaxLabelsFlag && sourceSyntaxLabelCountsLHSFlag && !inverseFlag) { + writeLeftHandSideLabelCounts( sourceLHSCounts, + targetLHSAndSourceLHSJointCounts, + fileNameLeftHandSideSourceLabelCounts, + fileNameLeftHandSideTargetSourceLabelCounts ); + } + + // target preference labels + if (targetPreferenceLabelsFlag && !inverseFlag) { + writeLabelSet( targetPreferenceLabelSet, fileNameTargetPreferenceLabelSet ); + writeLeftHandSideLabelCounts( targetPreferenceLHSCounts, + ruleTargetLHSAndTargetPreferenceLHSJointCounts, + fileNameLeftHandSideTargetPreferenceLabelCounts, + fileNameLeftHandSideRuleTargetTargetPreferenceLabelCounts ); + } } @@ -467,6 +529,70 @@ void writeCountOfCounts( const string &fileNameCountOfCounts ) }
+// Writes left-hand side label statistics to two files:
+//  - fileNameLeftHandSideSourceLabelCounts: one "<label> <count>" line per
+//    entry of countsLabelLHS
+//  - fileNameLeftHandSideTargetSourceLabelCounts: one
+//    "<outer key> <inner key> <count>" line per entry of jointCountsLabelLHS
+// If a file cannot be opened, an error is printed to std::cerr and the
+// function returns (the second file is not attempted when the first fails).
+// The counts are read from the arguments, not from the global source-label
+// maps, because main() calls this routine for the source labels as well as
+// for the target preference labels.
+void writeLeftHandSideLabelCounts( const boost::unordered_map<std::string,float> &countsLabelLHS,
+                                   const boost::unordered_map<std::string, boost::unordered_map<std::string,float>* > &jointCountsLabelLHS,
+                                   const std::string &fileNameLeftHandSideSourceLabelCounts,
+                                   const std::string &fileNameLeftHandSideTargetSourceLabelCounts )
+{
+  // open file
+  Moses::OutputFileStream leftHandSideSourceLabelCounts;
+  bool success = leftHandSideSourceLabelCounts.Open(fileNameLeftHandSideSourceLabelCounts.c_str());
+  if (!success) {
+    std::cerr << "ERROR: could not open left-hand side label counts file "
+              << fileNameLeftHandSideSourceLabelCounts << std::endl;
+    return;
+  }
+
+  // write left-hand side label counts
+  for (boost::unordered_map<std::string,float>::const_iterator iter=countsLabelLHS.begin();
+       iter!=countsLabelLHS.end(); ++iter) {
+    leftHandSideSourceLabelCounts << iter->first << " " << iter->second << std::endl;
+  }
+
+  leftHandSideSourceLabelCounts.Close();
+
+  // open file
+  Moses::OutputFileStream leftHandSideTargetSourceLabelCounts;
+  success = leftHandSideTargetSourceLabelCounts.Open(fileNameLeftHandSideTargetSourceLabelCounts.c_str());
+  if (!success) {
+    std::cerr << "ERROR: could not open left-hand side label joint counts file "
+              << fileNameLeftHandSideTargetSourceLabelCounts << std::endl;
+    return;
+  }
+
+  // write left-hand side joint counts (outer key, inner key, count)
+  for (boost::unordered_map<std::string, boost::unordered_map<std::string,float>* >::const_iterator iter=jointCountsLabelLHS.begin();
+       iter!=jointCountsLabelLHS.end(); ++iter) {
+    for (boost::unordered_map<std::string,float>::const_iterator iter2=(iter->second)->begin();
+         iter2!=(iter->second)->end(); ++iter2) {
+      leftHandSideTargetSourceLabelCounts << iter->first << " "<< iter2->first << " " << iter2->second << std::endl;
+    }
+  }
+
+  leftHandSideTargetSourceLabelCounts.Close();
+}
+
+
+// Writes every label of labelSet on its own line to the file fileName.
+// If the file cannot be opened, an error is printed to std::cerr and nothing
+// is written.
+void writeLabelSet( const std::set<std::string> &labelSet, const std::string &fileName )
+{
+  // open file
+  Moses::OutputFileStream out;
+  bool success = out.Open(fileName.c_str());
+  if (!success) {
+    std::cerr << "ERROR: could not open label set file "
+              << fileName << std::endl;
+    return;
+  }
+
+  for (std::set<std::string>::const_iterator iter=labelSet.begin();
+       iter!=labelSet.end(); ++iter) {
+    out << *iter << std::endl;
+  }
+
+  out.Close();
+}
+
+
 void processPhrasePairs( std::vector< ExtractionPhrasePair* > &phrasePairsWithSameSource, ostream &phraseTableFile, const ScoreFeatureManager& featureManager, const MaybeLog& maybeLogProb ) { @@ -639,7 +765,7 @@ void outputPhrasePair(const ExtractionPhrasePair &phrasePair, if (kneserNeyFlag) phraseTableFile << " " << distinctCount; - if ((treeFragmentsFlag) && + if ((treeFragmentsFlag || sourceSyntaxLabelsFlag ||
targetPreferenceLabelsFlag) && !inverseFlag) { phraseTableFile << " |||"; } @@ -654,6 +780,49 @@ void outputPhrasePair(const ExtractionPhrasePair &phrasePair, } } + // syntax labels + if ((sourceSyntaxLabelsFlag || targetPreferenceLabelsFlag) && !inverseFlag) { + unsigned nNTs = 1; + for(size_t j=0; j<phraseSource->size()-1; ++j) { + if (isNonTerminal(vcbS.getWord( phraseSource->at(j) ))) + ++nNTs; + } + // source syntax labels + if (sourceSyntaxLabelsFlag) { + std::string sourceLabelCounts; + sourceLabelCounts = phrasePair.CollectAllLabelsSeparateLHSAndRHS("SourceLabels", + sourceLabelSet, + sourceLHSCounts, + targetLHSAndSourceLHSJointCounts, + vcbT); + if ( !sourceLabelCounts.empty() ) { + phraseTableFile << " {{SourceLabels " + << nNTs // for convenience: number of non-terminal symbols in this rule (incl. left hand side NT) + << " " + << count // rule count + << sourceLabelCounts + << "}}"; + } + } + // target preference labels + if (targetPreferenceLabelsFlag) { + std::string targetPreferenceLabelCounts; + targetPreferenceLabelCounts = phrasePair.CollectAllLabelsSeparateLHSAndRHS("TargetPreferences", + targetPreferenceLabelSet, + targetPreferenceLHSCounts, + ruleTargetLHSAndTargetPreferenceLHSJointCounts, + vcbT); + if ( !targetPreferenceLabelCounts.empty() ) { + phraseTableFile << " {{TargetPreferences " + << nNTs // for convenience: number of non-terminal symbols in this rule (incl. left hand side NT) + << " " + << count // rule count + << targetPreferenceLabelCounts + << "}}"; + } + } + } + phraseTableFile << std::endl; } @@ -894,3 +1063,4 @@ void invertAlignment(const PHRASE *phraseSource, const PHRASE *phraseTarget, } } } + diff --git a/phrase-extract/score.h b/phrase-extract/score.h index 6a10536c1..470332a06 100644 --- a/phrase-extract/score.h +++ b/phrase-extract/score.h @@ -1,12 +1,22 @@ -#pragma once -/* - * score.h - * extract - * - * Created by Hieu Hoang on 28/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ +/*********************************************************************** + Moses - factored phrase-based language decoder + Copyright (C) 2009 University of Edinburgh + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + ***********************************************************************/ + #include <string> #include <vector> diff --git a/phrase-extract/tables-core.h b/phrase-extract/tables-core.h index e239e5900..9662ced2a 100644 --- a/phrase-extract/tables-core.h +++ b/phrase-extract/tables-core.h @@ -27,7 +27,7 @@ public: std::vector< WORD > vocab; WORD_ID storeIfNew( const WORD& ); WORD_ID getWordID( const WORD& ); - inline WORD &getWord( WORD_ID id ) { + inline WORD &getWord( const WORD_ID id ) { return vocab[ id ]; } }; |