Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-06-08 01:43:54 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-06-08 01:43:54 +0300
commit92c6c077868a48f45deeaf901c74f633e27319d0 (patch)
tree7305721f98d1b0b6e979c0e29bb48a869715910e
parenteb3f540d4260361968ce63b1fdc758121c618382 (diff)
reshape cachedShortWt_
-rw-r--r--src/data/shortlist.cpp2
-rw-r--r--src/layers/output.cpp1
2 files changed, 1 insertions, 2 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 496b9ecb..7db84cb9 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -113,7 +113,7 @@ void Shortlist::broadcast(Expr weights,
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
//std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
- cachedShortWt_ = transpose(cachedShortWt_, {1, 2, 0, 3});
+ cachedShortWt_ = transpose(cachedShortWt_, {1, 0, 2, 3});
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
if (b) {
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 9cdda430..4f413272 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -74,7 +74,6 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
ret = transpose(ret, {0, 3, 2, 1});
*/
x = transpose(x, {0, 2, 1, 3});
- W = transpose(W, {0, 2, 1, 3});
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
Expr ret = bdot(x, W, false, true);