diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2016-10-04 18:48:52 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2016-10-04 18:48:52 +0300 |
commit | 2eea4dd5e0e369a43300298190c4b860c17d19ad (patch) | |
tree | 2eb6668dd0c8bed3ebf6086fdd6134b6f52486ca | |
parent | 3a72b4958a3fc468b6bd6102e67e24007c9b2d9b (diff) |
compiles
-rw-r--r-- | moses/TranslationModel/ProbingPT/ProbingPT.cpp | 131 | ||||
-rw-r--r-- | moses/TranslationModel/ProbingPT/ProbingPT.h | 15 |
2 files changed, 76 insertions, 70 deletions
diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.cpp b/moses/TranslationModel/ProbingPT/ProbingPT.cpp index bb3f26e22..1298f8149 100644 --- a/moses/TranslationModel/ProbingPT/ProbingPT.cpp +++ b/moses/TranslationModel/ProbingPT/ProbingPT.cpp @@ -161,99 +161,94 @@ TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &s // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' assert(sourcePhrase.GetSize()); - TargetPhraseCollection::shared_ptr tpColl; - bool ok; - vector<uint64_t> probingSource = ConvertToProbingSourcePhrase(sourcePhrase, ok); - if (!ok) { - // source phrase contains a word unknown in the pt. - // We know immediately there's no translation for it - return tpColl; + std::pair<bool, uint64_t> keyStruct = GetKey(sourcePhrase); + if (!keyStruct.first) { + return TargetPhraseCollection::shared_ptr(); } - std::pair<bool, std::vector<target_text> > query_result; - - //Actual lookup - query_result = m_engine->query(probingSource); + // check in cache + CachePb::const_iterator iter = m_cachePb.find(keyStruct.second); + if (iter != m_cachePb.end()) { + //cerr << "FOUND IN CACHE " << keyStruct.second << " " << sourcePhrase.Debug(mgr.system) << endl; + TargetPhraseCollection *tps = iter->second; + return TargetPhraseCollection::shared_ptr(tps); + } - if (query_result.first) { - //m_engine->printTargetInfo(query_result.second); - tpColl.reset(new TargetPhraseCollection()); + // query pt + TargetPhraseCollection *tps = CreateTargetPhrases(sourcePhrase, + keyStruct.second); + return TargetPhraseCollection::shared_ptr(tps); +} - const std::vector<target_text> &probingTargetPhrases = query_result.second; - for (size_t i = 0; i < probingTargetPhrases.size(); ++i) { - const target_text &probingTargetPhrase = probingTargetPhrases[i]; - TargetPhrase *tp = CreateTargetPhrase(sourcePhrase, probingTargetPhrase); +std::pair<bool, uint64_t> ProbingPT::GetKey(const Phrase &sourcePhrase) const +{ + std::pair<bool, uint64_t> ret; - tpColl->Add(tp); - } + // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' + size_t sourceSize = sourcePhrase.GetSize(); + assert(sourceSize); - tpColl->Prune(true, m_tableLimit); + uint64_t probingSource[sourceSize]; + GetSourceProbingIds(sourcePhrase, ret.first, probingSource); + if (!ret.first) { + // source phrase contains a word unknown in the pt. + // We know immediately there's no translation for it + } + else { + ret.second = m_engine->getKey(probingSource, sourceSize); } - return tpColl; + return ret; + } -TargetPhrase *ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const +void ProbingPT::GetSourceProbingIds(const Phrase &sourcePhrase, + bool &ok, uint64_t probingSource[]) const { - const std::vector<unsigned int> &probingPhrase = probingTargetPhrase.target_phrase; - size_t size = probingPhrase.size(); - - TargetPhrase *tp = new TargetPhrase(this); - // words + size_t size = sourcePhrase.GetSize(); for (size_t i = 0; i < size; ++i) { - uint64_t probingId = probingPhrase[i]; - const Factor *factor = GetTargetFactor(probingId); - assert(factor); - - Word &word = tp->AddWord(); - word.SetFactor(m_output[0], factor); + const Word &word = sourcePhrase.GetWord(i); + uint64_t probingId = GetSourceProbingId(word); + if (probingId == m_unkId) { + ok = false; + return; + } + else { + probingSource[i] = probingId; + } } - // score for this phrase table - vector<float> scores = probingTargetPhrase.prob; - std::transform(scores.begin(), scores.end(), scores.begin(),TransformScore); - tp->GetScoreBreakdown().PlusEquals(this, scores); + ok = true; +} - // alignment - /* - const std::vector<unsigned char> &alignments = probingTargetPhrase.word_all1; +uint64_t ProbingPT::GetSourceProbingId(const Word &word) const +{ + uint64_t ret = 0; - AlignmentInfo &aligns = tp->GetAlignTerm(); - for (size_t i = 0; i < alignS.size(); i += 2 ) { - aligns.Add((size_t) alignments[i], (size_t) alignments[i+1]); + for (size_t i = 0; i < m_input.size(); ++i) { + FactorType factorType = m_input[i]; + const Factor *factor = word[factorType]; + + size_t factorId = factor->GetId(); + if (factorId >= m_sourceVocab.size()) { + return m_unkId; + } + ret += m_sourceVocab[factorId]; } - */ - // score of all other ff when this rule is being loaded - tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply()); - return tp; + return ret; } -const Factor *ProbingPT::GetTargetFactor(uint64_t probingId) const +TargetPhraseCollection *ProbingPT::CreateTargetPhrases( + const Phrase &sourcePhrase, uint64_t key) const { - TargetVocabMap::right_map::const_iterator iter; - iter = m_vocabMap.right.find(probingId); - if (iter != m_vocabMap.right.end()) { - return iter->second; - } else { - // not in mapping. Must be UNK - return NULL; - } -} -uint64_t ProbingPT::GetSourceProbingId(const Factor *factor) const -{ - SourceVocabMap::left_map::const_iterator iter; - iter = m_sourceVocabMap.left.find(factor); - if (iter != m_sourceVocabMap.left.end()) { - return iter->second; - } else { - // not in mapping. Must be UNK - return m_unkId; - } } +////////////////////////////////////////////////////////////////// + + ChartRuleLookupManager *ProbingPT::CreateRuleLookupManager( const ChartParser &, const ChartCellCollectionBase &, diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.h b/moses/TranslationModel/ProbingPT/ProbingPT.h index 3b5dfc895..98d052e07 100644 --- a/moses/TranslationModel/ProbingPT/ProbingPT.h +++ b/moses/TranslationModel/ProbingPT/ProbingPT.h @@ -2,6 +2,7 @@ #pragma once #include <boost/iostreams/device/mapped_file.hpp> #include <boost/bimap.hpp> +#include <boost/unordered_map.hpp> #include "../PhraseDictionary.h" @@ -48,15 +49,25 @@ protected: boost::iostreams::mapped_file_source file; const char *data; + // caching + typedef boost::unordered_map<uint64_t, TargetPhraseCollection*> CachePb; + CachePb m_cachePb; + void CreateAlignmentMap(const std::string path); TargetPhraseCollection::shared_ptr CreateTargetPhrase(const Phrase &sourcePhrase) const; - TargetPhrase *CreateTargetPhrase(const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const; - const Factor *GetTargetFactor(uint64_t probingId) const; uint64_t GetSourceProbingId(const Factor *factor) const; std::vector<uint64_t> ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const; + std::pair<bool, uint64_t> GetKey(const Phrase &sourcePhrase) const; + void GetSourceProbingIds(const Phrase &sourcePhrase, bool &ok, + uint64_t probingSource[]) const; + uint64_t GetSourceProbingId(const Word &word) const; + + TargetPhraseCollection *CreateTargetPhrases( + const Phrase &sourcePhrase, uint64_t key) const; + }; } // namespace Moses |