Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-06-20 17:27:41 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-06-20 17:27:41 +0300
commit1604516232133ffcc5d7cc47f2c9cb48fe771609 (patch)
tree8a171ccb86a60f16f9761a1a66159e7944e9ac3b /src/amun/cpu/nematus/encoder_decoder.cpp
parent4a58c88d3373583f15664a69b94f809a4b1e77e4 (diff)
parentb357f082f3d550197bef448ae71f6df02bbbd2bf (diff)
merge
Diffstat (limited to 'src/amun/cpu/nematus/encoder_decoder.cpp')
-rw-r--r--src/amun/cpu/nematus/encoder_decoder.cpp98
1 files changed, 98 insertions, 0 deletions
diff --git a/src/amun/cpu/nematus/encoder_decoder.cpp b/src/amun/cpu/nematus/encoder_decoder.cpp
new file mode 100644
index 00000000..dd542dc0
--- /dev/null
+++ b/src/amun/cpu/nematus/encoder_decoder.cpp
@@ -0,0 +1,98 @@
+#include "cpu/nematus/encoder_decoder.h"
+
+#include <vector>
+#include <yaml-cpp/yaml.h>
+
+#include "common/sentence.h"
+#include "common/sentences.h"
+
+#include "cpu/decoder/encoder_decoder_loader.h"
+#include "cpu/mblas/matrix.h"
+
+using namespace std;
+
+namespace amunmt {
+namespace CPU {
+namespace Nematus {
+
+using EDState = EncoderDecoderState;
+
+EncoderDecoder::EncoderDecoder(const God &god,
+ const std::string& name,
+ const YAML::Node& config,
+ size_t tab,
+ const Nematus::Weights& model)
+ : CPUEncoderDecoderBase(god, name, config, tab),
+ model_(model),
+ encoder_(new CPU::Nematus::Encoder(model_)),
+ decoder_(new CPU::Nematus::Decoder(model_))
+{}
+
+
+void EncoderDecoder::Decode(const State& in, State& out, const std::vector<uint>&) {
+ const EDState& edIn = in.get<EDState>();
+ EDState& edOut = out.get<EDState>();
+
+ decoder_->Decode(edOut.GetStates(), edIn.GetStates(),
+ edIn.GetEmbeddings(), SourceContext_);
+}
+
+
+void EncoderDecoder::BeginSentenceState(State& state, size_t batchSize) {
+ EDState& edState = state.get<EDState>();
+ decoder_->EmptyState(edState.GetStates(), SourceContext_, batchSize);
+ decoder_->EmptyEmbedding(edState.GetEmbeddings(), batchSize);
+}
+
+
+void EncoderDecoder::SetSource(const Sentences& sources) {
+ encoder_->GetContext(sources.at(0)->GetWords(tab_),
+ SourceContext_);
+}
+
+
+void EncoderDecoder::AssembleBeamState(const State& in,
+ const Beam& beam,
+ State& out) {
+ std::vector<size_t> beamWords;
+ std::vector<size_t> beamStateIds;
+ for(auto h : beam) {
+ beamWords.push_back(h->GetWord());
+ beamStateIds.push_back(h->GetPrevStateIndex());
+ }
+
+ const EDState& edIn = in.get<EDState>();
+ EDState& edOut = out.get<EDState>();
+
+ edOut.GetStates() = mblas::Assemble<mblas::byRow, mblas::Matrix>(edIn.GetStates(), beamStateIds);
+ decoder_->Lookup(edOut.GetEmbeddings(), beamWords);
+}
+
+
+void EncoderDecoder::GetAttention(mblas::Matrix& Attention) {
+ decoder_->GetAttention(Attention);
+}
+
+
+mblas::Matrix& EncoderDecoder::GetAttention() {
+ return decoder_->GetAttention();
+}
+
+
+size_t EncoderDecoder::GetVocabSize() const {
+ return decoder_->GetVocabSize();
+}
+
+
+void EncoderDecoder::Filter(const std::vector<size_t>& filterIds) {
+ decoder_->Filter(filterIds);
+}
+
+
+BaseMatrix& EncoderDecoder::GetProbs() {
+ return decoder_->GetProbs();
+}
+
+}
+}
+}