diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2017-06-20 17:27:41 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2017-06-20 17:27:41 +0300 |
commit | 1604516232133ffcc5d7cc47f2c9cb48fe771609 (patch) | |
tree | 8a171ccb86a60f16f9761a1a66159e7944e9ac3b /src/amun/cpu/nematus/encoder_decoder.cpp | |
parent | 4a58c88d3373583f15664a69b94f809a4b1e77e4 (diff) | |
parent | b357f082f3d550197bef448ae71f6df02bbbd2bf (diff) |
merge
Diffstat (limited to 'src/amun/cpu/nematus/encoder_decoder.cpp')
-rw-r--r-- | src/amun/cpu/nematus/encoder_decoder.cpp | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/src/amun/cpu/nematus/encoder_decoder.cpp b/src/amun/cpu/nematus/encoder_decoder.cpp new file mode 100644 index 00000000..dd542dc0 --- /dev/null +++ b/src/amun/cpu/nematus/encoder_decoder.cpp @@ -0,0 +1,98 @@ +#include "cpu/nematus/encoder_decoder.h" + +#include <vector> +#include <yaml-cpp/yaml.h> + +#include "common/sentence.h" +#include "common/sentences.h" + +#include "cpu/decoder/encoder_decoder_loader.h" +#include "cpu/mblas/matrix.h" + +using namespace std; + +namespace amunmt { +namespace CPU { +namespace Nematus { + +using EDState = EncoderDecoderState; + +EncoderDecoder::EncoderDecoder(const God &god, + const std::string& name, + const YAML::Node& config, + size_t tab, + const Nematus::Weights& model) + : CPUEncoderDecoderBase(god, name, config, tab), + model_(model), + encoder_(new CPU::Nematus::Encoder(model_)), + decoder_(new CPU::Nematus::Decoder(model_)) +{} + + +void EncoderDecoder::Decode(const State& in, State& out, const std::vector<uint>&) { + const EDState& edIn = in.get<EDState>(); + EDState& edOut = out.get<EDState>(); + + decoder_->Decode(edOut.GetStates(), edIn.GetStates(), + edIn.GetEmbeddings(), SourceContext_); +} + + +void EncoderDecoder::BeginSentenceState(State& state, size_t batchSize) { + EDState& edState = state.get<EDState>(); + decoder_->EmptyState(edState.GetStates(), SourceContext_, batchSize); + decoder_->EmptyEmbedding(edState.GetEmbeddings(), batchSize); +} + + +void EncoderDecoder::SetSource(const Sentences& sources) { + encoder_->GetContext(sources.at(0)->GetWords(tab_), + SourceContext_); +} + + +void EncoderDecoder::AssembleBeamState(const State& in, + const Beam& beam, + State& out) { + std::vector<size_t> beamWords; + std::vector<size_t> beamStateIds; + for(auto h : beam) { + beamWords.push_back(h->GetWord()); + beamStateIds.push_back(h->GetPrevStateIndex()); + } + + const EDState& edIn = in.get<EDState>(); + EDState& edOut = out.get<EDState>(); + + edOut.GetStates() = mblas::Assemble<mblas::byRow, mblas::Matrix>(edIn.GetStates(), beamStateIds); + decoder_->Lookup(edOut.GetEmbeddings(), beamWords); +} + + +void EncoderDecoder::GetAttention(mblas::Matrix& Attention) { + decoder_->GetAttention(Attention); +} + + +mblas::Matrix& EncoderDecoder::GetAttention() { + return decoder_->GetAttention(); +} + + +size_t EncoderDecoder::GetVocabSize() const { + return decoder_->GetVocabSize(); +} + + +void EncoderDecoder::Filter(const std::vector<size_t>& filterIds) { + decoder_->Filter(filterIds); +} + + +BaseMatrix& EncoderDecoder::GetProbs() { + return decoder_->GetProbs(); +} + +} +} +} |