diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-10-13 23:20:14 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-10-13 23:20:14 +0300 |
commit | 2d79ad02bb66d7e0ba264defbf5ff9b47c70ba74 (patch) | |
tree | d5901877d639ea139edb9cda0a19f30b5ca3bb0f | |
parent | 03fe1758763c99dd55bcf6c1c5e0e1dd60ae4e1a (diff) |
Merged PR 20933: beam & batch works for n on-factored models
-rw-r--r-- | src/layers/output.cpp | 22 | ||||
-rw-r--r-- | src/translator/beam_search.cpp | 5 | ||||
-rw-r--r-- | src/translator/nth_element.cpp | 2 |
3 files changed, 21 insertions, 8 deletions
diff --git a/src/layers/output.cpp b/src/layers/output.cpp index 92cccdfb..af72b794 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -313,14 +313,24 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { } return Logits(std::move(allLogits), factoredVocab_); } else if(shortlist_) { - return Logits(affineOrDot(input, - shortlist_->getCachedShortWt(), - shortlist_->getCachedShortb(), + const Shape &inputShape = input->shape(); + assert(inputShape[1] == 1); // time dimension always 1 for decoding + input = reshape(input, {inputShape[0], inputShape[2], 1, inputShape[3]}); + + Expr Wt = shortlist_->getCachedShortWt(); + Expr b = shortlist_->getCachedShortb(); + Expr ret = affineShortlist(input, + Wt, + b, false, - /*transB=*/isLegacyUntransposedW ? false : true)); + /*transB=*/isLegacyUntransposedW ? false : true); + const Shape &retShape = ret->shape(); + assert(retShape[2] == 1); // time dimension always 1 for decoding + ret = reshape(ret, {retShape[0], 1, retShape[1], retShape[3]}); + return Logits(ret); } else { - return Logits( - affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); + Expr ret = affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true); + return Logits(ret); } } diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp index 2a0d3947..580895f2 100644 --- a/src/translator/beam_search.cpp +++ b/src/translator/beam_search.cpp @@ -94,7 +94,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current // For factored decoding, the word is built over multiple decoding steps, // starting with the lemma, then adding factors one by one. if (factorGroup == 0) { - word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0 + word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices); //LOG(info, "{} + {} ({}) -> {} -> {}", // factoredVocab->decode(prevHyp->tracebackWords()), @@ -115,7 +115,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current } } else if (shortlist) - word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) origBatchIdx, wordIdx)); + word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx)); else word = Word::fromWordIndex(wordIdx); @@ -330,6 +330,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> auto prevBatchIdxMap = batchIdxMap; // [origBatchIdx -> currentBatchIdx] but shifted by one time step // main loop over output time steps for (size_t t = 0; ; t++) { + //std::cerr << "\nstep=" << t << std::endl; ABORT_IF(origDimBatch != beams.size(), "Lost a batch entry??"); // determine beam size for next output time step, as max over still-active sentences // E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then diff --git a/src/translator/nth_element.cpp b/src/translator/nth_element.cpp index 237d9b9d..dbcceec4 100644 --- a/src/translator/nth_element.cpp +++ b/src/translator/nth_element.cpp @@ -3,7 +3,9 @@ * SPDX-License-Identifier: MIT */ +#include "common/utils.h" #include "translator/nth_element.h" + #include <algorithm> #include <iterator> #include <limits> |