Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Tomasz Dwojak <t.dwojak@amu.edu.pl> 2017-11-28 19:25:45 +0300
committer: Tomasz Dwojak <t.dwojak@amu.edu.pl> 2017-11-28 19:25:45 +0300
commit: 3cce98c3569dbaad886dc1e51f16cbfd33d2adfc (patch)
tree: 3d1aedf9fea048cc8a7f19e7c041483e30633384
parent: 9aee648fa1d913925bdc153f26d4835a7e773f08 (diff)
Add marian style tie-embs [tied-amun]
-rw-r--r-- src/amun/cpu/dl4mt/model.cpp 17
-rw-r--r-- src/amun/cpu/dl4mt/model.h 1
-rw-r--r-- src/amun/cpu/nematus/model.cpp 11
-rw-r--r-- src/amun/cpu/nematus/model.h 1
-rw-r--r-- src/amun/gpu/dl4mt/model.cu 6
m--------- src/marian 12
6 files changed, 24 insertions, 24 deletions
diff --git a/src/amun/cpu/dl4mt/model.cpp b/src/amun/cpu/dl4mt/model.cpp
index ab8a3e90..9c1b5972 100644
--- a/src/amun/cpu/dl4mt/model.cpp
+++ b/src/amun/cpu/dl4mt/model.cpp
@@ -6,10 +6,6 @@ namespace amunmt {
namespace CPU {
namespace dl4mt {
-Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
- : E_(model[key])
-{}
-
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
: E_(model.getFirstOfMany(keys))
{}
@@ -67,8 +63,10 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
B2_(model("ff_logit_prev_b", true)),
W3_(model["ff_logit_ctx_W"]),
B3_(model("ff_logit_ctx_b", true)),
- W4_(model.getFirstOfMany({std::pair<std::string, bool>(std::string("ff_logit_W"), false),
- std::make_pair(std::string("Wemb_dec"), true)})),
+ W4_(model.getFirstOfMany({std::pair<std::string, bool>(
+ std::string("ff_logit_W"), false),
+ std::make_pair(std::string("Wemb_dec"), true),
+ std::make_pair(std::string("Wemb"), true)})),
B4_(model("ff_logit_b", true)),
Gamma_0_(model["ff_logit_l1_gamma0"]),
Gamma_1_(model["ff_logit_l1_gamma1"]),
@@ -78,12 +76,15 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
//////////////////////////////////////////////////////////////////////////////
Weights::Weights(const NpzConverter& model, size_t)
-: encEmbeddings_(model, "Wemb"),
+: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
+ std::make_pair(std::string("Wemb"), false),
+ std::make_pair(std::string("Wemb_dec"), false)})),
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
"encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
"encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
- decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({std::make_pair(std::string("Wemb_dec"), false),
+ decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
+ std::make_pair(std::string("Wemb_dec"), false),
std::make_pair(std::string("Wemb"), false)})),
decInit_(model),
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",
diff --git a/src/amun/cpu/dl4mt/model.h b/src/amun/cpu/dl4mt/model.h
index 87299b87..5fa8e0de 100644
--- a/src/amun/cpu/dl4mt/model.h
+++ b/src/amun/cpu/dl4mt/model.h
@@ -16,7 +16,6 @@ struct Weights {
//////////////////////////////////////////////////////////////////////////////
struct Embeddings {
- Embeddings(const NpzConverter& model, const std::string &key);
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
const mblas::Matrix E_;
diff --git a/src/amun/cpu/nematus/model.cpp b/src/amun/cpu/nematus/model.cpp
index be0b6df9..c6f426e8 100644
--- a/src/amun/cpu/nematus/model.cpp
+++ b/src/amun/cpu/nematus/model.cpp
@@ -60,10 +60,6 @@ std::string Weights::Transition::name(const std::string& prefix, std::string nam
return prefix + name + infix + "_drt_" + std::to_string(index) + suffix;
}
-Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
- : E_(model[key])
-{}
-
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
: E_(model.getFirstOfMany(keys))
{}
@@ -143,7 +139,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
W3_(model["ff_logit_ctx_W"]),
B3_(model("ff_logit_ctx_b", true)),
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
- std::make_pair(std::string("Wemb_dec"), true)})),
+ std::make_pair(std::string("Wemb_dec"), true),
+ std::make_pair(std::string("Wemb"), true)})),
B4_(model("ff_logit_b", true)),
lns_1_(model["ff_logit_lstm_ln_s"]),
lns_2_(model["ff_logit_prev_ln_s"]),
@@ -156,7 +153,9 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
//////////////////////////////////////////////////////////////////////////////
Weights::Weights(const NpzConverter& model, size_t)
- : encEmbeddings_(model, "Wemb"),
+ : encEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
+ {std::make_pair(std::string("Wemb"), false),
+ std::make_pair(std::string("Wemb_dec"), false)})),
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
{std::make_pair(std::string("Wemb_dec"), false),
std::make_pair(std::string("Wemb"), false)})),
diff --git a/src/amun/cpu/nematus/model.h b/src/amun/cpu/nematus/model.h
index 82fe09f3..9a68df20 100644
--- a/src/amun/cpu/nematus/model.h
+++ b/src/amun/cpu/nematus/model.h
@@ -49,7 +49,6 @@ struct Weights {
};
struct Embeddings {
- Embeddings(const NpzConverter& model, const std::string &key);
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
const mblas::Matrix E_;
diff --git a/src/amun/gpu/dl4mt/model.cu b/src/amun/gpu/dl4mt/model.cu
index 326844e1..e1b1c3e7 100644
--- a/src/amun/gpu/dl4mt/model.cu
+++ b/src/amun/gpu/dl4mt/model.cu
@@ -5,7 +5,8 @@ namespace GPU {
////////////////////////////////////////////////////////////////////////////////////////////////////
Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model)
-: E_(model.get("Wemb", true))
+: E_(model.getFirstOfMany({std::make_pair("Wemb", false),
+ std::make_pair("Wemb_dec", false)}, true))
{}
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -141,7 +142,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
W3_(model.get("ff_logit_ctx_W", true)),
B3_(model.get("ff_logit_ctx_b", true, true)),
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
- std::make_pair(std::string("Wemb_dec"), true)}, true)),
+ std::make_pair(std::string("Wemb_dec"), true),
+ std::make_pair(std::string("Wemb"), true)}, true)),
B4_(model.get("ff_logit_b", true, true)),
Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
Gamma_1_(model.get("ff_logit_l1_gamma1", false)),
diff --git a/src/marian b/src/marian
-Subproject f0f06468e1633af503b81eaa8b1fbca881b9525
+Subproject 7f3802e97f08e91382c6609cd772c4e80488b4d