Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Hieu Hoang <hihoan@microsoft.com>, 2021-10-13 23:20:14 +0300
committer: Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>, 2021-10-13 23:20:14 +0300
commit: 2d79ad02bb66d7e0ba264defbf5ff9b47c70ba74 (patch)
tree: d5901877d639ea139edb9cda0a19f30b5ca3bb0f
parent: 03fe1758763c99dd55bcf6c1c5e0e1dd60ae4e1a (diff)
Merged PR 20933: beam & batch works for non-factored models
-rw-r--r-- src/layers/output.cpp 22
-rw-r--r-- src/translator/beam_search.cpp 5
-rw-r--r-- src/translator/nth_element.cpp 2
3 files changed, 21 insertions, 8 deletions
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 92cccdfb..af72b794 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -313,14 +313,24 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
}
return Logits(std::move(allLogits), factoredVocab_);
} else if(shortlist_) {
- return Logits(affineOrDot(input,
- shortlist_->getCachedShortWt(),
- shortlist_->getCachedShortb(),
+ const Shape &inputShape = input->shape();
+ assert(inputShape[1] == 1); // time dimension always 1 for decoding
+ input = reshape(input, {inputShape[0], inputShape[2], 1, inputShape[3]});
+
+ Expr Wt = shortlist_->getCachedShortWt();
+ Expr b = shortlist_->getCachedShortb();
+ Expr ret = affineShortlist(input,
+ Wt,
+ b,
false,
- /*transB=*/isLegacyUntransposedW ? false : true));
+ /*transB=*/isLegacyUntransposedW ? false : true);
+ const Shape &retShape = ret->shape();
+ assert(retShape[2] == 1); // time dimension always 1 for decoding
+ ret = reshape(ret, {retShape[0], 1, retShape[1], retShape[3]});
+ return Logits(ret);
} else {
- return Logits(
- affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+ Expr ret = affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true);
+ return Logits(ret);
}
}
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index 2a0d3947..580895f2 100644
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
@@ -94,7 +94,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// For factored decoding, the word is built over multiple decoding steps,
// starting with the lemma, then adding factors one by one.
if (factorGroup == 0) {
- word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
+ word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx);
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
//LOG(info, "{} + {} ({}) -> {} -> {}",
// factoredVocab->decode(prevHyp->tracebackWords()),
@@ -115,7 +115,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
}
}
else if (shortlist)
- word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) origBatchIdx, wordIdx));
+ word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx));
else
word = Word::fromWordIndex(wordIdx);
@@ -330,6 +330,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
auto prevBatchIdxMap = batchIdxMap; // [origBatchIdx -> currentBatchIdx] but shifted by one time step
// main loop over output time steps
for (size_t t = 0; ; t++) {
+ //std::cerr << "\nstep=" << t << std::endl;
ABORT_IF(origDimBatch != beams.size(), "Lost a batch entry??");
// determine beam size for next output time step, as max over still-active sentences
// E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then
diff --git a/src/translator/nth_element.cpp b/src/translator/nth_element.cpp
index 237d9b9d..dbcceec4 100644
--- a/src/translator/nth_element.cpp
+++ b/src/translator/nth_element.cpp
@@ -3,7 +3,9 @@
* SPDX-License-Identifier: MIT
*/
+#include "common/utils.h"
#include "translator/nth_element.h"
+
#include <algorithm>
#include <iterator>
#include <limits>