Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRihards Krišlauks <rihards.krislauks@gmail.com>2018-02-28 16:59:16 +0300
committerRihards Krišlauks <rihards.krislauks@gmail.com>2018-02-28 18:22:51 +0300
commit138f53ec8f7324d9f771d1cc6fae92187d5c5420 (patch)
treee017eafb96fd2a30849c1f93d64e161c8274d596
parent1de1e7cb22c5e953eb309395ed059297f0583a0d (diff)
Switch to LSTM matrix format used in Marian; be backwards compatible (branch: lstm-overhaul)
-rw-r--r--src/amun/gpu/dl4mt/lstm.h29
-rw-r--r--src/amun/gpu/dl4mt/model.cu67
-rw-r--r--src/amun/gpu/dl4mt/model.h12
3 files changed, 56 insertions, 52 deletions
diff --git a/src/amun/gpu/dl4mt/lstm.h b/src/amun/gpu/dl4mt/lstm.h
index 2f343db6..be8977c8 100644
--- a/src/amun/gpu/dl4mt/lstm.h
+++ b/src/amun/gpu/dl4mt/lstm.h
@@ -25,33 +25,31 @@ class SlowLSTM: public Cell {
const unsigned cols = GetStateLength().output;
- // transform context for use with gates
- Prod(FIO_, Context, *w_.W_);
- BroadcastVec(_1 + _2, FIO_, *w_.B_); // Broadcasting row-wise
- // transform context for use with computing the input
- Prod(H_, Context, *w_.Wx_);
- BroadcastVec(_1 + _2, H_, *w_.Bx_); // Broadcasting row-wise
-
- // transform previous output for use with gates
+ // transform context for use with gates and for computing the input (the C part)
+ Prod(FIOC_, Context, *w_.W_);
+ BroadcastVec(_1 + _2, FIOC_, *w_.B_); // Broadcasting row-wise
+
+ // transform previous output for use with gates and for computing the input
Prod(Temp1_, *(State.output), *w_.U_);
- // transform previous output for use with computing the input
- Prod(Temp2_, *(State.output), *w_.Ux_);
+ Element(_1 + _2, FIOC_, Temp1_);
// compute the gates
- Element(Logit(_1 + _2), FIO_, Temp1_);
+ Slice(FIO_, FIOC_, 0, cols * 3);
+ Element(Logit(_1), FIO_);
Slice(F_, FIO_, 0, cols);
Slice(I_, FIO_, 1, cols);
Slice(O_, FIO_, 2, cols);
// compute the input
- Element(Tanh(_1 + _2), H_, Temp2_);
+ Slice(C_, FIOC_, 3, cols);
+ Element(Tanh(_1), C_);
// apply the forget gate
Copy(*NextState.cell, *State.cell);
Element(_1 * _2, *NextState.cell, F_);
// apply the input gate
- Element(_1 * _2, H_, I_);
+ Element(_1 * _2, C_, I_);
// update the cell state with the input
- Element(_1 + _2, *NextState.cell, H_);
+ Element(_1 + _2, *NextState.cell, C_);
// apply the output gate
Element(_1 * Tanh(_2), O_, *NextState.cell);
Swap(*(NextState.output), O_);
@@ -67,10 +65,11 @@ class SlowLSTM: public Cell {
// reused to avoid allocation
mutable mblas::Tensor FIO_;
+ mutable mblas::Tensor FIOC_;
mutable mblas::Tensor F_;
mutable mblas::Tensor I_;
mutable mblas::Tensor O_;
- mutable mblas::Tensor H_;
+ mutable mblas::Tensor C_;
mutable mblas::Tensor Temp1_;
mutable mblas::Tensor Temp2_;
diff --git a/src/amun/gpu/dl4mt/model.cu b/src/amun/gpu/dl4mt/model.cu
index 9942d93c..179419a8 100644
--- a/src/amun/gpu/dl4mt/model.cu
+++ b/src/amun/gpu/dl4mt/model.cu
@@ -1,7 +1,24 @@
#include "model.h"
+#include "gpu/mblas/tensor_functions.h"
using namespace std;
+namespace {
+ using namespace amunmt;
+ using namespace GPU;
+ using namespace mblas;
+ shared_ptr<Tensor> merge (shared_ptr<Tensor> m1, shared_ptr<Tensor> m2) {
+ if (m2->size()) {
+ Transpose(*m1);
+ Transpose(*m2);
+ Concat(*m1, *m2);
+ Transpose(*m1);
+ Transpose(*m2);
+ }
+ return m1;
+ }
+}
+
namespace amunmt {
namespace GPU {
@@ -34,24 +51,24 @@ Weights::EncForwardGRU::EncForwardGRU(const NpzConverter& model)
////////////////////////////////////////////////////////////////////////////////////////////////////
Weights::EncForwardLSTM::EncForwardLSTM(const NpzConverter& model)
-: W_(model.get("encoder_W", true)),
- B_(model.get("encoder_b", true, true)),
- U_(model.get("encoder_U", true)),
- Wx_(model.get("encoder_Wx", true)),
- Bx_(model.get("encoder_bx", true, true)),
- Ux_(model.get("encoder_Ux", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("encoder_W", true), model.get("encoder_Wx", false))),
+ B_(merge(model.get("encoder_b", true, true), model.get("encoder_bx", false, true))),
+ U_(merge(model.get("encoder_U", true), model.get("encoder_Ux", false))),
Gamma_1_(model.get("encoder_gamma1", false)),
Gamma_2_(model.get("encoder_gamma2", false))
-{ }
+{}
////////////////////////////////////////////////////////////////////////////////////////////////////
Weights::EncBackwardLSTM::EncBackwardLSTM(const NpzConverter& model)
-: W_(model.get("encoder_r_W", true)),
- B_(model.get("encoder_r_b", true, true)),
- U_(model.get("encoder_r_U", true)),
- Wx_(model.get("encoder_r_Wx", true)),
- Bx_(model.get("encoder_r_bx", true, true)),
- Ux_(model.get("encoder_r_Ux", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("encoder_r_W", true), model.get("encoder_r_Wx", false))),
+ B_(merge(model.get("encoder_r_b", true, true), model.get("encoder_r_bx", false, true))),
+ U_(merge(model.get("encoder_r_U", true), model.get("encoder_r_Ux", false))),
Gamma_1_(model.get("encoder_r_gamma1", false)),
Gamma_2_(model.get("encoder_r_gamma2", false))
{}
@@ -110,24 +127,24 @@ Weights::DecGRU2::DecGRU2(const NpzConverter& model)
////////////////////////////////////////////////////////////////////////////////////////////////////
Weights::DecLSTM1::DecLSTM1(const NpzConverter& model)
-: W_(model.get("decoder_W", true)),
- B_(model.get("decoder_b", true, true)),
- U_(model.get("decoder_U", true)),
- Wx_(model.get("decoder_Wx", true)),
- Bx_(model.get("decoder_bx", true, true)),
- Ux_(model.get("decoder_Ux", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("decoder_W", true), model.get("decoder_Wx", false))),
+ B_(merge(model.get("decoder_b", true, true), model.get("decoder_bx", false, true))),
+ U_(merge(model.get("decoder_U", true), model.get("decoder_Ux", false))),
Gamma_1_(model.get("decoder_cell1_gamma1", false)),
Gamma_2_(model.get("decoder_cell1_gamma2", false))
{}
////////////////////////////////////////////////////////////////////////////////////////////////////
Weights::DecLSTM2::DecLSTM2(const NpzConverter& model)
-: W_(model.get("decoder_Wc", true)),
- B_(model.get("decoder_b_nl", true, true)),
- U_(model.get("decoder_U_nl", true)),
- Wx_(model.get("decoder_Wcx", true)),
- Bx_(model.get("decoder_bx_nl", true, true)),
- Ux_(model.get("decoder_Ux_nl", true)),
+// matrix merging is done to be backwards-compatible with the original LSTM implementation in Amun
+// we now use the same format used in Marian
+// TODO: adapt to support Nematus LSTM models which use a similar format to Amun's original format
+: W_(merge(model.get("decoder_Wc", true), model.get("decoder_Wcx", false))),
+ B_(merge(model.get("decoder_b_nl", true, true), model.get("decoder_bx_nl", false, true))),
+ U_(merge(model.get("decoder_U_nl", true), model.get("decoder_Ux_nl", false))),
Gamma_1_(model.get("decoder_cell2_gamma1", false)),
Gamma_2_(model.get("decoder_cell2_gamma2", false))
{}
diff --git a/src/amun/gpu/dl4mt/model.h b/src/amun/gpu/dl4mt/model.h
index 0829d233..3eaf4878 100644
--- a/src/amun/gpu/dl4mt/model.h
+++ b/src/amun/gpu/dl4mt/model.h
@@ -62,9 +62,6 @@ struct Weights {
const std::shared_ptr<mblas::Tensor> W_;
const std::shared_ptr<mblas::Tensor> B_;
const std::shared_ptr<mblas::Tensor> U_;
- const std::shared_ptr<mblas::Tensor> Wx_;
- const std::shared_ptr<mblas::Tensor> Bx_;
- const std::shared_ptr<mblas::Tensor> Ux_;
const std::shared_ptr<mblas::Tensor> Gamma_1_;
const std::shared_ptr<mblas::Tensor> Gamma_2_;
};
@@ -77,9 +74,6 @@ struct Weights {
const std::shared_ptr<mblas::Tensor> W_;
const std::shared_ptr<mblas::Tensor> B_;
const std::shared_ptr<mblas::Tensor> U_;
- const std::shared_ptr<mblas::Tensor> Wx_;
- const std::shared_ptr<mblas::Tensor> Bx_;
- const std::shared_ptr<mblas::Tensor> Ux_;
const std::shared_ptr<mblas::Tensor> Gamma_1_;
const std::shared_ptr<mblas::Tensor> Gamma_2_;
};
@@ -147,9 +141,6 @@ struct Weights {
const std::shared_ptr<mblas::Tensor> W_;
const std::shared_ptr<mblas::Tensor> B_;
const std::shared_ptr<mblas::Tensor> U_;
- const std::shared_ptr<mblas::Tensor> Wx_;
- const std::shared_ptr<mblas::Tensor> Bx_;
- const std::shared_ptr<mblas::Tensor> Ux_;
const std::shared_ptr<mblas::Tensor> Gamma_1_;
const std::shared_ptr<mblas::Tensor> Gamma_2_;
};
@@ -163,9 +154,6 @@ struct Weights {
const std::shared_ptr<mblas::Tensor> W_;
const std::shared_ptr<mblas::Tensor> B_;
const std::shared_ptr<mblas::Tensor> U_;
- const std::shared_ptr<mblas::Tensor> Wx_;
- const std::shared_ptr<mblas::Tensor> Bx_;
- const std::shared_ptr<mblas::Tensor> Ux_;
const std::shared_ptr<mblas::Tensor> Gamma_1_;
const std::shared_ptr<mblas::Tensor> Gamma_2_;
};