diff options
Diffstat (limited to 'moses/TranslationModel/ProbingPT/ProbingPT.cpp')
-rw-r--r-- | moses/TranslationModel/ProbingPT/ProbingPT.cpp | 315 |
1 files changed, 218 insertions, 97 deletions
diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.cpp b/moses/TranslationModel/ProbingPT/ProbingPT.cpp index cbfd2c1a4..1ae0c67c3 100644 --- a/moses/TranslationModel/ProbingPT/ProbingPT.cpp +++ b/moses/TranslationModel/ProbingPT/ProbingPT.cpp @@ -3,8 +3,9 @@ #include "moses/StaticData.h" #include "moses/FactorCollection.h" #include "moses/TargetPhraseCollection.h" +#include "moses/InputFileStream.h" #include "moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerSkeleton.h" -#include "quering.hh" +#include "querying.hh" using namespace std; @@ -34,44 +35,94 @@ void ProbingPT::Load(AllOptions::ptr const& opts) m_unkId = 456456546456; + FactorCollection &vocab = FactorCollection::Instance(); + // source vocab - const std::map<uint64_t, std::string> &sourceVocab = m_engine->getSourceVocab(); + const std::map<uint64_t, std::string> &sourceVocab = + m_engine->getSourceVocab(); std::map<uint64_t, std::string>::const_iterator iterSource; - for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end(); ++iterSource) { - const string &wordStr = iterSource->second; - const Factor *factor = FactorCollection::Instance().AddFactor(wordStr); + for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end(); + ++iterSource) { + string wordStr = iterSource->second; + //cerr << "wordStr=" << wordStr << endl; - uint64_t probingId = iterSource->first; + const Factor *factor = vocab.AddFactor(wordStr); - SourceVocabMap::value_type entry(factor, probingId); - m_sourceVocabMap.insert(entry); + uint64_t probingId = iterSource->first; + size_t factorId = factor->GetId(); + if (factorId >= m_sourceVocab.size()) { + m_sourceVocab.resize(factorId + 1, m_unkId); + } + m_sourceVocab[factorId] = probingId; } // target vocab - const std::map<unsigned int, std::string> &probingVocab = m_engine->getVocab(); - std::map<unsigned int, std::string>::const_iterator iter; - for (iter = probingVocab.begin(); iter != probingVocab.end(); ++iter) { - const string &wordStr = iter->second; - const Factor *factor = FactorCollection::Instance().AddFactor(wordStr); + InputFileStream targetVocabStrme(m_filePath + "/TargetVocab.dat"); + string line; + while (getline(targetVocabStrme, line)) { + vector<string> toks = Tokenize(line, "\t"); + UTIL_THROW_IF2(toks.size() != 2, string("Incorrect format:") + line + "\n"); + + //cerr << "wordStr=" << toks[0] << endl; + + const Factor *factor = vocab.AddFactor(toks[0]); + uint32_t probingId = Scan<uint32_t>(toks[1]); + + if (probingId >= m_targetVocab.size()) { + m_targetVocab.resize(probingId + 1); + } + + m_targetVocab[probingId] = factor; + } + + // alignments + CreateAlignmentMap(m_filePath + "/Alignments.dat"); - unsigned int probingId = iter->first; + // memory mapped file to tps + string filePath = m_filePath + "/TargetColl.dat"; + file.open(filePath.c_str()); + if (!file.is_open()) { + throw "Couldn't open file "; + } + + data = file.data(); + //size_t size = file.size(); + + // cache + //CreateCache(system); - TargetVocabMap::value_type entry(factor, probingId); - m_vocabMap.insert(entry); +} +void ProbingPT::CreateAlignmentMap(const std::string path) +{ + const std::vector< std::vector<unsigned char> > &probingAlignColl = m_engine->getAlignments(); + m_aligns.resize(probingAlignColl.size(), NULL); + + for (size_t i = 0; i < probingAlignColl.size(); ++i) { + AlignmentInfo::CollType aligns; + + const std::vector<unsigned char> &probingAligns = probingAlignColl[i]; + for (size_t j = 0; j < probingAligns.size(); j += 2) { + size_t startPos = probingAligns[j]; + size_t endPos = probingAligns[j+1]; + //cerr << "startPos=" << startPos << " " << endPos << endl; + aligns.insert(std::pair<size_t,size_t>(startPos, endPos)); + } + + const AlignmentInfo *align = AlignmentInfoCollection::Instance().Add(aligns); + m_aligns[i] = align; + //cerr << "align=" << align->Debug(system) << endl; } } void ProbingPT::InitializeForInput(ttasksptr const& ttask) { - ReduceCache(); + } void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const { - CacheColl &cache = GetCache(); - InputPathList::const_iterator iter; for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter) { InputPath &inputPath = **iter; @@ -82,132 +133,202 @@ void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQue } TargetPhraseCollection::shared_ptr tpColl = CreateTargetPhrase(sourcePhrase); + inputPath.SetTargetPhrases(*this, tpColl, NULL); + } +} - // add target phrase to phrase-table cache - size_t hash = hash_value(sourcePhrase); - std::pair<TargetPhraseCollection::shared_ptr , clock_t> value(tpColl, clock()); - cache[hash] = value; +TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const +{ + // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' + assert(sourcePhrase.GetSize()); - inputPath.SetTargetPhrases(*this, tpColl, NULL); + std::pair<bool, uint64_t> keyStruct = GetKey(sourcePhrase); + if (!keyStruct.first) { + return TargetPhraseCollection::shared_ptr(); + } + + // check in cache + CachePb::const_iterator iter = m_cachePb.find(keyStruct.second); + if (iter != m_cachePb.end()) { + //cerr << "FOUND IN CACHE " << keyStruct.second << " " << sourcePhrase.Debug(mgr.system) << endl; + TargetPhraseCollection *tps = iter->second; + return TargetPhraseCollection::shared_ptr(tps); + } + + // query pt + TargetPhraseCollection *tps = CreateTargetPhrases(sourcePhrase, + keyStruct.second); + return TargetPhraseCollection::shared_ptr(tps); +} + +std::pair<bool, uint64_t> ProbingPT::GetKey(const Phrase &sourcePhrase) const +{ + std::pair<bool, uint64_t> ret; + + // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' + size_t sourceSize = sourcePhrase.GetSize(); + assert(sourceSize); + + uint64_t probingSource[sourceSize]; + GetSourceProbingIds(sourcePhrase, ret.first, probingSource); + if (!ret.first) { + // source phrase contains a word unknown in the pt. + // We know immediately there's no translation for it + } else { + ret.second = m_engine->getKey(probingSource, sourceSize); } + + return ret; + } -std::vector<uint64_t> ProbingPT::ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const +void ProbingPT::GetSourceProbingIds(const Phrase &sourcePhrase, + bool &ok, uint64_t probingSource[]) const { + size_t size = sourcePhrase.GetSize(); - std::vector<uint64_t> ret(size); for (size_t i = 0; i < size; ++i) { - const Factor *factor = sourcePhrase.GetFactor(i, m_input[0]); - uint64_t probingId = GetSourceProbingId(factor); + const Word &word = sourcePhrase.GetWord(i); + uint64_t probingId = GetSourceProbingId(word); if (probingId == m_unkId) { ok = false; - return ret; + return; } else { - ret[i] = probingId; + probingSource[i] = probingId; } } ok = true; - return ret; } -TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const +uint64_t ProbingPT::GetSourceProbingId(const Word &word) const { - // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' - assert(sourcePhrase.GetSize()); + uint64_t ret = 0; - TargetPhraseCollection::shared_ptr tpColl; - bool ok; - vector<uint64_t> probingSource = ConvertToProbingSourcePhrase(sourcePhrase, ok); - if (!ok) { - // source phrase contains a word unknown in the pt. - // We know immediately there's no translation for it - return tpColl; + for (size_t i = 0; i < m_input.size(); ++i) { + FactorType factorType = m_input[i]; + const Factor *factor = word[factorType]; + + size_t factorId = factor->GetId(); + if (factorId >= m_sourceVocab.size()) { + return m_unkId; + } + ret += m_sourceVocab[factorId]; } - std::pair<bool, std::vector<target_text> > query_result; + return ret; +} + +TargetPhraseCollection *ProbingPT::CreateTargetPhrases( + const Phrase &sourcePhrase, uint64_t key) const +{ + TargetPhraseCollection *tps = NULL; //Actual lookup - query_result = m_engine->query(probingSource); + std::pair<bool, uint64_t> query_result; // 1st=found, 2nd=target file offset + query_result = m_engine->query(key); + //cerr << "key2=" << query_result.second << endl; if (query_result.first) { - //m_engine->printTargetInfo(query_result.second); - tpColl.reset(new TargetPhraseCollection()); + const char *offset = data + query_result.second; + uint64_t *numTP = (uint64_t*) offset; + + tps = new TargetPhraseCollection(); - const std::vector<target_text> &probingTargetPhrases = query_result.second; - for (size_t i = 0; i < probingTargetPhrases.size(); ++i) { - const target_text &probingTargetPhrase = probingTargetPhrases[i]; - TargetPhrase *tp = CreateTargetPhrase(sourcePhrase, probingTargetPhrase); + offset += sizeof(uint64_t); + for (size_t i = 0; i < *numTP; ++i) { + TargetPhrase *tp = CreateTargetPhrase(offset); + assert(tp); + tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply()); + + tps->Add(tp); - tpColl->Add(tp); } - tpColl->Prune(true, m_tableLimit); + tps->Prune(true, m_tableLimit); + //cerr << *tps << endl; } - return tpColl; + return tps; + } -TargetPhrase *ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const +TargetPhrase *ProbingPT::CreateTargetPhrase( + const char *&offset) const { - const std::vector<unsigned int> &probingPhrase = probingTargetPhrase.target_phrase; - size_t size = probingPhrase.size(); + TargetPhraseInfo *tpInfo = (TargetPhraseInfo*) offset; + size_t numRealWords = tpInfo->numWords / m_output.size(); TargetPhrase *tp = new TargetPhrase(this); - // words - for (size_t i = 0; i < size; ++i) { - uint64_t probingId = probingPhrase[i]; - const Factor *factor = GetTargetFactor(probingId); - assert(factor); + offset += sizeof(TargetPhraseInfo); - Word &word = tp->AddWord(); - word.SetFactor(m_output[0], factor); + // scores + float *scores = (float*) offset; + + size_t totalNumScores = m_engine->num_scores + m_engine->num_lex_scores; + + if (m_engine->logProb) { + // set pt score for rule + tp->GetScoreBreakdown().PlusEquals(this, scores); + + // save scores for other FF, eg. lex RO. Just give the offset + /* + if (m_engine->num_lex_scores) { + tp->scoreProperties = scores + m_engine->num_scores; + } + */ + } else { + // log score 1st + float logScores[totalNumScores]; + for (size_t i = 0; i < totalNumScores; ++i) { + logScores[i] = FloorScore(TransformScore(scores[i])); + } + + // set pt score for rule + tp->GetScoreBreakdown().PlusEquals(this, logScores); + + // save scores for other FF, eg. lex RO. + /* + tp->scoreProperties = pool.Allocate<SCORE>(m_engine->num_lex_scores); + for (size_t i = 0; i < m_engine->num_lex_scores; ++i) { + tp->scoreProperties[i] = logScores[i + m_engine->num_scores]; + } + */ } - // score for this phrase table - vector<float> scores = probingTargetPhrase.prob; - std::transform(scores.begin(), scores.end(), scores.begin(),TransformScore); - tp->GetScoreBreakdown().PlusEquals(this, scores); + offset += sizeof(float) * totalNumScores; + + // words + for (size_t targetPos = 0; targetPos < numRealWords; ++targetPos) { + Word &word = tp->AddWord(); + for (size_t i = 0; i < m_output.size(); ++i) { + FactorType factorType = m_output[i]; + + uint32_t *probingId = (uint32_t*) offset; + + const Factor *factor = GetTargetFactor(*probingId); + assert(factor); - // alignment - /* - const std::vector<unsigned char> &alignments = probingTargetPhrase.word_all1; + word[factorType] = factor; - AlignmentInfo &aligns = tp->GetAlignTerm(); - for (size_t i = 0; i < alignS.size(); i += 2 ) { - aligns.Add((size_t) alignments[i], (size_t) alignments[i+1]); + offset += sizeof(uint32_t); + } } - */ - // score of all other ff when this rule is being loaded - tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply()); + // align + uint32_t alignTerm = tpInfo->alignTerm; + //cerr << "alignTerm=" << alignTerm << endl; + UTIL_THROW_IF2(alignTerm >= m_aligns.size(), "Unknown alignInd"); + tp->SetAlignTerm(m_aligns[alignTerm]); + + // properties TODO + return tp; } -const Factor *ProbingPT::GetTargetFactor(uint64_t probingId) const -{ - TargetVocabMap::right_map::const_iterator iter; - iter = m_vocabMap.right.find(probingId); - if (iter != m_vocabMap.right.end()) { - return iter->second; - } else { - // not in mapping. Must be UNK - return NULL; - } -} +////////////////////////////////////////////////////////////////// -uint64_t ProbingPT::GetSourceProbingId(const Factor *factor) const -{ - SourceVocabMap::left_map::const_iterator iter; - iter = m_sourceVocabMap.left.find(factor); - if (iter != m_sourceVocabMap.left.end()) { - return iter->second; - } else { - // not in mapping. Must be UNK - return m_unkId; - } -} ChartRuleLookupManager *ProbingPT::CreateRuleLookupManager( const ChartParser &, |