diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-05 16:14:47 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-05 17:10:33 +0300 |
commit | 9b91163b169ba0b7d9a771b9e7b33b6345ccf1a1 (patch) | |
tree | d13c068765ed11c1f544f7d0b3ee03be82358e83 /src/amun/cpu/npz_converter.h | |
parent | e7bc69b077e214f7fc47b246d565a13d6836e6c5 (diff) |
Support tie_embeddings from Nematus: CPU
Diffstat (limited to 'src/amun/cpu/npz_converter.h')
-rw-r--r-- | src/amun/cpu/npz_converter.h | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/src/amun/cpu/npz_converter.h b/src/amun/cpu/npz_converter.h index 7488ead7..ffb7c692 100644 --- a/src/amun/cpu/npz_converter.h +++ b/src/amun/cpu/npz_converter.h @@ -79,6 +79,29 @@ class NpzConverter { return std::move(ret); } + mblas::Matrix operator[](const std::vector<std::pair<std::string, bool>> keys) const { + BlazeWrapper matrix; + for (auto key : keys) { + auto it = model_.find(key.first); + if(it != model_.end()) { + NpyMatrixWrapper np(it->second); + matrix = BlazeWrapper(np.data(), np.size1(), np.size2()); + mblas::Matrix ret; + if (key.second) { + const auto matrix2 = blaze::trans(matrix); + ret = matrix2; + } else { + ret = matrix; + } + return std::move(ret); + } + } + std::cerr << "Matrix not found: " << keys[0].first << "\n"; + + mblas::Matrix ret; + return std::move(ret); + } + mblas::Matrix operator()(const std::string& key, bool transpose) const { BlazeWrapper matrix; |