diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-28 19:25:45 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-11-28 19:25:45 +0300 |
commit | 3cce98c3569dbaad886dc1e51f16cbfd33d2adfc (patch) | |
tree | 3d1aedf9fea048cc8a7f19e7c041483e30633384 | |
parent | 9aee648fa1d913925bdc153f26d4835a7e773f08 (diff) |
Add marian style tie-embstied-amun
-rw-r--r-- | src/amun/cpu/dl4mt/model.cpp | 17 | ||||
-rw-r--r-- | src/amun/cpu/dl4mt/model.h | 1 | ||||
-rw-r--r-- | src/amun/cpu/nematus/model.cpp | 11 | ||||
-rw-r--r-- | src/amun/cpu/nematus/model.h | 1 | ||||
-rw-r--r-- | src/amun/gpu/dl4mt/model.cu | 6 | ||||
m--------- | src/marian | 12 |
6 files changed, 24 insertions, 24 deletions
diff --git a/src/amun/cpu/dl4mt/model.cpp b/src/amun/cpu/dl4mt/model.cpp index ab8a3e90..9c1b5972 100644 --- a/src/amun/cpu/dl4mt/model.cpp +++ b/src/amun/cpu/dl4mt/model.cpp @@ -6,10 +6,6 @@ namespace amunmt { namespace CPU { namespace dl4mt { -Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key) - : E_(model[key]) -{} - Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys) : E_(model.getFirstOfMany(keys)) {} @@ -67,8 +63,10 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) B2_(model("ff_logit_prev_b", true)), W3_(model["ff_logit_ctx_W"]), B3_(model("ff_logit_ctx_b", true)), - W4_(model.getFirstOfMany({std::pair<std::string, bool>(std::string("ff_logit_W"), false), - std::make_pair(std::string("Wemb_dec"), true)})), + W4_(model.getFirstOfMany({std::pair<std::string, bool>( + std::string("ff_logit_W"), false), + std::make_pair(std::string("Wemb_dec"), true), + std::make_pair(std::string("Wemb"), true)})), B4_(model("ff_logit_b", true)), Gamma_0_(model["ff_logit_l1_gamma0"]), Gamma_1_(model["ff_logit_l1_gamma1"]), @@ -78,12 +76,15 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) ////////////////////////////////////////////////////////////////////////////// Weights::Weights(const NpzConverter& model, size_t) -: encEmbeddings_(model, "Wemb"), +: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>({ + std::make_pair(std::string("Wemb"), false), + std::make_pair(std::string("Wemb_dec"), false)})), encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx", "encoder_Ux", "encoder_gamma1", "encoder_gamma2"}), encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx", "encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}), - decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({std::make_pair(std::string("Wemb_dec"), false), + decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({ + std::make_pair(std::string("Wemb_dec"), false), std::make_pair(std::string("Wemb"), false)})), decInit_(model), decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux", diff --git a/src/amun/cpu/dl4mt/model.h b/src/amun/cpu/dl4mt/model.h index 87299b87..5fa8e0de 100644 --- a/src/amun/cpu/dl4mt/model.h +++ b/src/amun/cpu/dl4mt/model.h @@ -16,7 +16,6 @@ struct Weights { ////////////////////////////////////////////////////////////////////////////// struct Embeddings { - Embeddings(const NpzConverter& model, const std::string &key); Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys); const mblas::Matrix E_; diff --git a/src/amun/cpu/nematus/model.cpp b/src/amun/cpu/nematus/model.cpp index be0b6df9..c6f426e8 100644 --- a/src/amun/cpu/nematus/model.cpp +++ b/src/amun/cpu/nematus/model.cpp @@ -60,10 +60,6 @@ std::string Weights::Transition::name(const std::string& prefix, std::string nam return prefix + name + infix + "_drt_" + std::to_string(index) + suffix; } -Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key) - : E_(model[key]) -{} - Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys) : E_(model.getFirstOfMany(keys)) {} @@ -143,7 +139,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) W3_(model["ff_logit_ctx_W"]), B3_(model("ff_logit_ctx_b", true)), W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false), - std::make_pair(std::string("Wemb_dec"), true)})), + std::make_pair(std::string("Wemb_dec"), true), + std::make_pair(std::string("Wemb"), true)})), B4_(model("ff_logit_b", true)), lns_1_(model["ff_logit_lstm_ln_s"]), lns_2_(model["ff_logit_prev_ln_s"]), @@ -156,7 +153,9 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) ////////////////////////////////////////////////////////////////////////////// Weights::Weights(const NpzConverter& model, size_t) - : encEmbeddings_(model, "Wemb"), + : encEmbeddings_(model, std::vector<std::pair<std::string, bool>>( + {std::make_pair(std::string("Wemb"), false), + std::make_pair(std::string("Wemb_dec"), false)})), decEmbeddings_(model, std::vector<std::pair<std::string, bool>>( {std::make_pair(std::string("Wemb_dec"), false), std::make_pair(std::string("Wemb"), false)})), diff --git a/src/amun/cpu/nematus/model.h b/src/amun/cpu/nematus/model.h index 82fe09f3..9a68df20 100644 --- a/src/amun/cpu/nematus/model.h +++ b/src/amun/cpu/nematus/model.h @@ -49,7 +49,6 @@ struct Weights { }; struct Embeddings { - Embeddings(const NpzConverter& model, const std::string &key); Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys); const mblas::Matrix E_; diff --git a/src/amun/gpu/dl4mt/model.cu b/src/amun/gpu/dl4mt/model.cu index 326844e1..e1b1c3e7 100644 --- a/src/amun/gpu/dl4mt/model.cu +++ b/src/amun/gpu/dl4mt/model.cu @@ -5,7 +5,8 @@ namespace GPU { //////////////////////////////////////////////////////////////////////////////////////////////////// Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model) -: E_(model.get("Wemb", true)) +: E_(model.getFirstOfMany({std::make_pair("Wemb", false), + std::make_pair("Wemb_dec", false)}, true)) {} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -141,7 +142,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model) W3_(model.get("ff_logit_ctx_W", true)), B3_(model.get("ff_logit_ctx_b", true, true)), W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false), - std::make_pair(std::string("Wemb_dec"), true)}, true)), + std::make_pair(std::string("Wemb_dec"), true), + std::make_pair(std::string("Wemb"), true)}, true)), B4_(model.get("ff_logit_b", true, true)), Gamma_0_(model.get("ff_logit_l1_gamma0", false)), Gamma_1_(model.get("ff_logit_l1_gamma1", false)), diff --git a/src/marian b/src/marian -Subproject f0f06468e1633af503b81eaa8b1fbca881b9525 +Subproject 7f3802e97f08e91382c6609cd772c4e80488b4d |