Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Hieu Hoang <hieuhoang@gmail.com> 2016-09-29 18:21:49 +0300
committer Hieu Hoang <hieuhoang@gmail.com> 2016-09-29 18:21:49 +0300
commit 230d4dd13a3fb0543e0541310a52b345beb1e018 (patch)
tree 108f0a9a2abf50e2cc5f61c6c16cede24a6aba66
parent fc6679178c67ee91376bd7e7148af046e746ec24 (diff)
add PhraseDecoder
-rw-r--r-- contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.cpp 469
-rw-r--r-- contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.h 147
2 files changed, 616 insertions, 0 deletions
diff --git a/contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.cpp b/contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.cpp
new file mode 100644
index 000000000..81983c011
--- /dev/null
+++ b/contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.cpp
@@ -0,0 +1,469 @@
+// $Id$
+// vim:tabstop=2
+/***********************************************************************
+Moses - factored phrase-based language decoder
+Copyright (C) 2006 University of Edinburgh
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+***********************************************************************/
+
+#include <deque>
+
+#include "PhraseDecoder.h"
+#include "moses/StaticData.h"
+
+#include "../../ManagerBase.h"
+
+using namespace std;
+
+namespace Moses2
+{
+
+/**
+ * Constructs a decoder bound to one compact phrase table.
+ *
+ * Only bookkeeping state is set up here: the tree pointers start out null
+ * and the symbol tables are default-constructed. Load() fills them in.
+ *
+ * @param phraseDictionary   phrase table whose encoded collections are decoded
+ * @param input              source factor types; the pointer is stored, not copied
+ * @param output             target factor types; the pointer is stored, not copied
+ * @param numScoreComponent  number of scores expected per target phrase
+ */
+PhraseDecoder::PhraseDecoder(
+  PhraseTableCompact &phraseDictionary,
+  const std::vector<FactorType>* input,
+  const std::vector<FactorType>* output,
+  size_t numScoreComponent
+  // , const std::vector<float>* weight
+)
+  : m_coding(None), m_numScoreComponent(numScoreComponent),
+    m_containsAlignmentInfo(true), m_maxRank(0),
+    // Was left uninitialised, so GetMaxSourcePhraseLength() returned an
+    // indeterminate value until Load() had run.
+    m_maxPhraseLength(0),
+    m_symbolTree(0), m_multipleScoreTrees(false),
+    m_scoreTrees(1), m_alignTree(0),
+    m_phraseDictionary(phraseDictionary), m_input(input), m_output(output),
+    // m_weight(weight),
+    m_separator(" ||| ")
+{ }
+
+/** Releases the symbol tree, every score tree and the alignment tree. */
+PhraseDecoder::~PhraseDecoder()
+{
+  // Deleting a null pointer is a no-op, so unset trees need no guard.
+  delete m_symbolTree;
+
+  typedef std::vector<CanonicalHuffman<float>*>::iterator ScoreTreeIter;
+  for(ScoreTreeIter tree = m_scoreTrees.begin(); tree != m_scoreTrees.end(); ++tree)
+    delete *tree;
+
+  delete m_alignTree;
+}
+
+/**
+ * Returns the id of a source symbol, memoising the answer.
+ *
+ * A hit in m_sourceSymbolsMap is returned directly. Otherwise the symbol is
+ * looked up in m_sourceSymbols and whatever that lookup yields is stored in
+ * the map before being returned.
+ *
+ * @param symbol  surface string of the source word
+ * @return the cached or freshly looked-up id
+ */
+inline unsigned PhraseDecoder::GetSourceSymbolId(std::string& symbol)
+{
+  typedef boost::unordered_map<std::string, unsigned> SymbolCache;
+
+  SymbolCache::iterator cached = m_sourceSymbolsMap.find(symbol);
+  if(cached == m_sourceSymbolsMap.end()) {
+    // First request for this symbol: consult the table and remember the result.
+    size_t position = m_sourceSymbols.find(symbol);
+    m_sourceSymbolsMap[symbol] = position;
+    return position;
+  }
+
+  return cached->second;
+}
+
+/**
+ * Maps a target symbol id to its string.
+ *
+ * @param idx  index into m_targetSymbols
+ * @return the stored symbol, or the literal "##ERROR##" when idx is out of range
+ */
+inline std::string PhraseDecoder::GetTargetSymbol(unsigned idx) const
+{
+  const bool outOfRange = !(idx < m_targetSymbols.size());
+  if(outOfRange)
+    return std::string("##ERROR##");
+
+  return m_targetSymbols[idx];
+}
+
+/**
+ * Extracts the REnc symbol type from the two most significant bits.
+ *
+ * @return the top two bits of encodedSymbol plus one
+ */
+inline size_t PhraseDecoder::GetREncType(unsigned encodedSymbol)
+{
+  const unsigned typeBits = encodedSymbol >> 30;
+  return typeBits + 1;
+}
+
+/**
+ * Extracts the PREnc symbol type from the most significant bit.
+ *
+ * @return the top bit of encodedSymbol plus one; the decoder treats 1 as a
+ *         plain word and anything else as a sub-phrase pointer
+ */
+inline size_t PhraseDecoder::GetPREncType(unsigned encodedSymbol)
+{
+  const unsigned typeBit = encodedSymbol >> 31;
+  return typeBit + 1;
+}
+
+/**
+ * Returns the target symbol id of the rank-th lexical translation of a
+ * source symbol.
+ *
+ * m_lexicalTableIndex[srcIdx] gives the offset of the symbol's first entry
+ * in m_lexicalTable; rank is added to that offset.
+ *
+ * Both lookups are bounds-checked because the caller feeds in ids and ranks
+ * decoded from data it treats as possibly false-positive. An out-of-range
+ * request yields ~0u instead of reading past the end of a vector.
+ * NOTE(review): ~0u is expected to fail GetTargetSymbol()'s range check and
+ * so decode to "##ERROR##" -- confirm against the size of m_targetSymbols.
+ *
+ * @param srcIdx  source symbol id
+ * @param rank    position within that symbol's translation list
+ * @return target symbol id, or ~0u when srcIdx or rank is out of range
+ */
+inline unsigned PhraseDecoder::GetTranslation(unsigned srcIdx, size_t rank)
+{
+  const unsigned unknownTarget = ~0u;
+
+  if(srcIdx >= m_lexicalTableIndex.size())
+    return unknownTarget;
+
+  const size_t srcTrgIdx = m_lexicalTableIndex[srcIdx];
+
+  // Written as a subtraction so that srcTrgIdx + rank cannot wrap around.
+  if(srcTrgIdx >= m_lexicalTable.size()
+      || rank >= m_lexicalTable.size() - srcTrgIdx)
+    return unknownTarget;
+
+  return m_lexicalTable[srcTrgIdx + rank].second;
+}
+
+/** Returns the maximum source phrase length as read from the file by Load(). */
+size_t PhraseDecoder::GetMaxSourcePhraseLength()
+{
+  return this->m_maxPhraseLength;
+}
+
+/**
+ * Strips the two type bits from a type-1 REnc symbol.
+ *
+ * @return the low 30 bits of encodedSymbol
+ */
+inline unsigned PhraseDecoder::DecodeREncSymbol1(unsigned encodedSymbol)
+{
+  // Unsigned literal: 3 << 30 does not fit in a 32-bit signed int.
+  const unsigned typeMask = 3u << 30;
+  return encodedSymbol & ~typeMask;
+}
+
+/**
+ * Extracts the rank field of a type-2 REnc symbol.
+ *
+ * @return the low 24 bits of encodedSymbol
+ */
+inline unsigned PhraseDecoder::DecodeREncSymbol2Rank(unsigned encodedSymbol)
+{
+  // Unsigned literal: 255 << 24 does not fit in a 32-bit signed int.
+  const unsigned topByteMask = 255u << 24;
+  return encodedSymbol & ~topByteMask;
+}
+
+/**
+ * Extracts the source position field of a type-2 REnc symbol.
+ *
+ * @return bits 24..29 of encodedSymbol (the two type bits are cleared first)
+ */
+inline unsigned PhraseDecoder::DecodeREncSymbol2Position(unsigned encodedSymbol)
+{
+  // Unsigned literal: 3 << 30 does not fit in a 32-bit signed int.
+  const unsigned typeMask = 3u << 30;
+  return (encodedSymbol & ~typeMask) >> 24;
+}
+
+/**
+ * Strips the two type bits from a type-3 REnc symbol, leaving the rank.
+ *
+ * @return the low 30 bits of encodedSymbol
+ */
+inline unsigned PhraseDecoder::DecodeREncSymbol3(unsigned encodedSymbol)
+{
+  // Unsigned literal: 3 << 30 does not fit in a 32-bit signed int.
+  const unsigned typeMask = 3u << 30;
+  return encodedSymbol & ~typeMask;
+}
+
+/**
+ * Strips the type bit from a plain-word PREnc symbol.
+ *
+ * @return the low 31 bits of encodedSymbol
+ */
+inline unsigned PhraseDecoder::DecodePREncSymbol1(unsigned encodedSymbol)
+{
+  // Unsigned literal: 1 << 31 does not fit in a 32-bit signed int.
+  const unsigned typeMask = 1u << 31;
+  return encodedSymbol & ~typeMask;
+}
+
+/**
+ * Extracts the left offset of a PREnc sub-phrase pointer.
+ *
+ * The field is the six bits starting at bit 25, stored with a bias of 32.
+ *
+ * @return a value in the range [-32, 31]
+ */
+inline int PhraseDecoder::DecodePREncSymbol2Left(unsigned encodedSymbol)
+{
+  // Convert to int before removing the bias so the subtraction is signed.
+  const int biased = static_cast<int>((encodedSymbol >> 25) & 63u);
+  return biased - 32;
+}
+
+/**
+ * Extracts the right offset of a PREnc sub-phrase pointer.
+ *
+ * The field is the six bits starting at bit 19, stored with a bias of 32.
+ *
+ * @return a value in the range [-32, 31]
+ */
+inline int PhraseDecoder::DecodePREncSymbol2Right(unsigned encodedSymbol)
+{
+  // Convert to int before removing the bias so the subtraction is signed.
+  const int biased = static_cast<int>((encodedSymbol >> 19) & 63u);
+  return biased - 32;
+}
+
+/**
+ * Extracts the rank field of a PREnc sub-phrase pointer.
+ *
+ * @return the low 19 bits of encodedSymbol
+ */
+inline unsigned PhraseDecoder::DecodePREncSymbol2Rank(unsigned encodedSymbol)
+{
+  const unsigned rankMask = 0x7FFFF; // 524287, i.e. 2^19 - 1
+  return encodedSymbol & rankMask;
+}
+
+/**
+ * Reads the decoder state from an open binary file.
+ *
+ * Order of the data read: coding scheme, score count, alignment flag,
+ * maximum rank, maximum phrase length; for REnc only, the source symbols
+ * and both lexical tables; then the target symbols, the symbol tree, the
+ * score tree(s) and, if flagged, the alignment tree.
+ *
+ * @param in  file positioned at the start of the decoder section
+ * @return the difference between the ftell() positions after and before loading
+ */
+size_t PhraseDecoder::Load(std::FILE* in)
+{
+  size_t start = std::ftell(in);
+  // Accumulates fread()'s element counts; not otherwise inspected.
+  size_t read = 0;
+
+  read += std::fread(&m_coding, sizeof(m_coding), 1, in);
+  read += std::fread(&m_numScoreComponent, sizeof(m_numScoreComponent), 1, in);
+  read += std::fread(&m_containsAlignmentInfo, sizeof(m_containsAlignmentInfo), 1, in);
+  read += std::fread(&m_maxRank, sizeof(m_maxRank), 1, in);
+  read += std::fread(&m_maxPhraseLength, sizeof(m_maxPhraseLength), 1, in);
+
+  if(m_coding == REnc) {
+    m_sourceSymbols.load(in);
+
+    // Zero-initialised so a failed read gives an empty table rather than a
+    // resize to an indeterminate value.
+    size_t size = 0;
+    read += std::fread(&size, sizeof(size_t), 1, in);
+    m_lexicalTableIndex.resize(size);
+    // Taking &v[0] of an empty vector is undefined, so skip empty tables.
+    if(size)
+      read += std::fread(&m_lexicalTableIndex[0], sizeof(size_t), size, in);
+
+    size = 0;
+    read += std::fread(&size, sizeof(size_t), 1, in);
+    m_lexicalTable.resize(size);
+    if(size)
+      read += std::fread(&m_lexicalTable[0], sizeof(SrcTrg), size, in);
+  }
+
+  m_targetSymbols.load(in);
+
+  m_symbolTree = new CanonicalHuffman<unsigned>(in);
+
+  read += std::fread(&m_multipleScoreTrees, sizeof(m_multipleScoreTrees), 1, in);
+  if(m_multipleScoreTrees) {
+    // One tree per score component.
+    m_scoreTrees.resize(m_numScoreComponent);
+    for(size_t i = 0; i < m_numScoreComponent; i++)
+      m_scoreTrees[i] = new CanonicalHuffman<float>(in);
+  } else {
+    // A single tree shared by all score components.
+    m_scoreTrees.resize(1);
+    m_scoreTrees[0] = new CanonicalHuffman<float>(in);
+  }
+
+  if(m_containsAlignmentInfo)
+    m_alignTree = new CanonicalHuffman<AlignPoint>(in);
+
+  size_t end = std::ftell(in);
+  return end - start;
+}
+
+/**
+ * Builds the hash key for a source phrase.
+ *
+ * @param source  source phrase string
+ * @return source followed by m_separator
+ */
+std::string PhraseDecoder::MakeSourceKey(std::string &source)
+{
+  std::string key(source);
+  key.append(m_separator);
+  return key;
+}
+
+/**
+ * Looks up a source phrase and decodes its target phrase collection.
+ *
+ * With PREnc coding the decoding cache is consulted first. A cached
+ * collection is returned as is when it is complete or when the caller is
+ * not top level; an incomplete one is copied and decoding resumes at the
+ * cached bit offset.
+ *
+ * @param mgr           manager forwarded to DecodeCollection()
+ * @param sourcePhrase  phrase to look up
+ * @param topLevel      forwarded to DecodeCollection(); also decides whether
+ *                      an incomplete cache entry is good enough
+ * @param eval          forwarded to DecodeCollection()
+ * @return the decoded collection, or an empty pointer when the phrase is not
+ *         in the hash
+ */
+TargetPhraseVectorPtr PhraseDecoder::CreateTargetPhraseCollection(
+  const ManagerBase &mgr,
+  const Phrase<Word> &sourcePhrase,
+  bool topLevel,
+  bool eval)
+{
+  // Not using TargetPhraseCollection avoiding "new" operator
+  // which can introduce heavy locking with multiple threads
+  TargetPhraseVectorPtr tpv(new TargetPhraseVector());
+  size_t bitsLeft = 0;
+
+  if(m_coding == PREnc) {
+    std::pair<TargetPhraseVectorPtr, size_t> cached
+      = m_decodingCache.Retrieve(sourcePhrase);
+
+    if(cached.first != NULL) {
+      // Complete entry, or the caller does not need the complete one.
+      if(!topLevel || cached.second == 0)
+        return cached.first;
+
+      // Incomplete entry: start from a copy and remember where to resume.
+      bitsLeft = cached.second;
+      tpv->resize(cached.first->size());
+      std::copy(cached.first->begin(), cached.first->end(), tpv->begin());
+    }
+  }
+
+  // Retrieve source phrase identifier
+  std::string sourceText = sourcePhrase.GetString(*m_input);
+  size_t phraseId = m_phraseDictionary.m_hash[MakeSourceKey(sourceText)];
+
+  // An id equal to the hash size is treated as "not found".
+  if(phraseId == m_phraseDictionary.m_hash.GetSize())
+    return TargetPhraseVectorPtr();
+
+  // Retrieve compressed and encoded target phrase collection
+  std::string encoded;
+  if(m_phraseDictionary.m_inMemory)
+    encoded = m_phraseDictionary.m_targetPhrasesMemory[phraseId].str();
+  else
+    encoded = m_phraseDictionary.m_targetPhrasesMapped[phraseId].str();
+
+  BitWrapper<> bitStream(encoded);
+  if(m_coding == PREnc && bitsLeft)
+    bitStream.SeekFromEnd(bitsLeft);
+
+  // Decompress and decode target phrase collection
+  return DecodeCollection(mgr, tpv, bitStream, sourcePhrase, topLevel, eval);
+}
+
+/**
+ * Decodes a Huffman-compressed bit stream into target phrases for one
+ * source phrase.
+ *
+ * The loop is a small state machine (New -> Symbol -> Score -> Alignment
+ * -> Add) that runs until the stream reports no bits left or one of the
+ * break conditions in the Add state is hit.
+ *
+ * @param mgr               supplies the System, its vocabulary and the pool
+ * @param tpv               collection handed in by the caller; a non-empty
+ *                          one marks the call as "extending"
+ * @param encodedBitStream  stream positioned at the first phrase to decode
+ * @param sourcePhrase      source phrase the collection belongs to
+ * @param topLevel          when false, PREnc decoding stops once tpv holds
+ *                          at least m_maxRank entries (if m_maxRank != 0)
+ * @param eval              when true, each phrase is passed to
+ *                          EvaluateInIsolation
+ * @return tpv, or an empty TargetPhraseVectorPtr if a consistency check fails
+ *
+ * NOTE(review): within this function the decoded `words` and `scores` are
+ * never copied into the TargetPhraseImpl built in the Add state, and that
+ * object is never appended to tpv -- confirm whether this is unfinished.
+ */
+TargetPhraseVectorPtr PhraseDecoder::DecodeCollection(
+ const ManagerBase &mgr,
+ TargetPhraseVectorPtr tpv,
+ BitWrapper<> &encodedBitStream,
+ const Phrase<Word> &sourcePhrase,
+ bool topLevel,
+ bool eval)
+{
+ const System &system = mgr.system;
+ FactorCollection &vocab = system.GetVocab();
+
+ // A non-empty tpv means the caller passed in a partially decoded collection.
+ bool extending = tpv->size();
+ size_t bitsLeft = encodedBitStream.TellFromEnd();
+
+ typedef std::pair<size_t, size_t> AlignPointSizeT;
+
+ // REnc symbols refer to source words by id, so resolve the ids up front.
+ std::vector<int> sourceWords;
+ if(m_coding == REnc) {
+ for(size_t i = 0; i < sourcePhrase.GetSize(); i++) {
+ std::string sourceWord
+ = sourcePhrase[i].GetString(*m_input);
+ unsigned idx = GetSourceSymbolId(sourceWord);
+ sourceWords.push_back(idx);
+ }
+ }
+
+ // Symbol value 0 ends the word sequence of a phrase.
+ unsigned phraseStopSymbol = 0;
+ // AlignPoint holds unsigned chars, so (-1, -1) is stored as (255, 255).
+ AlignPoint alignStopSymbol(-1, -1);
+
+ // Scratch buffers for the phrase currently being decoded.
+ std::vector<Word> words;
+ std::vector<float> scores;
+ std::set<AlignPointSizeT> alignment;
+
+ enum DecodeState { New, Symbol, Score, Alignment, Add } state = New;
+
+ size_t srcSize = sourcePhrase.GetSize();
+
+ while(encodedBitStream.TellFromEnd()) {
+
+ if(state == New) {
+ // Start a new phrase: reset the per-phrase scratch buffers
+ words.clear();
+ alignment.clear();
+ scores.clear();
+
+ state = Symbol;
+ }
+
+ if(state == Symbol) {
+ unsigned symbol = m_symbolTree->Read(encodedBitStream);
+ if(symbol == phraseStopSymbol) {
+ state = Score;
+ } else {
+ if(m_coding == REnc) {
+ std::string wordString;
+ size_t type = GetREncType(symbol);
+
+ if(type == 1) {
+ // Type 1: the payload is a target symbol id.
+ unsigned decodedSymbol = DecodeREncSymbol1(symbol);
+ wordString = GetTargetSymbol(decodedSymbol);
+ } else if (type == 2) {
+ // Type 2: explicit source position plus a rank in that word's
+ // lexical translation list.
+ size_t rank = DecodeREncSymbol2Rank(symbol);
+ size_t srcPos = DecodeREncSymbol2Position(symbol);
+
+ if(srcPos >= sourceWords.size())
+ return TargetPhraseVectorPtr();
+
+ wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
+ if(m_phraseDictionary.m_useAlignmentInfo) {
+ size_t trgPos = words.size();
+ alignment.insert(AlignPoint(srcPos, trgPos));
+ }
+ } else if(type == 3) {
+ // Type 3: rank only; the source position is the current target
+ // position.
+ size_t rank = DecodeREncSymbol3(symbol);
+ size_t srcPos = words.size();
+
+ if(srcPos >= sourceWords.size())
+ return TargetPhraseVectorPtr();
+
+ wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
+ if(m_phraseDictionary.m_useAlignmentInfo) {
+ size_t trgPos = srcPos;
+ alignment.insert(AlignPoint(srcPos, trgPos));
+ }
+ }
+
+ Word word;
+ word.CreateFromString(vocab, system, wordString);
+ words.push_back(word);
+ } else if(m_coding == PREnc) {
+ // if the symbol is just a word
+ if(GetPREncType(symbol) == 1) {
+ unsigned decodedSymbol = DecodePREncSymbol1(symbol);
+
+ Word word;
+ word.CreateFromString(vocab, system, GetTargetSymbol(decodedSymbol));
+ words.push_back(word);
+ }
+ // if the symbol is a subphrase pointer
+ else {
+ int left = DecodePREncSymbol2Left(symbol);
+ int right = DecodePREncSymbol2Right(symbol);
+ unsigned rank = DecodePREncSymbol2Rank(symbol);
+
+ // left is relative to the current target position, right is
+ // counted back from the last source word.
+ int srcStart = left + words.size();
+ int srcEnd = srcSize - right - 1;
+
+ // false positive consistency check
+ if(0 > srcStart || srcStart > srcEnd || unsigned(srcEnd) >= srcSize)
+ return TargetPhraseVectorPtr();
+
+ // false positive consistency check
+ if(m_maxRank && rank > m_maxRank)
+ return TargetPhraseVectorPtr();
+
+ // set subphrase by default to itself
+ TargetPhraseVectorPtr subTpv = tpv;
+
+ // if range smaller than source phrase retrieve subphrase
+ if(unsigned(srcEnd - srcStart + 1) != srcSize) {
+ SubPhrase<Word> subPhrase = sourcePhrase.GetSubPhrase(srcStart, srcEnd - srcStart + 1);
+ // Recursive lookup with topLevel == false; eval takes its default.
+ subTpv = CreateTargetPhraseCollection(mgr, subPhrase, false);
+ } else {
+ // false positive consistency check
+ // NOTE(review): tpv->size()-1 wraps around when tpv is empty, so
+ // this check cannot fire then; the rank test below still rejects
+ // that case -- confirm this is intended.
+ if(rank >= tpv->size()-1)
+ return TargetPhraseVectorPtr();
+ }
+
+ // false positive consistency check
+ if(subTpv != NULL && rank < subTpv->size()) {
+ // insert the subphrase into the main target phrase
+ const TargetPhraseImpl& subTp = *subTpv->at(rank);
+ if(m_phraseDictionary.m_useAlignmentInfo) {
+ // reconstruct the alignment data based on the alignment of the subphrase
+ for(AlignmentInfo::const_iterator it = subTp.GetAlignTerm().begin();
+ it != subTp.GetAlignTerm().end(); it++) {
+ alignment.insert(AlignPointSizeT(srcStart + it->first,
+ words.size() + it->second));
+ }
+ }
+
+ for (size_t i = 0; i < subTp.GetSize(); ++i) {
+ words.push_back(subTp[i]);
+ }
+ } else
+ return TargetPhraseVectorPtr();
+ }
+ } else {
+ // No special coding: the symbol is a target symbol id.
+ Word word;
+ word.CreateFromString(vocab, system, GetTargetSymbol(symbol));
+ words.push_back(word);
+ }
+ }
+ } else if(state == Score) {
+ // With a single shared tree every score is read from tree 0.
+ size_t idx = m_multipleScoreTrees ? scores.size() : 0;
+ float score = m_scoreTrees[idx]->Read(encodedBitStream);
+ scores.push_back(score);
+
+ if(scores.size() == m_numScoreComponent) {
+ if(m_containsAlignmentInfo)
+ state = Alignment;
+ else
+ state = Add;
+ }
+ } else if(state == Alignment) {
+ // Alignment points are always consumed from the stream, but only kept
+ // when the phrase table asks for alignment info.
+ AlignPoint alignPoint = m_alignTree->Read(encodedBitStream);
+ if(alignPoint == alignStopSymbol) {
+ state = Add;
+ } else {
+ if(m_phraseDictionary.m_useAlignmentInfo)
+ alignment.insert(AlignPointSizeT(alignPoint));
+ }
+ }
+
+ if(state == Add) {
+ size_t targetSize = words.size();
+ // NOTE(review): nothing below copies `words`/`scores` into targetPhrase
+ // or appends it to tpv, so tpv->size() is not changed by this loop.
+ TargetPhraseImpl *targetPhrase = new TargetPhraseImpl(mgr.GetPool(), m_phraseDictionary, system, targetSize);
+
+ if(m_phraseDictionary.m_useAlignmentInfo) {
+ size_t sourceSize = sourcePhrase.GetSize();
+ // Reject the whole collection if any alignment point is out of range.
+ for(std::set<AlignPointSizeT>::iterator it = alignment.begin(); it != alignment.end(); it++) {
+ // NOTE(review): targetPhrase is not released on this early return --
+ // confirm who owns it.
+ if(it->first >= sourceSize || it->second >= targetSize)
+ return TargetPhraseVectorPtr();
+ }
+ targetPhrase->SetAlignTerm(alignment);
+ }
+
+ if(eval) {
+ mgr.system.featureFunctions.EvaluateInIsolation(mgr.GetPool(), mgr.system, sourcePhrase, *targetPhrase);
+ }
+
+ if(m_coding == PREnc) {
+ // Remember the stream position while the collection is within the
+ // rank limit (or no limit is set); it is cached after the loop.
+ if(!m_maxRank || tpv->size() <= m_maxRank)
+ bitsLeft = encodedBitStream.TellFromEnd();
+
+ // Non-top-level callers stop once the rank limit is reached.
+ if(!topLevel && m_maxRank && tpv->size() >= m_maxRank)
+ break;
+ }
+
+ // NOTE(review): 8 or fewer remaining bits are treated as end of data,
+ // presumably padding -- confirm against the encoder.
+ if(encodedBitStream.TellFromEnd() <= 8)
+ break;
+
+ state = New;
+ }
+ }
+
+ // Only collections decoded from scratch are cached; a bitsLeft of 0 marks
+ // the cached entry as complete.
+ if(m_coding == PREnc && !extending) {
+ bitsLeft = bitsLeft > 8 ? bitsLeft : 0;
+ m_decodingCache.Cache(sourcePhrase, tpv, bitsLeft, m_maxRank);
+ }
+
+ return tpv;
+}
+
+/** Forwards to Prune() on the decoding cache. */
+void PhraseDecoder::PruneCache()
+{
+  this->m_decodingCache.Prune();
+}
+
+}
diff --git a/contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.h b/contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.h
new file mode 100644
index 000000000..01a7c23c5
--- /dev/null
+++ b/contrib/moses2/TranslationModel/CompactPT/PhraseDecoder.h
@@ -0,0 +1,147 @@
+// $Id$
+// vim:tabstop=2
+/***********************************************************************
+Moses - factored phrase-based language decoder
+Copyright (C) 2006 University of Edinburgh
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+***********************************************************************/
+
+#pragma once
+
+#include <sstream>
+#include <vector>
+#include <boost/unordered_map.hpp>
+#include <boost/unordered_set.hpp>
+#include <string>
+#include <iterator>
+#include <algorithm>
+#include <sys/stat.h>
+
+#include "moses/TypeDef.h"
+#include "moses/FactorCollection.h"
+#include "moses/Word.h"
+#include "moses/Util.h"
+#include "moses/InputFileStream.h"
+#include "moses/StaticData.h"
+#include "moses/Range.h"
+
+#include "PhraseTableCompact.h"
+#include "StringVector.h"
+#include "CanonicalHuffman.h"
+#include "TargetPhraseCollectionCache.h"
+
+namespace Moses2
+{
+
+class PhraseTableCompact;
+
+/**
+ * Decoder for the target side of a compact phrase table.
+ *
+ * Load() reads the coding scheme, the symbol tables and the Huffman trees
+ * from a file. CreateTargetPhraseCollection() then looks a source phrase up
+ * in the owning PhraseTableCompact and decodes its stored bit stream.
+ *
+ * The object allocates its Huffman trees in Load() and deletes them in the
+ * destructor, so copying is disabled to rule out a double delete.
+ */
+class PhraseDecoder
+{
+protected:
+
+  friend class PhraseTableCompact;
+
+  typedef std::pair<unsigned char, unsigned char> AlignPoint;
+  typedef std::pair<unsigned, unsigned> SrcTrg;
+
+  // Symbol coding scheme; read from the file by Load().
+  enum Coding { None, REnc, PREnc } m_coding;
+
+  // Number of scores decoded per target phrase.
+  size_t m_numScoreComponent;
+  // Whether the stream carries alignment points after the scores.
+  bool m_containsAlignmentInfo;
+  // Rank limit used by the PREnc consistency and cut-off checks; 0 disables it.
+  size_t m_maxRank;
+  size_t m_maxPhraseLength;
+
+  // Memoised results of lookups in m_sourceSymbols.
+  boost::unordered_map<std::string, unsigned> m_sourceSymbolsMap;
+  StringVector<unsigned char, unsigned, std::allocator> m_sourceSymbols;
+  StringVector<unsigned char, unsigned, std::allocator> m_targetSymbols;
+
+  // For each source symbol id, the offset of its first entry in m_lexicalTable.
+  std::vector<size_t> m_lexicalTableIndex;
+  std::vector<SrcTrg> m_lexicalTable;
+
+  // Owned; allocated in Load().
+  CanonicalHuffman<unsigned>* m_symbolTree;
+
+  // True: one score tree per score component. False: one shared tree.
+  bool m_multipleScoreTrees;
+  // Owned; allocated in Load().
+  std::vector<CanonicalHuffman<float>*> m_scoreTrees;
+
+  // Owned; allocated in Load() only when alignment info is present.
+  CanonicalHuffman<AlignPoint>* m_alignTree;
+
+  // Cache of decoded collections, used with PREnc coding.
+  TargetPhraseCollectionCache m_decodingCache;
+
+  PhraseTableCompact& m_phraseDictionary;
+
+  // ***********************************************
+
+  // Factor type lists; not owned.
+  const std::vector<FactorType>* m_input;
+  const std::vector<FactorType>* m_output;
+
+  // Appended to the source phrase string to form the hash key.
+  std::string m_separator;
+
+  // ***********************************************
+
+  unsigned GetSourceSymbolId(std::string& s);
+  std::string GetTargetSymbol(unsigned id) const;
+
+  size_t GetREncType(unsigned encodedSymbol);
+  size_t GetPREncType(unsigned encodedSymbol);
+
+  unsigned GetTranslation(unsigned srcIdx, size_t rank);
+
+  size_t GetMaxSourcePhraseLength();
+
+  unsigned DecodeREncSymbol1(unsigned encodedSymbol);
+  unsigned DecodeREncSymbol2Rank(unsigned encodedSymbol);
+  unsigned DecodeREncSymbol2Position(unsigned encodedSymbol);
+  unsigned DecodeREncSymbol3(unsigned encodedSymbol);
+
+  unsigned DecodePREncSymbol1(unsigned encodedSymbol);
+  int DecodePREncSymbol2Left(unsigned encodedSymbol);
+  int DecodePREncSymbol2Right(unsigned encodedSymbol);
+  unsigned DecodePREncSymbol2Rank(unsigned encodedSymbol);
+
+  std::string MakeSourceKey(std::string &);
+
+public:
+
+  PhraseDecoder(
+    PhraseTableCompact &phraseDictionary,
+    const std::vector<FactorType>* input,
+    const std::vector<FactorType>* output,
+    size_t numScoreComponent
+  );
+
+  ~PhraseDecoder();
+
+  // Reads the decoder state from `in`; returns the ftell() distance covered.
+  size_t Load(std::FILE* in);
+
+  // Looks up sourcePhrase and decodes its target phrases; returns an empty
+  // pointer when the phrase is not in the table.
+  TargetPhraseVectorPtr CreateTargetPhraseCollection(
+    const ManagerBase &mgr,
+    const Phrase<Word> &sourcePhrase,
+    bool topLevel = false,
+    bool eval = true);
+
+  // Decodes encodedBitStream for sourcePhrase; returns tpv, or an empty
+  // pointer when a consistency check fails.
+  TargetPhraseVectorPtr DecodeCollection(
+    const ManagerBase &mgr,
+    TargetPhraseVectorPtr tpv,
+    BitWrapper<> &encodedBitStream,
+    const Phrase<Word> &sourcePhrase,
+    bool topLevel,
+    bool eval);
+
+  void PruneCache();
+
+private:
+
+  // Not copyable: the destructor deletes the Huffman trees, so a member-wise
+  // copy would free them twice. Declared but intentionally never defined.
+  PhraseDecoder(const PhraseDecoder&);
+  PhraseDecoder& operator=(const PhraseDecoder&);
+};
+
+}
+