#include "cpu/nematus/encoder_decoder.h"

#include <string>
#include <vector>

#include "common/sentence.h"
#include "common/sentences.h"
#include "cpu/decoder/encoder_decoder_loader.h"
#include "cpu/mblas/matrix.h"

namespace amunmt {
namespace CPU {
namespace Nematus {

using EDState = EncoderDecoderState;

// Builds the CPU Nematus scorer. The encoder and the decoder are both
// constructed from the same model weights.
//
// Non-static members are initialized in declaration order, not in the order
// of this initializer list. Passing `model_` (rather than `model`) to the
// encoder/decoder is therefore only safe if `model_` is declared before
// `encoder_` and `decoder_` in the class.
// NOTE(review): confirm that ordering in cpu/nematus/encoder_decoder.h.
EncoderDecoder::EncoderDecoder(const God &god,
                               const std::string& name,
                               const YAML::Node& config,
                               unsigned tab,
                               const Nematus::Weights& model)
  : CPUEncoderDecoderBase(god, name, config, tab),
    model_(model),
    encoder_(new CPU::Nematus::Encoder(model_)),
    decoder_(new CPU::Nematus::Decoder(model_))
{}

// Runs one decoder step.
// - Reads the previous states and embeddings from `in`.
// - Writes the new states into `out`.
// - Passes the encoded SourceContext_ to the decoder.
// The trailing std::vector<unsigned> argument is not used by this scorer.
void EncoderDecoder::Decode(const State& in, State& out,
                            const std::vector<unsigned>&) {
  const EDState& edIn = in.get<EDState>();
  EDState& edOut = out.get<EDState>();

  decoder_->Decode(edOut.GetStates(), edIn.GetStates(), edIn.GetEmbeddings(),
                   SourceContext_);
}

// Initializes `state` for a new sentence. Both calls receive `batchSize`:
// - the decoder states are set via EmptyState (which also receives
//   SourceContext_);
// - the embeddings are set via EmptyEmbedding.
void EncoderDecoder::BeginSentenceState(State& state, unsigned batchSize) {
  EDState& edState = state.get<EDState>();
  decoder_->EmptyState(edState.GetStates(), SourceContext_, batchSize);
  decoder_->EmptyEmbedding(edState.GetEmbeddings(), batchSize);
}

// Encodes the source words for this scorer's tab into SourceContext_.
// Only the first sentence, sources.Get(0), is encoded.
// NOTE(review): confirm callers never hand this CPU scorer a batch of more
// than one sentence.
void EncoderDecoder::Encode(const Sentences& sources) {
  encoder_->GetContext(sources.Get(0).GetWords(tab_), SourceContext_);
}

// Builds the state for the next step from the surviving hypotheses in `beam`.
// - The previous states are reassembled according to each hypothesis'
//   GetPrevStateIndex().
// - The embeddings are looked up for each hypothesis' GetWord().
void EncoderDecoder::AssembleBeamState(const State& in,
                                       const Beam& beam,
                                       State& out) {
  std::vector<unsigned> beamWords;
  std::vector<unsigned> beamStateIds;
  beamWords.reserve(beam.size());
  beamStateIds.reserve(beam.size());

  // Bind by const reference: iterating by value would copy every hypothesis
  // handle (an atomic refcount round-trip if it is a shared_ptr).
  for (const auto& h : beam) {
    beamWords.push_back(h->GetWord());
    beamStateIds.push_back(h->GetPrevStateIndex());
  }

  const EDState& edIn = in.get<EDState>();
  EDState& edOut = out.get<EDState>();

  // NOTE(review): assumes Assemble<byRow> gathers one state row per entry of
  // beamStateIds, as its name suggests; confirm in cpu/mblas/matrix.h.
  edOut.GetStates() = mblas::Assemble<mblas::byRow>(edIn.GetStates(), beamStateIds);
  decoder_->Lookup(edOut.GetEmbeddings(), beamWords);
}

// Forwards the caller-supplied tensor to Decoder::GetAttention.
void EncoderDecoder::GetAttention(mblas::Tensor& Attention) {
  decoder_->GetAttention(Attention);
}

// Returns the tensor exposed by Decoder::GetAttention().
mblas::Tensor& EncoderDecoder::GetAttention() {
  return decoder_->GetAttention();
}

// Returns the decoder's (target) vocabulary size.
unsigned EncoderDecoder::GetVocabSize() const {
  return decoder_->GetVocabSize();
}

// Forwards the vocabulary filter ids to the decoder.
void EncoderDecoder::Filter(const std::vector<unsigned>& filterIds) {
  decoder_->Filter(filterIds);
}

// Returns the tensor exposed by Decoder::GetProbs().
BaseTensor& EncoderDecoder::GetProbs() {
  return decoder_->GetProbs();
}

}  // namespace Nematus
}  // namespace CPU
}  // namespace amunmt