diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2016-10-05 17:15:47 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2016-10-05 17:32:56 +0300 |
commit | 041b13eb19f364b79809a7efa08c4552d41d4e75 (patch) | |
tree | 44d8c9f2bd7a182ccd2591bf0deeba91fd07ff80 | |
parent | 2eea4dd5e0e369a43300298190c4b860c17d19ad (diff) |
compiles but segfault
-rw-r--r-- | moses/ScoreComponentCollection.h | 9 | ||||
-rw-r--r-- | moses/TranslationModel/ProbingPT/ProbingPT.cpp | 122 | ||||
-rw-r--r-- | moses/TranslationModel/ProbingPT/ProbingPT.h | 14 |
3 files changed, 123 insertions, 22 deletions
diff --git a/moses/ScoreComponentCollection.h b/moses/ScoreComponentCollection.h index 1305e9c16..0ab57a73a 100644 --- a/moses/ScoreComponentCollection.h +++ b/moses/ScoreComponentCollection.h @@ -247,6 +247,15 @@ public: } } + void PlusEquals(const FeatureFunction* sp, float scores[]) + { + size_t numScores = sp->GetNumScoreComponents(); + size_t offset = sp->GetIndex(); + for (size_t i = 0; i < numScores; ++i) { + m_scores[i + offset] += scores[i]; + } + } + //! Special version PlusEquals(ScoreProducer, vector<float>) //! to add the score from a single ScoreProducer that produces //! a single value diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.cpp b/moses/TranslationModel/ProbingPT/ProbingPT.cpp index 1298f8149..1fd982f0e 100644 --- a/moses/TranslationModel/ProbingPT/ProbingPT.cpp +++ b/moses/TranslationModel/ProbingPT/ProbingPT.cpp @@ -137,25 +137,6 @@ void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQue } } -std::vector<uint64_t> ProbingPT::ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const -{ - size_t size = sourcePhrase.GetSize(); - std::vector<uint64_t> ret(size); - for (size_t i = 0; i < size; ++i) { - const Factor *factor = sourcePhrase.GetFactor(i, m_input[0]); - uint64_t probingId = GetSourceProbingId(factor); - if (probingId == m_unkId) { - ok = false; - return ret; - } else { - ret[i] = probingId; - } - } - - ok = true; - return ret; -} - TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const { // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' @@ -243,7 +224,110 @@ uint64_t ProbingPT::GetSourceProbingId(const Word &word) const TargetPhraseCollection *ProbingPT::CreateTargetPhrases( const Phrase &sourcePhrase, uint64_t key) const { + TargetPhraseCollection *tps = NULL; + + //Actual lookup + std::pair<bool, uint64_t> query_result; // 1st=found, 2nd=target file offset + query_result = m_engine->query(key); + //cerr << "key2=" << query_result.second << endl; + + if (query_result.first) { + const char *offset = data + query_result.second; + uint64_t *numTP = (uint64_t*) offset; + + tps = new TargetPhraseCollection(); + + offset += sizeof(uint64_t); + for (size_t i = 0; i < *numTP; ++i) { + TargetPhrase *tp = CreateTargetPhrase(offset); + assert(tp); + tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply()); + + tps->Add(tp); + + } + + tps->Prune(true, m_tableLimit); + //cerr << *tps << endl; + } + + return tps; + +} + +TargetPhrase *ProbingPT::CreateTargetPhrase( + const char *&offset) const +{ + TargetPhraseInfo *tpInfo = (TargetPhraseInfo*) offset; + size_t numRealWords = tpInfo->numWords / m_output.size(); + + TargetPhrase *tp = new TargetPhrase(this); + + offset += sizeof(TargetPhraseInfo); + + // scores + float *scores = (float*) offset; + + size_t totalNumScores = m_engine->num_scores + m_engine->num_lex_scores; + + if (m_engine->logProb) { + // set pt score for rule + tp->GetScoreBreakdown().PlusEquals(this, scores); + + // save scores for other FF, eg. lex RO. Just give the offset + /* + if (m_engine->num_lex_scores) { + tp->scoreProperties = scores + m_engine->num_scores; + } + */ + } + else { + // log score 1st + float logScores[totalNumScores]; + for (size_t i = 0; i < totalNumScores; ++i) { + logScores[i] = FloorScore(TransformScore(scores[i])); + } + + // set pt score for rule + tp->GetScoreBreakdown().PlusEquals(this, logScores); + + // save scores for other FF, eg. lex RO. + /* + tp->scoreProperties = pool.Allocate<SCORE>(m_engine->num_lex_scores); + for (size_t i = 0; i < m_engine->num_lex_scores; ++i) { + tp->scoreProperties[i] = logScores[i + m_engine->num_scores]; + } + */ + } + + offset += sizeof(float) * totalNumScores; + + // words + for (size_t targetPos = 0; targetPos < numRealWords; ++targetPos) { + for (size_t i = 0; i < m_output.size(); ++i) { + FactorType factorType = m_output[i]; + + uint32_t *probingId = (uint32_t*) offset; + + const Factor *factor = GetTargetFactor(*probingId); + assert(factor); + + Word &word = tp->GetWord(targetPos); + word[factorType] = factor; + + offset += sizeof(uint32_t); + } + } + + // align + uint32_t alignTerm = tpInfo->alignTerm; + //cerr << "alignTerm=" << alignTerm << endl; + UTIL_THROW_IF2(alignTerm >= m_aligns.size(), "Unknown alignInd"); + tp->SetAlignTerm(m_aligns[alignTerm]); + + // properties TODO + return tp; } ////////////////////////////////////////////////////////////////// diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.h b/moses/TranslationModel/ProbingPT/ProbingPT.h index 98d052e07..21c01df28 100644 --- a/moses/TranslationModel/ProbingPT/ProbingPT.h +++ b/moses/TranslationModel/ProbingPT/ProbingPT.h @@ -56,17 +56,25 @@ protected: void CreateAlignmentMap(const std::string path); TargetPhraseCollection::shared_ptr CreateTargetPhrase(const Phrase &sourcePhrase) const; - uint64_t GetSourceProbingId(const Factor *factor) const; - - std::vector<uint64_t> ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const; std::pair<bool, uint64_t> GetKey(const Phrase &sourcePhrase) const; void GetSourceProbingIds(const Phrase &sourcePhrase, bool &ok, uint64_t probingSource[]) const; uint64_t GetSourceProbingId(const Word &word) const; + uint64_t GetSourceProbingId(const Factor *factor) const; TargetPhraseCollection *CreateTargetPhrases( const Phrase &sourcePhrase, uint64_t key) const; + TargetPhrase *CreateTargetPhrase( + const char *&offset) const; + + inline const Factor *GetTargetFactor(uint32_t probingId) const + { + if (probingId >= m_targetVocab.size()) { + return NULL; + } + return m_targetVocab[probingId]; + } }; |