Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Hieu Hoang <hihoan@microsoft.com> 2021-04-29 09:56:25 +0300
committer Hieu Hoang <hihoan@microsoft.com> 2021-04-29 09:56:25 +0300
commit 592854f571e5c114c2e1f9d0469b07f0652381ce (patch)
tree 824cf7e5b2852118b8915e8aa0ecfc138660f2a8
parent 909df372d10803395684a60d6d6fe0cb7de83637 (diff)
move cache variables into shortlist class
-rw-r--r-- src/data/shortlist.h 3
-rw-r--r-- src/layers/output.cpp 23
-rw-r--r-- src/layers/output.h 8
3 files changed, 15 insertions, 19 deletions
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index dd7d0589..44da6faa 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -42,6 +42,9 @@ public:
WordIndex tryForwardMap(WordIndex wIdx);
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
+ virtual Expr getCachedShortWt() const { return cachedShortWt_; }
+ virtual Expr getCachedShortb() const { return cachedShortb_; }
+ virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; }
};
class ShortlistGenerator {
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index e9bffac4..0d46583a 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -66,11 +66,9 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return affineOrDot(x, W, b, transA, transB);
};
- if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one
+ if(shortlist_ && !shortlist_->getCachedShortWt()) { // shortlisted versions of parameters are cached within one
// batch, then clear()ed
- cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
- if(hasBias_)
- cachedShortb_ = index_select(b_, -1, shortlist_->indices());
+ shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_);
}
if(factoredVocab_) {
@@ -93,8 +91,8 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// slice this group's section out of W_
Expr factorWt, factorB;
if(g == 0 && shortlist_) {
- factorWt = cachedShortWt_;
- factorB = cachedShortb_;
+ factorWt = shortlist_->getCachedShortWt();
+ factorB = shortlist_->getCachedShortb();
} else {
factorWt = slice(
Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
@@ -240,10 +238,13 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
}
#endif
// re-embedding lookup, soft-indexed by softmax
- if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
- cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
+ Expr cachedShortLemmaEt;
+ if(shortlist_) // short-listed version of re-embedding matrix
+ cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt();
+ else
+ cachedShortLemmaEt = lemmaEt_;
auto e = dot(factorSoftmax,
- cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_,
+ cachedShortLemmaEt,
false,
true); // [B... x L]
// project it back to regular hidden dim
@@ -265,8 +266,8 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return Logits(std::move(allLogits), factoredVocab_);
} else if(shortlist_) {
return Logits(affineOrLSH(input,
- cachedShortWt_,
- cachedShortb_,
+ shortlist_->getCachedShortWt(),
+ shortlist_->getCachedShortb(),
false,
/*transB=*/isLegacyUntransposedW ? false : true));
} else {
diff --git a/src/layers/output.h b/src/layers/output.h
index bf8a580a..d3afdead 100644
--- a/src/layers/output.h
+++ b/src/layers/output.h
@@ -19,9 +19,6 @@ private:
bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
bool hasBias_{true};
- Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
- Expr cachedShortb_; // these match the current value of shortlist_
- Expr cachedShortLemmaEt_;
Ptr<FactoredVocab> factoredVocab_;
// optional parameters set/updated after construction
@@ -49,8 +46,6 @@ public:
ABORT_IF(shortlist.get() != shortlist_.get(),
"Output shortlist cannot be changed except after clear()");
else {
- ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_,
- "No shortlist but cached parameters??");
shortlist_ = shortlist;
}
// cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
@@ -60,9 +55,6 @@ public:
// cachedShortWt_ etc. in the graph's short-term cache
void clear() override final {
shortlist_ = nullptr;
- cachedShortWt_ = nullptr;
- cachedShortb_ = nullptr;
- cachedShortLemmaEt_ = nullptr;
}
Logits applyAsLogits(Expr input) override final;