author    | Rihards Krišlauks <rihards.krislauks@gmail.com> | 2018-02-28 16:59:16 +0300
committer | Rihards Krišlauks <rihards.krislauks@gmail.com> | 2018-02-28 18:22:51 +0300
commit    | 138f53ec8f7324d9f771d1cc6fae92187d5c5420 (patch)
tree      | e017eafb96fd2a30849c1f93d64e161c8274d596
parent    | 1de1e7cb22c5e953eb309395ed059297f0583a0d (diff)
Switch to LSTM matrix format used in Marian; be backwards compatible (lstm-overhaul)
-rw-r--r-- | src/amun/gpu/dl4mt/lstm.h   | 29
-rw-r--r-- | src/amun/gpu/dl4mt/model.cu | 67
-rw-r--r-- | src/amun/gpu/dl4mt/model.h  | 12

3 files changed, 56 insertions, 52 deletions
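
Before the per-file diffs, a note on the two formats involved. Amun's original dl4mt-style LSTM keeps the gate parameters (W_, U_, B_) separate from the cell-candidate parameters (Wx_, Ux_, Bx_); Marian stores all four blocks in one matrix whose columns are ordered F, I, O, C. A minimal sketch of the resulting column offsets, assuming a per-gate width cols (the diff's GetStateLength().output; the concrete value below is made up for illustration):

    #include <cstddef>

    // Hypothetical per-gate width; in the diff this is GetStateLength().output.
    constexpr std::size_t cols = 512;

    // Old Amun/dl4mt layout: two parameter sets per connection.
    //   W  : 3*cols columns -> F, I, O gate pre-activations
    //   Wx :   cols columns -> cell-candidate pre-activation
    constexpr std::size_t old_w_cols  = 3 * cols;
    constexpr std::size_t old_wx_cols = cols;

    // New Marian layout: one merged matrix with 4*cols columns, ordered F|I|O|C.
    constexpr std::size_t fioc_cols = 4 * cols;
    constexpr std::size_t f_off = 0 * cols;  // forget gate block
    constexpr std::size_t i_off = 1 * cols;  // input gate block
    constexpr std::size_t o_off = 2 * cols;  // output gate block
    constexpr std::size_t c_off = 3 * cols;  // cell-candidate block

These offsets are why the lstm.h hunk below slices the gates out of FIOC_ at [0, 3*cols) and the candidate at [3*cols, 4*cols).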
diff --git a/src/amun/gpu/dl4mt/lstm.h b/src/amun/gpu/dl4mt/lstm.h
index 2f343db6..be8977c8 100644
--- a/src/amun/gpu/dl4mt/lstm.h
+++ b/src/amun/gpu/dl4mt/lstm.h
@@ -25,33 +25,31 @@ class SlowLSTM: public Cell {
       const unsigned cols = GetStateLength().output;
 
-      // transform context for use with gates
-      Prod(FIO_, Context, *w_.W_);
-      BroadcastVec(_1 + _2, FIO_, *w_.B_); // Broadcasting row-wise
-      // transform context for use with computing the input
-      Prod(H_, Context, *w_.Wx_);
-      BroadcastVec(_1 + _2, H_, *w_.Bx_); // Broadcasting row-wise
-
-      // transform previous output for use with gates
+      // transform context for use with gates and for computing the input (the C part)
+      Prod(FIOC_, Context, *w_.W_);
+      BroadcastVec(_1 + _2, FIOC_, *w_.B_); // Broadcasting row-wise
+
+      // transform previous output for use with gates and for computing the input
       Prod(Temp1_, *(State.output), *w_.U_);
-      // transform previous output for use with computing the input
-      Prod(Temp2_, *(State.output), *w_.Ux_);
+      Element(_1 + _2, FIOC_, Temp1_);
 
       // compute the gates
-      Element(Logit(_1 + _2), FIO_, Temp1_);
+      Slice(FIO_, FIOC_, 0, cols * 3);
+      Element(Logit(_1), FIO_);
       Slice(F_, FIO_, 0, cols);
       Slice(I_, FIO_, 1, cols);
       Slice(O_, FIO_, 2, cols);
 
       // compute the input
-      Element(Tanh(_1 + _2), H_, Temp2_);
+      Slice(C_, FIOC_, 3, cols);
+      Element(Tanh(_1), C_);
 
       // apply the forget gate
       Copy(*NextState.cell, *State.cell);
       Element(_1 * _2, *NextState.cell, F_);
 
       // apply the input gate
-      Element(_1 * _2, H_, I_);
+      Element(_1 * _2, C_, I_);
 
       // update the cell state with the input
-      Element(_1 + _2, *NextState.cell, H_);
+      Element(_1 + _2, *NextState.cell, C_);
 
       // apply the output gate
       Element(_1 * Tanh(_2), O_, *NextState.cell);
 
       Swap(*(NextState.output), O_);
@@ -67,10 +65,11 @@ class SlowLSTM: public Cell {
     // reused to avoid allocation
     mutable mblas::Tensor FIO_;
+    mutable mblas::Tensor FIOC_;
     mutable mblas::Tensor F_;
     mutable mblas::Tensor I_;
     mutable mblas::Tensor O_;
-    mutable mblas::Tensor H_;
+    mutable mblas::Tensor C_;
     mutable mblas::Tensor Temp1_;
     mutable mblas::Tensor Temp2_;
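
For reference, here is the arithmetic the rewritten GetNextState() performs, written as a plain CPU loop over a single state vector. This is only a sketch under assumptions: the real code runs on GPU tensors through Prod/BroadcastVec/Element/Slice, and the function name and plain-vector types here are invented.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // fioc holds the 4*cols merged pre-activations (context*W + prev_output*U + b),
    // i.e. what FIOC_ contains after the two Prods and the bias broadcast.
    void lstm_step(const std::vector<float>& fioc,
                   std::vector<float>& cell,    // cols entries, updated in place
                   std::vector<float>& output,  // cols entries, overwritten
                   std::size_t cols) {
      auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
      for (std::size_t j = 0; j < cols; ++j) {
        const float f = sigmoid(fioc[0 * cols + j]);    // forget gate
        const float i = sigmoid(fioc[1 * cols + j]);    // input gate
        const float o = sigmoid(fioc[2 * cols + j]);    // output gate
        const float c = std::tanh(fioc[3 * cols + j]);  // cell candidate
        cell[j]   = f * cell[j] + i * c;     // forget, then add gated candidate
        output[j] = o * std::tanh(cell[j]);  // gated output
      }
    }

The practical payoff of the merged layout is visible in the hunk above: the step now needs two matrix products (Context with W_, previous output with U_) where the old code needed four (W_, Wx_, U_, Ux_).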
diff --git a/src/amun/gpu/dl4mt/model.cu b/src/amun/gpu/dl4mt/model.cu
index 9942d93c..179419a8 100644
--- a/src/amun/gpu/dl4mt/model.cu
+++ b/src/amun/gpu/dl4mt/model.cu
@@ -1,7 +1,24 @@
 #include "model.h"
+#include "gpu/mblas/tensor_functions.h"
 
 using namespace std;
 
+namespace {
+  using namespace amunmt;
+  using namespace GPU;
+  using namespace mblas;
+
+  shared_ptr<Tensor> merge(shared_ptr<Tensor> m1, shared_ptr<Tensor> m2) {
+    if (m2->size()) {
+      Transpose(*m1);
+      Transpose(*m2);
+      Concat(*m1, *m2);
+      Transpose(*m1);
+      Transpose(*m2);
+    }
+    return m1;
+  }
+}
+
 namespace amunmt {
 namespace GPU {
 
@@ -34,24 +51,24 @@ Weights::EncForwardGRU::EncForwardGRU(const NpzConverter& model)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 Weights::EncForwardLSTM::EncForwardLSTM(const NpzConverter& model)
-: W_(model.get("encoder_W", true)),
-  B_(model.get("encoder_b", true, true)),
-  U_(model.get("encoder_U", true)),
-  Wx_(model.get("encoder_Wx", true)),
-  Bx_(model.get("encoder_bx", true, true)),
-  Ux_(model.get("encoder_Ux", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("encoder_W", true), model.get("encoder_Wx", false))),
+  B_(merge(model.get("encoder_b", true, true), model.get("encoder_bx", false, true))),
+  U_(merge(model.get("encoder_U", true), model.get("encoder_Ux", false))),
   Gamma_1_(model.get("encoder_gamma1", false)),
   Gamma_2_(model.get("encoder_gamma2", false))
-{ }
+{}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 Weights::EncBackwardLSTM::EncBackwardLSTM(const NpzConverter& model)
-: W_(model.get("encoder_r_W", true)),
-  B_(model.get("encoder_r_b", true, true)),
-  U_(model.get("encoder_r_U", true)),
-  Wx_(model.get("encoder_r_Wx", true)),
-  Bx_(model.get("encoder_r_bx", true, true)),
-  Ux_(model.get("encoder_r_Ux", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("encoder_r_W", true), model.get("encoder_r_Wx", false))),
+  B_(merge(model.get("encoder_r_b", true, true), model.get("encoder_r_bx", false, true))),
+  U_(merge(model.get("encoder_r_U", true), model.get("encoder_r_Ux", false))),
   Gamma_1_(model.get("encoder_r_gamma1", false)),
   Gamma_2_(model.get("encoder_r_gamma2", false))
 {}
@@ -110,24 +127,24 @@ Weights::DecGRU2::DecGRU2(const NpzConverter& model)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 Weights::DecLSTM1::DecLSTM1(const NpzConverter& model)
-: W_(model.get("decoder_W", true)),
-  B_(model.get("decoder_b", true, true)),
-  U_(model.get("decoder_U", true)),
-  Wx_(model.get("decoder_Wx", true)),
-  Bx_(model.get("decoder_bx", true, true)),
-  Ux_(model.get("decoder_Ux", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("decoder_W", true), model.get("decoder_Wx", false))),
+  B_(merge(model.get("decoder_b", true, true), model.get("decoder_bx", false, true))),
+  U_(merge(model.get("decoder_U", true), model.get("decoder_Ux", false))),
   Gamma_1_(model.get("decoder_cell1_gamma1", false)),
   Gamma_2_(model.get("decoder_cell1_gamma2", false))
 {}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 Weights::DecLSTM2::DecLSTM2(const NpzConverter& model)
-: W_(model.get("decoder_Wc", true)),
-  B_(model.get("decoder_b_nl", true, true)),
-  U_(model.get("decoder_U_nl", true)),
-  Wx_(model.get("decoder_Wcx", true)),
-  Bx_(model.get("decoder_bx_nl", true, true)),
-  Ux_(model.get("decoder_Ux_nl", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("decoder_Wc", true), model.get("decoder_Wcx", false))),
+  B_(merge(model.get("decoder_b_nl", true, true), model.get("decoder_bx_nl", false, true))),
+  U_(merge(model.get("decoder_U_nl", true), model.get("decoder_Ux_nl", false))),
   Gamma_1_(model.get("decoder_cell2_gamma1", false)),
   Gamma_2_(model.get("decoder_cell2_gamma2", false))
 {}
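
The anonymous-namespace merge() stitches the legacy secondary matrix onto the primary one along the column axis. It does this with a Transpose/Concat/Transpose round trip, which suggests mblas::Concat appends rows (my reading, not stated in the diff). Written directly on row-major buffers, the effect is the following; the function name and types are illustrative only:

    #include <cstddef>
    #include <vector>

    // Append m2's columns to m1's: [rows x cols1] ++ [rows x cols2]
    // -> [rows x (cols1 + cols2)], both stored row-major.
    std::vector<float> concat_cols(const std::vector<float>& m1,
                                   const std::vector<float>& m2,
                                   std::size_t rows, std::size_t cols1,
                                   std::size_t cols2) {
      std::vector<float> out(rows * (cols1 + cols2));
      for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols1; ++c)       // copy the m1 block
          out[r * (cols1 + cols2) + c] = m1[r * cols1 + c];
        for (std::size_t c = 0; c < cols2; ++c)       // append the m2 block
          out[r * (cols1 + cols2) + cols1 + c] = m2[r * cols2 + c];
      }
      return out;
    }

Because model.get(..., false) marks the legacy names as optional, a model already saved in the merged Marian format yields an empty m2, the if (m2->size()) guard skips the concatenation, and merge() passes the matrix through untouched; old Amun models get their split matrices stitched together at load time. That is the backwards compatibility promised in the subject line.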
diff --git a/src/amun/gpu/dl4mt/model.h b/src/amun/gpu/dl4mt/model.h
index 0829d233..3eaf4878 100644
--- a/src/amun/gpu/dl4mt/model.h
+++ b/src/amun/gpu/dl4mt/model.h
@@ -62,9 +62,6 @@ struct Weights {
     const std::shared_ptr<mblas::Tensor> W_;
     const std::shared_ptr<mblas::Tensor> B_;
     const std::shared_ptr<mblas::Tensor> U_;
-    const std::shared_ptr<mblas::Tensor> Wx_;
-    const std::shared_ptr<mblas::Tensor> Bx_;
-    const std::shared_ptr<mblas::Tensor> Ux_;
     const std::shared_ptr<mblas::Tensor> Gamma_1_;
     const std::shared_ptr<mblas::Tensor> Gamma_2_;
   };
@@ -77,9 +74,6 @@ struct Weights {
     const std::shared_ptr<mblas::Tensor> W_;
     const std::shared_ptr<mblas::Tensor> B_;
     const std::shared_ptr<mblas::Tensor> U_;
-    const std::shared_ptr<mblas::Tensor> Wx_;
-    const std::shared_ptr<mblas::Tensor> Bx_;
-    const std::shared_ptr<mblas::Tensor> Ux_;
     const std::shared_ptr<mblas::Tensor> Gamma_1_;
     const std::shared_ptr<mblas::Tensor> Gamma_2_;
   };
@@ -147,9 +141,6 @@ struct Weights {
     const std::shared_ptr<mblas::Tensor> W_;
     const std::shared_ptr<mblas::Tensor> B_;
     const std::shared_ptr<mblas::Tensor> U_;
-    const std::shared_ptr<mblas::Tensor> Wx_;
-    const std::shared_ptr<mblas::Tensor> Bx_;
-    const std::shared_ptr<mblas::Tensor> Ux_;
     const std::shared_ptr<mblas::Tensor> Gamma_1_;
     const std::shared_ptr<mblas::Tensor> Gamma_2_;
   };
@@ -163,9 +154,6 @@ struct Weights {
     const std::shared_ptr<mblas::Tensor> W_;
     const std::shared_ptr<mblas::Tensor> B_;
     const std::shared_ptr<mblas::Tensor> U_;
-    const std::shared_ptr<mblas::Tensor> Wx_;
-    const std::shared_ptr<mblas::Tensor> Bx_;
-    const std::shared_ptr<mblas::Tensor> Ux_;
     const std::shared_ptr<mblas::Tensor> Gamma_1_;
     const std::shared_ptr<mblas::Tensor> Gamma_2_;
   };