diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-06-08 01:43:54 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-06-08 01:43:54 +0300 |
commit | 92c6c077868a48f45deeaf901c74f633e27319d0 (patch) | |
tree | 7305721f98d1b0b6e979c0e29bb48a869715910e | |
parent | eb3f540d4260361968ce63b1fdc758121c618382 (diff) |
reshape cachedShortWt_
-rw-r--r-- | src/data/shortlist.cpp | 2 | ||||
-rw-r--r-- | src/layers/output.cpp | 1 |
2 files changed, 1 insertions, 2 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 496b9ecb..7db84cb9 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -113,7 +113,7 @@ void Shortlist::broadcast(Expr weights, //std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl; cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]}); //std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl; - cachedShortWt_ = transpose(cachedShortWt_, {1, 2, 0, 3}); + cachedShortWt_ = transpose(cachedShortWt_, {1, 0, 2, 3}); //std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl; if (b) { diff --git a/src/layers/output.cpp b/src/layers/output.cpp index 9cdda430..4f413272 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -74,7 +74,6 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { ret = transpose(ret, {0, 3, 2, 1}); */ x = transpose(x, {0, 2, 1, 3}); - W = transpose(W, {0, 2, 1, 3}); //std::cerr << "x=" << x->shape() << std::endl; //std::cerr << "W=" << W->shape() << std::endl; Expr ret = bdot(x, W, false, true); |