diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-19 12:29:32 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-19 12:29:32 +0300 |
commit | 49bdf1d91e386c95a10ffcd935be2bc41a64d05b (patch) | |
tree | 37aa315c44f18336d0823b7bf83628ebe0affa69 /src/amun/cpu/nematus | |
parent | 1400139357944b1ed1d6366338dca56a328a2822 (diff) |
Copy nematus model to new directory
Diffstat (limited to 'src/amun/cpu/nematus')
-rw-r--r-- | src/amun/cpu/nematus/decoder.cpp | 7 | ||||
-rw-r--r-- | src/amun/cpu/nematus/decoder.h | 374 | ||||
-rw-r--r-- | src/amun/cpu/nematus/dl4mt.h | 10 | ||||
-rw-r--r-- | src/amun/cpu/nematus/encoder.cpp | 30 | ||||
-rw-r--r-- | src/amun/cpu/nematus/encoder.h | 110 | ||||
-rw-r--r-- | src/amun/cpu/nematus/gru.cpp | 6 | ||||
-rw-r--r-- | src/amun/cpu/nematus/gru.h | 153 | ||||
-rw-r--r-- | src/amun/cpu/nematus/model.cpp | 125 | ||||
-rw-r--r-- | src/amun/cpu/nematus/model.h | 295 | ||||
-rw-r--r-- | src/amun/cpu/nematus/transition.h | 104 |
10 files changed, 1214 insertions, 0 deletions
diff --git a/src/amun/cpu/nematus/decoder.cpp b/src/amun/cpu/nematus/decoder.cpp new file mode 100644 index 00000000..36d9876a --- /dev/null +++ b/src/amun/cpu/nematus/decoder.cpp @@ -0,0 +1,7 @@ +#include "decoder.h" + + +namespace amunmt { + +} + diff --git a/src/amun/cpu/nematus/decoder.h b/src/amun/cpu/nematus/decoder.h new file mode 100644 index 00000000..0e14d6e5 --- /dev/null +++ b/src/amun/cpu/nematus/decoder.h @@ -0,0 +1,374 @@ +#pragma once + +#include "../mblas/matrix.h" +#include "model.h" +#include "gru.h" +#include "transition.h" +#include "common/god.h" + +namespace amunmt { +namespace CPU { + +class Decoder { + private: + template <class Weights> + class Embeddings { + public: + Embeddings(const Weights& model) + : w_(model) + {} + + void Lookup(mblas::Matrix& Rows, const std::vector<size_t>& ids) { + using namespace mblas; + std::vector<size_t> tids = ids; + for (auto&& id : tids) { + if (id >= w_.E_.rows()) { + id = 1; + } + } + Rows = Assemble<byRow, Matrix>(w_.E_, tids); + } + + size_t GetCols() { + return w_.E_.columns(); + } + + size_t GetRows() const { + return w_.E_.rows(); + } + + private: + const Weights& w_; + }; + + ////////////////////////////////////////////////////////////// + template <class Weights1, class Weights2> + class RNNHidden { + public: + RNNHidden(const Weights1& initModel, const Weights2& gruModel) + : w_(initModel), + gru_(gruModel) + {} + + void InitializeState( + mblas::Matrix& State, + const mblas::Matrix& SourceContext, + const size_t batchSize = 1) + { + using namespace mblas; + + // Calculate mean of source context, rowwise + // Repeat mean batchSize times by broadcasting + Temp1_ = Mean<byRow, Matrix>(SourceContext); + + Temp2_.resize(batchSize, SourceContext.columns()); + Temp2_ = 0.0f; + AddBiasVector<byRow>(Temp2_, Temp1_); + + State = Temp2_ * w_.Wi_; + AddBiasVector<byRow>(State, w_.Bi_); + + if (w_.lns_.rows()) { + LayerNormalization(State, w_.lns_, w_.lnb_); + } + State = blaze::forEach(State, Tanh()); + // std::cerr << "INIT: " << std::endl; + // for (int i = 0; i < 5; ++i) std::cerr << State(0, i) << " "; + // std::cerr << std::endl; + } + + void GetNextState(mblas::Matrix& NextState, + const mblas::Matrix& State, + const mblas::Matrix& Context) { + gru_.GetNextState(NextState, State, Context); + } + + private: + const Weights1& w_; + const GRU<Weights2> gru_; + + mblas::Matrix Temp1_; + mblas::Matrix Temp2_; + }; + + ////////////////////////////////////////////////////////////// + template <class WeightsGRU, class WeightsTrans> + class RNNFinal { + public: + RNNFinal(const WeightsGRU& modelGRU, const WeightsTrans& modelTrans) + : gru_(modelGRU), + transition_(modelTrans) + {} + + void GetNextState( + mblas::Matrix& nextState, + const mblas::Matrix& state, + const mblas::Matrix& context) + { + gru_.GetNextState(nextState, state, context); + transition_.GetNextState(nextState); + // std::cerr << "TRANS: " << std::endl; + // for (int i = 0; i < 10; ++i) std::cerr << nextState(0, i) << " "; + // std::cerr << std::endl; + } + + private: + const GRU<WeightsGRU> gru_; + const Transition<WeightsTrans> transition_; + }; + + ////////////////////////////////////////////////////////////// + template <class Weights> + class Attention { + public: + Attention(const Weights& model) + : w_(model) + { + V_ = blaze::trans(blaze::row(w_.V_, 0)); + } + + void Init(const mblas::Matrix& SourceContext) { + using namespace mblas; + SCU_ = SourceContext * w_.U_; + mblas::AddBiasVector<mblas::byRow>(SCU_, w_.B_); + + if (w_.Wc_att_lns_.rows()) { + LayerNormalization(SCU_, w_.Wc_att_lns_, w_.Wc_att_lnb_); + } + } + + void GetAlignedSourceContext( + mblas::Matrix& AlignedSourceContext, + const mblas::Matrix& HiddenState, + const mblas::Matrix& SourceContext) + { + using namespace mblas; + + Temp2_ = HiddenState * w_.W_; + if (w_.W_comb_lns_.rows()) { + LayerNormalization(Temp2_, w_.W_comb_lns_, w_.W_comb_lnb_); + } + + Temp1_ = Broadcast<Matrix>(Tanh(), SCU_, Temp2_); + + A_.resize(Temp1_.rows(), 1); + blaze::column(A_, 0) = Temp1_ * V_; + size_t words = SourceContext.rows(); + // batch size, for batching, divide by numer of sentences + size_t batchSize = HiddenState.rows(); + Reshape(A_, batchSize, words); // due to broadcasting above + + float bias = w_.C_(0,0); + blaze::forEach(A_, [=](float x) { return x + bias; }); + + mblas::SafeSoftmax(A_); + AlignedSourceContext = A_ * SourceContext; + } + + void GetAttention(mblas::Matrix& Attention) { + Attention = A_; + } + + mblas::Matrix& GetAttention() { + return A_; + } + + private: + const Weights& w_; + + mblas::Matrix SCU_; + mblas::Matrix Temp1_; + mblas::Matrix Temp2_; + mblas::Matrix A_; + mblas::ColumnVector V_; + }; + + ////////////////////////////////////////////////////////////// + template <class Weights> + class Softmax { + public: + Softmax(const Weights& model) + : w_(model), + filtered_(false) + {} + + void GetProbs(mblas::ArrayMatrix& Probs, + const mblas::Matrix& State, + const mblas::Matrix& Embedding, + const mblas::Matrix& AlignedSourceContext) { + using namespace mblas; + + T1_ = State * w_.W1_; + AddBiasVector<byRow>(T1_, w_.B1_); + if (w_.lns_1_.rows()) { + LayerNormalization(T1_, w_.lns_1_, w_.lnb_1_); + } + // std::cerr << "State" << std::endl; + // for(int i = 0; i < 5; ++i) std::cerr << T1_(0, i) << " "; + // std::cerr << std::endl; + + T2_ = Embedding * w_.W2_; + AddBiasVector<byRow>(T2_, w_.B2_); + if (w_.lns_2_.rows()) { + LayerNormalization(T2_, w_.lns_2_, w_.lnb_2_); + } + // std::cerr << "emb" << std::endl; + // for(int i = 0; i < 5; ++i) std::cerr << T2_(0, i) << " "; + // std::cerr << std::endl; + + T3_ = AlignedSourceContext * w_.W3_; + AddBiasVector<byRow>(T3_, w_.B3_); + if (w_.lns_3_.rows()) { + LayerNormalization(T3_, w_.lns_3_, w_.lnb_3_); + } + // std::cerr << "CTX" << std::endl; + // for(int i = 0; i < 5; ++i) std::cerr << T3_(0, i) << " "; + // std::cerr << std::endl; + + auto t = blaze::forEach(T1_ + T2_ + T3_, Tanh()); + + if(!filtered_) { + Probs = t * w_.W4_; + AddBiasVector<byRow>(Probs, w_.B4_); + } else { + Probs = t * FilteredW4_; + AddBiasVector<byRow>(Probs, FilteredB4_); + } + // std::cerr << "LOgit" << std::endl; + // for(int i = 0; i < 5; ++i) std::cerr << Probs(0, i) << " "; + // std::cerr << std::endl; + LogSoftmax(Probs); + } + + void Filter(const std::vector<size_t>& ids) { + filtered_ = true; + using namespace mblas; + FilteredW4_ = Assemble<byColumn, Matrix>(w_.W4_, ids); + FilteredB4_ = Assemble<byColumn, Matrix>(w_.B4_, ids); + } + + private: + const Weights& w_; + bool filtered_; + + mblas::Matrix FilteredW4_; + mblas::Matrix FilteredB4_; + + mblas::Matrix T1_; + mblas::Matrix T2_; + mblas::Matrix T3_; + }; + + public: + Decoder(const Weights& model) + : embeddings_(model.decEmbeddings_), + rnn1_(model.decInit_, model.decGru1_), + rnn2_(model.decGru2_, model.decTransition_), + attention_(model.decAttention_), + softmax_(model.decSoftmax_) + {} + + void Decode( + mblas::Matrix& NextState, + const mblas::Matrix& State, + const mblas::Matrix& Embeddings, + const mblas::Matrix& SourceContext) + { + GetHiddenState(HiddenState_, State, Embeddings); + // std::cerr << "HIDDEN: " << std::endl; + // for (int i = 0; i < 5; ++i) std::cerr << HiddenState_(0, i) << " "; + // std::cerr << std::endl; + + GetAlignedSourceContext(AlignedSourceContext_, HiddenState_, SourceContext); + // std::cerr << "ALIGNED SRC: " << std::endl; + // for (int i = 0; i < 5; ++i) std::cerr << AlignedSourceContext_(0, i) << " "; + // std::cerr << std::endl; + + GetNextState(NextState, HiddenState_, AlignedSourceContext_); + // std::cerr << "NEXT: " << std::endl; + // for (int i = 0; i < 5; ++i) std::cerr << NextState(0, i) << " "; + // std::cerr << std::endl; + + GetProbs(NextState, Embeddings, AlignedSourceContext_); + } + + mblas::ArrayMatrix& GetProbs() { + return Probs_; + } + + void EmptyState(mblas::Matrix& State, + const mblas::Matrix& SourceContext, + size_t batchSize = 1) { + rnn1_.InitializeState(State, SourceContext, batchSize); + attention_.Init(SourceContext); + } + + void EmptyEmbedding(mblas::Matrix& Embedding, + size_t batchSize = 1) { + Embedding.resize(batchSize, embeddings_.GetCols()); + Embedding = 0.0f; + } + + void Lookup(mblas::Matrix& Embedding, + const std::vector<size_t>& w) { + embeddings_.Lookup(Embedding, w); + } + + void Filter(const std::vector<size_t>& ids) { + softmax_.Filter(ids); + } + + void GetAttention(mblas::Matrix& attention) { + attention_.GetAttention(attention); + } + + mblas::Matrix& GetAttention() { + return attention_.GetAttention(); + } + + size_t GetVocabSize() const { + return embeddings_.GetRows(); + } + + private: + + void GetHiddenState(mblas::Matrix& HiddenState, + const mblas::Matrix& PrevState, + const mblas::Matrix& Embedding) { + rnn1_.GetNextState(HiddenState, PrevState, Embedding); + } + + void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext, + const mblas::Matrix& HiddenState, + const mblas::Matrix& SourceContext) { + attention_.GetAlignedSourceContext(AlignedSourceContext, HiddenState, SourceContext); + } + + void GetNextState(mblas::Matrix& State, + const mblas::Matrix& HiddenState, + const mblas::Matrix& AlignedSourceContext) { + rnn2_.GetNextState(State, HiddenState, AlignedSourceContext); + } + + + void GetProbs(const mblas::Matrix& State, + const mblas::Matrix& Embedding, + const mblas::Matrix& AlignedSourceContext) { + softmax_.GetProbs(Probs_, State, Embedding, AlignedSourceContext); + } + + private: + mblas::Matrix HiddenState_; + mblas::Matrix AlignedSourceContext_; + mblas::ArrayMatrix Probs_; + + Embeddings<Weights::Embeddings> embeddings_; + RNNHidden<Weights::DecInit, Weights::GRU> rnn1_; + RNNFinal<Weights::DecGRU2, Weights::Transition> rnn2_; + Attention<Weights::DecAttention> attention_; + Softmax<Weights::DecSoftmax> softmax_; +}; + +} +} + diff --git a/src/amun/cpu/nematus/dl4mt.h b/src/amun/cpu/nematus/dl4mt.h new file mode 100644 index 00000000..ed4b90bf --- /dev/null +++ b/src/amun/cpu/nematus/dl4mt.h @@ -0,0 +1,10 @@ +#pragma once + +#include "model.h" +#include "encoder.h" +#include "decoder.h" + +namespace amunmt { + +} + diff --git a/src/amun/cpu/nematus/encoder.cpp b/src/amun/cpu/nematus/encoder.cpp new file mode 100644 index 00000000..032528fa --- /dev/null +++ b/src/amun/cpu/nematus/encoder.cpp @@ -0,0 +1,30 @@ +#include "encoder.h" + +using namespace std; + +namespace amunmt { +namespace CPU { + +void Encoder::GetContext(const std::vector<size_t>& words, mblas::Matrix& context) { + std::vector<mblas::Matrix> embeddedWords; + + context.resize(words.size(), + forwardRnn_.GetStateLength() + backwardRnn_.GetStateLength()); + + for (auto& w : words) { + embeddedWords.emplace_back(); + mblas::Matrix &embed = embeddedWords.back(); + embeddings_.Lookup(embed, w); + } + + forwardRnn_.GetContext(embeddedWords.cbegin(), + embeddedWords.cend(), + context, false); + backwardRnn_.GetContext(embeddedWords.crbegin(), + embeddedWords.crend(), + context, true); +} + +} // namespace CPU +} // namespace amunmt + diff --git a/src/amun/cpu/nematus/encoder.h b/src/amun/cpu/nematus/encoder.h new file mode 100644 index 00000000..31266e42 --- /dev/null +++ b/src/amun/cpu/nematus/encoder.h @@ -0,0 +1,110 @@ +#pragma once + +#include "../mblas/matrix.h" +#include "model.h" +#include "gru.h" +#include "transition.h" + +namespace amunmt { +namespace CPU { + +class Encoder { + private: + + ///////////////////////////////////////////////////////////////// + template <class Weights> + class Embeddings { + public: + Embeddings(const Weights& model) + : w_(model) + {} + + void Lookup(mblas::Matrix& Row, size_t i) { + size_t len = w_.E_.columns(); + if(i < w_.E_.rows()) + Row = blaze::submatrix(w_.E_, i, 0, 1, len); + else + Row = blaze::submatrix(w_.E_, 1, 0, 1, len); // UNK + } + + const Weights& w_; + private: + }; + + ///////////////////////////////////////////////////////////////// + template <class WeightsGRU, class WeightsTrans> + class EncoderRNN { + public: + EncoderRNN(const WeightsGRU& modelGRU, const WeightsTrans& modelTrans) + : gru_(modelGRU), + transition_(modelTrans) + {} + + void InitializeState(size_t batchSize = 1) { + State_.resize(batchSize, gru_.GetStateLength()); + State_ = 0.0f; + } + + void GetNextState(mblas::Matrix& nextState, + const mblas::Matrix& state, + const mblas::Matrix& embd) { + gru_.GetNextState(nextState, state, embd); + // std::cerr << "GRU: " << std::endl; + // for (int i = 0; i < 10; ++i) std::cerr << nextState(0, i) << " "; + // std::cerr << std::endl; + transition_.GetNextState(nextState); + // std::cerr << "TRANS: " << std::endl; + // for (int i = 0; i < 10; ++i) std::cerr << nextState(0, i) << " "; + // std::cerr << std::endl; + } + + template <class It> + void GetContext(It it, It end, mblas::Matrix& Context, bool invert) { + InitializeState(); + + size_t n = std::distance(it, end); + size_t i = 0; + while(it != end) { + GetNextState(State_, State_, *it++); + + size_t len = gru_.GetStateLength(); + if(invert) + blaze::submatrix(Context, n - i - 1, len, 1, len) = State_; + else + blaze::submatrix(Context, i, 0, 1, len) = State_; + ++i; + } + } + + size_t GetStateLength() const { + return gru_.GetStateLength(); + } + + private: + // Model matrices + const GRU<WeightsGRU> gru_; + const Transition<WeightsTrans> transition_; + + mblas::Matrix State_; + }; + + ///////////////////////////////////////////////////////////////// + public: + Encoder(const Weights& model) + : embeddings_(model.encEmbeddings_), + forwardRnn_(model.encForwardGRU_, model.encForwardTransition_), + backwardRnn_(model.encBackwardGRU_, model.encBackwardTransition_) + {} + + void GetContext(const std::vector<size_t>& words, + mblas::Matrix& context); + + private: + Embeddings<Weights::Embeddings> embeddings_; + EncoderRNN<Weights::GRU, Weights::Transition> forwardRnn_; + EncoderRNN<Weights::GRU, Weights::Transition> backwardRnn_; +}; + +} +} + diff --git a/src/amun/cpu/nematus/gru.cpp b/src/amun/cpu/nematus/gru.cpp new file mode 100644 index 00000000..778659aa --- /dev/null +++ b/src/amun/cpu/nematus/gru.cpp @@ -0,0 +1,6 @@ +#include "gru.h" + +namespace amunmt { + +} + diff --git a/src/amun/cpu/nematus/gru.h b/src/amun/cpu/nematus/gru.h new file mode 100644 index 00000000..33e855fd --- /dev/null +++ b/src/amun/cpu/nematus/gru.h @@ -0,0 +1,153 @@ +#pragma once +#include "cpu/mblas/matrix.h" +#include <iomanip> + +namespace amunmt { +namespace CPU { + +template <class Weights> +class GRU { + public: + GRU(const Weights& model) + : w_(model), + layerNormalization_(w_.W_lns_.rows()) + { + if (!layerNormalization_) { + WWx_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(w_.W_, w_.Wx_); + UUx_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(w_.U_, w_.Ux_); + } + } + + void GetNextState( + mblas::Matrix& nextState, + const mblas::Matrix& state, + const mblas::Matrix& context) const + { + // std::cerr << "Get next state" << std::endl; + if (layerNormalization_) { + RUH_1_ = context * w_.W_; + mblas::AddBiasVector<mblas::byRow>(RUH_1_, w_.B_); + LayerNormalization(RUH_1_, w_.W_lns_, w_.W_lnb_); + + RUH_2_ = context * w_.Wx_; + mblas::AddBiasVector<mblas::byRow>(RUH_2_, w_.Bx1_); + LayerNormalization(RUH_2_, w_.Wx_lns_, w_.Wx_lnb_); + + RUH_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(RUH_1_, RUH_2_); + + Temp_1_ = state * w_.U_; + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.Bx3_); + LayerNormalization(Temp_1_, w_.U_lns_, w_.U_lnb_); + + Temp_2_ = state * w_.Ux_; + mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx2_); + LayerNormalization(Temp_2_, w_.Ux_lns_, w_.Ux_lnb_); + + Temp_ = mblas::Concat<mblas::byColumn, mblas::Matrix>(Temp_1_, Temp_2_); + + ElementwiseOpsLayerNorm(nextState, state); + + } else { + RUH_ = context * WWx_; + Temp_ = state * UUx_; + ElementwiseOps(nextState, state); + } + } + + void ElementwiseOps(mblas::Matrix& NextState, const mblas::Matrix& State) const { + using namespace mblas; + using namespace blaze; + + const int rowNo = State.rows(); + const int colNo = State.columns(); + NextState.resize(rowNo, colNo); + + for (int j = 0; j < rowNo; ++j) { + auto rowOut = row(NextState, j); + auto rowState = row(State, j); + + auto rowRuh = row(RUH_, j); + auto rowT = row(Temp_, j); + + auto rowH = subvector(rowRuh, 2 * colNo, colNo); + auto rowT2 = subvector(rowT, 2 * colNo, colNo); + + for (int i = 0; i < colNo; ++i) { + float ev1 = expapprox(-(rowRuh[i] + w_.B_(0, i) + rowT[i])); + float r = 1.0f / (1.0f + ev1); + + int k = i + colNo; + float ev2 = expapprox(-(rowRuh[k] + w_.B_(0, k) + rowT[k])); + float u = 1.0f / (1.0f + ev2); + + float hv = rowH[i] + w_.Bx1_(0, i); + float t2v = rowT2[i]; + hv = tanhapprox(hv + r * t2v); + rowOut[i] = (1.0f - u) * hv + u * rowState[i]; + } + } + } + + void ElementwiseOpsLayerNorm(mblas::Matrix& NextState, const mblas::Matrix& State) const { + using namespace mblas; + using namespace blaze; + + const int rowNo = State.rows(); + const int colNo = State.columns(); + NextState.resize(rowNo, colNo); + + for (int j = 0; j < rowNo; ++j) { + auto rowOut = row(NextState, j); + auto rowState = row(State, j); + + auto rowRuh = row(RUH_, j); + auto rowT = row(Temp_, j); + + auto rowH = subvector(rowRuh, 2 * colNo, colNo); + auto rowT2 = subvector(rowT, 2 * colNo, colNo); + + for (int i = 0; i < colNo; ++i) { + float ev1 = expapprox(-(rowRuh[i] + rowT[i])); + float r = 1.0f / (1.0f + ev1); + + int k = i + colNo; + float ev2 = expapprox(-(rowRuh[k] + rowT[k])); + float u = 1.0f / (1.0f + ev2); + + float hv = rowH[i]; + float t2v = rowT2[i] + w_.Bx2_(0, i); + hv = tanhapprox(hv + r * t2v); + rowOut[i] = (1.0f - u) * hv + u * rowState[i]; + } + } + } + size_t GetStateLength() const { + return w_.U_.rows(); + } + + + private: + // Model matrices + const Weights& w_; + mutable mblas::Matrix WWx_; + mutable mblas::Matrix UUx_; + mutable mblas::Matrix Wbbx_; + mutable mblas::Matrix lns_WWx_; + mutable mblas::Matrix lns_UUx_; + mutable mblas::Matrix lnb_WWx_; + mutable mblas::Matrix lnb_UUx_; + + // reused to avoid allocation + mutable mblas::Matrix RUH_; + mutable mblas::Matrix RUH_1_; + mutable mblas::Matrix RUH_2_; + mutable mblas::Matrix Temp_; + mutable mblas::Matrix Temp_1_; + mutable mblas::Matrix Temp_2_; + + bool layerNormalization_; +}; + +} +} + diff --git a/src/amun/cpu/nematus/model.cpp b/src/amun/cpu/nematus/model.cpp new file mode 100644 index 00000000..71b8694a --- /dev/null +++ b/src/amun/cpu/nematus/model.cpp @@ -0,0 +1,125 @@ +#include "model.h" + +using namespace std; + +namespace amunmt { +namespace CPU { + +Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key) + : E_(model[key]) +{} + +Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys) + : E_(model.getFirstOfMany(keys)) +{} + +Weights::GRU::GRU(const NpzConverter& model, std::string prefix, std::vector<std::string> keys) + : W_(model[prefix + keys.at(0)]), + B_(model(prefix + keys.at(1), true)), + U_(model[prefix + keys.at(2)]), + Wx_(model[prefix + keys.at(3)]), + Bx1_(model(prefix + keys.at(4), true)), + Bx2_(Bx1_.rows(), Bx1_.columns()), + Bx3_(B_.rows(), B_.columns()), + Ux_(model[prefix + keys.at(5)]), + W_lns_(model[prefix + keys.at(6)]), + W_lnb_(model[prefix + keys.at(7)]), + Wx_lns_(model[prefix + keys.at(8)]), + Wx_lnb_(model[prefix + keys.at(9)]), + U_lns_(model[prefix + keys.at(10)]), + U_lnb_(model[prefix + keys.at(11)]), + Ux_lns_(model[prefix + keys.at(12)]), + Ux_lnb_(model[prefix + keys.at(13)]) +{ + const_cast<mblas::Matrix&>(Bx2_) = 0.0f; + const_cast<mblas::Matrix&>(Bx3_) = 0.0f; +} + +////////////////////////////////////////////////////////////////////////////// + +Weights::DecInit::DecInit(const NpzConverter& model) + : Wi_(model["ff_state_W"]), + Bi_(model("ff_state_b", true)), + lns_(model["ff_state_ln_s"]), + lnb_(model["ff_state_ln_b"]) +{} + + +Weights::DecGRU2::DecGRU2(const NpzConverter& model, std::string prefix, std::vector<std::string> keys) + : W_(model[prefix + keys.at(0)]), // Wc + B_(1, W_.Cols()), + U_(model[prefix + keys.at(1)]), // U_nl + Bx3_(model(prefix + keys.at(2), true)), // b_nl + Wx_(model[prefix + keys.at(3)]), // Wcx + Bx1_(1, Wx_.Cols()), + Ux_(model[prefix + keys.at(4)]), // Ux_nl + Bx2_(model(prefix + keys.at(5), true)), // bx_nl + W_lns_(model[prefix + keys.at(6)]), // Wc_lns + W_lnb_(model[prefix + keys.at(7)]), // Wc_nlb + Wx_lns_(model[prefix + keys.at(8)]), // Wcx_lns + Wx_lnb_(model[prefix + keys.at(9)]), // Wcx_lnb + U_lns_(model[prefix + keys.at(10)]), // U_nl_lns + U_lnb_(model[prefix + keys.at(11)]), // U_nl_lnb + Ux_lns_(model[prefix + keys.at(12)]), // Ux_nl_lns + Ux_lnb_(model[prefix + keys.at(13)]) // Ux_nl_lnb + +{ + const_cast<mblas::Matrix&>(B_) = 0.0f; + const_cast<mblas::Matrix&>(Bx1_) = 0.0f; +} + +Weights::DecAttention::DecAttention(const NpzConverter& model) + : V_(model("decoder_U_att", true)), + W_(model["decoder_W_comb_att"]), + B_(model("decoder_b_att", true)), + U_(model["decoder_Wc_att"]), + C_(model["decoder_c_tt"]), + Wc_att_lns_(model["decoder_Wc_att_lns"]), + Wc_att_lnb_(model["decoder_Wc_att_lnb"]), + W_comb_lns_(model["decoder_W_comb_att_lns"]), + W_comb_lnb_(model["decoder_W_comb_att_lnb"]) +{} + +Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) + : W1_(model["ff_logit_lstm_W"]), + B1_(model("ff_logit_lstm_b", true)), + W2_(model["ff_logit_prev_W"]), + B2_(model("ff_logit_prev_b", true)), + W3_(model["ff_logit_ctx_W"]), + B3_(model("ff_logit_ctx_b", true)), + W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false), + std::make_pair(std::string("Wemb_dec"), true)})), + B4_(model("ff_logit_b", true)), + lns_1_(model["ff_logit_lstm_ln_s"]), + lns_2_(model["ff_logit_prev_ln_s"]), + lns_3_(model["ff_logit_ctx_ln_s"]), + lnb_1_(model["ff_logit_lstm_ln_b"]), + lnb_2_(model["ff_logit_prev_ln_b"]), + lnb_3_(model["ff_logit_ctx_ln_b"]) +{} + +////////////////////////////////////////////////////////////////////////////// + +Weights::Weights(const NpzConverter& model, size_t) + : encEmbeddings_(model, "Wemb"), + decEmbeddings_(model, {std::make_pair(std::string("Wemb_dec"), false), + std::make_pair(std::string("Wemb"), false)}), + encForwardGRU_(model, "encoder_", {"W", "b", "U", "Wx", "bx", "Ux", "W_lns", "W_lnb", "Wx_lns", + "Wx_lnb", "U_lns", "U_lnb", "Ux_lns", "Ux_lnb" }), + encBackwardGRU_(model, "encoder_r_", {"W", "b", "U", "Wx", "bx", "Ux", "W_lns", "W_lnb", + "Wx_lns", "Wx_lnb", "U_lns", "U_lnb", "Ux_lns", "Ux_lnb" }), + decInit_(model), + decGru1_(model, "decoder_", {"W", "b", "U", "Wx", "bx", "Ux", "W_lns", "W_lnb", "Wx_lns", + "Wx_lnb", "U_lns", "U_lnb", "Ux_lns", "Ux_lnb" }), + decGru2_(model, "decoder_", {"Wc", "U_nl", "b_nl", "Wcx", "Ux_nl", "bx_nl", "Wc_lns", "Wc_lnb", + "Wcx_lns", "Wcx_lnb", "U_nl_lns", "U_nl_lnb", "Ux_nl_lns", + "Ux_nl_lnb"}), + decAttention_(model), + decSoftmax_(model), + encForwardTransition_(model, Weights::Transition::TransitionType::Encoder, "encoder_"), + encBackwardTransition_(model,Weights::Transition::TransitionType::Encoder, "encoder_r_"), + decTransition_(model, Weights::Transition::TransitionType::Decoder, "decoder_", "_nl") +{} + +} // namespace cpu +} // namespace amunmt diff --git a/src/amun/cpu/nematus/model.h b/src/amun/cpu/nematus/model.h new file mode 100644 index 00000000..e4bded82 --- /dev/null +++ b/src/amun/cpu/nematus/model.h @@ -0,0 +1,295 @@ +#pragma once + +#include <iostream> +#include <map> +#include <string> + +#include "../npz_converter.h" + +#include "../mblas/matrix.h" + +namespace amunmt { +namespace CPU { + +struct Weights { + class Transition { + public: + enum class TransitionType {Encoder, Decoder}; + Transition(const NpzConverter& model, TransitionType type, std::string prefix, + std::string infix="") + : depth_(findTransitionDepth(model, prefix, infix)), type_(type) + { + for (int i = 1; i <= depth_; ++i) { + U_.emplace_back(model[name(prefix, "U", infix, i)]); + Ux_.emplace_back(model[name(prefix, "Ux", infix, i)]); + B_.emplace_back(model(name(prefix, "b", infix, i), true)); + U_lns_.emplace_back(model[name(prefix, "U", infix, i, "_lns")]); + U_lnb_.emplace_back(model[name(prefix, "U", infix, i, "_lnb")]); + Ux_lns_.emplace_back(model[name(prefix, "Ux", infix, i, "_lns")]); + Ux_lnb_.emplace_back(model[name(prefix, "Ux", infix, i, "_lnb")]); + // decoder_U_nl_drt_4_lnb + switch(type) { + case TransitionType::Encoder: + Bx1_.emplace_back(1, Ux_.back().Cols()); + const_cast<mblas::Matrix&>(Bx1_.back()) = 0.0f; + Bx2_.emplace_back(model(name(prefix, "bx", infix, i), true)); + break; + case TransitionType::Decoder: + Bx1_.emplace_back(model(name(prefix, "bx", infix, i), true)); + Bx2_.emplace_back(1, Ux_.back().Cols()); + const_cast<mblas::Matrix&>(Bx2_.back()) = 0.0f; + break; + } + } + } + + static int findTransitionDepth(const NpzConverter& model, std::string prefix, std::string infix) { + int currentDepth = 0; + while (true) { + if (model.has(prefix + "b" + infix + "_drt_" + std::to_string(currentDepth + 1))) { + ++currentDepth; + } else { + break; + } + } + std::cerr << "Found transition depth: " << currentDepth << std::endl; + return currentDepth; + } + + int size() const { + return depth_; + } + + TransitionType type() const { + return type_; + } + + protected: + std::string name(const std::string& prefix, std::string name, std::string infix, int index, + std::string suffix = "") + { + return prefix + name + infix + "_drt_" + std::to_string(index) + suffix; + } + + private: + int depth_; + TransitionType type_; + + public: + std::vector<mblas::Matrix> B_; + std::vector<mblas::Matrix> Bx1_; + std::vector<mblas::Matrix> Bx2_; + std::vector<mblas::Matrix> U_; + std::vector<mblas::Matrix> Ux_; + + std::vector<mblas::Matrix> U_lns_; + std::vector<mblas::Matrix> U_lnb_; + std::vector<mblas::Matrix> Ux_lns_; + std::vector<mblas::Matrix> Ux_lnb_; + + }; + + struct Embeddings { + Embeddings(const NpzConverter& model, const std::string &key); + Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys); + + const mblas::Matrix E_; + }; + + struct GRU { + GRU(const NpzConverter& model, std::string prefix, std::vector<std::string> keys); + + const mblas::Matrix W_; + const mblas::Matrix B_; + const mblas::Matrix U_; + const mblas::Matrix Wx_; + const mblas::Matrix Bx1_; + const mblas::Matrix Bx2_; + const mblas::Matrix Bx3_; + const mblas::Matrix Ux_; + + const mblas::Matrix W_lns_; + const mblas::Matrix W_lnb_; + const mblas::Matrix Wx_lns_; + const mblas::Matrix Wx_lnb_; + const mblas::Matrix U_lns_; + const mblas::Matrix U_lnb_; + const mblas::Matrix Ux_lns_; + const mblas::Matrix Ux_lnb_; + }; + + struct DecInit { + DecInit(const NpzConverter& model); + + const mblas::Matrix Wi_; + const mblas::Matrix Bi_; + const mblas::Matrix lns_; + const mblas::Matrix lnb_; + }; + + struct DecGRU2 { + DecGRU2(const NpzConverter& model, std::string prefix, std::vector<std::string> keys); + + const mblas::Matrix W_; + const mblas::Matrix B_; + const mblas::Matrix U_; + const mblas::Matrix Wx_; + const mblas::Matrix Bx3_; + const mblas::Matrix Bx2_; + const mblas::Matrix Bx1_; + const mblas::Matrix Ux_; + + const mblas::Matrix W_lns_; + const mblas::Matrix W_lnb_; + const mblas::Matrix Wx_lns_; + const mblas::Matrix Wx_lnb_; + const mblas::Matrix U_lns_; + const mblas::Matrix U_lnb_; + const mblas::Matrix Ux_lns_; + const mblas::Matrix Ux_lnb_; + }; + + struct DecAttention { + DecAttention(const NpzConverter& model); + + const mblas::Matrix V_; + const mblas::Matrix W_; + const mblas::Matrix B_; + const mblas::Matrix U_; + const mblas::Matrix C_; + const mblas::Matrix Wc_att_lns_; + const mblas::Matrix Wc_att_lnb_; + const mblas::Matrix W_comb_lns_; + const mblas::Matrix W_comb_lnb_; + }; + + struct DecSoftmax { + DecSoftmax(const NpzConverter& model); + + const mblas::Matrix W1_; + const mblas::Matrix B1_; + const mblas::Matrix W2_; + const mblas::Matrix B2_; + const mblas::Matrix W3_; + const mblas::Matrix B3_; + const mblas::Matrix W4_; + const mblas::Matrix B4_; + const mblas::Matrix lns_1_; + const mblas::Matrix lns_2_; + const mblas::Matrix lns_3_; + const mblas::Matrix lnb_1_; + const mblas::Matrix lnb_2_; + const mblas::Matrix lnb_3_; + }; + + + Weights(const std::string& npzFile, size_t device = 0) + : Weights(NpzConverter(npzFile), device) + {} + + Weights(const NpzConverter& model, size_t device = 0); + + size_t GetDevice() { + return std::numeric_limits<size_t>::max(); + } + + const Embeddings encEmbeddings_; + const Embeddings decEmbeddings_; + const GRU encForwardGRU_; + const GRU encBackwardGRU_; + const DecInit decInit_; + const GRU decGru1_; + const DecGRU2 decGru2_; + const DecAttention decAttention_; + const DecSoftmax decSoftmax_; + const Transition encForwardTransition_; + const Transition encBackwardTransition_; + const Transition decTransition_; +}; + +inline std::ostream& operator<<(std::ostream &out, const Weights::Embeddings &obj) +{ + out << "E_ \t" << obj.E_; + return out; +} + +inline std::ostream& operator<<(std::ostream &out, const Weights::GRU &obj) +{ + out << "W_ \t" << obj.W_ << std::endl; + out << "B_ \t" << obj.B_ << std::endl; + out << "U_ \t" << obj.U_ << std::endl; + out << "Wx_ \t" << obj.Wx_ << std::endl; + out << "Bx1_ \t" << obj.Bx1_ << std::endl; + out << "Bx2_ \t" << obj.Bx2_ << std::endl; + out << "Ux_ \t" << obj.Ux_; + return out; +} + +inline std::ostream& operator<<(std::ostream &out, const Weights::DecGRU2 &obj) +{ + out << "W_ \t" << obj.W_ << std::endl; + out << "B_ \t" << obj.B_ << std::endl; + out << "U_ \t" << obj.U_ << std::endl; + out << "Wx_ \t" << obj.Wx_ << std::endl; + out << "Bx1_ \t" << obj.Bx1_ << std::endl; + out << "Bx2_ \t" << obj.Bx2_ << std::endl; + out << "Ux_ \t" << obj.Ux_; + return out; +} + +inline std::ostream& operator<<(std::ostream &out, const Weights::DecInit &obj) +{ + out << "Wi_ \t" << obj.Wi_ << std::endl; + out << "Bi_ \t" << obj.Bi_ ; + return out; +} + +inline std::ostream& operator<<(std::ostream &out, const Weights::DecAttention &obj) +{ + out << "V_ \t" << obj.V_ << std::endl; + out << "W_ \t" << obj.W_ << std::endl; + out << "B_ \t" << obj.B_ << std::endl; + out << "U_ \t" << obj.U_ << std::endl; + out << "C_ \t" << obj.C_ ; + return out; +} + +inline std::ostream& operator<<(std::ostream &out, const Weights::DecSoftmax &obj) +{ + out << "W1_ \t" << obj.W1_ << std::endl; + out << "B1_ \t" << obj.B1_ << std::endl; + out << "W2_ \t" << obj.W2_ << std::endl; + out << "B2_ \t" << obj.B2_ << std::endl; + out << "W3_ \t" << obj.W3_ << std::endl; + out << "B3_ \t" << obj.B3_ << std::endl; + out << "W4_ \t" << obj.W4_ << std::endl; + out << "B4_ \t" << obj.B4_ ; + + return out; +} + +inline std::ostream& operator<<(std::ostream &out, const Weights &obj) +{ + out << "\n encEmbeddings_ \n" << obj.encEmbeddings_ << std::endl; + out << "\n decEmbeddings_ \n" << obj.decEmbeddings_ << std::endl; + + out << "\n encForwardGRU_ \n" << obj.encForwardGRU_ << std::endl; + out << "\n encBackwardGRU_ \n" << obj.encBackwardGRU_ << std::endl; + + out << "\n decInit_ \n" << obj.decInit_ << std::endl; + + out << "\n decGru1_ \n" << obj.decGru1_ << std::endl; + out << "\n decGru2_ \n" << obj.decGru2_ << std::endl; + + out << "\n decAttention_ \n" << obj.decAttention_ << std::endl; + + out << "\n decSoftmax_ \n" << obj.decSoftmax_ << std::endl; + + //Debug2(obj.encEmbeddings_.E_); + + return out; +} + +} +} + diff --git a/src/amun/cpu/nematus/transition.h b/src/amun/cpu/nematus/transition.h new file mode 100644 index 00000000..f4d342d7 --- /dev/null +++ b/src/amun/cpu/nematus/transition.h @@ -0,0 +1,104 @@ +#pragma once +#include "cpu/mblas/matrix.h" +#include <iomanip> + +namespace amunmt { +namespace CPU { + +template <class Weights> +class Transition { + public: + Transition(const Weights& model) + : w_(model), + layerNormalization_(false) + { + if (w_.U_lns_.size() > 1 && w_.U_lns_[0].rows() > 1) { + layerNormalization_ = true; + } + } + + void GetNextState(mblas::Matrix& state) const + { + if (layerNormalization_) { + for (int i = 0; i < w_.size(); ++i) { + Temp_1_ = state * w_.U_[i]; + Temp_2_ = state * w_.Ux_[i]; + + switch(w_.type()) { + case Weights::Transition::TransitionType::Encoder: + LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]); + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); + + LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]); + break; + + case Weights::Transition::TransitionType::Decoder: + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); + LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]); + + mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]); + LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]); + break; + } + ElementwiseOps(state, i); + } + } else { + for (int i = 0; i < w_.size(); ++i) { + Temp_1_ = state * w_.U_[i]; + Temp_2_ = state * w_.Ux_[i]; + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); + mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]); + ElementwiseOps(state, i); + } + } + } + + void ElementwiseOps(mblas::Matrix& state, int idx) const { + using namespace mblas; + using namespace blaze; + + for (int j = 0; j < (int)state.Rows(); ++j) { + auto rowState = row(state, j); + auto rowT = row(Temp_1_, j); + auto rowT2 = row(Temp_2_, j); + + for (int i = 0; i < (int)state.Cols(); ++i) { + float ev1 = expapprox(-(rowT[i])); // + w_.B_[idx](0, i))); + float r = 1.0f / (1.0f + ev1); + + int k = i + state.Cols(); + float ev2 = expapprox(-(rowT[k])); // + w_.B_[idx](0, k))); + float u = 1.0f / (1.0f + ev2); + + float hv = w_.Bx2_[idx](0, i); + float t2v = rowT2[i]; // + w_.Bx1_[idx](0, i); + hv = tanhapprox(hv + r * t2v); + rowState[i] = (1.0f - u) * hv + u * rowState[i]; + } + } + } + + size_t GetStateLength() const { + return w_.U_.rows(); + } + + + private: + // Model matrices + const Weights& w_; + + // reused to avoid allocation + mutable mblas::Matrix UUx_; + mutable mblas::Matrix RUH_; + mutable mblas::Matrix RUH_1_; + mutable mblas::Matrix RUH_2_; + mutable mblas::Matrix Temp_; + mutable mblas::Matrix Temp_1_; + mutable mblas::Matrix Temp_2_; + + bool layerNormalization_; +}; + +} +} + |