diff options
author | bhaddow <bhaddow@1f5c12ca-751b-0410-a591-d2e778427230> | 2011-09-07 20:42:46 +0400 |
---|---|---|
committer | bhaddow <bhaddow@1f5c12ca-751b-0410-a591-d2e778427230> | 2011-09-07 20:42:46 +0400 |
commit | 2c585ce6e797cc3a510272038972c61a642d8bb4 (patch) | |
tree | b9ad558be560f9738922f4b7394b1f9810483052 /scripts | |
parent | de51b69d030a02d3e3117d97774c398e0cdd333b (diff) |
restore
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4186 1f5c12ca-751b-0410-a591-d2e778427230
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/training/phrase-extract/score.cpp | 549 |
1 files changed, 549 insertions, 0 deletions
diff --git a/scripts/training/phrase-extract/score.cpp b/scripts/training/phrase-extract/score.cpp new file mode 100644 index 000000000..fbb27b944 --- /dev/null +++ b/scripts/training/phrase-extract/score.cpp @@ -0,0 +1,549 @@ +/*********************************************************************** + Moses - factored phrase-based language decoder + Copyright (C) 2009 University of Edinburgh + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + ***********************************************************************/ + +#include <sstream> +#include <cstdio> +#include <iostream> +#include <fstream> +#include <vector> +#include <stdlib.h> +#include <assert.h> +#include <cstring> +#include <set> + +#include "SafeGetline.h" +#include "tables-core.h" +#include "PhraseAlignment.h" +#include "score.h" +#include "InputFileStream.h" + +using namespace std; + +#define LINE_MAX_LENGTH 100000 + +Vocabulary vcbT; +Vocabulary vcbS; + +class LexicalTable +{ +public: + map< WORD_ID, map< WORD_ID, double > > ltable; + void load( char[] ); + double permissiveLookup( WORD_ID wordS, WORD_ID wordT ) { + // cout << endl << vcbS.getWord( wordS ) << "-" << vcbT.getWord( wordT ) << ":"; + if (ltable.find( wordS ) == ltable.end()) return 1.0; + if (ltable[ wordS ].find( wordT ) == ltable[ wordS ].end()) return 1.0; + // cout << ltable[ wordS ][ wordT ]; + return ltable[ wordS ][ wordT ]; + } +}; + +vector<string> tokenize( const char [] ); + +void writeCountOfCounts( const char* fileNameCountOfCounts ); +void processPhrasePairs( vector< PhraseAlignment > & , ostream &phraseTableFile); +PhraseAlignment* findBestAlignment( vector< PhraseAlignment* > & ); +void outputPhrasePair( vector< PhraseAlignment * > &, float, int, ostream &phraseTableFile ); +double computeLexicalTranslation( const PHRASE &, const PHRASE &, PhraseAlignment * ); +double computeUnalignedPenalty( const PHRASE &, const PHRASE &, PhraseAlignment * ); +set<string> functionWordList; +void loadFunctionWords( const char* fileNameFunctionWords ); +double computeUnalignedFWPenalty( const PHRASE &, const PHRASE &, PhraseAlignment * ); + +LexicalTable lexTable; +bool inverseFlag = false; +bool hierarchicalFlag = false; +bool wordAlignmentFlag = false; +bool goodTuringFlag = false; +bool kneserNeyFlag = false; +#define COC_MAX 10 +bool logProbFlag = false; +int negLogProb = 1; +bool lexFlag = true; +bool unalignedFlag = false; +bool unalignedFWFlag = false; +int countOfCounts[COC_MAX+1]; +int totalDistinct = 0; +float minCountHierarchical = 0; + +int main(int argc, char* argv[]) +{ + cerr << "Score v2.0 written by Philipp Koehn\n" + << "scoring methods for extracted rules\n"; + + if (argc < 4) { + cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring coc-file] [--KneserNey coc-file] [--WordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count]\n"; + exit(1); + } + char* fileNameExtract = argv[1]; + char* fileNameLex = argv[2]; + char* fileNamePhraseTable = argv[3]; + char* fileNameCountOfCounts; + char* fileNameFunctionWords; + + for(int i=4; i<argc; i++) { + if (strcmp(argv[i],"inverse") == 0 || strcmp(argv[i],"--Inverse") == 0) { + inverseFlag = true; + cerr << "using inverse mode\n"; + } else if (strcmp(argv[i],"--Hierarchical") == 0) { + hierarchicalFlag = true; + cerr << "processing hierarchical rules\n"; + } else if (strcmp(argv[i],"--WordAlignment") == 0) { + wordAlignmentFlag = true; + cerr << "outputing word alignment" << endl; + } else if (strcmp(argv[i],"--NoLex") == 0) { + lexFlag = false; + cerr << "not computing lexical translation score\n"; + } else if (strcmp(argv[i],"--GoodTuring") == 0) { + goodTuringFlag = true; + if (i+1==argc) { + cerr << "ERROR: specify count of count files for Good Turing discounting!\n"; + exit(1); + } + fileNameCountOfCounts = argv[++i]; + cerr << "adjusting phrase translation probabilities with Good Turing discounting\n"; + } else if (strcmp(argv[i],"--KneserNey") == 0) { + kneserNeyFlag = true; + if (i+1==argc) { + cerr << "ERROR: specify count of count files for Kneser Ney discounting!\n"; + exit(1); + } + fileNameCountOfCounts = argv[++i]; + cerr << "adjusting phrase translation probabilities with Kneser Ney discounting\n"; + } else if (strcmp(argv[i],"--UnalignedPenalty") == 0) { + unalignedFlag = true; + cerr << "using unaligned word penalty\n"; + } else if (strcmp(argv[i],"--UnalignedFunctionWordPenalty") == 0) { + unalignedFWFlag = true; + if (i+1==argc) { + cerr << "ERROR: specify count of count files for Kneser Ney discounting!\n"; + exit(1); + } + fileNameFunctionWords = argv[++i]; + cerr << "using unaligned function word penalty with function words from " << fileNameFunctionWords << endl; + } else if (strcmp(argv[i],"--LogProb") == 0) { + logProbFlag = true; + cerr << "using log-probabilities\n"; + } else if (strcmp(argv[i],"--NegLogProb") == 0) { + logProbFlag = true; + negLogProb = -1; + cerr << "using negative log-probabilities\n"; + } else if (strcmp(argv[i],"--MinCountHierarchical") == 0) { + minCountHierarchical = atof(argv[++i]); + cerr << "dropping all phrase pairs occurring less than " << minCountHierarchical << " times\n"; + minCountHierarchical -= 0.00001; // account for rounding + } else { + cerr << "ERROR: unknown option " << argv[i] << endl; + exit(1); + } + } + + // lexical translation table + if (lexFlag) + lexTable.load( fileNameLex ); + + // function word list + if (unalignedFWFlag) + loadFunctionWords( fileNameFunctionWords ); + + // compute count of counts for Good Turing discounting + if (goodTuringFlag || kneserNeyFlag) { + for(int i=1; i<=COC_MAX; i++) countOfCounts[i] = 0; + } + + // sorted phrase extraction file + Moses::InputFileStream extractFile(fileNameExtract); + + if (extractFile.fail()) { + cerr << "ERROR: could not open extract file " << fileNameExtract << endl; + exit(1); + } + istream &extractFileP = extractFile; + + // output file: phrase translation table + ostream *phraseTableFile; + + if (strcmp(fileNamePhraseTable, "-") == 0) { + phraseTableFile = &cout; + } + else { + ofstream *outputFile = new ofstream(); + outputFile->open(fileNamePhraseTable); + if (outputFile->fail()) { + cerr << "ERROR: could not open file phrase table file " + << fileNamePhraseTable << endl; + exit(1); + } + phraseTableFile = outputFile; + } + + // loop through all extracted phrase translations + float lastCount = 0.0f; + vector< PhraseAlignment > phrasePairsWithSameF; + int i=0; + char line[LINE_MAX_LENGTH],lastLine[LINE_MAX_LENGTH]; + lastLine[0] = '\0'; + PhraseAlignment *lastPhrasePair = NULL; + while(true) { + if (extractFileP.eof()) break; + if (++i % 100000 == 0) cerr << "." << flush; + SAFE_GETLINE((extractFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); + if (extractFileP.eof()) break; + + // identical to last line? just add count + if (strcmp(line,lastLine) == 0) { + lastPhrasePair->count += lastCount; + continue; + } + strcpy( lastLine, line ); + + // create new phrase pair + PhraseAlignment phrasePair; + phrasePair.create( line, i ); + lastCount = phrasePair.count; + + // only differs in count? just add count + if (lastPhrasePair != NULL && lastPhrasePair->equals( phrasePair )) { + lastPhrasePair->count += phrasePair.count; + continue; + } + + // if new source phrase, process last batch + if (lastPhrasePair != NULL && + lastPhrasePair->GetSource() != phrasePair.GetSource()) { + processPhrasePairs( phrasePairsWithSameF, *phraseTableFile ); + phrasePairsWithSameF.clear(); + lastPhrasePair = NULL; + } + + // add phrase pairs to list, it's now the last one + phrasePairsWithSameF.push_back( phrasePair ); + lastPhrasePair = &phrasePairsWithSameF.back(); + } + processPhrasePairs( phrasePairsWithSameF, *phraseTableFile ); + + phraseTableFile->flush(); + if (phraseTableFile != &cout) { + (dynamic_cast<ofstream*>(phraseTableFile))->close(); + delete phraseTableFile; + } + + // output count of count statistics + if (goodTuringFlag || kneserNeyFlag) { + writeCountOfCounts( fileNameCountOfCounts ); + } +} + +void writeCountOfCounts( const char* fileNameCountOfCounts ) +{ + // open file + ofstream countOfCountsFile; + countOfCountsFile.open(fileNameCountOfCounts); + if (countOfCountsFile.fail()) { + cerr << "ERROR: could not open count-of-counts file " + << fileNameCountOfCounts << endl; + return; + } + + // Kneser-Ney needs the total number of phrase pairs + countOfCountsFile << totalDistinct; + + // write out counts + for(int i=1; i<=COC_MAX; i++) { + countOfCountsFile << countOfCounts[ i ] << endl; + } + countOfCountsFile.close(); +} + +void processPhrasePairs( vector< PhraseAlignment > &phrasePair, ostream &phraseTableFile ) +{ + if (phrasePair.size() == 0) return; + + // group phrase pairs based on alignments that matter + // (i.e. that re-arrange non-terminals) + vector< vector< PhraseAlignment * > > phrasePairGroup; + float totalSource = 0; + + // loop through phrase pairs + for(size_t i=0; i<phrasePair.size(); i++) { + // add to total count + totalSource += phrasePair[i].count; + + // check for matches + bool matched = false; + for(size_t g=0; g<phrasePairGroup.size(); g++) { + vector< PhraseAlignment* > &group = phrasePairGroup[g]; + // matched? place into same group + if ( group[0]->match( phrasePair[i] )) { + group.push_back( &phrasePair[i] ); + matched = true; + } + } + // not matched? create new group + if (! matched) { + vector< PhraseAlignment* > newGroup; + newGroup.push_back( &phrasePair[i] ); + phrasePairGroup.push_back( newGroup ); + } + } + + // output the distinct phrase pairs, one at a time + for(size_t g=0; g<phrasePairGroup.size(); g++) { + vector< PhraseAlignment* > &group = phrasePairGroup[g]; + outputPhrasePair( group, totalSource, phrasePairGroup.size(), phraseTableFile ); + } +} + +PhraseAlignment* findBestAlignment( vector< PhraseAlignment* > &phrasePair ) +{ + float bestAlignmentCount = -1; + PhraseAlignment* bestAlignment; + + for(int i=0; i<phrasePair.size(); i++) { + if (phrasePair[i]->count > bestAlignmentCount) { + bestAlignmentCount = phrasePair[i]->count; + bestAlignment = phrasePair[i]; + } + } + + return bestAlignment; +} + +void outputPhrasePair( vector< PhraseAlignment* > &phrasePair, float totalCount, int distinctCount, ostream &phraseTableFile ) +{ + if (phrasePair.size() == 0) return; + + PhraseAlignment *bestAlignment = findBestAlignment( phrasePair ); + + // compute count + float count = 0; + for(size_t i=0; i<phrasePair.size(); i++) { + count += phrasePair[i]->count; + } + + // collect count of count statistics + if (goodTuringFlag || kneserNeyFlag) { + totalDistinct++; + int countInt = count + 0.99999; + if(countInt <= COC_MAX) + countOfCounts[ countInt ]++; + } + + // output phrases + const PHRASE &phraseS = phrasePair[0]->GetSource(); + const PHRASE &phraseT = phrasePair[0]->GetTarget(); + + // do not output if hierarchical and count below threshold + if (hierarchicalFlag && count < minCountHierarchical) { + for(int j=0; j<phraseS.size()-1; j++) { + if (isNonTerminal(vcbS.getWord( phraseS[j] ))) + return; + } + } + + // source phrase (unless inverse) + if (! inverseFlag) { + for(int j=0; j<phraseS.size(); j++) { + phraseTableFile << vcbS.getWord( phraseS[j] ); + phraseTableFile << " "; + } + phraseTableFile << "||| "; + } + + // target phrase + for(int j=0; j<phraseT.size(); j++) { + phraseTableFile << vcbT.getWord( phraseT[j] ); + phraseTableFile << " "; + } + phraseTableFile << "||| "; + + // source phrase (if inverse) + if (inverseFlag) { + for(int j=0; j<phraseS.size(); j++) { + phraseTableFile << vcbS.getWord( phraseS[j] ); + phraseTableFile << " "; + } + phraseTableFile << "||| "; + } + + // lexical translation probability + if (lexFlag) { + double lexScore = computeLexicalTranslation( phraseS, phraseT, bestAlignment); + phraseTableFile << ( logProbFlag ? negLogProb*log(lexScore) : lexScore ); + } + + // unaligned word penalty + if (unalignedFlag) { + double penalty = computeUnalignedPenalty( phraseS, phraseT, bestAlignment); + phraseTableFile << " " << ( logProbFlag ? negLogProb*log(penalty) : penalty ); + } + + // unaligned function word penalty + if (unalignedFWFlag) { + double penalty = computeUnalignedFWPenalty( phraseS, phraseT, bestAlignment); + phraseTableFile << " " << ( logProbFlag ? negLogProb*log(penalty) : penalty ); + } + + phraseTableFile << " ||| "; + + // alignment info for non-terminals + if (! inverseFlag) { + if (hierarchicalFlag) { + // always output alignment if hiero style, but only for non-terms + assert(phraseT.size() == bestAlignment->alignedToT.size() + 1); + for(int j = 0; j < phraseT.size() - 1; j++) { + if (isNonTerminal(vcbT.getWord( phraseT[j] ))) { + if (bestAlignment->alignedToT[ j ].size() != 1) { + cerr << "Error: unequal numbers of non-terminals. Make sure the text does not contain words in square brackets (like [xxx])." << endl; + phraseTableFile.flush(); + assert(bestAlignment->alignedToT[ j ].size() == 1); + } + int sourcePos = *(bestAlignment->alignedToT[ j ].begin()); + phraseTableFile << sourcePos << "-" << j << " "; + } + } + } else if (wordAlignmentFlag) { + // alignment info in pb model + for(int j=0; j<bestAlignment->alignedToT.size(); j++) { + const set< size_t > &aligned = bestAlignment->alignedToT[j]; + for (set< size_t >::const_iterator p(aligned.begin()); p != aligned.end(); ++p) { + phraseTableFile << *p << "-" << j << " "; + } + } + } + } + + // counts + phraseTableFile << " ||| " << totalCount << " " << count; + if (kneserNeyFlag) + phraseTableFile << " " << distinctCount; + phraseTableFile << endl; +} + +double computeUnalignedPenalty( const PHRASE &phraseS, const PHRASE &phraseT, PhraseAlignment *alignment ) +{ + // unaligned word counter + double unaligned = 1.0; + // only checking target words - source words are caught when computing inverse + for(int ti=0; ti<alignment->alignedToT.size(); ti++) { + const set< size_t > & srcIndices = alignment->alignedToT[ ti ]; + if (srcIndices.empty()) { + unaligned *= 2.718; + } + } + return unaligned; +} + +double computeUnalignedFWPenalty( const PHRASE &phraseS, const PHRASE &phraseT, PhraseAlignment *alignment ) +{ + // unaligned word counter + double unaligned = 1.0; + // only checking target words - source words are caught when computing inverse + for(int ti=0; ti<alignment->alignedToT.size(); ti++) { + const set< size_t > & srcIndices = alignment->alignedToT[ ti ]; + if (srcIndices.empty() && functionWordList.find( vcbT.getWord( phraseT[ ti ] ) ) != functionWordList.end()) { + unaligned *= 2.718; + } + } + return unaligned; +} + +void loadFunctionWords( const char *fileName ) +{ + cerr << "Loading function word list from " << fileName; + ifstream inFile; + inFile.open(fileName); + if (inFile.fail()) { + cerr << " - ERROR: could not open file\n"; + exit(1); + } + istream *inFileP = &inFile; + + char line[LINE_MAX_LENGTH]; + while(true) { + SAFE_GETLINE((*inFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); + if (inFileP->eof()) break; + vector<string> token = tokenize( line ); + if (token.size() > 0) + functionWordList.insert( token[0] ); + } + inFile.close(); + + cerr << " - read " << functionWordList.size() << " function words\n"; + inFile.close(); +} + +double computeLexicalTranslation( const PHRASE &phraseS, const PHRASE &phraseT, PhraseAlignment *alignment ) +{ + // lexical translation probability + double lexScore = 1.0; + int null = vcbS.getWordID("NULL"); + // all target words have to be explained + for(int ti=0; ti<alignment->alignedToT.size(); ti++) { + const set< size_t > & srcIndices = alignment->alignedToT[ ti ]; + if (srcIndices.empty()) { + // explain unaligned word by NULL + lexScore *= lexTable.permissiveLookup( null, phraseT[ ti ] ); + } else { + // go through all the aligned words to compute average + double thisWordScore = 0; + for (set< size_t >::const_iterator p(srcIndices.begin()); p != srcIndices.end(); ++p) { + thisWordScore += lexTable.permissiveLookup( phraseS[ *p ], phraseT[ ti ] ); + } + lexScore *= thisWordScore / (double)srcIndices.size(); + } + } + return lexScore; +} + +void LexicalTable::load( char *fileName ) +{ + cerr << "Loading lexical translation table from " << fileName; + ifstream inFile; + inFile.open(fileName); + if (inFile.fail()) { + cerr << " - ERROR: could not open file\n"; + exit(1); + } + istream *inFileP = &inFile; + + char line[LINE_MAX_LENGTH]; + + int i=0; + while(true) { + i++; + if (i%100000 == 0) cerr << "." << flush; + SAFE_GETLINE((*inFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); + if (inFileP->eof()) break; + + vector<string> token = tokenize( line ); + if (token.size() != 3) { + cerr << "line " << i << " in " << fileName + << " has wrong number of tokens, skipping:\n" + << token.size() << " " << token[0] << " " << line << endl; + continue; + } + + double prob = atof( token[2].c_str() ); + WORD_ID wordT = vcbT.storeIfNew( token[0] ); + WORD_ID wordS = vcbS.storeIfNew( token[1] ); + ltable[ wordS ][ wordT ] = prob; + } + cerr << endl; +} |