Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-07-21 03:12:02 +0300
committerMartin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>2021-07-21 03:12:02 +0300
commitf6cb1b5c6aa7b35d80454a7fd01301e097945a3a (patch)
treeba0f15f85a46f56577602108080c24fce5a0b078
parent056c4bef5b99d266f8984fd20b14ab578cd55ee3 (diff)
Merged PR 19864: add bias if it exists
Fixes backcompat with shortlist and bias.
-rw-r--r--src/data/shortlist.h3
-rw-r--r--src/layers/output.cpp32
2 files changed, 26 insertions, 9 deletions
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 6cfb650d..82b0df69 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -43,6 +43,7 @@ public:
Shortlist(const std::vector<WordIndex>& indices);
virtual ~Shortlist();
+ virtual bool isDynamic() const { return false; }
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
virtual WordIndex tryForwardMap(WordIndex wIdx) const;
@@ -87,6 +88,8 @@ private:
public:
LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
+
+ virtual bool isDynamic() const override { return true; }
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index d7ba4490..8fe5096a 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -59,7 +59,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
/*
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
- std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
+ if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
*/
if(b)
@@ -68,18 +68,32 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return dot(x, W, transA, transB);
};
- auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
- /*
+ auto affineShortlist = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ /*
std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
- std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
+ if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
*/
- ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
- ABORT_IF(b, "affineShortlist not tested with bias");
- Expr ret = bdot(x, W, transA, transB);
- //std::cerr << "ret=" << ret->shape() << std::endl;
- //std::cerr << std::endl;
+
+ Expr ret;
+
+ if (b) {
+ // original shortlist. W always has 1 for beam & batch
+ ABORT_UNLESS(!shortlist_->isDynamic(), "affineShortlist. Bias not supported with LSH/dynamic shortlist"); // todo rename ABORT_UNLESS to ASSERT
+ ret = affine(x, W, b, transA, transB);
+ }
+ else if (shortlist_->isDynamic()) {
+ // LSH produces W entry for each beam and batch => need bdot()
+ ABORT_IF(!(!transA && transB), "affineShortlist. Only tested with transA==0 and transB==1");
+ ret = bdot(x, W, transA, transB);
+ }
+ else {
+ // original shortlist. W always has 1 for beam & batch
+ ret = dot(x, W, transA, transB);
+ }
+
+ //std::cerr << "ret.x=" << ret->shape() << std::endl;
return ret;
};