Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-06-16 20:19:24 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-06-16 20:19:24 +0300
commit892554129e2a1700755ed0e97ccb1cd5fa6b76f5 (patch)
tree3ae47ad09adfcfd449ea93119fbcbf092a3e36f6 /src
parent395a4f94d0d3182600d22203625bd1be1f8a042b (diff)
lemma Et is optional
Diffstat (limited to 'src')
-rw-r--r--src/data/shortlist.cpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 9943198b..808ffd7a 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -83,7 +83,6 @@ Ptr<faiss::IndexLSH> LSHShortlist::index_;
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
: Shortlist(std::vector<WordIndex>())
, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
- std::cerr << "LSHShortlist lemmaSize_=" << lemmaSize_ << std::endl;
/*
for (int i = 0; i < k_; ++i) {
indices_.push_back(i);
@@ -190,10 +189,12 @@ void LSHShortlist::createCachedTensors(Expr weights,
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
}
- int dim = lemmaEt->shape()[0];
- cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprFlatten);
- cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {dim, currBeamSize, batchSize, k});
- cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {1, 2, 0, 3});
+ if (lemmaEt) {
+ int dim = lemmaEt->shape()[0];
+ cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprFlatten);
+ cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {dim, currBeamSize, batchSize, k});
+ cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {1, 2, 0, 3});
+ }
}
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize)