/////////////////////////////////////////////////////////////////////////////// // // // This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. // // // // ModelBlocks is free software: you can redistribute it and/or modify // // it under the terms of the GNU General Public License as published by // // the Free Software Foundation, either version 3 of the License, or // // (at your option) any later version. // // // // ModelBlocks is distributed in the hope that it will be useful, // // but WITHOUT ANY WARRANTY; without even the implied warranty of // // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // // GNU General Public License for more details. // // // // You should have received a copy of the GNU General Public License // // along with ModelBlocks. If not, see <http://www.gnu.org/licenses/>. // // // // ModelBlocks developers designate this particular file as subject to // // the "Moses" exception as provided by ModelBlocks developers in // // the LICENSE file that accompanies this code. 
// // // /////////////////////////////////////////////////////////////////////////////// #ifndef _NL_HMM_ #define _NL_HMM_ #include #include #include #include //#include //#include //#include #include "nl-prob.h" #include "nl-safeids.h" #include "nl-beam.h" typedef int Frame; //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// // // NullBackDat - default empty back-pointer data; can replace with word or sem relation // //////////////////////////////////////////////////////////////////////////////// template class NullBackDat { static const string sDummy; public: NullBackDat () {} NullBackDat (const MY& my) {} void write (FILE*) const {} string getString() const { return sDummy; } }; template const string NullBackDat::sDummy ( "" ); //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// // // Index - pointer to source in previous beam heap // //////////////////////////////////////////////////////////////////////////////// class Index : public Id { public: Index ( ) { } Index (int i) {set(i);} Index& operator++ ( ) {set(toInt()+1); return *this;} }; //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// // // TrellNode - node in viterbi trellis // //////////////////////////////////////////////////////////////////////////////// template class TrellNode { private: // Data members... Index indSource; B backptrData; S sId; LogProb lgprMax; public: // Constructor / destructor methods... TrellNode ( ) { } TrellNode ( const Index& indS, const S& sI, const B& bDat, LogProb lgpr) { indSource=indS; sId=sI; lgprMax=lgpr; backptrData=bDat; /* fo = -1; */ } // Specification methods... 
const Index& setSource ( ) const { return indSource; } const B& setBackData( ) const { return backptrData; } const S& setId ( ) const { return sId; } LogProb& setScore ( ) { return lgprMax; } // Extraction methods... bool operator== ( const TrellNode& tnsb ) const { return(sId==tnsb.sId); } // size_t getHashKey ( ) const { return sId.getHashKey(); } const Index& getSource ( ) const { return indSource; } const B& getBackData( ) const { return backptrData; } const S& getId ( ) const { return sId; } LogProb getLogProb ( ) const { return lgprMax; } LogProb getScore ( ) const { return lgprMax; } }; //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// // // HMM // //////////////////////////////////////////////////////////////////////////////// template > class HMM { private: typedef std::pair IB; // Data members... const MY& my; const MX& mx; SafeArray2D,Id,TrellNode > aatnTrellis; Frame frameLast; int iNextNode; public: // Static member varaibles... static bool OUTPUT_QUIET; static bool OUTPUT_NOISY; static bool OUTPUT_VERYNOISY; static int BEAM_WIDTH; // Constructor / destructor methods... HMM ( const MY& my1, const MX& mx1 ) : my(my1), mx(mx1) { } // Specification methods... void init ( int, int, const S& ) ; void init ( int, int, SafeArray1D,pair >* ); void updateRanked ( const typename MX::RandVarType&, bool ) ; void updateSerial ( const typename MX::RandVarType& ) ; void updatePara ( const typename MX::RandVarType& ) ; bool unknown ( const typename MX::RandVarType& ) ; void each ( const typename MX::RandVarType&, Beam&, SafeArray1D,std::pair,LogProb> >& ) ; // Extraction methods... const TrellNode& getTrellNode ( int i ) const { return aatnTrellis.get(frameLast,i); } int getBeamUsed ( int ) const ; // Input / output methods... 
void writeMLS ( FILE* ) const ; void writeMLS ( FILE*, const S& ) const ; void debugPrint() const; double getCurrSum(int) const; //void writeCurr ( FILE*, int ) const ; void writeCurr ( ostream&, int ) const ; void writeCurrSum ( FILE*, int ) const ; void gatherElementsInBeam( SafeArray1D,pair >* result, int f ) const; void writeCurrEntropy ( FILE*, int ) const; //void writeCurrDepths ( FILE*, int ) const; void writeFoll ( FILE*, int, int, const typename MX::RandVarType& ) const ; void writeFollRanked ( FILE*, int, int, const typename MX::RandVarType&, bool ) const ; std::list getMLS() const; std::list > getMLSnodes() const; std::list getMLS(const S&) const; std::list > getMLSnodes(const S&) const; }; template bool HMM::OUTPUT_QUIET = false; template bool HMM::OUTPUT_NOISY = false; template bool HMM::OUTPUT_VERYNOISY = false; template int HMM::BEAM_WIDTH = 1; //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// template void HMM::init ( int numFr, int numS, const S& s ) { // Alloc trellis... BEAM_WIDTH = numS; aatnTrellis.init(numFr,BEAM_WIDTH); frameLast=0; // Set initial element at first time slice... aatnTrellis.set(frameLast,0) = TrellNode ( Index(0), s, B(), 0 ) ; } template void HMM::init ( int numFr, int beamSize, SafeArray1D,pair >* existingBeam ) { // Alloc trellis... // int numToCopy = existingBeam->getSize(); BEAM_WIDTH = beamSize; aatnTrellis.init(numFr,BEAM_WIDTH); frameLast=0; // Set initial beam elements at first time slice... 
for ( int i=0, n=existingBeam->getSize(); i ( Index(0), existingBeam->get(i).first, B(), existingBeam->get(i).second ) ; } } template void HMM::debugPrint() const{ for (int frame=0, numFrames=aatnTrellis.getxSize(); frame 0) { cerr << "\t" << "aatnTrellis.get(frame=" << frame << ",beamIndex=" << beamIndex << ") is\t" << aatnTrellis.get(frame,beamIndex).getId() << "\tprob=" << aatnTrellis.get(frame,beamIndex).getLogProb().toDouble() << endl; } } } } //////////////////////////////////////////////////////////////////////////////// template bool outRank ( const quad >& a1, const quad >& a2 ) { return (a1.third>a2.third); } template bool HMM::unknown( const typename MX::RandVarType& x ) { return mx.unknown(x); } template void HMM::updateRanked ( const typename MX::RandVarType& x, bool b1 ) { // Increment frame counter... frameLast++; // Init beam for new frame... Beam btn(BEAM_WIDTH); SafeArray1D,std::pair,LogProb> > atnSorted (BEAM_WIDTH); Heap < quad >, outRank > ashpiQueue; typedef quad > SHPI; SHPI shpi, shpiTop; int aCtr; ashpiQueue.clear(); //shpi.first = -1; //shpi.second = HModel::IterVal(); //shpi.third = 1.0; shpi.first = 0; shpi.third = aatnTrellis.get(frameLast-1,shpi.first).getScore(); shpi.third *= my.setIterProb ( shpi.second, aatnTrellis.get(frameLast-1,shpi.first).getId(), x, b1, aCtr=-1 ); //S s; my.setTrellDat(s,shpi.second); shpi.fourth = -1; ////cerr<<"????? "<0; iTrg++ ) { // Iterate A* (best-first) search until a complete path is at the top of the queue... while ( ashpiQueue.getSize() > 0 && ashpiQueue.getTop().fourth < MY::IterVal::NUM_ITERS ) { // Remove top... shpiTop = ashpiQueue.dequeueTop(); // Fork off (try to advance each elementary variable a)... for ( int a=shpiTop.fourth.toInt(); a<=MY::IterVal::NUM_ITERS; a++ ) { // Copy top into new queue element... shpi = shpiTop; // At variable position -1, advance beam element for transition source... if ( a == -1 ) shpi.first++; // Incorporate prob from transition source... 
shpi.third = aatnTrellis.get(frameLast-1,shpi.first).getScore(); if ( shpi.third > LogProb() ) { // Try to advance variable at position a and return probability (subsequent variables set to first, probability ignored)... shpi.third *= my.setIterProb ( shpi.second, aatnTrellis.get(frameLast-1,shpi.first).getId(), x, b1, aCtr=a ); // At end of variables, incorporate observation probability... if ( a == MY::IterVal::NUM_ITERS && shpi.fourth != MY::IterVal::NUM_ITERS ) { S s; my.setTrellDat(s,shpi.second); shpi.third *= mx.getProb(x,s); } // Record variable position at which this element was forked off... shpi.fourth = a; //cerr<<" from partial: "< LogProb() ) { ////if ( frameLast == 4 ) cerr<<" from partial: "< 0 ) { S s; my.setTrellDat(s,ashpiQueue.getTop().second); bFull |= btn.tryAdd ( s, IB(ashpiQueue.getTop().first,my.setBackDat(ashpiQueue.getTop().second)), ashpiQueue.getTop().third ); ////cerr<,LogProb>* tn1 = &atnSorted.get(i); aatnTrellis.set(frameLast,i)=TrellNode(tn1->first.second.first, tn1->first.first, tn1->first.second.second, tn1->second); } my.update(); } //////////////////////////////////////////////////////////////////////////////// template void HMM::updateSerial ( const typename MX::RandVarType& x ) { // Increment frame counter... frameLast++; // Init beam for new frame... Beam btn(BEAM_WIDTH); SafeArray1D,std::pair,LogProb> > atnSorted (BEAM_WIDTH); // // Copy beam to trellis... // for ( int i=0; i