diff options
Diffstat (limited to 'GIZA++-v2/model2.cpp')
-rw-r--r-- | GIZA++-v2/model2.cpp | 232 |
1 file changed, 232 insertions, 0 deletions
diff --git a/GIZA++-v2/model2.cpp b/GIZA++-v2/model2.cpp new file mode 100644 index 0000000..945b91e --- /dev/null +++ b/GIZA++-v2/model2.cpp @@ -0,0 +1,232 @@ +/* + +EGYPT Toolkit for Statistical Machine Translation +Written by Yaser Al-Onaizan, Jan Curin, Michael Jahr, Kevin Knight, John Lafferty, Dan Melamed, David Purdy, Franz Och, Noah Smith, and David Yarowsky. + +This program is free software; you can redistribute it and/or +modify it under the terms of the GNU General Public License +as published by the Free Software Foundation; either version 2 +of the License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, write to the Free Software +Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, +USA. 
+ +*/ +#include "model2.h" +#include "Globals.h" +#include "utility.h" +#include "Parameter.h" +#include "defs.h" + +extern short NoEmptyWord; + + +GLOBAL_PARAMETER2(int,Model2_Dump_Freq,"MODEL 2 DUMP FREQUENCY","t2","dump frequency of Model 2",PARLEV_OUTPUT,0); + +model2::model2(model1& m,amodel<PROB>&_aTable,amodel<COUNT>&_aCountTable): + model1(m),aTable(_aTable),aCountTable(_aCountTable) +{ } + +void model2::initialize_table_uniformly(sentenceHandler& sHandler1){ + // initialize the aTable uniformly (run this before running em_with_tricks) + int n=0; + sentPair sent ; + sHandler1.rewind(); + while(sHandler1.getNextSentence(sent)){ + Vector<WordIndex>& es = sent.eSent; + Vector<WordIndex>& fs = sent.fSent; + WordIndex l = es.size() - 1; + WordIndex m = fs.size() - 1; + n++; + if(1<=m&&aTable.getValue(l,m,l,m)<=PROB_SMOOTH) + { + PROB uniform_val = 1.0 / (l+1) ; + for(WordIndex j=1; j <= m; j++) + for(WordIndex i=0; i <= l; i++) + aTable.setValue(i,j, l, m, uniform_val); + } + } +} + +int model2::em_with_tricks(int noIterations) +{ + double minErrors=1.0;int minIter=0; + string modelName="Model2",shortModelName="2"; + time_t it_st, st, it_fn, fn; + string tfile, afile, number, alignfile, test_alignfile; + int pair_no = 0; + bool dump_files = false ; + ofstream of2 ; + st = time(NULL) ; + sHandler1.rewind(); + cout << "\n==========================================================\n"; + cout << modelName << " Training Started at: " << ctime(&st) << " iter: " << noIterations << "\n"; + for(int it=1; it <= noIterations ; it++){ + pair_no = 0; + it_st = time(NULL) ; + cout << endl << "-----------\n" << modelName << ": Iteration " << it << '\n'; + dump_files = (Model2_Dump_Freq != 0) && ((it % Model2_Dump_Freq) == 0) && !NODUMPS; + number = ""; + int n = it; + do{ + number.insert((size_t)0, 1, (char)(n % 10 + '0')); + } while((n /= 10) > 0); + tfile = Prefix + ".t" + shortModelName + "." + number ; + afile = Prefix + ".a" + shortModelName + "." 
+ number ; + alignfile = Prefix + ".A" + shortModelName + "." + number ; + test_alignfile = Prefix + ".tst.A" + shortModelName + "." + number ; + aCountTable.clear(); + initAL(); + em_loop(perp, sHandler1, dump_files, alignfile.c_str(), trainViterbiPerp, false); + if( errorsAL()<minErrors ) + { + minErrors=errorsAL(); + minIter=it; + } + if (testPerp && testHandler) + em_loop(*testPerp, *testHandler, dump_files, test_alignfile.c_str(), *testViterbiPerp, true); + if (dump_files&&OutputInAachenFormat==1) + tTable.printCountTable(tfile.c_str(),Elist.getVocabList(),Flist.getVocabList(),1); + tTable.normalizeTable(Elist, Flist); + aCountTable.normalize(aTable); + cout << modelName << ": ("<<it<<") TRAIN CROSS-ENTROPY " << perp.cross_entropy() + << " PERPLEXITY " << perp.perplexity() << '\n'; + if (testPerp && testHandler) + cout << modelName << ": ("<<it<<") TEST CROSS-ENTROPY " << (*testPerp).cross_entropy() + << " PERPLEXITY " << (*testPerp).perplexity() + << '\n'; + cout << modelName << ": ("<<it<<") VITERBI TRAIN CROSS-ENTROPY " << trainViterbiPerp.cross_entropy() + << " PERPLEXITY " << trainViterbiPerp.perplexity() << '\n'; + if (testPerp && testHandler) + cout << modelName << ": ("<<it<<") VITERBI TEST CROSS-ENTROPY " << testViterbiPerp->cross_entropy() + << " PERPLEXITY " << testViterbiPerp->perplexity() + << '\n'; + if (dump_files) + { + if(OutputInAachenFormat==0) + tTable.printProbTable(tfile.c_str(),Elist.getVocabList(),Flist.getVocabList(),OutputInAachenFormat); + aCountTable.printTable(afile.c_str()); + } + it_fn = time(NULL) ; + cout << modelName << " Iteration: " << it<< " took: " << difftime(it_fn, it_st) << " seconds\n"; + } // end of iterations + aCountTable.clear(); + fn = time(NULL) ; + cout << endl << "Entire " << modelName << " Training took: " << difftime(fn, st) << " seconds\n"; + // cout << "tTable contains " << tTable.getHash().bucket_count() + // << " buckets and " << tTable.getHash().size() << " entries." 
; + cout << "==========================================================\n"; + return minIter; +} + +void model2::load_table(const char* aname){ + /* This function loads the a table from the given file; use it + when you want to load results from previous a training without + doing any new training. + NAS, 7/11/99 + */ + cout << "Model2: loading a table \n"; + aTable.readTable(aname); +} + + +void model2::em_loop(Perplexity& perp, sentenceHandler& sHandler1, + bool dump_alignment, const char* alignfile, Perplexity& viterbi_perp, + bool test) +{ + massert( aTable.is_distortion==0 ); + massert( aCountTable.is_distortion==0 ); + WordIndex i, j, l, m ; + double cross_entropy; + int pair_no=0 ; + perp.clear(); + viterbi_perp.clear(); + ofstream of2; + // for each sentence pair in the corpus + if (dump_alignment||FEWDUMPS ) + of2.open(alignfile); + sentPair sent ; + + vector<double> ferts(evlist.size()); + + sHandler1.rewind(); + while(sHandler1.getNextSentence(sent)){ + Vector<WordIndex>& es = sent.eSent; + Vector<WordIndex>& fs = sent.fSent; + const float so = sent.getCount(); + l = es.size() - 1; + m = fs.size() - 1; + cross_entropy = log(1.0); + Vector<WordIndex> viterbi_alignment(fs.size()); + double viterbi_score = 1; + for(j=1; j <= m; j++){ + Vector<LpPair<COUNT,PROB> *> sPtrCache(es.size(),0); // cache pointers to table + // entries that map fs to all possible ei in this sentence. 
+ PROB denom = 0.0; + PROB e = 0.0, word_best_score = 0; + WordIndex best_i = 0 ; // i for which fj is best maped to ei + for(i=0; i <= l; i++){ + sPtrCache[i] = tTable.getPtr(es[i], fs[j]) ; + if (sPtrCache[i] != 0 &&(*(sPtrCache[i])).prob > PROB_SMOOTH ) + e = (*(sPtrCache[i])).prob * aTable.getValue(i,j, l, m) ; + else e = PROB_SMOOTH * aTable.getValue(i,j, l, m); + denom += e ; + if (e > word_best_score){ + word_best_score = e ; + best_i = i ; + } + } + viterbi_alignment[j] = best_i ; + viterbi_score *= word_best_score; ///denom ; + cross_entropy += log(denom) ; + if (denom == 0){ + if (test) + cerr << "WARNING: denom is zero (TEST)\n"; + else + cerr << "WARNING: denom is zero (TRAIN)\n"; + } + if (!test){ + if(denom > 0){ + COUNT val = COUNT(so) / (COUNT) double(denom) ; + for( i=0; i <= l; i++){ + PROB e(0.0); + if (sPtrCache[i] != 0 && (*(sPtrCache[i])).prob > PROB_SMOOTH) + e = (*(sPtrCache[i])).prob ; + else e = PROB_SMOOTH ; + e *= aTable.getValue(i,j, l, m); + COUNT temp = COUNT(e) * val ; + if( NoEmptyWord==0 || i!=0 ) + if (sPtrCache[i] != 0) + (*(sPtrCache[i])).count += temp ; + else + tTable.incCount(es[i], fs[j], temp); + aCountTable.getRef(i,j, l, m)+= temp ; + } /* end of for i */ + } // end of if (denom > 0) + }// if (!test) + } // end of for (j) ; + sHandler1.setProbOfSentence(sent,cross_entropy); + perp.addFactor(cross_entropy, so, l, m,1); + viterbi_perp.addFactor(log(viterbi_score), so, l, m,1); + if (dump_alignment||(FEWDUMPS&&sent.sentenceNo<1000) ) + printAlignToFile(es, fs, Elist.getVocabList(), Flist.getVocabList(), of2, viterbi_alignment, sent.sentenceNo, viterbi_score); + addAL(viterbi_alignment,sent.sentenceNo,l); + pair_no++; + } /* of while */ + sHandler1.rewind(); + perp.record("Model2"); + viterbi_perp.record("Model2"); + errorReportAL(cout,"IBM-2"); +} + + + + + |