Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-06-29 07:26:02 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-06-29 07:26:02 +0300
commit24c644bae0312a5b640b2cda43020fd4b74cdb07 (patch)
treefeb87ca8be1f22efd6e237cfff7fff53aa94e2d8
parentcd292d3b32428b6c1cf57e9eb6ad06b1db1e5452 (diff)
pass shortlist regression tests
-rw-r--r--src/data/shortlist.cpp8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index b7c03436..ad2525dc 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -68,13 +68,13 @@ void Shortlist::createCachedTensors(Expr weights,
cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]});
if (b) {
- ABORT("Bias not yet tested");
cachedShortb_ = index_select(b, -1, indicesExpr_);
- cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested
}
- cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
- cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
+ if (lemmaEt) {
+ cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
+ cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
+ }
}
///////////////////////////////////////////////////////////////////////////////////