diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2017-11-26 21:24:07 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2017-11-26 21:24:07 +0300 |
commit | 284c2389ce12d060bd55788a1d6168470c31f068 (patch) | |
tree | f33bbb7bf65ca974732a8008dc74479dadb003fd | |
parent | 69f02ac398988e2285258cd584101e81c4e6c530 (diff) |
move implementations into .cu file
-rw-r--r-- | contrib/other-builds/amunmt/.project | 5 | ||||
m--------- | examples | 0 | ||||
-rwxr-xr-x | src/amun/CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/amun/gpu/dl4mt/model.cu | 220 | ||||
-rw-r--r-- | src/amun/gpu/dl4mt/model.h | 238 | ||||
m--------- | src/marian | 11 |
6 files changed, 277 insertions, 200 deletions
diff --git a/contrib/other-builds/amunmt/.project b/contrib/other-builds/amunmt/.project index 4cfd4037..9fc6aff9 100644 --- a/contrib/other-builds/amunmt/.project +++ b/contrib/other-builds/amunmt/.project @@ -1526,6 +1526,11 @@ <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/lstm.h</locationURI> </link> <link> + <name>src/amun/gpu/dl4mt/model.cu</name> + <type>1</type> + <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/model.cu</locationURI> + </link> + <link> <name>src/amun/gpu/dl4mt/model.h</name> <type>1</type> <locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/dl4mt/model.h</locationURI> diff --git a/examples b/examples -Subproject 050843dd917043bb5588ea431d02dcab68b567e +Subproject 280481dc29167dc5b9fe63d85bbb6bacdf918c7 diff --git a/src/amun/CMakeLists.txt b/src/amun/CMakeLists.txt index b6662951..3e8899d4 100755 --- a/src/amun/CMakeLists.txt +++ b/src/amun/CMakeLists.txt @@ -84,6 +84,7 @@ cuda_add_executable( gpu/decoder/encoder_decoder_state.cu gpu/dl4mt/encoder.cu gpu/dl4mt/gru.cu + gpu/dl4mt/model.cu gpu/mblas/matrix.cu gpu/mblas/matrix_functions.cu gpu/mblas/nth_element.cu @@ -112,6 +113,7 @@ cuda_add_library(python SHARED gpu/mblas/nth_element_kernels.cu gpu/dl4mt/encoder.cu gpu/dl4mt/gru.cu + gpu/dl4mt/model.cu gpu/npz_converter.cu gpu/types-gpu.cu common/loader_factory.cpp @@ -139,6 +141,7 @@ cuda_add_library(mosesplugin STATIC gpu/mblas/nth_element_kernels.cu gpu/dl4mt/encoder.cu gpu/dl4mt/gru.cu + gpu/dl4mt/model.cu gpu/npz_converter.cu gpu/types-gpu.cu common/loader_factory.cpp diff --git a/src/amun/gpu/dl4mt/model.cu b/src/amun/gpu/dl4mt/model.cu new file mode 100644 index 00000000..326844e1 --- /dev/null +++ b/src/amun/gpu/dl4mt/model.cu @@ -0,0 +1,220 @@ +#include "model.h" + +namespace amunmt { +namespace GPU { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model) +: E_(model.get("Wemb", true)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::EncForwardGRU::EncForwardGRU(const NpzConverter& model) +: W_(model.get("encoder_W", true)), + B_(model.get("encoder_b", true, true)), + U_(model.get("encoder_U", true)), + Wx_(model.get("encoder_Wx", true)), + Bx1_(model.get("encoder_bx", true, true)), + Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)), + Ux_(model.get("encoder_Ux", true)), + Gamma_1_(model.get("encoder_gamma1", false)), + Gamma_2_(model.get("encoder_gamma2", false)) +{ } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::EncForwardLSTM::EncForwardLSTM(const NpzConverter& model) +: W_(model.get("encoder_W", true)), + B_(model.get("encoder_b", true, true)), + U_(model.get("encoder_U", true)), + Wx_(model.get("encoder_Wx", true)), + Bx_(model.get("encoder_bx", true, true)), + Ux_(model.get("encoder_Ux", true)), + Gamma_1_(model.get("encoder_gamma1", false)), + Gamma_2_(model.get("encoder_gamma2", false)) +{ } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::EncBackwardLSTM::EncBackwardLSTM(const NpzConverter& model) +: W_(model.get("encoder_r_W", true)), + B_(model.get("encoder_r_b", true, true)), + U_(model.get("encoder_r_U", true)), + Wx_(model.get("encoder_r_Wx", true)), + Bx_(model.get("encoder_r_bx", true, true)), + Ux_(model.get("encoder_r_Ux", true)), + Gamma_1_(model.get("encoder_r_gamma1", false)), + Gamma_2_(model.get("encoder_r_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::EncBackwardGRU::EncBackwardGRU(const NpzConverter& model) +: W_(model.get("encoder_r_W", true)), + B_(model.get("encoder_r_b", true, true)), + U_(model.get("encoder_r_U", true)), + Wx_(model.get("encoder_r_Wx", true)), + Bx1_(model.get("encoder_r_bx", true, true)), + Bx2_(new mblas::Matrix( Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)), + Ux_(model.get("encoder_r_Ux", true)), + Gamma_1_(model.get("encoder_r_gamma1", false)), + Gamma_2_(model.get("encoder_r_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecEmbeddings::DecEmbeddings(const NpzConverter& model) +: E_(model.getFirstOfMany({std::make_pair("Wemb_dec", false), + std::make_pair("Wemb", false)}, true)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecInit::DecInit(const NpzConverter& model) +: Wi_(model.get("ff_state_W", true)), + Bi_(model.get("ff_state_b", true, true)), + Gamma_(model.get("ff_state_gamma", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecGRU1::DecGRU1(const NpzConverter& model) +: W_(model.get("decoder_W", true)), + B_(model.get("decoder_b", true, true)), + U_(model.get("decoder_U", true)), + Wx_(model.get("decoder_Wx", true)), + Bx1_(model.get("decoder_bx", true, true)), + Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)), + Ux_(model.get("decoder_Ux", true)), + Gamma_1_(model.get("decoder_cell1_gamma1", false)), + Gamma_2_(model.get("decoder_cell1_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecGRU2::DecGRU2(const NpzConverter& model) +: W_(model.get("decoder_Wc", true)), + B_(model.get("decoder_b_nl", true, true)), + U_(model.get("decoder_U_nl", true)), + Wx_(model.get("decoder_Wcx", true)), + Bx2_(model.get("decoder_bx_nl", true, true)), + Bx1_(new mblas::Matrix(Bx2_->dim(0), Bx2_->dim(1), Bx2_->dim(2), Bx2_->dim(3), true)), + Ux_(model.get("decoder_Ux_nl", true)), + Gamma_1_(model.get("decoder_cell2_gamma1", false)), + Gamma_2_(model.get("decoder_cell2_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecLSTM1::DecLSTM1(const NpzConverter& model) +: W_(model.get("decoder_W", true)), + B_(model.get("decoder_b", true, true)), + U_(model.get("decoder_U", true)), + Wx_(model.get("decoder_Wx", true)), + Bx_(model.get("decoder_bx", true, true)), + Ux_(model.get("decoder_Ux", true)), + Gamma_1_(model.get("decoder_cell1_gamma1", false)), + Gamma_2_(model.get("decoder_cell1_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecLSTM2::DecLSTM2(const NpzConverter& model) +: W_(model.get("decoder_Wc", true)), + B_(model.get("decoder_b_nl", true, true)), + U_(model.get("decoder_U_nl", true)), + Wx_(model.get("decoder_Wcx", true)), + Bx_(model.get("decoder_bx_nl", true, true)), + Ux_(model.get("decoder_Ux_nl", true)), + Gamma_1_(model.get("decoder_cell2_gamma1", false)), + Gamma_2_(model.get("decoder_cell2_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecAlignment::DecAlignment(const NpzConverter& model) +: V_(model.get("decoder_U_att", true, true)), + W_(model.get("decoder_W_comb_att", true)), + B_(model.get("decoder_b_att", true, true)), + U_(model.get("decoder_Wc_att", true)), + C_(model.get("decoder_c_tt", true)), // scalar? + Gamma_1_(model.get("decoder_att_gamma1", false)), + Gamma_2_(model.get("decoder_att_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) +: W1_(model.get("ff_logit_lstm_W", true)), + B1_(model.get("ff_logit_lstm_b", true, true)), + W2_(model.get("ff_logit_prev_W", true)), + B2_(model.get("ff_logit_prev_b", true, true)), + W3_(model.get("ff_logit_ctx_W", true)), + B3_(model.get("ff_logit_ctx_b", true, true)), + W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false), + std::make_pair(std::string("Wemb_dec"), true)}, true)), + B4_(model.get("ff_logit_b", true, true)), + Gamma_0_(model.get("ff_logit_l1_gamma0", false)), + Gamma_1_(model.get("ff_logit_l1_gamma1", false)), + Gamma_2_(model.get("ff_logit_l1_gamma2", false)) +{} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +Weights::Weights(const std::string& npzFile, const YAML::Node& config, size_t device) +: Weights(NpzConverter(npzFile), config, device) +{} + +Weights::Weights(const NpzConverter& model, const YAML::Node& config, size_t device) +: encEmbeddings_(model), + decEmbeddings_(model), + decInit_(model), + decAlignment_(model), + decSoftmax_(model), + device_(device) +{ + + std::string encCell = config["enc-cell"] ? config["enc-cell"].as<std::string>() : "gru"; + std::string encCell_r = config["enc-cell-r"] ? config["enc-cell-r"].as<std::string>() : encCell; + initEncForward(model, encCell); + initEncBackward(model, encCell_r); + + std::string decCell = config["dec-cell"] ? config["dec-cell"].as<std::string>() : "gru"; + std::string decCell2 = config["dec-cell-2"] ? config["dec-cell-2"].as<std::string>() : decCell; + initDec1(model, decCell); + initDec2(model, decCell2); +} + +void Weights::initEncForward(const NpzConverter& model,std::string celltype){ + if(celltype == "lstm"){ + encForwardLSTM_ = std::shared_ptr<EncForwardLSTM>(new EncForwardLSTM(model)); + } else if (celltype == "mlstm") { + encForwardMLSTM_ = std::shared_ptr<MultWeights<EncForwardLSTM>> + (new MultWeights<EncForwardLSTM>(model, "encoder")); + } else if (celltype == "gru"){ + encForwardGRU_ = std::shared_ptr<EncForwardGRU>(new EncForwardGRU(model)); + } +} + +void Weights::initEncBackward(const NpzConverter& model,std::string celltype) { + if(celltype == "lstm"){ + encBackwardLSTM_ = std::shared_ptr<EncBackwardLSTM>(new EncBackwardLSTM(model)); + } else if (celltype == "mlstm") { + encBackwardMLSTM_ = std::shared_ptr<MultWeights<EncBackwardLSTM>> + (new MultWeights<EncBackwardLSTM>(model, "encoder_r")); + } else if (celltype == "gru"){ + encBackwardGRU_ = std::shared_ptr<EncBackwardGRU>(new EncBackwardGRU(model)); + } +} + +void Weights::initDec1(const NpzConverter& model,std::string celltype){ + if (celltype == "lstm"){ + decLSTM1_ = std::shared_ptr<DecLSTM1>(new DecLSTM1(model)); + } else if (celltype == "mlstm") { + decMLSTM1_ = std::shared_ptr<MultWeights<DecLSTM1>>(new MultWeights<DecLSTM1>(model, "decoder")); + } else if (celltype == "gru") { + decGru1_ = std::shared_ptr<DecGRU1>(new DecGRU1(model)); + } +} + +void Weights::initDec2(const NpzConverter& model,std::string celltype){ + if (celltype == "lstm"){ + decLSTM2_ = std::shared_ptr<DecLSTM2>(new DecLSTM2(model)); + } else if (celltype == "mlstm") { + decMLSTM2_ = std::shared_ptr<MultWeights<DecLSTM2>>(new MultWeights<DecLSTM2>(model, "decoder_2")); + } else if (celltype == "gru") { + decGru2_ = std::shared_ptr<DecGRU2>(new DecGRU2(model)); + } +} + +} // namespace +} + diff --git a/src/amun/gpu/dl4mt/model.h b/src/amun/gpu/dl4mt/model.h index 283140e5..9a29c344 100644 --- a/src/amun/gpu/dl4mt/model.h +++ b/src/amun/gpu/dl4mt/model.h @@ -13,31 +13,17 @@ namespace GPU { struct Weights { ////////////////////////////////////////////////////////////////////////////// - struct EncEmbeddings { EncEmbeddings(const EncEmbeddings&) = delete; - - EncEmbeddings(const NpzConverter& model) - : E_(model.get("Wemb", true)) - {} + EncEmbeddings(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> E_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// struct EncForwardGRU { EncForwardGRU(const EncForwardGRU&) = delete; - - EncForwardGRU(const NpzConverter& model) - : W_(model.get("encoder_W", true)), - B_(model.get("encoder_b", true, true)), - U_(model.get("encoder_U", true)), - Wx_(model.get("encoder_Wx", true)), - Bx1_(model.get("encoder_bx", true, true)), - Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)), - Ux_(model.get("encoder_Ux", true)), - Gamma_1_(model.get("encoder_gamma1", false)), - Gamma_2_(model.get("encoder_gamma2", false)) - { } + EncForwardGRU(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -50,20 +36,10 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// struct EncBackwardGRU { EncBackwardGRU(const EncBackwardGRU&) = delete; - - EncBackwardGRU(const NpzConverter& model) - : W_(model.get("encoder_r_W", true)), - B_(model.get("encoder_r_b", true, true)), - U_(model.get("encoder_r_U", true)), - Wx_(model.get("encoder_r_Wx", true)), - Bx1_(model.get("encoder_r_bx", true, true)), - Bx2_(new mblas::Matrix( Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)), - Ux_(model.get("encoder_r_Ux", true)), - Gamma_1_(model.get("encoder_r_gamma1", false)), - Gamma_2_(model.get("encoder_r_gamma2", false)) - {} + EncBackwardGRU(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -76,19 +52,10 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// struct EncForwardLSTM { EncForwardLSTM(const EncForwardLSTM&) = delete; - - EncForwardLSTM(const NpzConverter& model) - : W_(model.get("encoder_W", true)), - B_(model.get("encoder_b", true, true)), - U_(model.get("encoder_U", true)), - Wx_(model.get("encoder_Wx", true)), - Bx_(model.get("encoder_bx", true, true)), - Ux_(model.get("encoder_Ux", true)), - Gamma_1_(model.get("encoder_gamma1", false)), - Gamma_2_(model.get("encoder_gamma2", false)) - { } + EncForwardLSTM(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -100,19 +67,10 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// struct EncBackwardLSTM { EncBackwardLSTM(const EncBackwardLSTM&) = delete; - - EncBackwardLSTM(const NpzConverter& model) - : W_(model.get("encoder_r_W", true)), - B_(model.get("encoder_r_b", true, true)), - U_(model.get("encoder_r_U", true)), - Wx_(model.get("encoder_r_Wx", true)), - Bx_(model.get("encoder_r_bx", true, true)), - Ux_(model.get("encoder_r_Ux", true)), - Gamma_1_(model.get("encoder_r_gamma1", false)), - Gamma_2_(model.get("encoder_r_gamma2", false)) - {} + EncBackwardLSTM(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -124,47 +82,31 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; - ////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// struct DecEmbeddings { DecEmbeddings(const DecEmbeddings&) = delete; - - DecEmbeddings(const NpzConverter& model) - : E_(model.getFirstOfMany({std::make_pair("Wemb_dec", false), - std::make_pair("Wemb", false)}, true)) - {} + DecEmbeddings(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> E_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecInit { DecInit(const DecInit&) = delete; - - DecInit(const NpzConverter& model) - : Wi_(model.get("ff_state_W", true)), - Bi_(model.get("ff_state_b", true, true)), - Gamma_(model.get("ff_state_gamma", false)) - {} + DecInit(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> Wi_; const std::shared_ptr<mblas::Matrix> Bi_; const std::shared_ptr<mblas::Matrix> Gamma_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecGRU1 { DecGRU1(const DecGRU1&) = delete; - - DecGRU1(const NpzConverter& model) - : W_(model.get("decoder_W", true)), - B_(model.get("decoder_b", true, true)), - U_(model.get("decoder_U", true)), - Wx_(model.get("decoder_Wx", true)), - Bx1_(model.get("decoder_bx", true, true)), - Bx2_(new mblas::Matrix(Bx1_->dim(0), Bx1_->dim(1), Bx1_->dim(2), Bx1_->dim(3), true)), - Ux_(model.get("decoder_Ux", true)), - Gamma_1_(model.get("decoder_cell1_gamma1", false)), - Gamma_2_(model.get("decoder_cell1_gamma2", false)) - {} + DecGRU1(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -177,20 +119,11 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecGRU2 { DecGRU2(const DecGRU2&) = delete; - - DecGRU2(const NpzConverter& model) - : W_(model.get("decoder_Wc", true)), - B_(model.get("decoder_b_nl", true, true)), - U_(model.get("decoder_U_nl", true)), - Wx_(model.get("decoder_Wcx", true)), - Bx2_(model.get("decoder_bx_nl", true, true)), - Bx1_(new mblas::Matrix(Bx2_->dim(0), Bx2_->dim(1), Bx2_->dim(2), Bx2_->dim(3), true)), - Ux_(model.get("decoder_Ux_nl", true)), - Gamma_1_(model.get("decoder_cell2_gamma1", false)), - Gamma_2_(model.get("decoder_cell2_gamma2", false)) - {} + DecGRU2(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -203,19 +136,11 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecLSTM1 { DecLSTM1(const DecLSTM1&) = delete; - - DecLSTM1(const NpzConverter& model) - : W_(model.get("decoder_W", true)), - B_(model.get("decoder_b", true, true)), - U_(model.get("decoder_U", true)), - Wx_(model.get("decoder_Wx", true)), - Bx_(model.get("decoder_bx", true, true)), - Ux_(model.get("decoder_Ux", true)), - Gamma_1_(model.get("decoder_cell1_gamma1", false)), - Gamma_2_(model.get("decoder_cell1_gamma2", false)) - {} + DecLSTM1(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -227,19 +152,11 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecLSTM2 { DecLSTM2(const DecLSTM2&) = delete; - - DecLSTM2(const NpzConverter& model) - : W_(model.get("decoder_Wc", true)), - B_(model.get("decoder_b_nl", true, true)), - U_(model.get("decoder_U_nl", true)), - Wx_(model.get("decoder_Wcx", true)), - Bx_(model.get("decoder_bx_nl", true, true)), - Ux_(model.get("decoder_Ux_nl", true)), - Gamma_1_(model.get("decoder_cell2_gamma1", false)), - Gamma_2_(model.get("decoder_cell2_gamma2", false)) - {} + DecLSTM2(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W_; const std::shared_ptr<mblas::Matrix> B_; @@ -251,6 +168,8 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + // A wrapper class to deserialize weights for multiplicative-LSTM, // multiplicative-GRU and such template<class BaseWeights> @@ -273,18 +192,11 @@ struct Weights { } }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecAlignment { DecAlignment(const DecAlignment&) = delete; - - DecAlignment(const NpzConverter& model) - : V_(model.get("decoder_U_att", true, true)), - W_(model.get("decoder_W_comb_att", true)), - B_(model.get("decoder_b_att", true, true)), - U_(model.get("decoder_Wc_att", true)), - C_(model.get("decoder_c_tt", true)), // scalar? - Gamma_1_(model.get("decoder_att_gamma1", false)), - Gamma_2_(model.get("decoder_att_gamma2", false)) - {} + DecAlignment(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> V_; const std::shared_ptr<mblas::Matrix> W_; @@ -295,23 +207,11 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// + struct DecSoftmax { DecSoftmax(const DecSoftmax&) = delete; - - DecSoftmax(const NpzConverter& model) - : W1_(model.get("ff_logit_lstm_W", true)), - B1_(model.get("ff_logit_lstm_b", true, true)), - W2_(model.get("ff_logit_prev_W", true)), - B2_(model.get("ff_logit_prev_b", true, true)), - W3_(model.get("ff_logit_ctx_W", true)), - B3_(model.get("ff_logit_ctx_b", true, true)), - W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false), - std::make_pair(std::string("Wemb_dec"), true)}, true)), - B4_(model.get("ff_logit_b", true, true)), - Gamma_0_(model.get("ff_logit_l1_gamma0", false)), - Gamma_1_(model.get("ff_logit_l1_gamma1", false)), - Gamma_2_(model.get("ff_logit_l1_gamma2", false)) - {} + DecSoftmax(const NpzConverter& model); const std::shared_ptr<mblas::Matrix> W1_; const std::shared_ptr<mblas::Matrix> B1_; @@ -326,29 +226,11 @@ struct Weights { const std::shared_ptr<mblas::Matrix> Gamma_2_; }; - Weights(const std::string& npzFile, const YAML::Node& config, size_t device) - : Weights(NpzConverter(npzFile), config, device) - {} - - Weights(const NpzConverter& model, const YAML::Node& config, size_t device) - : encEmbeddings_(model), - decEmbeddings_(model), - decInit_(model), - decAlignment_(model), - decSoftmax_(model), - device_(device) - { - - std::string encCell = config["enc-cell"] ? config["enc-cell"].as<std::string>() : "gru"; - std::string encCell_r = config["enc-cell-r"] ? config["enc-cell-r"].as<std::string>() : encCell; - initEncForward(model, encCell); - initEncBackward(model, encCell_r); - - std::string decCell = config["dec-cell"] ? config["dec-cell"].as<std::string>() : "gru"; - std::string decCell2 = config["dec-cell-2"] ? config["dec-cell-2"].as<std::string>() : decCell; - initDec1(model, decCell); - initDec2(model, decCell2); - } + //////////////////////////////////////////////////////////////////////////////////////////////////// + + Weights(const std::string& npzFile, const YAML::Node& config, size_t device); + + Weights(const NpzConverter& model, const YAML::Node& config, size_t device); Weights(const Weights&) = delete; @@ -357,44 +239,10 @@ struct Weights { } private: - void initEncForward(const NpzConverter& model,std::string celltype){ - if(celltype == "lstm"){ - encForwardLSTM_ = std::shared_ptr<EncForwardLSTM>(new EncForwardLSTM(model)); - } else if (celltype == "mlstm") { - encForwardMLSTM_ = std::shared_ptr<MultWeights<EncForwardLSTM>> - (new MultWeights<EncForwardLSTM>(model, "encoder")); - } else if (celltype == "gru"){ - encForwardGRU_ = std::shared_ptr<EncForwardGRU>(new EncForwardGRU(model)); - } - } - void initEncBackward(const NpzConverter& model,std::string celltype) { - if(celltype == "lstm"){ - encBackwardLSTM_ = std::shared_ptr<EncBackwardLSTM>(new EncBackwardLSTM(model)); - } else if (celltype == "mlstm") { - encBackwardMLSTM_ = std::shared_ptr<MultWeights<EncBackwardLSTM>> - (new MultWeights<EncBackwardLSTM>(model, "encoder_r")); - } else if (celltype == "gru"){ - encBackwardGRU_ = std::shared_ptr<EncBackwardGRU>(new EncBackwardGRU(model)); - } - } - void initDec1(const NpzConverter& model,std::string celltype){ - if (celltype == "lstm"){ - decLSTM1_ = std::shared_ptr<DecLSTM1>(new DecLSTM1(model)); - } else if (celltype == "mlstm") { - decMLSTM1_ = std::shared_ptr<MultWeights<DecLSTM1>>(new MultWeights<DecLSTM1>(model, "decoder")); - } else if (celltype == "gru") { - decGru1_ = std::shared_ptr<DecGRU1>(new DecGRU1(model)); - } - } - void initDec2(const NpzConverter& model,std::string celltype){ - if (celltype == "lstm"){ - decLSTM2_ = std::shared_ptr<DecLSTM2>(new DecLSTM2(model)); - } else if (celltype == "mlstm") { - decMLSTM2_ = std::shared_ptr<MultWeights<DecLSTM2>>(new MultWeights<DecLSTM2>(model, "decoder_2")); - } else if (celltype == "gru") { - decGru2_ = std::shared_ptr<DecGRU2>(new DecGRU2(model)); - } - } + void initEncForward(const NpzConverter& model,std::string celltype); + void initEncBackward(const NpzConverter& model,std::string celltype); + void initDec1(const NpzConverter& model,std::string celltype); + void initDec2(const NpzConverter& model,std::string celltype); public: const EncEmbeddings encEmbeddings_; diff --git a/src/marian b/src/marian -Subproject fed955ae472e5394b06f65a280ad2254c2cca91 +Subproject 73e7737582aea9e241dc9e4b73c3d38b1181e3d |