Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mgiza.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'mgizapp/src/model2to3.cpp')
-rw-r--r--mgizapp/src/model2to3.cpp398
1 files changed, 398 insertions, 0 deletions
diff --git a/mgizapp/src/model2to3.cpp b/mgizapp/src/model2to3.cpp
new file mode 100644
index 0000000..d4db81a
--- /dev/null
+++ b/mgizapp/src/model2to3.cpp
@@ -0,0 +1,398 @@
+/*
+
+EGYPT Toolkit for Statistical Machine Translation
+Written by Yaser Al-Onaizan, Jan Curin, Michael Jahr, Kevin Knight, John Lafferty, Dan Melamed, David Purdy, Franz Och, Noah Smith, and David Yarowsky.
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation; either version 2
+of the License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, write to the Free Software
+Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
+USA.
+
+*/
+#include "model3.h"
+#include "utility.h"
+#include "Globals.h"
+
+#define _MAX_FERTILITY 10
+
+double get_sum_of_partitions(int n, int source_pos, double alpha[_MAX_FERTILITY][MAX_SENTENCE_LENGTH_ALLOWED])
+{
+ int done, init ;
+ double sum = 0, prod ;
+ int s, w, u, v;
+ WordIndex k, k1, i ;
+ WordIndex num_parts = 0 ;
+ int total_partitions_considered = 0;
+
+ int part[_MAX_FERTILITY], mult[_MAX_FERTILITY];
+
+ done = false ;
+ init = true ;
+ for (i = 0 ; i < _MAX_FERTILITY ; i++){
+ part[i] = mult[i] = 0 ;
+ }
+
+ //printf("Entering get sum of partitions\n");
+ while(! done){
+ total_partitions_considered++;
+ if (init){
+ part[1] = n ;
+ mult[1] = 1 ;
+ num_parts = 1 ;
+ init = false ;
+ }
+ else {
+ if ((part[num_parts] > 1) || (num_parts > 1)){
+ if (part[num_parts] == 1){
+ s = part[num_parts-1] + mult[num_parts];
+ k = num_parts - 1;
+ }
+ else {
+ s = part[num_parts];
+ k = num_parts ;
+ }
+ w = part[k] - 1 ;
+ u = s / w ;
+ v = s % w ;
+ mult[k] -= 1 ;
+ if (mult[k] == 0)
+ k1 = k ;
+ else k1 = k + 1 ;
+ mult[k1] = u ;
+ part[k1] = w ;
+ if (v == 0){
+ num_parts = k1 ;
+ }
+ else {
+ mult[k1+1] = 1 ;
+ part[k1+1] = v ;
+ num_parts = k1 + 1;
+ }
+ } /* of if num_parts > 1 || part[num_parts] > 1 */
+ else {
+ done = true ;
+ }
+ }
+ /* of else of if(init) */
+ if (!done){
+ prod = 1.0 ;
+ if (n != 0)
+ for (i = 1 ; i <= num_parts ; i++){
+ prod *= pow(alpha[part[i]][source_pos], mult[i]) / factorial(mult[i]) ;
+ }
+ sum += prod ;
+ }
+ } /* of while */
+ if (sum < 0) sum = 0 ;
+ return(sum) ;
+}
+
+void model3::estimate_t_a_d(sentenceHandler& sHandler1, Perplexity& perp, Perplexity& trainVPerp,
+ bool simple, bool dump_files,bool updateT)
+{
+ string tfile, nfile, dfile, p0file, afile, alignfile;
+ WordIndex i, j, l, m, max_fertility_here, k ;
+ PROB val, temp_mult[MAX_SENTENCE_LENGTH_ALLOWED][MAX_SENTENCE_LENGTH_ALLOWED];
+ double cross_entropy;
+ double beta, sum,
+ alpha[_MAX_FERTILITY][MAX_SENTENCE_LENGTH_ALLOWED];
+ double total, temp, r ;
+
+ dCountTable.clear();
+ aCountTable.clear();
+ initAL();
+ nCountTable.clear() ;
+ if (simple)
+ nTable.clear();
+ perp.clear() ;
+ trainVPerp.clear() ;
+ ofstream of2;
+ if (dump_files){
+ alignfile = Prefix +".A2to3";
+ of2.open(alignfile.c_str());
+ }
+ if (simple) cerr <<"Using simple estimation for fertilties\n";
+ sHandler1.rewind() ;
+ sentPair sent ;
+ while(sHandler1.getNextSentence(sent)){
+ Vector<WordIndex>& es = sent.eSent;
+ Vector<WordIndex>& fs = sent.fSent;
+ const float count = sent.getCount();
+ Vector<WordIndex> viterbi_alignment(fs.size());
+ l = es.size() - 1;
+ m = fs.size() - 1;
+ cross_entropy = log(1.0);
+ double viterbi_score = 1 ;
+ PROB word_best_score ; // score for the best mapping of fj
+ for(j = 1 ; j <= m ; j++){
+ word_best_score = 0 ; // score for the best mapping of fj
+ Vector<LpPair<COUNT,PROB> *> sPtrCache(es.size(),0);
+ total = 0 ;
+ WordIndex best_i = 0 ;
+ for(i = 0; i <= l ; i++){
+ sPtrCache[i] = tTable.getPtr(es[i], fs[j]) ;
+ if (sPtrCache[i] != 0 && (*(sPtrCache[i])).prob > PROB_SMOOTH) // if valid pointer
+ temp_mult[i][j]= (*(sPtrCache[i])).prob * aTable.getValue(i, j, l, m) ;
+ else
+ temp_mult[i][j] = PROB_SMOOTH * aTable.getValue(i, j, l, m) ;
+ total += temp_mult[i][j] ;
+ if (temp_mult[i][j] > word_best_score){
+ word_best_score = temp_mult[i][j] ;
+ best_i = i ;
+ }
+ } // end of for (i)
+ viterbi_alignment[j] = best_i ;
+ viterbi_score *= word_best_score ; /// total ;
+ cross_entropy += log(total) ;
+ if (total == 0){
+ cerr << "WARNING: total is zero (TRAIN)\n";
+ viterbi_score = 0 ;
+ }
+ if (total > 0){
+ for(i = 0; i <= l ; i++){
+ temp_mult[i][j] /= total ;
+ if (temp_mult[i][j] == 1) // smooth to prevent underflow
+ temp_mult[i][j] = 0.99 ;
+ else if (temp_mult[i][j] == 0)
+ temp_mult[i][j] = PROB_SMOOTH ;
+ val = temp_mult[i][j] * PROB(count) ;
+ if ( val > PROB_SMOOTH) {
+ if( updateT )
+ {
+ if (sPtrCache[i] != 0)
+ (*(sPtrCache[i])).count += val ;
+ else
+ tTable.incCount(es[i], fs[j], val);
+ }
+ aCountTable.addValue(i, j, l, m,val);
+ if (0 != i)
+ dCountTable.addValue(j, i, l, m,val);
+ }
+ } // for (i = ..)
+ } // for (if total ...)
+ } // end of for (j ...)
+ if (dump_files)
+ printAlignToFile(es, fs, Elist.getVocabList(), Flist.getVocabList(), of2, viterbi_alignment, sent.sentenceNo, viterbi_score);
+ addAL(viterbi_alignment,sent.sentenceNo,l);
+ if (!simple){
+ max_fertility_here = min(WordIndex(m+1), MAX_FERTILITY);
+ for (i = 1; i <= l ; i++) {
+ for ( k = 1; k < max_fertility_here; k++){
+ beta = 0 ;
+ alpha[k][i] = 0 ;
+ for (j = 1 ; j <= m ; j++){
+ temp = temp_mult[i][j];
+ if (temp > 0.95) temp = 0.95; // smooth to prevent under/over flow
+ else if (temp < 0.05) temp = 0.05;
+ beta += pow(temp/(1.0-temp), (double) k);
+ }
+ alpha[k][i] = beta * pow((double) -1, (double) (k+1)) / (double) k ;
+ }
+ }
+ for (i = 1; i <= l ; i++){
+ r = 1;
+ for (j = 1 ; j <= m ; j++)
+ r *= (1 - temp_mult[i][j]);
+ for (k = 0 ; k < max_fertility_here ; k++){
+ sum = get_sum_of_partitions(k, i, alpha);
+ temp = r * sum * count;
+ nCountTable.addValue(es[i], k,temp);
+ } // end of for (k ..)
+ } // end of for (i == ..)
+ } // end of if (!simple)
+ perp.addFactor(cross_entropy, count, l, m,1);
+ trainVPerp.addFactor(log(viterbi_score), count, l, m,1);
+ } // end of while
+ sHandler1.rewind();
+ cerr << "Normalizing t, a, d, n count tables now ... " ;
+ if( dump_files && OutputInAachenFormat==1 )
+ {
+ tfile = Prefix + ".t2to3" ;
+ tTable.printCountTable(tfile.c_str(),Elist.getVocabList(),Flist.getVocabList(),1);
+ }
+ if( updateT )
+ tTable.normalizeTable(Elist, Flist);
+ aCountTable.normalize(aTable);
+ dCountTable.normalize(dTable);
+ if (!simple)
+ nCountTable.normalize(nTable,&Elist.getVocabList());
+ else {
+ for (i = 0 ; i< Elist.uniqTokens() ; i++){
+ if (0 < MAX_FERTILITY){
+ nTable.addValue(i,0,PROB(0.2));
+ if (1 < MAX_FERTILITY){
+ nTable.addValue(i,1,PROB(0.65));
+ if (2 < MAX_FERTILITY){
+ nTable.addValue(i,2,PROB(0.1));
+ if (3 < MAX_FERTILITY)
+ nTable.addValue(i,3,PROB(0.04));
+ PROB val = 0.01/(MAX_FERTILITY-4);
+ for (k = 4 ; k < MAX_FERTILITY ; k++)
+ nTable.addValue(i, k,val);
+ }
+ }
+ }
+ }
+ } // end of else (!simple)
+ p0 = 0.95;
+ p1 = 0.05;
+ if (dump_files){
+ tfile = Prefix + ".t2to3" ;
+ afile = Prefix + ".a2to3" ;
+ nfile = Prefix + ".n2to3" ;
+ dfile = Prefix + ".d2to3" ;
+ p0file = Prefix + ".p0_2to3" ;
+
+ if( OutputInAachenFormat==0 )
+ tTable.printProbTable(tfile.c_str(),Elist.getVocabList(),Flist.getVocabList(),OutputInAachenFormat);
+ aTable.printTable(afile.c_str());
+ dTable.printTable(dfile.c_str());
+ nCountTable.printNTable(Elist.uniqTokens(), nfile.c_str(), Elist.getVocabList(),OutputInAachenFormat);
+ ofstream of(p0file.c_str());
+ of << p0;
+ of.close();
+ }
+ errorReportAL(cerr,"IBM-2");
+ if(simple)
+ {
+ perp.record("T2To3");
+ trainVPerp.record("T2To3");
+ }
+ else
+ {
+ perp.record("ST2To3");
+ trainVPerp.record("ST2To3");
+ }
+}
+
+void model3::transferSimple(/*model1& m1, model2& m2, */ sentenceHandler& sHandler1,
+ bool dump_files, Perplexity& perp, Perplexity& trainVPerp,bool updateT)
+{
+ /*
+ This function performs simple Model 2 -> Model 3 transfer.
+ It sets values for n and p without considering Model 2's ideas.
+ It sets d values based on a.
+ */
+ time_t st, fn;
+ // just inherit these from the previous models, to avoid data duplication
+
+ st = time(NULL);
+ cerr << "==========================================================\n";
+ cerr << "\nTransfer started at: "<< ctime(&st) << '\n';
+
+ cerr << "Simple tranfer of Model2 --> Model3 (i.e. estimating initial parameters of Model3 from Model2 tables)\n";
+
+ estimate_t_a_d(sHandler1, perp, trainVPerp, true, dump_files,updateT) ;
+ fn = time(NULL) ;
+ cerr << "\nTransfer: TRAIN CROSS-ENTROPY " << perp.cross_entropy()
+ << " PERPLEXITY " << perp.perplexity() << '\n';
+ cerr << "\nTransfer took: " << difftime(fn, st) << " seconds\n";
+ cerr << "\nTransfer Finished at: "<< ctime(&fn) << '\n';
+ cerr << "==========================================================\n";
+
+}
+
+
+void model3::transfer(sentenceHandler& sHandler1,bool dump_files, Perplexity& perp, Perplexity& trainVPerp,bool updateT)
+{
+ if (Transfer == TRANSFER_SIMPLE)
+ transferSimple(sHandler1,dump_files,perp, trainVPerp,updateT);
+ {
+ time_t st, fn ;
+
+ st = time(NULL);
+ cerr << "==========================================================\n";
+ cerr << "\nTransfer started at: "<< ctime(&st) << '\n';
+ cerr << "Transfering Model2 --> Model3 (i.e. estimating initial parameters of Model3 from Model2 tables)\n";
+
+ p1_count = p0_count = 0 ;
+
+ estimate_t_a_d(sHandler1, perp, trainVPerp, false, dump_files,updateT);
+
+
+
+ /* Below is a made-up stab at transferring t & a probs to p0/p1.
+ (Method not documented in IBM paper).
+ It seems to give p0 = .96, which may be right for Model 2, or may not.
+ I'm commenting it out for now and hardwiring p0 = .90 as above. -Kevin
+
+ // compute p0, p1 counts
+ Vector<LogProb> nm(Elist.uniqTokens(),0.0);
+
+ for(i=0; i < Elist.uniqTokens(); i++){
+ for(k=1; k < MAX_FERTILITY; k++){
+ nm[i] += nTable.getValue(i, k) * (LogProb) k;
+ }
+ }
+
+ LogProb mprime;
+ // sentenceHandler sHandler1(efFilename.c_str());
+ // sentPair sent ;
+
+ while(sHandler1.getNextSentence(sent)){
+ Vector<WordIndex>& es = sent.eSent;
+ Vector<WordIndex>& fs = sent.fSent;
+ const float count = sent.noOccurrences;
+
+ l = es.size() - 1;
+ m = fs.size() - 1;
+ mprime = 0 ;
+ for (i = 1; i <= l ; i++){
+ mprime += nm[es[i]] ;
+ }
+ mprime = LogProb((int((double) mprime + 0.5))); // round mprime to nearest integer
+ if ((mprime < m) && (2 * mprime >= m)) {
+ // cerr << "updating both p0_count and p1_count, mprime: " << mprime <<
+ // "m = " << m << "\n";
+ p1_count += (m - (double) mprime) * count ;
+ p0_count += (2 * (double) mprime - m) * count ;
+ // cerr << "p0_count = "<<p0_count << " , p1_count = " << p1_count << endl ;
+ }
+ else {
+ // p1_count += 0 ;
+ // cerr << "updating only p0_count, mprime: " << mprime <<
+ // "m = " << m << "\n";
+ p0_count += double(m * count) ;
+ // cerr << "p0_count = "<<p0_count << " , p1_count = " << p1_count << endl ;
+ }
+ }
+
+ // normalize p1, p0
+
+ cerr << "p0_count = "<<p0_count << " , p1_count = " << p1_count << endl ;
+ p1 = p1_count / (p1_count + p0_count ) ;
+ p0 = 1 - p1;
+ cerr << "p0 = "<<p0 << " , p1 = " << p1 << endl ;
+ // Smooth p0 probability to avoid getting zero probability.
+ if (0 == p0){
+ p0 = (LogProb) SMOOTH_THRESHOLD ;
+ p1 = p1 - (LogProb) SMOOTH_THRESHOLD ;
+ }
+ if (0 == p1){
+ p1 = (LogProb) SMOOTH_THRESHOLD ;
+ p0 = p0 - (LogProb) SMOOTH_THRESHOLD ;
+ }
+ */
+
+ fn = time(NULL) ;
+ cerr << "\nTransfer: TRAIN CROSS-ENTROPY " << perp.cross_entropy()
+ << " PERPLEXITY " << perp.perplexity() << '\n';
+ // cerr << "tTable contains " << tTable.getHash().bucket_count()
+ // << " buckets and " << tTable.getHash().size() << " entries." ;
+ cerr << "\nTransfer took: " << difftime(fn, st) << " seconds\n";
+ cerr << "\nTransfer Finished at: "<< ctime(&fn) << endl;
+ cerr << "==========================================================\n";
+
+ }
+
+}