diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-06-29 07:26:02 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-06-29 07:26:02 +0300 |
commit | 24c644bae0312a5b640b2cda43020fd4b74cdb07 (patch) | |
tree | feb87ca8be1f22efd6e237cfff7fff53aa94e2d8 | |
parent | cd292d3b32428b6c1cf57e9eb6ad06b1db1e5452 (diff) |
pass shortlist regression tests
-rw-r--r-- | src/data/shortlist.cpp | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index b7c03436..ad2525dc 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -68,13 +68,13 @@ void Shortlist::createCachedTensors(Expr weights, cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]}); if (b) { - ABORT("Bias not yet tested"); cachedShortb_ = index_select(b, -1, indicesExpr_); - cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested } - cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_); - cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k}); + if (lemmaEt) { + cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_); + cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k}); + } } /////////////////////////////////////////////////////////////////////////////////// |