From b653db0a9b5fc37b87d60572465724af80717805 Mon Sep 17 00:00:00 2001
From: Martin Junczys-Dowmunt
Date: Thu, 22 Jul 2021 21:00:44 +0000
Subject: Merged PR 19910: Fix training/scoring error with FSM

Fixes a dimension mismatch during training and scoring introduced in the
decoding-only shortlist changes.

Related work items: #122643
---
 src/layers/output.cpp | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 8fe5096a..92cccdfb 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -273,27 +273,27 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
       }
 #endif
       // re-embedding lookup, soft-indexed by softmax
-      Expr cachedShortLemmaEt;
+      Expr e;
       if(shortlist_) { // short-listed version of re-embedding matrix
-        cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt();
+        Expr cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt();
+        // std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
+        // std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
+        const Shape &fShape = factorSoftmax->shape();
+        ABORT_IF(fShape[1] != 1, "We are decoding with a shortlist but time step size {} != 1??", fShape[1]);
+        factorSoftmax = reshape(factorSoftmax, {fShape[0], fShape[2], 1, fShape[3]}); // we can switch dims because time step is of size 1
+        // std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
+        e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
+        // std::cerr << "e.1=" << e->shape() << std::endl;
+        const Shape &eShape = e->shape();
+        e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]}); // switch dims back, again possible because time step is of size 1
+        // std::cerr << "e.2=" << e->shape() << std::endl;
+        // std::cerr << std::endl;
+      } else { // for scoring, training and decoding without a shortlist we use a simple dot operation
+        e = dot(factorSoftmax,
+                lemmaEt_,
+                false,
+                true); // [B... x L]
       }
-      else {
-        const Shape &s = lemmaEt_->shape();
-        //std::cerr << "lemmaEt_=" << lemmaEt_->shape() << std::endl;
-        cachedShortLemmaEt = reshape(lemmaEt_, {1, 1, s[0], s[1]});
-      }
-      //std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
-      //std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
-      factorSoftmax = transpose(factorSoftmax, {0, 2, 1, 3});
-      //std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
-      //std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
-
-      Expr e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
-      //std::cerr << "e.1=" << e->shape() << std::endl;
-      const Shape &eShape = e->shape();
-      e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]});
-      //std::cerr << "e.3=" << e->shape() << std::endl;
-      //std::cerr << std::endl;

       // project it back to regular hidden dim
       int inputDim = input1->shape()[-1];
-- 
cgit v1.2.3
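
Note (not part of the patch): the shortlist branch relies on the fact that reshaping a tensor whose time dimension has size 1 can "swap" that dimension with its neighbour without moving any data, which is exactly the condition the ABORT_IF guards; with a time dimension larger than 1, as in training and scoring, the same reshape would scramble rows, hence the plain dot with the full lemmaEt_ in the else branch. The standalone sketch below illustrates this with row-major index arithmetic. The [beam, time, batch, vocab] dimension order is an assumption made for illustration; the diff itself only fixes dim 1 as the time step.

// Standalone illustration, not Marian code. Shows why reshape() can swap a
// size-1 time dimension with the batch dimension without moving data.
#include <cassert>
#include <cstddef>
#include <cstdio>

// Flat offset into a row-major 4-D tensor with shape s.
std::size_t offset(const int s[4], int a, int b, int c, int d) {
  return ((std::size_t(a) * s[1] + b) * s[2] + c) * s[3] + d;
}

int main() {
  // Hypothetical shapes: [beam, time, batch, vocab] with time == 1,
  // as in single-step decoding with a shortlist.
  const int beam = 2, time = 1, batch = 3, vocab = 4;
  const int asIs[4]    = {beam, time, batch, vocab};
  const int swapped[4] = {beam, batch, time, vocab}; // after the reshape

  // Because time == 1, element (a, 0, c, d) has the same flat offset under
  // both shape descriptions: the reshape is a relabeling, not a data move.
  for(int a = 0; a < beam; ++a)
    for(int c = 0; c < batch; ++c)
      for(int d = 0; d < vocab; ++d)
        assert(offset(asIs, a, 0, c, d) == offset(swapped, a, c, 0, d));

  // With time > 1 the offsets diverge, so the trick is decoding-only;
  // training and scoring take the plain dot() branch instead.
  const int trainAsIs[4]    = {beam, 5, batch, vocab};
  const int trainSwapped[4] = {beam, batch, 5, vocab};
  assert(offset(trainAsIs, 0, 1, 1, 0) != offset(trainSwapped, 0, 1, 1, 0));

  std::puts("size-1 time dim: reshape swaps dims for free; larger time dims would not");
  return 0;
}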