// $Id$ #include "PhraseDictionaryTreeAdaptor.h" #include #include "PhraseDictionaryTree.h" #include "Phrase.h" #include "FactorCollection.h" #include "InputFileStream.h" #include "Input.h" #include "ConfusionNet.h" inline bool existsFile(const char* filename) { struct stat mystat; return (stat(filename,&mystat)==0); } struct PDTAimp { std::vector m_weights; LMList const* m_languageModels; float m_weightWP; std::vector m_input,m_output; FactorCollection *m_factorCollection; PhraseDictionaryTree *m_dict; mutable std::vector m_tgtColls; typedef std::map MapSrc2Tgt; mutable MapSrc2Tgt m_cache; PhraseDictionaryTreeAdaptor *m_obj; int useCache; typedef std::vector vTPC; std::vector m_rangeCache; PDTAimp(PhraseDictionaryTreeAdaptor *p) : m_languageModels(0),m_weightWP(0.0),m_factorCollection(0),m_dict(0), m_obj(p),useCache(1) {} void Factors2String(FactorArray const& w,std::string& s) const { for(size_t j=0;jToString(); } } void CleanUp() { assert(m_dict); m_dict->FreeMemory(); for(size_t i=0;i p =m_cache.insert(std::make_pair(source,static_cast(0))); if(p.second || p.first->second==0) { TargetPhraseCollection *ptr=new TargetPhraseCollection; ptr->push_back(targetPhrase); p.first->second=ptr; m_tgtColls.push_back(ptr); } else std::cerr<<"WARNING: you added an already existing phrase!\n"; } TargetPhraseCollection const* GetTargetPhraseCollection(Phrase const &src) const { assert(m_dict); if(src.GetSize()==0) return 0; std::pair piter; if(useCache) { piter=m_cache.insert(std::make_pair(src,static_cast(0))); if(!piter.second) return piter.first->second; } else if (m_cache.size()) { MapSrc2Tgt::const_iterator i=m_cache.find(src); return (i!=m_cache.end() ? i->second : 0); } std::vector srcString(src.GetSize()); // convert source Phrase into vector of strings for(size_t i=0;i cands; m_dict->GetTargetCandidates(srcString,cands); if(cands.empty()) return 0; std::vector tCands;tCands.reserve(cands.size()); std::vector > costs;costs.reserve(cands.size()); // convert into TargetPhrases for(size_t i=0;iempty()) { delete rv; return 0; } else { if(useCache) piter.first->second=rv; m_tgtColls.push_back(rv); return rv; } } void Create(const std::vector &input , const std::vector &output , FactorCollection &factorCollection , const std::string &filePath , const std::vector &weight , const LMList &languageModels , float weightWP ) { // set my members m_factorCollection=&factorCollection; m_dict=new PhraseDictionaryTree(weight.size()); m_input=input; m_output=output; m_languageModels=&languageModels; m_weightWP=weightWP; m_weights=weight; std::string binFname=filePath+".binphr.idx"; if(!existsFile(binFname.c_str())) { TRACE_ERR("bin ttable does not exist -> create it\n"); InputFileStream in(filePath); m_dict->Create(in,filePath); } TRACE_ERR("reading bin ttable\n"); m_dict->Read(filePath); } typedef PhraseDictionaryTree::PrefixPtr PPtr; typedef std::pair Range; struct State { PPtr ptr; Range range; float score; State() : range(0,0),score(0.0) {} State(size_t b,size_t e,const PPtr& v,float sc=0.0) : ptr(v),range(b,e),score(sc) {} State(Range const& r,const PPtr& v,float sc=0.0) : ptr(v),range(r),score(sc) {} size_t begin() const {return range.first;} size_t end() const {return range.second;} float GetScore() const {return score;} }; void CreateTargetPhrase(TargetPhrase& targetPhrase, StringTgtCand::first_type const& factorStrings, StringTgtCand::second_type const& scoreVector) const { for(size_t k=0;k factors=Tokenize(*factorStrings[k],"|"); FactorArray& fa=targetPhrase.AddWord(); for(size_t l=0;lAddFactor(Output, 
  void CreateTargetPhrase(TargetPhrase& targetPhrase,
                          StringTgtCand::first_type const& factorStrings,
                          StringTgtCand::second_type const& scoreVector) const
  {
    for(size_t k=0;k<factorStrings.size();++k) {
      std::vector<std::string> factors=Tokenize(*factorStrings[k],"|");
      FactorArray& fa=targetPhrase.AddWord();
      for(size_t l=0;l<m_output.size();++l)
        fa[m_output[l]]=m_factorCollection->AddFactor(Output, m_output[l], factors[l]);
    }
    targetPhrase.SetScore(scoreVector, m_weights, *m_languageModels, m_weightWP);
  }

  TargetPhraseCollection* PruneTargetCandidates(std::vector<TargetPhrase> const & tCands,
                                                std::vector<std::pair<float,size_t> >& costs) const
  {
    // prune target candidates and sort according to score
    std::vector<std::pair<float,size_t> >::iterator nth=costs.end();
    if(m_obj->m_maxTargetPhrase>0 && costs.size()>m_obj->m_maxTargetPhrase) {
      nth=costs.begin()+m_obj->m_maxTargetPhrase;
      std::nth_element(costs.begin(),nth,costs.end(),std::greater<std::pair<float,size_t> >());
    }
    std::sort(costs.begin(),nth,std::greater<std::pair<float,size_t> >());

    // convert into TargetPhraseCollection
    TargetPhraseCollection *rv=new TargetPhraseCollection;
    for(std::vector<std::pair<float,size_t> >::iterator it=costs.begin();it!=nth;++it)
      rv->push_back(tCands[it->second]);
    return rv;
  }

  // pre-compute target phrase collections for all coverage ranges of a confusion
  // net by extending prefix pointers along every path through the columns
  void CacheSource(ConfusionNet const& src)
  {
    assert(m_dict);
    std::vector<State> stack;
    for(size_t i=0;i<src.GetSize();++i) stack.push_back(State(i,i,m_dict->GetRoot()));

    typedef StringTgtCand::first_type sPhrase;
    typedef std::map<StringTgtCand::first_type,std::pair<float,StringTgtCand::second_type> > E2Costs;

    std::map<Range,E2Costs> cov2cand;

    while(!stack.empty()) {
      State curr(stack.back());
      stack.pop_back();

      assert(curr.end()<src.GetSize());
      // try to extend the current prefix with every word alternative in the next
      // column (assumption: each column entry pairs a Word with an input score)
      ConfusionNet::Column const& currCol=src[curr.end()];
      for(size_t colidx=0;colidx<currCol.size();++colidx) {
        std::string s;
        Factors2String(currCol[colidx].first.GetFactorArray(),s);
        PPtr nextP=m_dict->Extend(curr.ptr,s);
        if(nextP) {
          Range newRange(curr.begin(),curr.end()+1);
          // accumulated input cost of the source path leading to this prefix
          float newInputScore=curr.GetScore()+currCol[colidx].second;
          if(newRange.second<src.GetSize())
            stack.push_back(State(newRange,nextP,newInputScore));

          std::vector<StringTgtCand> tcands;
          m_dict->GetTargetCandidates(nextP,tcands);

          if(tcands.size()) {
            E2Costs& e2costs=cov2cand[newRange];
            for(size_t i=0;i<tcands.size();++i) {
              // keep, per target string, the candidate reached over the cheapest
              // source path (the exact cost expression is an assumption)
              float costs=newInputScore;
              std::pair<E2Costs::iterator,bool> p=e2costs.insert(std::make_pair(tcands[i].first,std::make_pair(costs,tcands[i].second)));
              if(!p.second) {
                if(p.first->second.first>costs)
                  p.first->second=std::make_pair(costs,tcands[i].second);
              }
            }
          }
        }
      }
    } // end while(!stack.empty())

    m_rangeCache.resize(src.GetSize(),vTPC(src.GetSize(),0));

    for(std::map<Range,E2Costs>::const_iterator i=cov2cand.begin();i!=cov2cand.end();++i) {
      assert(i->first.first<m_rangeCache.size());
      assert(i->first.second>0);
      assert(i->first.second-1<m_rangeCache[i->first.first].size());
      assert(m_rangeCache[i->first.first][i->first.second-1]==0);

      std::vector<TargetPhrase> tCands;tCands.reserve(i->second.size());
      std::vector<std::pair<float,size_t> > costs;costs.reserve(i->second.size());

      for(E2Costs::const_iterator j=i->second.begin();j!=i->second.end();++j) {
        TargetPhrase targetPhrase(Output, m_obj);
        CreateTargetPhrase(targetPhrase,j->first,j->second.second);
        costs.push_back(std::make_pair(targetPhrase.GetFutureScore(),tCands.size()));
        tCands.push_back(targetPhrase);
      }

      TargetPhraseCollection *rv=PruneTargetCandidates(tCands,costs);
      if(rv->empty()) delete rv;
      else {
        m_rangeCache[i->first.first][i->first.second-1]=rv;
        m_tgtColls.push_back(rv);
      }
    }
  }
};

/*************************************************************
 function definitions of the interface class
 virtually everything is forwarded to the implementation class
*************************************************************/

PhraseDictionaryTreeAdaptor::
PhraseDictionaryTreeAdaptor(size_t noScoreComponent)
  : MyBase(noScoreComponent),imp(new PDTAimp(this)) {}

PhraseDictionaryTreeAdaptor::~PhraseDictionaryTreeAdaptor()
{
  imp->CleanUp();
}

void PhraseDictionaryTreeAdaptor::CleanUp()
{
  imp->CleanUp();
  MyBase::CleanUp();
}

void PhraseDictionaryTreeAdaptor::InitializeForInput(InputType const& source)
{
  // only required for confusion net
  if(ConfusionNet const* cn=dynamic_cast<ConfusionNet const*>(&source))
    imp->CacheSource(*cn);
}

void PhraseDictionaryTreeAdaptor::Create(const std::vector<FactorType> &input
                                         , const std::vector<FactorType> &output
                                         , FactorCollection &factorCollection
                                         , const std::string &filePath
                                         , const std::vector<float> &weight
                                         , size_t maxTargetPhrase
                                         , const LMList &languageModels
                                         , float weightWP
                                         )
{
  if(m_noScoreComponent!=weight.size()) {
    std::cerr<<"ERROR: mismatch of number of scaling factors: "
             <<weight.size()<<" "<<m_noScoreComponent<<"\n";
    abort();
  }

  // store the pruning limit used by the implementation class
  m_maxTargetPhrase=maxTargetPhrase;

  imp->Create(input,output,factorCollection,filePath,
              weight,languageModels,weightWP);
}
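// Lookup entry points: the Phrase overload always goes through the string-based
// tree lookup (with optional caching), while the (InputType, WordsRange) overload
// first consults m_rangeCache, which CacheSource() fills for confusion-net input.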
TargetPhraseCollection const*
PhraseDictionaryTreeAdaptor::GetTargetPhraseCollection(Phrase const &src) const
{
  return imp->GetTargetPhraseCollection(src);
}

TargetPhraseCollection const*
PhraseDictionaryTreeAdaptor::GetTargetPhraseCollection(InputType const& src,
                                                       WordsRange const &range) const
{
  if(imp->m_rangeCache.empty())
    return imp->GetTargetPhraseCollection(src.GetSubString(range));
  else
    return imp->m_rangeCache[range.GetStartPos()][range.GetEndPos()];
}

void PhraseDictionaryTreeAdaptor::
SetWeightTransModel(const std::vector<float> &weightT)
{
  CleanUp();
  imp->m_weights=weightT;
}

void PhraseDictionaryTreeAdaptor::
AddEquivPhrase(const Phrase &source, const TargetPhrase &targetPhrase)
{
  imp->AddEquivPhrase(source,targetPhrase);
}

void PhraseDictionaryTreeAdaptor::EnableCache()
{
  imp->useCache=1;
}

void PhraseDictionaryTreeAdaptor::DisableCache()
{
  imp->useCache=0;
}
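// Usage sketch (assumes the surrounding decoder already provides factor lists,
// a FactorCollection, scaling weights, an LMList and the word-penalty weight;
// the variable names below are illustrative only):
//
//   PhraseDictionaryTreeAdaptor ttable(weights.size());
//   ttable.Create(inputFactors, outputFactors, factorCollection,
//                 "phrase-table", weights, maxTargetPhrase,
//                 languageModels, weightWP);
//   ttable.InitializeForInput(source);   // precomputes ranges for ConfusionNet input
//   TargetPhraseCollection const* tpc = ttable.GetTargetPhraseCollection(srcPhrase);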