diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-07-21 03:12:02 +0300 |
---|---|---|
committer | Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-07-21 03:12:02 +0300 |
commit | f6cb1b5c6aa7b35d80454a7fd01301e097945a3a (patch) | |
tree | ba0f15f85a46f56577602108080c24fce5a0b078 | |
parent | 056c4bef5b99d266f8984fd20b14ab578cd55ee3 (diff) |
Merged PR 19864: add bias if it exists
Fixes backcompat with shortlist and bias.
-rw-r--r-- | src/data/shortlist.h | 3 | ||||
-rw-r--r-- | src/layers/output.cpp | 32 |
2 files changed, 26 insertions, 9 deletions
diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 6cfb650d..82b0df69 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -43,6 +43,7 @@ public: Shortlist(const std::vector<WordIndex>& indices); virtual ~Shortlist(); + virtual bool isDynamic() const { return false; } virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const; virtual WordIndex tryForwardMap(WordIndex wIdx) const; @@ -87,6 +88,8 @@ private: public: LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false); + + virtual bool isDynamic() const override { return true; } virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override; virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override; diff --git a/src/layers/output.cpp b/src/layers/output.cpp index d7ba4490..8fe5096a 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -59,7 +59,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { /* std::cerr << "affineOrDot.x=" << x->shape() << std::endl; std::cerr << "affineOrDot.W=" << W->shape() << std::endl; - std::cerr << "affineOrDot.b=" << b->shape() << std::endl; + if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl; std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl; */ if(b) @@ -68,18 +68,32 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { return dot(x, W, transA, transB); }; - auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) { - /* + auto affineShortlist = [this](Expr x, Expr W, Expr b, bool transA, bool transB) { + /* std::cerr << "affineShortlist.x=" << x->shape() << std::endl; std::cerr << "affineShortlist.W=" << W->shape() << std::endl; - std::cerr << "affineShortlist.b=" << b->shape() << std::endl; + if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl; std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl; */ - ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1"); - ABORT_IF(b, "affineShortlist not tested with bias"); - Expr ret = bdot(x, W, transA, transB); - //std::cerr << "ret=" << ret->shape() << std::endl; - //std::cerr << std::endl; + + Expr ret; + + if (b) { + // original shortlist. W always has 1 for beam & batch + ABORT_UNLESS(!shortlist_->isDynamic(), "affineShortlist. Bias not supported with LSH/dynamic shortlist"); // todo rename ABORT_UNLESS to ASSERT + ret = affine(x, W, b, transA, transB); + } + else if (shortlist_->isDynamic()) { + // LSH produces W entry for each beam and batch => need bdot() + ABORT_IF(!(!transA && transB), "affineShortlist. Only tested with transA==0 and transB==1"); + ret = bdot(x, W, transA, transB); + } + else { + // original shortlist. W always has 1 for beam & batch + ret = dot(x, W, transA, transB); + } + + //std::cerr << "ret.x=" << ret->shape() << std::endl; return ret; }; |