Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-19 12:29:32 +0300
committerTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-19 12:29:32 +0300
commit49bdf1d91e386c95a10ffcd935be2bc41a64d05b (patch)
tree37aa315c44f18336d0823b7bf83628ebe0affa69 /src/amun/cpu/nematus
parent1400139357944b1ed1d6366338dca56a328a2822 (diff)
Copy nematus model to new directory
Diffstat (limited to 'src/amun/cpu/nematus')
-rw-r--r--src/amun/cpu/nematus/decoder.cpp7
-rw-r--r--src/amun/cpu/nematus/decoder.h374
-rw-r--r--src/amun/cpu/nematus/dl4mt.h10
-rw-r--r--src/amun/cpu/nematus/encoder.cpp30
-rw-r--r--src/amun/cpu/nematus/encoder.h110
-rw-r--r--src/amun/cpu/nematus/gru.cpp6
-rw-r--r--src/amun/cpu/nematus/gru.h153
-rw-r--r--src/amun/cpu/nematus/model.cpp125
-rw-r--r--src/amun/cpu/nematus/model.h295
-rw-r--r--src/amun/cpu/nematus/transition.h104
10 files changed, 1214 insertions, 0 deletions
diff --git a/src/amun/cpu/nematus/decoder.cpp b/src/amun/cpu/nematus/decoder.cpp
new file mode 100644
index 00000000..36d9876a
--- /dev/null
+++ b/src/amun/cpu/nematus/decoder.cpp
@@ -0,0 +1,7 @@
+#include "decoder.h"
+
+
// Intentionally empty: the Nematus CPU decoder is header-only (decoder.h);
// this translation unit exists only to anchor the header in the build.
namespace amunmt {

}
+
diff --git a/src/amun/cpu/nematus/decoder.h b/src/amun/cpu/nematus/decoder.h
new file mode 100644
index 00000000..0e14d6e5
--- /dev/null
+++ b/src/amun/cpu/nematus/decoder.h
@@ -0,0 +1,374 @@
+#pragma once
+
+#include "../mblas/matrix.h"
+#include "model.h"
+#include "gru.h"
+#include "transition.h"
+#include "common/god.h"
+
+namespace amunmt {
+namespace CPU {
+
+class Decoder {
+ private:
+ template <class Weights>
+ class Embeddings {
+ public:
+ Embeddings(const Weights& model)
+ : w_(model)
+ {}
+
+ void Lookup(mblas::Matrix& Rows, const std::vector<size_t>& ids) {
+ using namespace mblas;
+ std::vector<size_t> tids = ids;
+ for (auto&& id : tids) {
+ if (id >= w_.E_.rows()) {
+ id = 1;
+ }
+ }
+ Rows = Assemble<byRow, Matrix>(w_.E_, tids);
+ }
+
+ size_t GetCols() {
+ return w_.E_.columns();
+ }
+
+ size_t GetRows() const {
+ return w_.E_.rows();
+ }
+
+ private:
+ const Weights& w_;
+ };
+
+ //////////////////////////////////////////////////////////////
+ template <class Weights1, class Weights2>
+ class RNNHidden {
+ public:
+ RNNHidden(const Weights1& initModel, const Weights2& gruModel)
+ : w_(initModel),
+ gru_(gruModel)
+ {}
+
+ void InitializeState(
+ mblas::Matrix& State,
+ const mblas::Matrix& SourceContext,
+ const size_t batchSize = 1)
+ {
+ using namespace mblas;
+
+ // Calculate mean of source context, rowwise
+ // Repeat mean batchSize times by broadcasting
+ Temp1_ = Mean<byRow, Matrix>(SourceContext);
+
+ Temp2_.resize(batchSize, SourceContext.columns());
+ Temp2_ = 0.0f;
+ AddBiasVector<byRow>(Temp2_, Temp1_);
+
+ State = Temp2_ * w_.Wi_;
+ AddBiasVector<byRow>(State, w_.Bi_);
+
+ if (w_.lns_.rows()) {
+ LayerNormalization(State, w_.lns_, w_.lnb_);
+ }
+ State = blaze::forEach(State, Tanh());
+ // std::cerr << "INIT: " << std::endl;
+ // for (int i = 0; i < 5; ++i) std::cerr << State(0, i) << " ";
+ // std::cerr << std::endl;
+ }
+
+ void GetNextState(mblas::Matrix& NextState,
+ const mblas::Matrix& State,
+ const mblas::Matrix& Context) {
+ gru_.GetNextState(NextState, State, Context);
+ }
+
+ private:
+ const Weights1& w_;
+ const GRU<Weights2> gru_;
+
+ mblas::Matrix Temp1_;
+ mblas::Matrix Temp2_;
+ };
+
+ //////////////////////////////////////////////////////////////
+ template <class WeightsGRU, class WeightsTrans>
+ class RNNFinal {
+ public:
+ RNNFinal(const WeightsGRU& modelGRU, const WeightsTrans& modelTrans)
+ : gru_(modelGRU),
+ transition_(modelTrans)
+ {}
+
+ void GetNextState(
+ mblas::Matrix& nextState,
+ const mblas::Matrix& state,
+ const mblas::Matrix& context)
+ {
+ gru_.GetNextState(nextState, state, context);
+ transition_.GetNextState(nextState);
+ // std::cerr << "TRANS: " << std::endl;
+ // for (int i = 0; i < 10; ++i) std::cerr << nextState(0, i) << " ";
+ // std::cerr << std::endl;
+ }
+
+ private:
+ const GRU<WeightsGRU> gru_;
+ const Transition<WeightsTrans> transition_;
+ };
+
+ //////////////////////////////////////////////////////////////
+ template <class Weights>
+ class Attention {
+ public:
+ Attention(const Weights& model)
+ : w_(model)
+ {
+ V_ = blaze::trans(blaze::row(w_.V_, 0));
+ }
+
+ void Init(const mblas::Matrix& SourceContext) {
+ using namespace mblas;
+ SCU_ = SourceContext * w_.U_;
+ mblas::AddBiasVector<mblas::byRow>(SCU_, w_.B_);
+
+ if (w_.Wc_att_lns_.rows()) {
+ LayerNormalization(SCU_, w_.Wc_att_lns_, w_.Wc_att_lnb_);
+ }
+ }
+
+ void GetAlignedSourceContext(
+ mblas::Matrix& AlignedSourceContext,
+ const mblas::Matrix& HiddenState,
+ const mblas::Matrix& SourceContext)
+ {
+ using namespace mblas;
+
+ Temp2_ = HiddenState * w_.W_;
+ if (w_.W_comb_lns_.rows()) {
+ LayerNormalization(Temp2_, w_.W_comb_lns_, w_.W_comb_lnb_);
+ }
+
+ Temp1_ = Broadcast<Matrix>(Tanh(), SCU_, Temp2_);
+
+ A_.resize(Temp1_.rows(), 1);
+ blaze::column(A_, 0) = Temp1_ * V_;
+ size_t words = SourceContext.rows();
+ // batch size, for batching, divide by numer of sentences
+ size_t batchSize = HiddenState.rows();
+ Reshape(A_, batchSize, words); // due to broadcasting above
+
+ float bias = w_.C_(0,0);
+ blaze::forEach(A_, [=](float x) { return x + bias; });
+
+ mblas::SafeSoftmax(A_);
+ AlignedSourceContext = A_ * SourceContext;
+ }
+
+ void GetAttention(mblas::Matrix& Attention) {
+ Attention = A_;
+ }
+
+ mblas::Matrix& GetAttention() {
+ return A_;
+ }
+
+ private:
+ const Weights& w_;
+
+ mblas::Matrix SCU_;
+ mblas::Matrix Temp1_;
+ mblas::Matrix Temp2_;
+ mblas::Matrix A_;
+ mblas::ColumnVector V_;
+ };
+
+ //////////////////////////////////////////////////////////////
+ template <class Weights>
+ class Softmax {
+ public:
+ Softmax(const Weights& model)
+ : w_(model),
+ filtered_(false)
+ {}
+
+ void GetProbs(mblas::ArrayMatrix& Probs,
+ const mblas::Matrix& State,
+ const mblas::Matrix& Embedding,
+ const mblas::Matrix& AlignedSourceContext) {
+ using namespace mblas;
+
+ T1_ = State * w_.W1_;
+ AddBiasVector<byRow>(T1_, w_.B1_);
+ if (w_.lns_1_.rows()) {
+ LayerNormalization(T1_, w_.lns_1_, w_.lnb_1_);
+ }
+ // std::cerr << "State" << std::endl;
+ // for(int i = 0; i < 5; ++i) std::cerr << T1_(0, i) << " ";
+ // std::cerr << std::endl;
+
+ T2_ = Embedding * w_.W2_;
+ AddBiasVector<byRow>(T2_, w_.B2_);
+ if (w_.lns_2_.rows()) {
+ LayerNormalization(T2_, w_.lns_2_, w_.lnb_2_);
+ }
+ // std::cerr << "emb" << std::endl;
+ // for(int i = 0; i < 5; ++i) std::cerr << T2_(0, i) << " ";
+ // std::cerr << std::endl;
+
+ T3_ = AlignedSourceContext * w_.W3_;
+ AddBiasVector<byRow>(T3_, w_.B3_);
+ if (w_.lns_3_.rows()) {
+ LayerNormalization(T3_, w_.lns_3_, w_.lnb_3_);
+ }
+ // std::cerr << "CTX" << std::endl;
+ // for(int i = 0; i < 5; ++i) std::cerr << T3_(0, i) << " ";
+ // std::cerr << std::endl;
+
+ auto t = blaze::forEach(T1_ + T2_ + T3_, Tanh());
+
+ if(!filtered_) {
+ Probs = t * w_.W4_;
+ AddBiasVector<byRow>(Probs, w_.B4_);
+ } else {
+ Probs = t * FilteredW4_;
+ AddBiasVector<byRow>(Probs, FilteredB4_);
+ }
+ // std::cerr << "LOgit" << std::endl;
+ // for(int i = 0; i < 5; ++i) std::cerr << Probs(0, i) << " ";
+ // std::cerr << std::endl;
+ LogSoftmax(Probs);
+ }
+
+ void Filter(const std::vector<size_t>& ids) {
+ filtered_ = true;
+ using namespace mblas;
+ FilteredW4_ = Assemble<byColumn, Matrix>(w_.W4_, ids);
+ FilteredB4_ = Assemble<byColumn, Matrix>(w_.B4_, ids);
+ }
+
+ private:
+ const Weights& w_;
+ bool filtered_;
+
+ mblas::Matrix FilteredW4_;
+ mblas::Matrix FilteredB4_;
+
+ mblas::Matrix T1_;
+ mblas::Matrix T2_;
+ mblas::Matrix T3_;
+ };
+
+ public:
+ Decoder(const Weights& model)
+ : embeddings_(model.decEmbeddings_),
+ rnn1_(model.decInit_, model.decGru1_),
+ rnn2_(model.decGru2_, model.decTransition_),
+ attention_(model.decAttention_),
+ softmax_(model.decSoftmax_)
+ {}
+
+ void Decode(
+ mblas::Matrix& NextState,
+ const mblas::Matrix& State,
+ const mblas::Matrix& Embeddings,
+ const mblas::Matrix& SourceContext)
+ {
+ GetHiddenState(HiddenState_, State, Embeddings);
+ // std::cerr << "HIDDEN: " << std::endl;
+ // for (int i = 0; i < 5; ++i) std::cerr << HiddenState_(0, i) << " ";
+ // std::cerr << std::endl;
+
+ GetAlignedSourceContext(AlignedSourceContext_, HiddenState_, SourceContext);
+ // std::cerr << "ALIGNED SRC: " << std::endl;
+ // for (int i = 0; i < 5; ++i) std::cerr << AlignedSourceContext_(0, i) << " ";
+ // std::cerr << std::endl;
+
+ GetNextState(NextState, HiddenState_, AlignedSourceContext_);
+ // std::cerr << "NEXT: " << std::endl;
+ // for (int i = 0; i < 5; ++i) std::cerr << NextState(0, i) << " ";
+ // std::cerr << std::endl;
+
+ GetProbs(NextState, Embeddings, AlignedSourceContext_);
+ }
+
+ mblas::ArrayMatrix& GetProbs() {
+ return Probs_;
+ }
+
+ void EmptyState(mblas::Matrix& State,
+ const mblas::Matrix& SourceContext,
+ size_t batchSize = 1) {
+ rnn1_.InitializeState(State, SourceContext, batchSize);
+ attention_.Init(SourceContext);
+ }
+
+ void EmptyEmbedding(mblas::Matrix& Embedding,
+ size_t batchSize = 1) {
+ Embedding.resize(batchSize, embeddings_.GetCols());
+ Embedding = 0.0f;
+ }
+
+ void Lookup(mblas::Matrix& Embedding,
+ const std::vector<size_t>& w) {
+ embeddings_.Lookup(Embedding, w);
+ }
+
+ void Filter(const std::vector<size_t>& ids) {
+ softmax_.Filter(ids);
+ }
+
+ void GetAttention(mblas::Matrix& attention) {
+ attention_.GetAttention(attention);
+ }
+
+ mblas::Matrix& GetAttention() {
+ return attention_.GetAttention();
+ }
+
+ size_t GetVocabSize() const {
+ return embeddings_.GetRows();
+ }
+
+ private:
+
+ void GetHiddenState(mblas::Matrix& HiddenState,
+ const mblas::Matrix& PrevState,
+ const mblas::Matrix& Embedding) {
+ rnn1_.GetNextState(HiddenState, PrevState, Embedding);
+ }
+
+ void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
+ const mblas::Matrix& HiddenState,
+ const mblas::Matrix& SourceContext) {
+ attention_.GetAlignedSourceContext(AlignedSourceContext, HiddenState, SourceContext);
+ }
+
+ void GetNextState(mblas::Matrix& State,
+ const mblas::Matrix& HiddenState,
+ const mblas::Matrix& AlignedSourceContext) {
+ rnn2_.GetNextState(State, HiddenState, AlignedSourceContext);
+ }
+
+
+ void GetProbs(const mblas::Matrix& State,
+ const mblas::Matrix& Embedding,
+ const mblas::Matrix& AlignedSourceContext) {
+ softmax_.GetProbs(Probs_, State, Embedding, AlignedSourceContext);
+ }
+
+ private:
+ mblas::Matrix HiddenState_;
+ mblas::Matrix AlignedSourceContext_;
+ mblas::ArrayMatrix Probs_;
+
+ Embeddings<Weights::Embeddings> embeddings_;
+ RNNHidden<Weights::DecInit, Weights::GRU> rnn1_;
+ RNNFinal<Weights::DecGRU2, Weights::Transition> rnn2_;
+ Attention<Weights::DecAttention> attention_;
+ Softmax<Weights::DecSoftmax> softmax_;
+};
+
+}
+}
+
diff --git a/src/amun/cpu/nematus/dl4mt.h b/src/amun/cpu/nematus/dl4mt.h
new file mode 100644
index 00000000..ed4b90bf
--- /dev/null
+++ b/src/amun/cpu/nematus/dl4mt.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "model.h"
+#include "encoder.h"
+#include "decoder.h"
+
// Intentionally empty: dl4mt.h only aggregates the model/encoder/decoder
// headers above; nothing is declared here directly.
namespace amunmt {

}
+
diff --git a/src/amun/cpu/nematus/encoder.cpp b/src/amun/cpu/nematus/encoder.cpp
new file mode 100644
index 00000000..032528fa
--- /dev/null
+++ b/src/amun/cpu/nematus/encoder.cpp
@@ -0,0 +1,30 @@
+#include "encoder.h"
+
+using namespace std;
+
+namespace amunmt {
+namespace CPU {
+
+void Encoder::GetContext(const std::vector<size_t>& words, mblas::Matrix& context) {
+ std::vector<mblas::Matrix> embeddedWords;
+
+ context.resize(words.size(),
+ forwardRnn_.GetStateLength() + backwardRnn_.GetStateLength());
+
+ for (auto& w : words) {
+ embeddedWords.emplace_back();
+ mblas::Matrix &embed = embeddedWords.back();
+ embeddings_.Lookup(embed, w);
+ }
+
+ forwardRnn_.GetContext(embeddedWords.cbegin(),
+ embeddedWords.cend(),
+ context, false);
+ backwardRnn_.GetContext(embeddedWords.crbegin(),
+ embeddedWords.crend(),
+ context, true);
+}
+
+} // namespace CPU
+} // namespace amunmt
+
diff --git a/src/amun/cpu/nematus/encoder.h b/src/amun/cpu/nematus/encoder.h
new file mode 100644
index 00000000..31266e42
--- /dev/null
+++ b/src/amun/cpu/nematus/encoder.h
@@ -0,0 +1,110 @@
+#pragma once
+
+#include "../mblas/matrix.h"
+#include "model.h"
+#include "gru.h"
+#include "transition.h"
+
+namespace amunmt {
+namespace CPU {
+
// Bidirectional deep-transition GRU encoder for the Nematus CPU backend.
// Produces one context row per source word, with forward and backward
// RNN states concatenated column-wise.
class Encoder {
  private:

    /////////////////////////////////////////////////////////////////
    // Source-side embedding lookup: one word id -> one embedding row.
    template <class Weights>
    class Embeddings {
      public:
        Embeddings(const Weights& model)
        : w_(model)
        {}

        // Copy the embedding row for word id `i` into `Row`; ids outside
        // the table are mapped to row 1 (the UNK embedding).
        void Lookup(mblas::Matrix& Row, size_t i) {
          size_t len = w_.E_.columns();
          if(i < w_.E_.rows())
            Row = blaze::submatrix(w_.E_, i, 0, 1, len);
          else
            Row = blaze::submatrix(w_.E_, 1, 0, 1, len); // UNK
        }

        const Weights& w_;
      private:
    };

    /////////////////////////////////////////////////////////////////
    // One direction of the bidirectional encoder: a GRU step followed by
    // the deep-transition layers, applied position by position.
    template <class WeightsGRU, class WeightsTrans>
    class EncoderRNN {
      public:
        EncoderRNN(const WeightsGRU& modelGRU, const WeightsTrans& modelTrans)
        : gru_(modelGRU),
          transition_(modelTrans)
        {}

        // Reset the recurrent state to zeros.
        void InitializeState(size_t batchSize = 1) {
          State_.resize(batchSize, gru_.GetStateLength());
          State_ = 0.0f;
        }

        // One recurrence step: GRU update, then the transition cells.
        void GetNextState(mblas::Matrix& nextState,
                          const mblas::Matrix& state,
                          const mblas::Matrix& embd) {
          gru_.GetNextState(nextState, state, embd);
          transition_.GetNextState(nextState);
        }

        // Run over the embedded words in [it, end) and write one state row
        // per position into Context. With invert == false the states go
        // into columns [0, len) in order; with invert == true they go into
        // columns [len, 2*len) at mirrored (reversed) row positions.
        template <class It>
        void GetContext(It it, It end, mblas::Matrix& Context, bool invert) {
          InitializeState();

          size_t n = std::distance(it, end);
          size_t i = 0;
          while(it != end) {
            GetNextState(State_, State_, *it++);

            size_t len = gru_.GetStateLength();
            if(invert)
              blaze::submatrix(Context, n - i - 1, len, 1, len) = State_;
            else
              blaze::submatrix(Context, i, 0, 1, len) = State_;
            ++i;
          }
        }

        size_t GetStateLength() const {
          return gru_.GetStateLength();
        }

      private:
        // Model matrices
        const GRU<WeightsGRU> gru_;
        const Transition<WeightsTrans> transition_;

        // Recurrent state, reused across steps of one sentence.
        mblas::Matrix State_;
    };

    /////////////////////////////////////////////////////////////////
  public:
    Encoder(const Weights& model)
    : embeddings_(model.encEmbeddings_),
      forwardRnn_(model.encForwardGRU_, model.encForwardTransition_),
      backwardRnn_(model.encBackwardGRU_, model.encBackwardTransition_)
    {}

    // Embed `words` and run both RNN directions; defined in encoder.cpp.
    void GetContext(const std::vector<size_t>& words,
                    mblas::Matrix& context);

  private:
    Embeddings<Weights::Embeddings> embeddings_;
    EncoderRNN<Weights::GRU, Weights::Transition> forwardRnn_;
    EncoderRNN<Weights::GRU, Weights::Transition> backwardRnn_;
};
+
+}
+}
+
diff --git a/src/amun/cpu/nematus/gru.cpp b/src/amun/cpu/nematus/gru.cpp
new file mode 100644
index 00000000..778659aa
--- /dev/null
+++ b/src/amun/cpu/nematus/gru.cpp
@@ -0,0 +1,6 @@
+#include "gru.h"
+
// Intentionally empty: the GRU implementation is header-only (gru.h);
// this translation unit only anchors the header in the build.
namespace amunmt {

}
+
diff --git a/src/amun/cpu/nematus/gru.h b/src/amun/cpu/nematus/gru.h
new file mode 100644
index 00000000..33e855fd
--- /dev/null
+++ b/src/amun/cpu/nematus/gru.h
@@ -0,0 +1,153 @@
+#pragma once
+#include "cpu/mblas/matrix.h"
+#include <iomanip>
+
+namespace amunmt {
+namespace CPU {
+
+template <class Weights>
+class GRU {
+ public:
+ GRU(const Weights& model)
+ : w_(model),
+ layerNormalization_(w_.W_lns_.rows())
+ {
+ if (!layerNormalization_) {
+ WWx_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(w_.W_, w_.Wx_);
+ UUx_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(w_.U_, w_.Ux_);
+ }
+ }
+
+ void GetNextState(
+ mblas::Matrix& nextState,
+ const mblas::Matrix& state,
+ const mblas::Matrix& context) const
+ {
+ // std::cerr << "Get next state" << std::endl;
+ if (layerNormalization_) {
+ RUH_1_ = context * w_.W_;
+ mblas::AddBiasVector<mblas::byRow>(RUH_1_, w_.B_);
+ LayerNormalization(RUH_1_, w_.W_lns_, w_.W_lnb_);
+
+ RUH_2_ = context * w_.Wx_;
+ mblas::AddBiasVector<mblas::byRow>(RUH_2_, w_.Bx1_);
+ LayerNormalization(RUH_2_, w_.Wx_lns_, w_.Wx_lnb_);
+
+ RUH_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(RUH_1_, RUH_2_);
+
+ Temp_1_ = state * w_.U_;
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.Bx3_);
+ LayerNormalization(Temp_1_, w_.U_lns_, w_.U_lnb_);
+
+ Temp_2_ = state * w_.Ux_;
+ mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx2_);
+ LayerNormalization(Temp_2_, w_.Ux_lns_, w_.Ux_lnb_);
+
+ Temp_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(Temp_1_, Temp_2_);
+
+ ElementwiseOpsLayerNorm(nextState, state);
+
+ } else {
+ RUH_ = context * WWx_;
+ Temp_ = state * UUx_;
+ ElementwiseOps(nextState, state);
+ }
+ }
+
+ void ElementwiseOps(mblas::Matrix& NextState, const mblas::Matrix& State) const {
+ using namespace mblas;
+ using namespace blaze;
+
+ const int rowNo = State.rows();
+ const int colNo = State.columns();
+ NextState.resize(rowNo, colNo);
+
+ for (int j = 0; j < rowNo; ++j) {
+ auto rowOut = row(NextState, j);
+ auto rowState = row(State, j);
+
+ auto rowRuh = row(RUH_, j);
+ auto rowT = row(Temp_, j);
+
+ auto rowH = subvector(rowRuh, 2 * colNo, colNo);
+ auto rowT2 = subvector(rowT, 2 * colNo, colNo);
+
+ for (int i = 0; i < colNo; ++i) {
+ float ev1 = expapprox(-(rowRuh[i] + w_.B_(0, i) + rowT[i]));
+ float r = 1.0f / (1.0f + ev1);
+
+ int k = i + colNo;
+ float ev2 = expapprox(-(rowRuh[k] + w_.B_(0, k) + rowT[k]));
+ float u = 1.0f / (1.0f + ev2);
+
+ float hv = rowH[i] + w_.Bx1_(0, i);
+ float t2v = rowT2[i];
+ hv = tanhapprox(hv + r * t2v);
+ rowOut[i] = (1.0f - u) * hv + u * rowState[i];
+ }
+ }
+ }
+
+ void ElementwiseOpsLayerNorm(mblas::Matrix& NextState, const mblas::Matrix& State) const {
+ using namespace mblas;
+ using namespace blaze;
+
+ const int rowNo = State.rows();
+ const int colNo = State.columns();
+ NextState.resize(rowNo, colNo);
+
+ for (int j = 0; j < rowNo; ++j) {
+ auto rowOut = row(NextState, j);
+ auto rowState = row(State, j);
+
+ auto rowRuh = row(RUH_, j);
+ auto rowT = row(Temp_, j);
+
+ auto rowH = subvector(rowRuh, 2 * colNo, colNo);
+ auto rowT2 = subvector(rowT, 2 * colNo, colNo);
+
+ for (int i = 0; i < colNo; ++i) {
+ float ev1 = expapprox(-(rowRuh[i] + rowT[i]));
+ float r = 1.0f / (1.0f + ev1);
+
+ int k = i + colNo;
+ float ev2 = expapprox(-(rowRuh[k] + rowT[k]));
+ float u = 1.0f / (1.0f + ev2);
+
+ float hv = rowH[i];
+ float t2v = rowT2[i] + w_.Bx2_(0, i);
+ hv = tanhapprox(hv + r * t2v);
+ rowOut[i] = (1.0f - u) * hv + u * rowState[i];
+ }
+ }
+ }
+ size_t GetStateLength() const {
+ return w_.U_.rows();
+ }
+
+
+ private:
+ // Model matrices
+ const Weights& w_;
+ mutable mblas::Matrix WWx_;
+ mutable mblas::Matrix UUx_;
+ mutable mblas::Matrix Wbbx_;
+ mutable mblas::Matrix lns_WWx_;
+ mutable mblas::Matrix lns_UUx_;
+ mutable mblas::Matrix lnb_WWx_;
+ mutable mblas::Matrix lnb_UUx_;
+
+ // reused to avoid allocation
+ mutable mblas::Matrix RUH_;
+ mutable mblas::Matrix RUH_1_;
+ mutable mblas::Matrix RUH_2_;
+ mutable mblas::Matrix Temp_;
+ mutable mblas::Matrix Temp_1_;
+ mutable mblas::Matrix Temp_2_;
+
+ bool layerNormalization_;
+};
+
+}
+}
+
diff --git a/src/amun/cpu/nematus/model.cpp b/src/amun/cpu/nematus/model.cpp
new file mode 100644
index 00000000..71b8694a
--- /dev/null
+++ b/src/amun/cpu/nematus/model.cpp
@@ -0,0 +1,125 @@
+#include "model.h"
+
+using namespace std;
+
+namespace amunmt {
+namespace CPU {
+
// Load an embedding matrix stored under `key` in the npz model file.
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
  : E_(model[key])
{}
+
// Load the first embedding matrix found among several candidate keys
// (each paired with a transpose flag), e.g. a tied decoder embedding.
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
  : E_(model.getFirstOfMany(keys))
{}
+
// Load the 14 GRU parameters named `prefix + keys[i]`. Key order:
// W, b, U, Wx, bx, Ux, then the eight layer-normalization scale/bias
// vectors ({W, Wx, U, Ux} x {lns, lnb}). The `true` flag on model(...)
// loads the tensor as a (transposed) bias row.
Weights::GRU::GRU(const NpzConverter& model, std::string prefix, std::vector<std::string> keys)
  : W_(model[prefix + keys.at(0)]),
    B_(model(prefix + keys.at(1), true)),
    U_(model[prefix + keys.at(2)]),
    Wx_(model[prefix + keys.at(3)]),
    Bx1_(model(prefix + keys.at(4), true)),
    Bx2_(Bx1_.rows(), Bx1_.columns()),
    Bx3_(B_.rows(), B_.columns()),
    Ux_(model[prefix + keys.at(5)]),
    W_lns_(model[prefix + keys.at(6)]),
    W_lnb_(model[prefix + keys.at(7)]),
    Wx_lns_(model[prefix + keys.at(8)]),
    Wx_lnb_(model[prefix + keys.at(9)]),
    U_lns_(model[prefix + keys.at(10)]),
    U_lnb_(model[prefix + keys.at(11)]),
    Ux_lns_(model[prefix + keys.at(12)]),
    Ux_lnb_(model[prefix + keys.at(13)])
{
  // Bx2_/Bx3_ have no counterpart in this parameterization; keep them as
  // zero vectors so the GRU code can add them unconditionally. The
  // const_cast is needed because the members are declared const.
  const_cast<mblas::Matrix&>(Bx2_) = 0.0f;
  const_cast<mblas::Matrix&>(Bx3_) = 0.0f;
}
+
+//////////////////////////////////////////////////////////////////////////////
+
// Parameters of the feed-forward layer that initializes the decoder
// state from the mean source context (ff_state), plus its layer-norm
// scale/bias.
Weights::DecInit::DecInit(const NpzConverter& model)
  : Wi_(model["ff_state_W"]),
    Bi_(model("ff_state_b", true)),
    lns_(model["ff_state_ln_s"]),
    lnb_(model["ff_state_ln_b"])
{}
+
+
// Parameters of the second decoder GRU (conditioned on the aligned
// source context). The "_nl" Nematus tensors map onto the generic GRU
// slots; B_ and Bx1_ have no stored counterpart here and are zeroed.
Weights::DecGRU2::DecGRU2(const NpzConverter& model, std::string prefix, std::vector<std::string> keys)
  : W_(model[prefix + keys.at(0)]),        // Wc
    B_(1, W_.Cols()),
    U_(model[prefix + keys.at(1)]),        // U_nl
    Bx3_(model(prefix + keys.at(2), true)),  // b_nl
    Wx_(model[prefix + keys.at(3)]),       // Wcx
    Bx1_(1, Wx_.Cols()),
    Ux_(model[prefix + keys.at(4)]),       // Ux_nl
    Bx2_(model(prefix + keys.at(5), true)),  // bx_nl
    W_lns_(model[prefix + keys.at(6)]),    // Wc_lns
    W_lnb_(model[prefix + keys.at(7)]),    // Wc_nlb
    Wx_lns_(model[prefix + keys.at(8)]),   // Wcx_lns
    Wx_lnb_(model[prefix + keys.at(9)]),   // Wcx_lnb
    U_lns_(model[prefix + keys.at(10)]),   // U_nl_lns
    U_lnb_(model[prefix + keys.at(11)]),   // U_nl_lnb
    Ux_lns_(model[prefix + keys.at(12)]),  // Ux_nl_lns
    Ux_lnb_(model[prefix + keys.at(13)])   // Ux_nl_lnb

{
  // Zero the unused bias slots; const_cast because members are const.
  const_cast<mblas::Matrix&>(B_) = 0.0f;
  const_cast<mblas::Matrix&>(Bx1_) = 0.0f;
}
+
// Attention parameters: score vector V, query/key projections W/U with
// biases B and scalar bias C, plus layer-norm scale/bias for the key
// (Wc_att) and query (W_comb_att) projections.
Weights::DecAttention::DecAttention(const NpzConverter& model)
  : V_(model("decoder_U_att", true)),
    W_(model["decoder_W_comb_att"]),
    B_(model("decoder_b_att", true)),
    U_(model["decoder_Wc_att"]),
    C_(model["decoder_c_tt"]),
    Wc_att_lns_(model["decoder_Wc_att_lns"]),
    Wc_att_lnb_(model["decoder_Wc_att_lnb"]),
    W_comb_lns_(model["decoder_W_comb_att_lns"]),
    W_comb_lnb_(model["decoder_W_comb_att_lnb"])
{}
+
// Output-layer parameters: the three input projections (state, previous
// embedding, context) with biases and layer norms, and the vocabulary
// projection W4/B4. W4 falls back to the transposed decoder embedding
// (tied weights) when ff_logit_W is absent.
Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
  : W1_(model["ff_logit_lstm_W"]),
    B1_(model("ff_logit_lstm_b", true)),
    W2_(model["ff_logit_prev_W"]),
    B2_(model("ff_logit_prev_b", true)),
    W3_(model["ff_logit_ctx_W"]),
    B3_(model("ff_logit_ctx_b", true)),
    W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
                              std::make_pair(std::string("Wemb_dec"), true)})),
    B4_(model("ff_logit_b", true)),
    lns_1_(model["ff_logit_lstm_ln_s"]),
    lns_2_(model["ff_logit_prev_ln_s"]),
    lns_3_(model["ff_logit_ctx_ln_s"]),
    lnb_1_(model["ff_logit_lstm_ln_b"]),
    lnb_2_(model["ff_logit_prev_ln_b"]),
    lnb_3_(model["ff_logit_ctx_ln_b"])
{}
+
+//////////////////////////////////////////////////////////////////////////////
+
// Load every sub-module from the npz model. The decoder embedding falls
// back to the (tied) encoder embedding "Wemb" when "Wemb_dec" is absent.
// The second (device) parameter is unused on the CPU backend.
Weights::Weights(const NpzConverter& model, size_t)
  : encEmbeddings_(model, "Wemb"),
    decEmbeddings_(model, {std::make_pair(std::string("Wemb_dec"), false),
                           std::make_pair(std::string("Wemb"), false)}),
    encForwardGRU_(model, "encoder_", {"W", "b", "U", "Wx", "bx", "Ux", "W_lns", "W_lnb", "Wx_lns",
                                       "Wx_lnb", "U_lns", "U_lnb", "Ux_lns", "Ux_lnb" }),
    encBackwardGRU_(model, "encoder_r_", {"W", "b", "U", "Wx", "bx", "Ux", "W_lns", "W_lnb",
                                          "Wx_lns", "Wx_lnb", "U_lns", "U_lnb", "Ux_lns", "Ux_lnb" }),
    decInit_(model),
    decGru1_(model, "decoder_", {"W", "b", "U", "Wx", "bx", "Ux", "W_lns", "W_lnb", "Wx_lns",
                                 "Wx_lnb", "U_lns", "U_lnb", "Ux_lns", "Ux_lnb" }),
    decGru2_(model, "decoder_", {"Wc", "U_nl", "b_nl", "Wcx", "Ux_nl", "bx_nl", "Wc_lns", "Wc_lnb",
                                 "Wcx_lns", "Wcx_lnb", "U_nl_lns", "U_nl_lnb", "Ux_nl_lns",
                                 "Ux_nl_lnb"}),
    decAttention_(model),
    decSoftmax_(model),
    encForwardTransition_(model, Weights::Transition::TransitionType::Encoder, "encoder_"),
    encBackwardTransition_(model,Weights::Transition::TransitionType::Encoder, "encoder_r_"),
    decTransition_(model, Weights::Transition::TransitionType::Decoder, "decoder_", "_nl")
{}
+
} // namespace CPU
+} // namespace amunmt
diff --git a/src/amun/cpu/nematus/model.h b/src/amun/cpu/nematus/model.h
new file mode 100644
index 00000000..e4bded82
--- /dev/null
+++ b/src/amun/cpu/nematus/model.h
@@ -0,0 +1,295 @@
+#pragma once
+
+#include <iostream>
+#include <map>
+#include <string>
+
+#include "../npz_converter.h"
+
+#include "../mblas/matrix.h"
+
+namespace amunmt {
+namespace CPU {
+
// All parameters of a Nematus model loaded from an npz file, grouped by
// sub-module. Every matrix is const after construction.
struct Weights {
  // Parameters of the deep-transition ("_drt_") layers of an encoder or
  // decoder cell. The depth is discovered from the model file by probing
  // for numbered bias tensors.
  class Transition {
    public:
      enum class TransitionType {Encoder, Decoder};
      Transition(const NpzConverter& model, TransitionType type, std::string prefix,
                 std::string infix="")
        : depth_(findTransitionDepth(model, prefix, infix)), type_(type)
      {
        // Tensor indices in the model file are 1-based.
        for (int i = 1; i <= depth_; ++i) {
          U_.emplace_back(model[name(prefix, "U", infix, i)]);
          Ux_.emplace_back(model[name(prefix, "Ux", infix, i)]);
          B_.emplace_back(model(name(prefix, "b", infix, i), true));
          U_lns_.emplace_back(model[name(prefix, "U", infix, i, "_lns")]);
          U_lnb_.emplace_back(model[name(prefix, "U", infix, i, "_lnb")]);
          Ux_lns_.emplace_back(model[name(prefix, "Ux", infix, i, "_lns")]);
          Ux_lnb_.emplace_back(model[name(prefix, "Ux", infix, i, "_lnb")]);
          // Encoder and decoder transitions store their candidate bias in
          // different slots (Bx2_ vs Bx1_); the other slot is zero-filled
          // so downstream code can add both unconditionally.
          // (example key: decoder_U_nl_drt_4_lnb)
          switch(type) {
            case TransitionType::Encoder:
              Bx1_.emplace_back(1, Ux_.back().Cols());
              const_cast<mblas::Matrix&>(Bx1_.back()) = 0.0f;
              Bx2_.emplace_back(model(name(prefix, "bx", infix, i), true));
              break;
            case TransitionType::Decoder:
              Bx1_.emplace_back(model(name(prefix, "bx", infix, i), true));
              Bx2_.emplace_back(1, Ux_.back().Cols());
              const_cast<mblas::Matrix&>(Bx2_.back()) = 0.0f;
              break;
          }
        }
      }

      // Count consecutive "<prefix>b<infix>_drt_<k>" tensors starting at
      // k = 1; the count is the transition depth. Logs the result to cerr.
      static int findTransitionDepth(const NpzConverter& model, std::string prefix, std::string infix) {
        int currentDepth = 0;
        while (true) {
          if (model.has(prefix + "b" + infix + "_drt_" + std::to_string(currentDepth + 1))) {
            ++currentDepth;
          } else {
            break;
          }
        }
        std::cerr << "Found transition depth: " << currentDepth << std::endl;
        return currentDepth;
      }

      // Number of transition layers (0 when the model has none).
      int size() const {
        return depth_;
      }

      TransitionType type() const {
        return type_;
      }

    protected:
      // Build a tensor key: <prefix><name><infix>_drt_<index><suffix>.
      std::string name(const std::string& prefix, std::string name, std::string infix, int index,
                       std::string suffix = "")
      {
        return prefix + name + infix + "_drt_" + std::to_string(index) + suffix;
      }

    private:
      int depth_;
      TransitionType type_;

    public:
      // One entry per transition layer (see constructor for slot meaning).
      std::vector<mblas::Matrix> B_;
      std::vector<mblas::Matrix> Bx1_;
      std::vector<mblas::Matrix> Bx2_;
      std::vector<mblas::Matrix> U_;
      std::vector<mblas::Matrix> Ux_;

      // Layer-normalization scales (lns) and biases (lnb).
      std::vector<mblas::Matrix> U_lns_;
      std::vector<mblas::Matrix> U_lnb_;
      std::vector<mblas::Matrix> Ux_lns_;
      std::vector<mblas::Matrix> Ux_lnb_;

  };

  // An embedding matrix (rows = vocabulary, columns = dimensions).
  struct Embeddings {
    Embeddings(const NpzConverter& model, const std::string &key);
    Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);

    const mblas::Matrix E_;
  };

  // Generic GRU parameter set (see model.cpp for the key layout).
  struct GRU {
    GRU(const NpzConverter& model, std::string prefix, std::vector<std::string> keys);

    const mblas::Matrix W_;
    const mblas::Matrix B_;
    const mblas::Matrix U_;
    const mblas::Matrix Wx_;
    const mblas::Matrix Bx1_;
    const mblas::Matrix Bx2_;
    const mblas::Matrix Bx3_;
    const mblas::Matrix Ux_;

    const mblas::Matrix W_lns_;
    const mblas::Matrix W_lnb_;
    const mblas::Matrix Wx_lns_;
    const mblas::Matrix Wx_lnb_;
    const mblas::Matrix U_lns_;
    const mblas::Matrix U_lnb_;
    const mblas::Matrix Ux_lns_;
    const mblas::Matrix Ux_lnb_;
  };

  // Feed-forward decoder-state initializer (ff_state).
  struct DecInit {
    DecInit(const NpzConverter& model);

    const mblas::Matrix Wi_;
    const mblas::Matrix Bi_;
    const mblas::Matrix lns_;
    const mblas::Matrix lnb_;
  };

  // Second decoder GRU, conditioned on the aligned source context.
  struct DecGRU2 {
    DecGRU2(const NpzConverter& model, std::string prefix, std::vector<std::string> keys);

    const mblas::Matrix W_;
    const mblas::Matrix B_;
    const mblas::Matrix U_;
    const mblas::Matrix Wx_;
    const mblas::Matrix Bx3_;
    const mblas::Matrix Bx2_;
    const mblas::Matrix Bx1_;
    const mblas::Matrix Ux_;

    const mblas::Matrix W_lns_;
    const mblas::Matrix W_lnb_;
    const mblas::Matrix Wx_lns_;
    const mblas::Matrix Wx_lnb_;
    const mblas::Matrix U_lns_;
    const mblas::Matrix U_lnb_;
    const mblas::Matrix Ux_lns_;
    const mblas::Matrix Ux_lnb_;
  };

  // MLP attention parameters.
  struct DecAttention {
    DecAttention(const NpzConverter& model);

    const mblas::Matrix V_;
    const mblas::Matrix W_;
    const mblas::Matrix B_;
    const mblas::Matrix U_;
    const mblas::Matrix C_;
    const mblas::Matrix Wc_att_lns_;
    const mblas::Matrix Wc_att_lnb_;
    const mblas::Matrix W_comb_lns_;
    const mblas::Matrix W_comb_lnb_;
  };

  // Output (logit) layer parameters.
  struct DecSoftmax {
    DecSoftmax(const NpzConverter& model);

    const mblas::Matrix W1_;
    const mblas::Matrix B1_;
    const mblas::Matrix W2_;
    const mblas::Matrix B2_;
    const mblas::Matrix W3_;
    const mblas::Matrix B3_;
    const mblas::Matrix W4_;
    const mblas::Matrix B4_;
    const mblas::Matrix lns_1_;
    const mblas::Matrix lns_2_;
    const mblas::Matrix lns_3_;
    const mblas::Matrix lnb_1_;
    const mblas::Matrix lnb_2_;
    const mblas::Matrix lnb_3_;
  };


  Weights(const std::string& npzFile, size_t device = 0)
    : Weights(NpzConverter(npzFile), device)
  {}

  Weights(const NpzConverter& model, size_t device = 0);

  // CPU backend has no device; returns a sentinel value.
  size_t GetDevice() {
    return std::numeric_limits<size_t>::max();
  }

  const Embeddings encEmbeddings_;
  const Embeddings decEmbeddings_;
  const GRU encForwardGRU_;
  const GRU encBackwardGRU_;
  const DecInit decInit_;
  const GRU decGru1_;
  const DecGRU2 decGru2_;
  const DecAttention decAttention_;
  const DecSoftmax decSoftmax_;
  const Transition encForwardTransition_;
  const Transition encBackwardTransition_;
  const Transition decTransition_;
};
+
// Debug printer for the embedding parameters.
inline std::ostream& operator<<(std::ostream &out, const Weights::Embeddings &obj)
{
  out << "E_ \t" << obj.E_;
  return out;
}
+
// Debug printer for GRU parameters (layer-norm tensors omitted).
inline std::ostream& operator<<(std::ostream &out, const Weights::GRU &obj)
{
  out << "W_ \t" << obj.W_ << std::endl;
  out << "B_ \t" << obj.B_ << std::endl;
  out << "U_ \t" << obj.U_ << std::endl;
  out << "Wx_ \t" << obj.Wx_ << std::endl;
  out << "Bx1_ \t" << obj.Bx1_ << std::endl;
  out << "Bx2_ \t" << obj.Bx2_ << std::endl;
  out << "Ux_ \t" << obj.Ux_;
  return out;
}
+
// Debug printer for the second decoder GRU (layer-norm tensors omitted).
inline std::ostream& operator<<(std::ostream &out, const Weights::DecGRU2 &obj)
{
  out << "W_ \t" << obj.W_ << std::endl;
  out << "B_ \t" << obj.B_ << std::endl;
  out << "U_ \t" << obj.U_ << std::endl;
  out << "Wx_ \t" << obj.Wx_ << std::endl;
  out << "Bx1_ \t" << obj.Bx1_ << std::endl;
  out << "Bx2_ \t" << obj.Bx2_ << std::endl;
  out << "Ux_ \t" << obj.Ux_;
  return out;
}
+
// Debug printer for the decoder-state initializer.
inline std::ostream& operator<<(std::ostream &out, const Weights::DecInit &obj)
{
  out << "Wi_ \t" << obj.Wi_ << std::endl;
  out << "Bi_ \t" << obj.Bi_ ;
  return out;
}
+
// Debug printer for the attention parameters (layer norms omitted).
inline std::ostream& operator<<(std::ostream &out, const Weights::DecAttention &obj)
{
  out << "V_ \t" << obj.V_ << std::endl;
  out << "W_ \t" << obj.W_ << std::endl;
  out << "B_ \t" << obj.B_ << std::endl;
  out << "U_ \t" << obj.U_ << std::endl;
  out << "C_ \t" << obj.C_ ;
  return out;
}
+
// Debug printer for the output-layer parameters (layer norms omitted).
inline std::ostream& operator<<(std::ostream &out, const Weights::DecSoftmax &obj)
{
  out << "W1_ \t" << obj.W1_ << std::endl;
  out << "B1_ \t" << obj.B1_ << std::endl;
  out << "W2_ \t" << obj.W2_ << std::endl;
  out << "B2_ \t" << obj.B2_ << std::endl;
  out << "W3_ \t" << obj.W3_ << std::endl;
  out << "B3_ \t" << obj.B3_ << std::endl;
  out << "W4_ \t" << obj.W4_ << std::endl;
  out << "B4_ \t" << obj.B4_ ;

  return out;
}
+
// Debug printer for the whole model (transition layers omitted).
inline std::ostream& operator<<(std::ostream &out, const Weights &obj)
{
  out << "\n encEmbeddings_ \n" << obj.encEmbeddings_ << std::endl;
  out << "\n decEmbeddings_ \n" << obj.decEmbeddings_ << std::endl;

  out << "\n encForwardGRU_ \n" << obj.encForwardGRU_ << std::endl;
  out << "\n encBackwardGRU_ \n" << obj.encBackwardGRU_ << std::endl;

  out << "\n decInit_ \n" << obj.decInit_ << std::endl;

  out << "\n decGru1_ \n" << obj.decGru1_ << std::endl;
  out << "\n decGru2_ \n" << obj.decGru2_ << std::endl;

  out << "\n decAttention_ \n" << obj.decAttention_ << std::endl;

  out << "\n decSoftmax_ \n" << obj.decSoftmax_ << std::endl;

  //Debug2(obj.encEmbeddings_.E_);

  return out;
}
+
+}
+}
+
diff --git a/src/amun/cpu/nematus/transition.h b/src/amun/cpu/nematus/transition.h
new file mode 100644
index 00000000..f4d342d7
--- /dev/null
+++ b/src/amun/cpu/nematus/transition.h
@@ -0,0 +1,104 @@
+#pragma once
+#include "cpu/mblas/matrix.h"
+#include <iomanip>
+
+namespace amunmt {
+namespace CPU {
+
+template <class Weights>
+class Transition {
+ public:
+ Transition(const Weights& model)
+ : w_(model),
+ layerNormalization_(false)
+ {
+ if (w_.U_lns_.size() > 1 && w_.U_lns_[0].rows() > 1) {
+ layerNormalization_ = true;
+ }
+ }
+
+ void GetNextState(mblas::Matrix& state) const
+ {
+ if (layerNormalization_) {
+ for (int i = 0; i < w_.size(); ++i) {
+ Temp_1_ = state * w_.U_[i];
+ Temp_2_ = state * w_.Ux_[i];
+
+ switch(w_.type()) {
+ case Weights::Transition::TransitionType::Encoder:
+ LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]);
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
+
+ LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]);
+ break;
+
+ case Weights::Transition::TransitionType::Decoder:
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
+ LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]);
+
+ mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]);
+ LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]);
+ break;
+ }
+ ElementwiseOps(state, i);
+ }
+ } else {
+ for (int i = 0; i < w_.size(); ++i) {
+ Temp_1_ = state * w_.U_[i];
+ Temp_2_ = state * w_.Ux_[i];
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
+ mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]);
+ ElementwiseOps(state, i);
+ }
+ }
+ }
+
+ void ElementwiseOps(mblas::Matrix& state, int idx) const {
+ using namespace mblas;
+ using namespace blaze;
+
+ for (int j = 0; j < (int)state.Rows(); ++j) {
+ auto rowState = row(state, j);
+ auto rowT = row(Temp_1_, j);
+ auto rowT2 = row(Temp_2_, j);
+
+ for (int i = 0; i < (int)state.Cols(); ++i) {
+ float ev1 = expapprox(-(rowT[i])); // + w_.B_[idx](0, i)));
+ float r = 1.0f / (1.0f + ev1);
+
+ int k = i + state.Cols();
+ float ev2 = expapprox(-(rowT[k])); // + w_.B_[idx](0, k)));
+ float u = 1.0f / (1.0f + ev2);
+
+ float hv = w_.Bx2_[idx](0, i);
+ float t2v = rowT2[i]; // + w_.Bx1_[idx](0, i);
+ hv = tanhapprox(hv + r * t2v);
+ rowState[i] = (1.0f - u) * hv + u * rowState[i];
+ }
+ }
+ }
+
+ size_t GetStateLength() const {
+ return w_.U_.rows();
+ }
+
+
+ private:
+ // Model matrices
+ const Weights& w_;
+
+ // reused to avoid allocation
+ mutable mblas::Matrix UUx_;
+ mutable mblas::Matrix RUH_;
+ mutable mblas::Matrix RUH_1_;
+ mutable mblas::Matrix RUH_2_;
+ mutable mblas::Matrix Temp_;
+ mutable mblas::Matrix Temp_1_;
+ mutable mblas::Matrix Temp_2_;
+
+ bool layerNormalization_;
+};
+
+}
+}
+