Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-05 16:14:47 +0300
committerTomasz Dwojak <t.dwojak@amu.edu.pl>2017-06-05 17:10:33 +0300
commit9b91163b169ba0b7d9a771b9e7b33b6345ccf1a1 (patch)
treed13c068765ed11c1f544f7d0b3ee03be82358e83 /src/amun/cpu/npz_converter.h
parente7bc69b077e214f7fc47b246d565a13d6836e6c5 (diff)
Support tie_embeddings from Nematus: CPU
Diffstat (limited to 'src/amun/cpu/npz_converter.h')
-rw-r--r--src/amun/cpu/npz_converter.h23
1 files changed, 23 insertions, 0 deletions
diff --git a/src/amun/cpu/npz_converter.h b/src/amun/cpu/npz_converter.h
index 7488ead7..ffb7c692 100644
--- a/src/amun/cpu/npz_converter.h
+++ b/src/amun/cpu/npz_converter.h
@@ -79,6 +79,29 @@ class NpzConverter {
return std::move(ret);
}
+ mblas::Matrix operator[](const std::vector<std::pair<std::string, bool>> keys) const {
+ BlazeWrapper matrix;
+ for (auto key : keys) {
+ auto it = model_.find(key.first);
+ if(it != model_.end()) {
+ NpyMatrixWrapper np(it->second);
+ matrix = BlazeWrapper(np.data(), np.size1(), np.size2());
+ mblas::Matrix ret;
+ if (key.second) {
+ const auto matrix2 = blaze::trans(matrix);
+ ret = matrix2;
+ } else {
+ ret = matrix;
+ }
+ return std::move(ret);
+ }
+ }
+ std::cerr << "Matrix not found: " << keys[0].first << "\n";
+
+ mblas::Matrix ret;
+ return std::move(ret);
+ }
+
mblas::Matrix operator()(const std::string& key,
bool transpose) const {
BlazeWrapper matrix;