Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-07-23 00:00:44 +0300
committerMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-07-23 00:00:44 +0300
commitb653db0a9b5fc37b87d60572465724af80717805 (patch)
tree1bdf3b98b527cf4f67aea9ac5f0bf7a5e908cf3f
parent6b568f4afa44b5bd7c9e335856b977fc054f343c (diff)
Merged PR 19910: Fix training/scoring error with FSM
Fixes a dimension mismatch during training and scoring introduced in the decoding-only shortlist changes. Related work items: #122643
-rw-r--r--src/layers/output.cpp38
1 files changed, 19 insertions, 19 deletions
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 8fe5096a..92cccdfb 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -273,27 +273,27 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
}
#endif
// re-embedding lookup, soft-indexed by softmax
- Expr cachedShortLemmaEt;
+ Expr e;
if(shortlist_) { // short-listed version of re-embedding matrix
- cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt();
+ Expr cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt();
+ // std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
+ // std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
+ const Shape &fShape = factorSoftmax->shape();
+ ABORT_IF(fShape[1] != 1, "We are decoding with a shortlist but time step size {} != 1??", fShape[1]);
+ factorSoftmax = reshape(factorSoftmax, {fShape[0], fShape[2], 1, fShape[3]}); // we can switch dims because time step is of size 1
+ // std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
+ e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
+ // std::cerr << "e.1=" << e->shape() << std::endl;
+ const Shape &eShape = e->shape();
+ e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]}); // switch dims back, again possible because time step is of size 1
+ // std::cerr << "e.2=" << e->shape() << std::endl;
+ // std::cerr << std::endl;
+ } else { // for scoring, training and decoding without a shortlist we use a simple dot operation
+ e = dot(factorSoftmax,
+ lemmaEt_,
+ false,
+ true); // [B... x L]
}
- else {
- const Shape &s = lemmaEt_->shape();
- //std::cerr << "lemmaEt_=" << lemmaEt_->shape() << std::endl;
- cachedShortLemmaEt = reshape(lemmaEt_, {1, 1, s[0], s[1]});
- }
- //std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
- //std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
- factorSoftmax = transpose(factorSoftmax, {0, 2, 1, 3});
- //std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
- //std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
-
- Expr e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
- //std::cerr << "e.1=" << e->shape() << std::endl;
- const Shape &eShape = e->shape();
- e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]});
- //std::cerr << "e.3=" << e->shape() << std::endl;
- //std::cerr << std::endl;
// project it back to regular hidden dim
int inputDim = input1->shape()[-1];