Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Hieu Hoang <hieuhoang@gmail.com>, 2017-11-26 21:24:07 +0300
committer: Hieu Hoang <hieuhoang@gmail.com>, 2017-11-26 21:24:07 +0300
commit: 284c2389ce12d060bd55788a1d6168470c31f068 (patch)
tree: f33bbb7bf65ca974732a8008dc74479dadb003fd
parent: 69f02ac398988e2285258cd584101e81c4e6c530 (diff)
move implementations into .cu file
-rw-r--r-- contrib/other-builds/amunmt/.project 5
m--------- examples 0
-rwxr-xr-x src/amun/CMakeLists.txt 3
-rw-r--r-- src/amun/gpu/dl4mt/model.cu 220
-rw-r--r-- src/amun/gpu/dl4mt/model.h 238
m--------- src/marian 11
6 files changed, 277 insertions, 200 deletions
diff --git a/contrib/other-builds/amunmt/.project b/contrib/other-builds/amunmt/.project
index 4cfd4037..9fc6aff9 100644
--- a/contrib/other-builds/amunmt/.project
+++ b/contrib/other-builds/amunmt/.project
@@ -1526,6 +1526,11 @@
<locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/lstm.h</locationURI>
</link>
<link>
+ <name>src/amun/gpu/dl4mt/model.cu</name>
+ <type>1</type>
+ <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/model.cu</locationURI>
+ </link>
+ <link>
<name>src/amun/gpu/dl4mt/model.h</name>
<type>1</type>
<locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/model.h</locationURI>
diff --git a/examples b/examples
-Subproject 050843dd917043bb5588ea431d02dcab68b567e
+Subproject 280481dc29167dc5b9fe63d85bbb6bacdf918c7
diff --git a/src/amun/CMakeLists.txt b/src/amun/CMakeLists.txt
index b6662951..3e8899d4 100755
--- a/src/amun/CMakeLists.txt
+++ b/src/amun/CMakeLists.txt
@@ -84,6 +84,7 @@ cuda_add_executable(
gpu/decoder/encoder_decoder_state.cu
gpu/dl4mt/encoder.cu
gpu/dl4mt/gru.cu
+ gpu/dl4mt/model.cu
gpu/mblas/matrix.cu
gpu/mblas/matrix_functions.cu
gpu/mblas/nth_element.cu
@@ -112,6 +113,7 @@ cuda_add_library(python SHARED
gpu/mblas/nth_element_kernels.cu
gpu/dl4mt/encoder.cu
gpu/dl4mt/gru.cu
+ gpu/dl4mt/model.cu
gpu/npz_converter.cu
gpu/types-gpu.cu
common/loader_factory.cpp
@@ -139,6 +141,7 @@ cuda_add_library(mosesplugin STATIC
gpu/mblas/nth_element_kernels.cu
gpu/dl4mt/encoder.cu
gpu/dl4mt/gru.cu
+ gpu/dl4mt/model.cu
gpu/npz_converter.cu
gpu/types-gpu.cu
common/loader_factory.cpp
diff --git a/src/amun/gpu/dl4mt/model.cu b/src/amun/gpu/dl4mt/model.cu
new file mode 100644
index 00000000..326844e1
--- /dev/null
+++ b/src/amun/gpu/dl4mt/model.cu
@@ -0,0 +1,220 @@
+#include "model.h"
+
+namespace amunmt {
+namespace GPU {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model)
+: E_(model.get("Wemb", true))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::EncForwardGRU::EncForwardGRU(const NpzConverter& model)
+: W_(model.get("encoder_W", true)),
+ B_(model.get("encoder_b", true, true)),
+ U_(model.get("encoder_U", true)),
+ Wx_(model.get("encoder_Wx", true)),
+ Bx1_(model.get("encoder_bx", true, true)),
+ Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)),
+ Ux_(model.get("encoder_Ux", true)),
+ Gamma_1_(model.get("encoder_gamma1", false)),
+ Gamma_2_(model.get("encoder_gamma2", false))
+{ }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::EncForwardLSTM::EncForwardLSTM(const NpzConverter& model)
+: W_(model.get("encoder_W", true)),
+ B_(model.get("encoder_b", true, true)),
+ U_(model.get("encoder_U", true)),
+ Wx_(model.get("encoder_Wx", true)),
+ Bx_(model.get("encoder_bx", true, true)),
+ Ux_(model.get("encoder_Ux", true)),
+ Gamma_1_(model.get("encoder_gamma1", false)),
+ Gamma_2_(model.get("encoder_gamma2", false))
+{ }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::EncBackwardLSTM::EncBackwardLSTM(const NpzConverter& model)
+: W_(model.get("encoder_r_W", true)),
+ B_(model.get("encoder_r_b", true, true)),
+ U_(model.get("encoder_r_U", true)),
+ Wx_(model.get("encoder_r_Wx", true)),
+ Bx_(model.get("encoder_r_bx", true, true)),
+ Ux_(model.get("encoder_r_Ux", true)),
+ Gamma_1_(model.get("encoder_r_gamma1", false)),
+ Gamma_2_(model.get("encoder_r_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::EncBackwardGRU::EncBackwardGRU(const NpzConverter& model)
+: W_(model.get("encoder_r_W", true)),
+ B_(model.get("encoder_r_b", true, true)),
+ U_(model.get("encoder_r_U", true)),
+ Wx_(model.get("encoder_r_Wx", true)),
+ Bx1_(model.get("encoder_r_bx", true, true)),
+ Bx2_(new mblas::Matrix( Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)),
+ Ux_(model.get("encoder_r_Ux", true)),
+ Gamma_1_(model.get("encoder_r_gamma1", false)),
+ Gamma_2_(model.get("encoder_r_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecEmbeddings::DecEmbeddings(const NpzConverter& model)
+: E_(model.getFirstOfMany({std::make_pair("Wemb_dec", false),
+ std::make_pair("Wemb", false)}, true))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecInit::DecInit(const NpzConverter& model)
+: Wi_(model.get("ff_state_W", true)),
+ Bi_(model.get("ff_state_b", true, true)),
+ Gamma_(model.get("ff_state_gamma", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecGRU1::DecGRU1(const NpzConverter& model)
+: W_(model.get("decoder_W", true)),
+ B_(model.get("decoder_b", true, true)),
+ U_(model.get("decoder_U", true)),
+ Wx_(model.get("decoder_Wx", true)),
+ Bx1_(model.get("decoder_bx", true, true)),
+ Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)),
+ Ux_(model.get("decoder_Ux", true)),
+ Gamma_1_(model.get("decoder_cell1_gamma1", false)),
+ Gamma_2_(model.get("decoder_cell1_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecGRU2::DecGRU2(const NpzConverter& model)
+: W_(model.get("decoder_Wc", true)),
+ B_(model.get("decoder_b_nl", true, true)),
+ U_(model.get("decoder_U_nl", true)),
+ Wx_(model.get("decoder_Wcx", true)),
+ Bx2_(model.get("decoder_bx_nl", true, true)),
+ Bx1_(new mblas::Matrix(Bx2_->dim(0), Bx2_->dim(1), Bx2_->dim(2), Bx2_->dim(3), true)),
+ Ux_(model.get("decoder_Ux_nl", true)),
+ Gamma_1_(model.get("decoder_cell2_gamma1", false)),
+ Gamma_2_(model.get("decoder_cell2_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecLSTM1::DecLSTM1(const NpzConverter& model)
+: W_(model.get("decoder_W", true)),
+ B_(model.get("decoder_b", true, true)),
+ U_(model.get("decoder_U", true)),
+ Wx_(model.get("decoder_Wx", true)),
+ Bx_(model.get("decoder_bx", true, true)),
+ Ux_(model.get("decoder_Ux", true)),
+ Gamma_1_(model.get("decoder_cell1_gamma1", false)),
+ Gamma_2_(model.get("decoder_cell1_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecLSTM2::DecLSTM2(const NpzConverter& model)
+: W_(model.get("decoder_Wc", true)),
+ B_(model.get("decoder_b_nl", true, true)),
+ U_(model.get("decoder_U_nl", true)),
+ Wx_(model.get("decoder_Wcx", true)),
+ Bx_(model.get("decoder_bx_nl", true, true)),
+ Ux_(model.get("decoder_Ux_nl", true)),
+ Gamma_1_(model.get("decoder_cell2_gamma1", false)),
+ Gamma_2_(model.get("decoder_cell2_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecAlignment::DecAlignment(const NpzConverter& model)
+: V_(model.get("decoder_U_att", true, true)),
+ W_(model.get("decoder_W_comb_att", true)),
+ B_(model.get("decoder_b_att", true, true)),
+ U_(model.get("decoder_Wc_att", true)),
+ C_(model.get("decoder_c_tt", true)), // scalar?
+ Gamma_1_(model.get("decoder_att_gamma1", false)),
+ Gamma_2_(model.get("decoder_att_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
+: W1_(model.get("ff_logit_lstm_W", true)),
+ B1_(model.get("ff_logit_lstm_b", true, true)),
+ W2_(model.get("ff_logit_prev_W", true)),
+ B2_(model.get("ff_logit_prev_b", true, true)),
+ W3_(model.get("ff_logit_ctx_W", true)),
+ B3_(model.get("ff_logit_ctx_b", true, true)),
+ W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
+ std::make_pair(std::string("Wemb_dec"), true)}, true)),
+ B4_(model.get("ff_logit_b", true, true)),
+ Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
+ Gamma_1_(model.get("ff_logit_l1_gamma1", false)),
+ Gamma_2_(model.get("ff_logit_l1_gamma2", false))
+{}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+Weights::Weights(const std::string& npzFile, const YAML::Node& config, size_t device)
+: Weights(NpzConverter(npzFile), config, device)
+{}
+
+Weights::Weights(const NpzConverter& model, const YAML::Node& config, size_t device)
+: encEmbeddings_(model),
+ decEmbeddings_(model),
+ decInit_(model),
+ decAlignment_(model),
+ decSoftmax_(model),
+ device_(device)
+{
+
+ std::string encCell = config["enc-cell"] ? config["enc-cell"].as<std::string>() : "gru";
+ std::string encCell_r = config["enc-cell-r"] ? config["enc-cell-r"].as<std::string>() : encCell;
+ initEncForward(model, encCell);
+ initEncBackward(model, encCell_r);
+
+ std::string decCell = config["dec-cell"] ? config["dec-cell"].as<std::string>() : "gru";
+ std::string decCell2 = config["dec-cell-2"] ? config["dec-cell-2"].as<std::string>() : decCell;
+ initDec1(model, decCell);
+ initDec2(model, decCell2);
+}
+
+void Weights::initEncForward(const NpzConverter& model,std::string celltype){
+ if(celltype == "lstm"){
+ encForwardLSTM_ = std::shared_ptr<EncForwardLSTM>(new EncForwardLSTM(model));
+ } else if (celltype == "mlstm") {
+ encForwardMLSTM_ = std::shared_ptr<MultWeights<EncForwardLSTM>>
+ (new MultWeights<EncForwardLSTM>(model, "encoder"));
+ } else if (celltype == "gru"){
+ encForwardGRU_ = std::shared_ptr<EncForwardGRU>(new EncForwardGRU(model));
+ }
+}
+
+void Weights::initEncBackward(const NpzConverter& model,std::string celltype) {
+ if(celltype == "lstm"){
+ encBackwardLSTM_ = std::shared_ptr<EncBackwardLSTM>(new EncBackwardLSTM(model));
+ } else if (celltype == "mlstm") {
+ encBackwardMLSTM_ = std::shared_ptr<MultWeights<EncBackwardLSTM>>
+ (new MultWeights<EncBackwardLSTM>(model, "encoder_r"));
+ } else if (celltype == "gru"){
+ encBackwardGRU_ = std::shared_ptr<EncBackwardGRU>(new EncBackwardGRU(model));
+ }
+}
+
+void Weights::initDec1(const NpzConverter& model,std::string celltype){
+ if (celltype == "lstm"){
+ decLSTM1_ = std::shared_ptr<DecLSTM1>(new DecLSTM1(model));
+ } else if (celltype == "mlstm") {
+ decMLSTM1_ = std::shared_ptr<MultWeights<DecLSTM1>>(new MultWeights<DecLSTM1>(model, "decoder"));
+ } else if (celltype == "gru") {
+ decGru1_ = std::shared_ptr<DecGRU1>(new DecGRU1(model));
+ }
+}
+
+void Weights::initDec2(const NpzConverter& model,std::string celltype){
+ if (celltype == "lstm"){
+ decLSTM2_ = std::shared_ptr<DecLSTM2>(new DecLSTM2(model));
+ } else if (celltype == "mlstm") {
+ decMLSTM2_ = std::shared_ptr<MultWeights<DecLSTM2>>(new MultWeights<DecLSTM2>(model, "decoder_2"));
+ } else if (celltype == "gru") {
+ decGru2_ = std::shared_ptr<DecGRU2>(new DecGRU2(model));
+ }
+}
+
+} // namespace
+}
+
diff --git a/src/amun/gpu/dl4mt/model.h b/src/amun/gpu/dl4mt/model.h
index 283140e5..9a29c344 100644
--- a/src/amun/gpu/dl4mt/model.h
+++ b/src/amun/gpu/dl4mt/model.h
@@ -13,31 +13,17 @@ namespace GPU {
struct Weights {
//////////////////////////////////////////////////////////////////////////////
-
struct EncEmbeddings {
EncEmbeddings(const EncEmbeddings&) = delete;
-
- EncEmbeddings(const NpzConverter& model)
- : E_(model.get("Wemb", true))
- {}
+ EncEmbeddings(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> E_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
struct EncForwardGRU {
EncForwardGRU(const EncForwardGRU&) = delete;
-
- EncForwardGRU(const NpzConverter& model)
- : W_(model.get("encoder_W", true)),
- B_(model.get("encoder_b", true, true)),
- U_(model.get("encoder_U", true)),
- Wx_(model.get("encoder_Wx", true)),
- Bx1_(model.get("encoder_bx", true, true)),
- Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)),
- Ux_(model.get("encoder_Ux", true)),
- Gamma_1_(model.get("encoder_gamma1", false)),
- Gamma_2_(model.get("encoder_gamma2", false))
- { }
+ EncForwardGRU(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -50,20 +36,10 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
struct EncBackwardGRU {
EncBackwardGRU(const EncBackwardGRU&) = delete;
-
- EncBackwardGRU(const NpzConverter& model)
- : W_(model.get("encoder_r_W", true)),
- B_(model.get("encoder_r_b", true, true)),
- U_(model.get("encoder_r_U", true)),
- Wx_(model.get("encoder_r_Wx", true)),
- Bx1_(model.get("encoder_r_bx", true, true)),
- Bx2_(new mblas::Matrix( Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)),
- Ux_(model.get("encoder_r_Ux", true)),
- Gamma_1_(model.get("encoder_r_gamma1", false)),
- Gamma_2_(model.get("encoder_r_gamma2", false))
- {}
+ EncBackwardGRU(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -76,19 +52,10 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
struct EncForwardLSTM {
EncForwardLSTM(const EncForwardLSTM&) = delete;
-
- EncForwardLSTM(const NpzConverter& model)
- : W_(model.get("encoder_W", true)),
- B_(model.get("encoder_b", true, true)),
- U_(model.get("encoder_U", true)),
- Wx_(model.get("encoder_Wx", true)),
- Bx_(model.get("encoder_bx", true, true)),
- Ux_(model.get("encoder_Ux", true)),
- Gamma_1_(model.get("encoder_gamma1", false)),
- Gamma_2_(model.get("encoder_gamma2", false))
- { }
+ EncForwardLSTM(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -100,19 +67,10 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
struct EncBackwardLSTM {
EncBackwardLSTM(const EncBackwardLSTM&) = delete;
-
- EncBackwardLSTM(const NpzConverter& model)
- : W_(model.get("encoder_r_W", true)),
- B_(model.get("encoder_r_b", true, true)),
- U_(model.get("encoder_r_U", true)),
- Wx_(model.get("encoder_r_Wx", true)),
- Bx_(model.get("encoder_r_bx", true, true)),
- Ux_(model.get("encoder_r_Ux", true)),
- Gamma_1_(model.get("encoder_r_gamma1", false)),
- Gamma_2_(model.get("encoder_r_gamma2", false))
- {}
+ EncBackwardLSTM(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -124,47 +82,31 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
- //////////////////////////////////////////////////////////////////////////////
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
struct DecEmbeddings {
DecEmbeddings(const DecEmbeddings&) = delete;
-
- DecEmbeddings(const NpzConverter& model)
- : E_(model.getFirstOfMany({std::make_pair("Wemb_dec", false),
- std::make_pair("Wemb", false)}, true))
- {}
+ DecEmbeddings(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> E_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecInit {
DecInit(const DecInit&) = delete;
-
- DecInit(const NpzConverter& model)
- : Wi_(model.get("ff_state_W", true)),
- Bi_(model.get("ff_state_b", true, true)),
- Gamma_(model.get("ff_state_gamma", false))
- {}
+ DecInit(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> Wi_;
const std::shared_ptr<mblas::Matrix> Bi_;
const std::shared_ptr<mblas::Matrix> Gamma_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecGRU1 {
DecGRU1(const DecGRU1&) = delete;
-
- DecGRU1(const NpzConverter& model)
- : W_(model.get("decoder_W", true)),
- B_(model.get("decoder_b", true, true)),
- U_(model.get("decoder_U", true)),
- Wx_(model.get("decoder_Wx", true)),
- Bx1_(model.get("decoder_bx", true, true)),
- Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)),
- Ux_(model.get("decoder_Ux", true)),
- Gamma_1_(model.get("decoder_cell1_gamma1", false)),
- Gamma_2_(model.get("decoder_cell1_gamma2", false))
- {}
+ DecGRU1(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -177,20 +119,11 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecGRU2 {
DecGRU2(const DecGRU2&) = delete;
-
- DecGRU2(const NpzConverter& model)
- : W_(model.get("decoder_Wc", true)),
- B_(model.get("decoder_b_nl", true, true)),
- U_(model.get("decoder_U_nl", true)),
- Wx_(model.get("decoder_Wcx", true)),
- Bx2_(model.get("decoder_bx_nl", true, true)),
- Bx1_(new mblas::Matrix(Bx2_->dim(0), Bx2_->dim(1), Bx2_->dim(2), Bx2_->dim(3), true)),
- Ux_(model.get("decoder_Ux_nl", true)),
- Gamma_1_(model.get("decoder_cell2_gamma1", false)),
- Gamma_2_(model.get("decoder_cell2_gamma2", false))
- {}
+ DecGRU2(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -203,19 +136,11 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecLSTM1 {
DecLSTM1(const DecLSTM1&) = delete;
-
- DecLSTM1(const NpzConverter& model)
- : W_(model.get("decoder_W", true)),
- B_(model.get("decoder_b", true, true)),
- U_(model.get("decoder_U", true)),
- Wx_(model.get("decoder_Wx", true)),
- Bx_(model.get("decoder_bx", true, true)),
- Ux_(model.get("decoder_Ux", true)),
- Gamma_1_(model.get("decoder_cell1_gamma1", false)),
- Gamma_2_(model.get("decoder_cell1_gamma2", false))
- {}
+ DecLSTM1(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -227,19 +152,11 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecLSTM2 {
DecLSTM2(const DecLSTM2&) = delete;
-
- DecLSTM2(const NpzConverter& model)
- : W_(model.get("decoder_Wc", true)),
- B_(model.get("decoder_b_nl", true, true)),
- U_(model.get("decoder_U_nl", true)),
- Wx_(model.get("decoder_Wcx", true)),
- Bx_(model.get("decoder_bx_nl", true, true)),
- Ux_(model.get("decoder_Ux_nl", true)),
- Gamma_1_(model.get("decoder_cell2_gamma1", false)),
- Gamma_2_(model.get("decoder_cell2_gamma2", false))
- {}
+ DecLSTM2(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W_;
const std::shared_ptr<mblas::Matrix> B_;
@@ -251,6 +168,8 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
// A wrapper class to deserialize weights for multiplicative-LSTM,
// multiplicative-GRU and such
template<class BaseWeights>
@@ -273,18 +192,11 @@ struct Weights {
}
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecAlignment {
DecAlignment(const DecAlignment&) = delete;
-
- DecAlignment(const NpzConverter& model)
- : V_(model.get("decoder_U_att", true, true)),
- W_(model.get("decoder_W_comb_att", true)),
- B_(model.get("decoder_b_att", true, true)),
- U_(model.get("decoder_Wc_att", true)),
- C_(model.get("decoder_c_tt", true)), // scalar?
- Gamma_1_(model.get("decoder_att_gamma1", false)),
- Gamma_2_(model.get("decoder_att_gamma2", false))
- {}
+ DecAlignment(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> V_;
const std::shared_ptr<mblas::Matrix> W_;
@@ -295,23 +207,11 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
struct DecSoftmax {
DecSoftmax(const DecSoftmax&) = delete;
-
- DecSoftmax(const NpzConverter& model)
- : W1_(model.get("ff_logit_lstm_W", true)),
- B1_(model.get("ff_logit_lstm_b", true, true)),
- W2_(model.get("ff_logit_prev_W", true)),
- B2_(model.get("ff_logit_prev_b", true, true)),
- W3_(model.get("ff_logit_ctx_W", true)),
- B3_(model.get("ff_logit_ctx_b", true, true)),
- W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
- std::make_pair(std::string("Wemb_dec"), true)}, true)),
- B4_(model.get("ff_logit_b", true, true)),
- Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
- Gamma_1_(model.get("ff_logit_l1_gamma1", false)),
- Gamma_2_(model.get("ff_logit_l1_gamma2", false))
- {}
+ DecSoftmax(const NpzConverter& model);
const std::shared_ptr<mblas::Matrix> W1_;
const std::shared_ptr<mblas::Matrix> B1_;
@@ -326,29 +226,11 @@ struct Weights {
const std::shared_ptr<mblas::Matrix> Gamma_2_;
};
- Weights(const std::string& npzFile, const YAML::Node& config, size_t device)
- : Weights(NpzConverter(npzFile), config, device)
- {}
-
- Weights(const NpzConverter& model, const YAML::Node& config, size_t device)
- : encEmbeddings_(model),
- decEmbeddings_(model),
- decInit_(model),
- decAlignment_(model),
- decSoftmax_(model),
- device_(device)
- {
-
- std::string encCell = config["enc-cell"] ? config["enc-cell"].as<std::string>() : "gru";
- std::string encCell_r = config["enc-cell-r"] ? config["enc-cell-r"].as<std::string>() : encCell;
- initEncForward(model, encCell);
- initEncBackward(model, encCell_r);
-
- std::string decCell = config["dec-cell"] ? config["dec-cell"].as<std::string>() : "gru";
- std::string decCell2 = config["dec-cell-2"] ? config["dec-cell-2"].as<std::string>() : decCell;
- initDec1(model, decCell);
- initDec2(model, decCell2);
- }
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ Weights(const std::string& npzFile, const YAML::Node& config, size_t device);
+
+ Weights(const NpzConverter& model, const YAML::Node& config, size_t device);
Weights(const Weights&) = delete;
@@ -357,44 +239,10 @@ struct Weights {
}
private:
- void initEncForward(const NpzConverter& model,std::string celltype){
- if(celltype == "lstm"){
- encForwardLSTM_ = std::shared_ptr<EncForwardLSTM>(new EncForwardLSTM(model));
- } else if (celltype == "mlstm") {
- encForwardMLSTM_ = std::shared_ptr<MultWeights<EncForwardLSTM>>
- (new MultWeights<EncForwardLSTM>(model, "encoder"));
- } else if (celltype == "gru"){
- encForwardGRU_ = std::shared_ptr<EncForwardGRU>(new EncForwardGRU(model));
- }
- }
- void initEncBackward(const NpzConverter& model,std::string celltype) {
- if(celltype == "lstm"){
- encBackwardLSTM_ = std::shared_ptr<EncBackwardLSTM>(new EncBackwardLSTM(model));
- } else if (celltype == "mlstm") {
- encBackwardMLSTM_ = std::shared_ptr<MultWeights<EncBackwardLSTM>>
- (new MultWeights<EncBackwardLSTM>(model, "encoder_r"));
- } else if (celltype == "gru"){
- encBackwardGRU_ = std::shared_ptr<EncBackwardGRU>(new EncBackwardGRU(model));
- }
- }
- void initDec1(const NpzConverter& model,std::string celltype){
- if (celltype == "lstm"){
- decLSTM1_ = std::shared_ptr<DecLSTM1>(new DecLSTM1(model));
- } else if (celltype == "mlstm") {
- decMLSTM1_ = std::shared_ptr<MultWeights<DecLSTM1>>(new MultWeights<DecLSTM1>(model, "decoder"));
- } else if (celltype == "gru") {
- decGru1_ = std::shared_ptr<DecGRU1>(new DecGRU1(model));
- }
- }
- void initDec2(const NpzConverter& model,std::string celltype){
- if (celltype == "lstm"){
- decLSTM2_ = std::shared_ptr<DecLSTM2>(new DecLSTM2(model));
- } else if (celltype == "mlstm") {
- decMLSTM2_ = std::shared_ptr<MultWeights<DecLSTM2>>(new MultWeights<DecLSTM2>(model, "decoder_2"));
- } else if (celltype == "gru") {
- decGru2_ = std::shared_ptr<DecGRU2>(new DecGRU2(model));
- }
- }
+ void initEncForward(const NpzConverter& model,std::string celltype);
+ void initEncBackward(const NpzConverter& model,std::string celltype);
+ void initDec1(const NpzConverter& model,std::string celltype);
+ void initDec2(const NpzConverter& model,std::string celltype);
public:
const EncEmbeddings encEmbeddings_;
diff --git a/src/marian b/src/marian
-Subproject fed955ae472e5394b06f65a280ad2254c2cca91
+Subproject 73e7737582aea9e241dc9e4b73c3d38b1181e3d