diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 09:56:25 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 09:56:25 +0300 |
commit | 592854f571e5c114c2e1f9d0469b07f0652381ce (patch) | |
tree | 824cf7e5b2852118b8915e8aa0ecfc138660f2a8 | |
parent | 909df372d10803395684a60d6d6fe0cb7de83637 (diff) |
move cache variables into shortlist class
-rw-r--r-- | src/data/shortlist.h | 3 |
-rw-r--r-- | src/layers/output.cpp | 23 |
-rw-r--r-- | src/layers/output.h | 8 |
3 files changed, 15 insertions, 19 deletions
diff --git a/src/data/shortlist.h b/src/data/shortlist.h index dd7d0589..44da6faa 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -42,6 +42,9 @@ public: WordIndex tryForwardMap(WordIndex wIdx); virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt); + virtual Expr getCachedShortWt() const { return cachedShortWt_; } + virtual Expr getCachedShortb() const { return cachedShortb_; } + virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; } }; class ShortlistGenerator { diff --git a/src/layers/output.cpp b/src/layers/output.cpp index e9bffac4..0d46583a 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -66,11 +66,9 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { return affineOrDot(x, W, b, transA, transB); }; - if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one + if(shortlist_ && !shortlist_->getCachedShortWt()) { // shortlisted versions of parameters are cached within one // batch, then clear()ed - cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); - if(hasBias_) - cachedShortb_ = index_select(b_, -1, shortlist_->indices()); + shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_); } if(factoredVocab_) { @@ -93,8 +91,8 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { // slice this group's section out of W_ Expr factorWt, factorB; if(g == 0 && shortlist_) { - factorWt = cachedShortWt_; - factorB = cachedShortb_; + factorWt = shortlist_->getCachedShortWt(); + factorB = shortlist_->getCachedShortb(); } else { factorWt = slice( Wt_, isLegacyUntransposedW ? 
-1 : 0, Slice((int)range.first, (int)range.second)); @@ -240,10 +238,13 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { } #endif // re-embedding lookup, soft-indexed by softmax - if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix - cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); + Expr cachedShortLemmaEt; + if(shortlist_) // short-listed version of re-embedding matrix + cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt(); + else + cachedShortLemmaEt = lemmaEt_; auto e = dot(factorSoftmax, - cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, + cachedShortLemmaEt, false, true); // [B... x L] // project it back to regular hidden dim @@ -265,8 +266,8 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { return Logits(std::move(allLogits), factoredVocab_); } else if(shortlist_) { return Logits(affineOrLSH(input, - cachedShortWt_, - cachedShortb_, + shortlist_->getCachedShortWt(), + shortlist_->getCachedShortb(), false, /*transB=*/isLegacyUntransposedW ? 
false : true)); } else { diff --git a/src/layers/output.h b/src/layers/output.h index bf8a580a..d3afdead 100644 --- a/src/layers/output.h +++ b/src/layers/output.h @@ -19,9 +19,6 @@ private: bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form bool hasBias_{true}; - Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) - Expr cachedShortb_; // these match the current value of shortlist_ - Expr cachedShortLemmaEt_; Ptr<FactoredVocab> factoredVocab_; // optional parameters set/updated after construction @@ -49,8 +46,6 @@ public: ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()"); else { - ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, - "No shortlist but cached parameters??"); shortlist_ = shortlist; } // cachedShortWt_ and cachedShortb_ will be created lazily inside apply() @@ -60,9 +55,6 @@ public: // cachedShortWt_ etc. in the graph's short-term cache void clear() override final { shortlist_ = nullptr; - cachedShortWt_ = nullptr; - cachedShortb_ = nullptr; - cachedShortLemmaEt_ = nullptr; } Logits applyAsLogits(Expr input) override final; |