Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'moses/TranslationModel/ProbingPT/ProbingPT.cpp')
-rw-r--r-- moses/TranslationModel/ProbingPT/ProbingPT.cpp 315
1 file changed, 218 insertions, 97 deletions
diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.cpp b/moses/TranslationModel/ProbingPT/ProbingPT.cpp
index cbfd2c1a4..1ae0c67c3 100644
--- a/moses/TranslationModel/ProbingPT/ProbingPT.cpp
+++ b/moses/TranslationModel/ProbingPT/ProbingPT.cpp
@@ -3,8 +3,9 @@
#include "moses/StaticData.h"
#include "moses/FactorCollection.h"
#include "moses/TargetPhraseCollection.h"
+#include "moses/InputFileStream.h"
#include "moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerSkeleton.h"
-#include "quering.hh"
+#include "querying.hh"
using namespace std;
@@ -34,44 +35,94 @@ void ProbingPT::Load(AllOptions::ptr const& opts)
m_unkId = 456456546456;
+ FactorCollection &vocab = FactorCollection::Instance();
+
// source vocab
- const std::map<uint64_t, std::string> &sourceVocab = m_engine->getSourceVocab();
+ const std::map<uint64_t, std::string> &sourceVocab =
+ m_engine->getSourceVocab();
std::map<uint64_t, std::string>::const_iterator iterSource;
- for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end(); ++iterSource) {
- const string &wordStr = iterSource->second;
- const Factor *factor = FactorCollection::Instance().AddFactor(wordStr);
+ for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end();
+ ++iterSource) {
+ string wordStr = iterSource->second;
+ //cerr << "wordStr=" << wordStr << endl;
- uint64_t probingId = iterSource->first;
+ const Factor *factor = vocab.AddFactor(wordStr);
- SourceVocabMap::value_type entry(factor, probingId);
- m_sourceVocabMap.insert(entry);
+ uint64_t probingId = iterSource->first;
+ size_t factorId = factor->GetId();
+ if (factorId >= m_sourceVocab.size()) {
+ m_sourceVocab.resize(factorId + 1, m_unkId);
+ }
+ m_sourceVocab[factorId] = probingId;
}
// target vocab
- const std::map<unsigned int, std::string> &probingVocab = m_engine->getVocab();
- std::map<unsigned int, std::string>::const_iterator iter;
- for (iter = probingVocab.begin(); iter != probingVocab.end(); ++iter) {
- const string &wordStr = iter->second;
- const Factor *factor = FactorCollection::Instance().AddFactor(wordStr);
+ InputFileStream targetVocabStrme(m_filePath + "/TargetVocab.dat");
+ string line;
+ while (getline(targetVocabStrme, line)) {
+ vector<string> toks = Tokenize(line, "\t");
+ UTIL_THROW_IF2(toks.size() != 2, string("Incorrect format:") + line + "\n");
+
+ //cerr << "wordStr=" << toks[0] << endl;
+
+ const Factor *factor = vocab.AddFactor(toks[0]);
+ uint32_t probingId = Scan<uint32_t>(toks[1]);
+
+ if (probingId >= m_targetVocab.size()) {
+ m_targetVocab.resize(probingId + 1);
+ }
+
+ m_targetVocab[probingId] = factor;
+ }
+
+ // alignments
+ CreateAlignmentMap(m_filePath + "/Alignments.dat");
- unsigned int probingId = iter->first;
+ // memory mapped file to tps
+ string filePath = m_filePath + "/TargetColl.dat";
+ file.open(filePath.c_str());
+ if (!file.is_open()) {
+ throw "Couldn't open file ";
+ }
+
+ data = file.data();
+ //size_t size = file.size();
+
+ // cache
+ //CreateCache(system);
- TargetVocabMap::value_type entry(factor, probingId);
- m_vocabMap.insert(entry);
+}
+void ProbingPT::CreateAlignmentMap(const std::string path)
+{
+ const std::vector< std::vector<unsigned char> > &probingAlignColl = m_engine->getAlignments();
+ m_aligns.resize(probingAlignColl.size(), NULL);
+
+ for (size_t i = 0; i < probingAlignColl.size(); ++i) {
+ AlignmentInfo::CollType aligns;
+
+ const std::vector<unsigned char> &probingAligns = probingAlignColl[i];
+ for (size_t j = 0; j < probingAligns.size(); j += 2) {
+ size_t startPos = probingAligns[j];
+ size_t endPos = probingAligns[j+1];
+ //cerr << "startPos=" << startPos << " " << endPos << endl;
+ aligns.insert(std::pair<size_t,size_t>(startPos, endPos));
+ }
+
+ const AlignmentInfo *align = AlignmentInfoCollection::Instance().Add(aligns);
+ m_aligns[i] = align;
+ //cerr << "align=" << align->Debug(system) << endl;
}
}
void ProbingPT::InitializeForInput(ttasksptr const& ttask)
{
- ReduceCache();
+
}
void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const
{
- CacheColl &cache = GetCache();
-
InputPathList::const_iterator iter;
for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter) {
InputPath &inputPath = **iter;
@@ -82,132 +133,202 @@ void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQue
}
TargetPhraseCollection::shared_ptr tpColl = CreateTargetPhrase(sourcePhrase);
+ inputPath.SetTargetPhrases(*this, tpColl, NULL);
+ }
+}
- // add target phrase to phrase-table cache
- size_t hash = hash_value(sourcePhrase);
- std::pair<TargetPhraseCollection::shared_ptr , clock_t> value(tpColl, clock());
- cache[hash] = value;
+TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const
+{
+ // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:'
+ assert(sourcePhrase.GetSize());
- inputPath.SetTargetPhrases(*this, tpColl, NULL);
+ std::pair<bool, uint64_t> keyStruct = GetKey(sourcePhrase);
+ if (!keyStruct.first) {
+ return TargetPhraseCollection::shared_ptr();
+ }
+
+ // check in cache
+ CachePb::const_iterator iter = m_cachePb.find(keyStruct.second);
+ if (iter != m_cachePb.end()) {
+ //cerr << "FOUND IN CACHE " << keyStruct.second << " " << sourcePhrase.Debug(mgr.system) << endl;
+ TargetPhraseCollection *tps = iter->second;
+ return TargetPhraseCollection::shared_ptr(tps);
+ }
+
+ // query pt
+ TargetPhraseCollection *tps = CreateTargetPhrases(sourcePhrase,
+ keyStruct.second);
+ return TargetPhraseCollection::shared_ptr(tps);
+}
+
+std::pair<bool, uint64_t> ProbingPT::GetKey(const Phrase &sourcePhrase) const
+{
+ std::pair<bool, uint64_t> ret;
+
+ // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:'
+ size_t sourceSize = sourcePhrase.GetSize();
+ assert(sourceSize);
+
+ uint64_t probingSource[sourceSize];
+ GetSourceProbingIds(sourcePhrase, ret.first, probingSource);
+ if (!ret.first) {
+ // source phrase contains a word unknown in the pt.
+ // We know immediately there's no translation for it
+ } else {
+ ret.second = m_engine->getKey(probingSource, sourceSize);
}
+
+ return ret;
+
}
-std::vector<uint64_t> ProbingPT::ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const
+void ProbingPT::GetSourceProbingIds(const Phrase &sourcePhrase,
+ bool &ok, uint64_t probingSource[]) const
{
+
size_t size = sourcePhrase.GetSize();
- std::vector<uint64_t> ret(size);
for (size_t i = 0; i < size; ++i) {
- const Factor *factor = sourcePhrase.GetFactor(i, m_input[0]);
- uint64_t probingId = GetSourceProbingId(factor);
+ const Word &word = sourcePhrase.GetWord(i);
+ uint64_t probingId = GetSourceProbingId(word);
if (probingId == m_unkId) {
ok = false;
- return ret;
+ return;
} else {
- ret[i] = probingId;
+ probingSource[i] = probingId;
}
}
ok = true;
- return ret;
}
-TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const
+uint64_t ProbingPT::GetSourceProbingId(const Word &word) const
{
- // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:'
- assert(sourcePhrase.GetSize());
+ uint64_t ret = 0;
- TargetPhraseCollection::shared_ptr tpColl;
- bool ok;
- vector<uint64_t> probingSource = ConvertToProbingSourcePhrase(sourcePhrase, ok);
- if (!ok) {
- // source phrase contains a word unknown in the pt.
- // We know immediately there's no translation for it
- return tpColl;
+ for (size_t i = 0; i < m_input.size(); ++i) {
+ FactorType factorType = m_input[i];
+ const Factor *factor = word[factorType];
+
+ size_t factorId = factor->GetId();
+ if (factorId >= m_sourceVocab.size()) {
+ return m_unkId;
+ }
+ ret += m_sourceVocab[factorId];
}
- std::pair<bool, std::vector<target_text> > query_result;
+ return ret;
+}
+
+TargetPhraseCollection *ProbingPT::CreateTargetPhrases(
+ const Phrase &sourcePhrase, uint64_t key) const
+{
+ TargetPhraseCollection *tps = NULL;
//Actual lookup
- query_result = m_engine->query(probingSource);
+ std::pair<bool, uint64_t> query_result; // 1st=found, 2nd=target file offset
+ query_result = m_engine->query(key);
+ //cerr << "key2=" << query_result.second << endl;
if (query_result.first) {
- //m_engine->printTargetInfo(query_result.second);
- tpColl.reset(new TargetPhraseCollection());
+ const char *offset = data + query_result.second;
+ uint64_t *numTP = (uint64_t*) offset;
+
+ tps = new TargetPhraseCollection();
- const std::vector<target_text> &probingTargetPhrases = query_result.second;
- for (size_t i = 0; i < probingTargetPhrases.size(); ++i) {
- const target_text &probingTargetPhrase = probingTargetPhrases[i];
- TargetPhrase *tp = CreateTargetPhrase(sourcePhrase, probingTargetPhrase);
+ offset += sizeof(uint64_t);
+ for (size_t i = 0; i < *numTP; ++i) {
+ TargetPhrase *tp = CreateTargetPhrase(offset);
+ assert(tp);
+ tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply());
+
+ tps->Add(tp);
- tpColl->Add(tp);
}
- tpColl->Prune(true, m_tableLimit);
+ tps->Prune(true, m_tableLimit);
+ //cerr << *tps << endl;
}
- return tpColl;
+ return tps;
+
}
-TargetPhrase *ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const
+TargetPhrase *ProbingPT::CreateTargetPhrase(
+ const char *&offset) const
{
- const std::vector<unsigned int> &probingPhrase = probingTargetPhrase.target_phrase;
- size_t size = probingPhrase.size();
+ TargetPhraseInfo *tpInfo = (TargetPhraseInfo*) offset;
+ size_t numRealWords = tpInfo->numWords / m_output.size();
TargetPhrase *tp = new TargetPhrase(this);
- // words
- for (size_t i = 0; i < size; ++i) {
- uint64_t probingId = probingPhrase[i];
- const Factor *factor = GetTargetFactor(probingId);
- assert(factor);
+ offset += sizeof(TargetPhraseInfo);
- Word &word = tp->AddWord();
- word.SetFactor(m_output[0], factor);
+ // scores
+ float *scores = (float*) offset;
+
+ size_t totalNumScores = m_engine->num_scores + m_engine->num_lex_scores;
+
+ if (m_engine->logProb) {
+ // set pt score for rule
+ tp->GetScoreBreakdown().PlusEquals(this, scores);
+
+ // save scores for other FF, eg. lex RO. Just give the offset
+ /*
+ if (m_engine->num_lex_scores) {
+ tp->scoreProperties = scores + m_engine->num_scores;
+ }
+ */
+ } else {
+ // log score 1st
+ float logScores[totalNumScores];
+ for (size_t i = 0; i < totalNumScores; ++i) {
+ logScores[i] = FloorScore(TransformScore(scores[i]));
+ }
+
+ // set pt score for rule
+ tp->GetScoreBreakdown().PlusEquals(this, logScores);
+
+ // save scores for other FF, eg. lex RO.
+ /*
+ tp->scoreProperties = pool.Allocate<SCORE>(m_engine->num_lex_scores);
+ for (size_t i = 0; i < m_engine->num_lex_scores; ++i) {
+ tp->scoreProperties[i] = logScores[i + m_engine->num_scores];
+ }
+ */
}
- // score for this phrase table
- vector<float> scores = probingTargetPhrase.prob;
- std::transform(scores.begin(), scores.end(), scores.begin(),TransformScore);
- tp->GetScoreBreakdown().PlusEquals(this, scores);
+ offset += sizeof(float) * totalNumScores;
+
+ // words
+ for (size_t targetPos = 0; targetPos < numRealWords; ++targetPos) {
+ Word &word = tp->AddWord();
+ for (size_t i = 0; i < m_output.size(); ++i) {
+ FactorType factorType = m_output[i];
+
+ uint32_t *probingId = (uint32_t*) offset;
+
+ const Factor *factor = GetTargetFactor(*probingId);
+ assert(factor);
- // alignment
- /*
- const std::vector<unsigned char> &alignments = probingTargetPhrase.word_all1;
+ word[factorType] = factor;
- AlignmentInfo &aligns = tp->GetAlignTerm();
- for (size_t i = 0; i < alignS.size(); i += 2 ) {
- aligns.Add((size_t) alignments[i], (size_t) alignments[i+1]);
+ offset += sizeof(uint32_t);
+ }
}
- */
- // score of all other ff when this rule is being loaded
- tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply());
+ // align
+ uint32_t alignTerm = tpInfo->alignTerm;
+ //cerr << "alignTerm=" << alignTerm << endl;
+ UTIL_THROW_IF2(alignTerm >= m_aligns.size(), "Unknown alignInd");
+ tp->SetAlignTerm(m_aligns[alignTerm]);
+
+ // properties TODO
+
return tp;
}
-const Factor *ProbingPT::GetTargetFactor(uint64_t probingId) const
-{
- TargetVocabMap::right_map::const_iterator iter;
- iter = m_vocabMap.right.find(probingId);
- if (iter != m_vocabMap.right.end()) {
- return iter->second;
- } else {
- // not in mapping. Must be UNK
- return NULL;
- }
-}
+//////////////////////////////////////////////////////////////////
-uint64_t ProbingPT::GetSourceProbingId(const Factor *factor) const
-{
- SourceVocabMap::left_map::const_iterator iter;
- iter = m_sourceVocabMap.left.find(factor);
- if (iter != m_sourceVocabMap.left.end()) {
- return iter->second;
- } else {
- // not in mapping. Must be UNK
- return m_unkId;
- }
-}
ChartRuleLookupManager *ProbingPT::CreateRuleLookupManager(
const ChartParser &,