author     Hieu Hoang <hihoan@microsoft.com>   2021-06-16 02:54:57 +0300
committer  Hieu Hoang <hihoan@microsoft.com>   2021-06-16 02:54:57 +0300
commit     6981b21f4e14d8be10a7e3011bc4d93afa00add0 (patch)
tree       d6f7fd7347f3b1a19e7d85437cf6c4e1da64debb
parent     488a532bdf85b71276de662ae160172171ca97fc (diff)
parent     7e6ec58507a946ee9e00aa80ea0835aff068b319 (diff)
Merge branch 'hihoan/lsh7' of vs-ssh.visualstudio.com:v3/machinetranslation/Marian/marian-dev into hihoan/lsh7
-rw-r--r--  src/data/shortlist.cpp | 27
-rw-r--r--  src/data/shortlist.h   |  4
-rw-r--r--  src/layers/logits.cpp  | 18
3 files changed, 10 insertions, 39 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 832d575b..36e2d22f 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -52,7 +52,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
   done_ = true;
 }
 
-Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {
+Expr Shortlist::getIndicesExpr() const {
   int k = indicesExpr_->shape()[0];
   Expr out = reshape(indicesExpr_, {1, 1, k});
   return out;
@@ -63,13 +63,8 @@ void Shortlist::createCachedTensors(Expr weights,
                                     Expr b,
                                     Expr lemmaEt,
                                     int k) {
-  //std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
   ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
-
-  //std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
-  //std::cerr << "weights=" << weights->shape() << std::endl;
   cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExpr_);
-  //std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
   cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]});
 
   if (b) {
@@ -78,11 +73,8 @@ void Shortlist::createCachedTensors(Expr weights,
     cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested
   }
 
-  //std::cerr << "lemmaEt.1_=" << lemmaEt->shape() << std::endl;
   cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
-  //std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
   cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
-  //std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
 }
 
 ///////////////////////////////////////////////////////////////////////////////////
@@ -110,7 +102,6 @@ WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
 }
 
 WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
-  //utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
   auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
   bool found = first != indices_.end();
   if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
@@ -119,16 +110,10 @@ WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
   return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
 }
 
-Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
-  assert(indicesExpr_->shape()[0] == currBeamSize);
-  assert(indicesExpr_->shape()[1] == batchSize);
+Expr LSHShortlist::getIndicesExpr() const {
   return indicesExpr_;
-  //Expr ret = transpose(indicesExpr_, {1, 0, 2});
-  //return ret;
 }
 
-#define BLAS_FOUND 1
-
 void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
 #if BLAS_FOUND
   ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
@@ -175,11 +160,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
   };
 
   Shape kShape({currBeamSize, batchSize, k_});
-
-  //std::cerr << "input=" << input->shape() << std::endl;
-  //std::cerr << "weights=" << weights->shape() << std::endl;
   indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
-  //std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
 
   createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
 
@@ -196,10 +177,6 @@ void LSHShortlist::createCachedTensors(Expr weights,
                                        int k) {
   int currBeamSize = indicesExpr_->shape()[0];
   int batchSize = indicesExpr_->shape()[1];
-  //int numHypos = batchSize * currBeamSize;
-  //std::cerr << "batchSize=" << batchSize << std::endl;
-  //std::cerr << "currBeamSize=" << currBeamSize << std::endl;
-  //std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
   ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
 
   Expr indicesExprFlatten = reshape(indicesExpr_, {indicesExpr_->shape().elements()});
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 315fdbcd..1d8903e6 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -46,7 +46,7 @@ public:
   virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
 
   virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
-  virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
+  virtual Expr getIndicesExpr() const;
   virtual Expr getCachedShortWt() const { return cachedShortWt_; }
   virtual Expr getCachedShortb() const { return cachedShortb_; }
   virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; }
@@ -85,7 +85,7 @@ public:
   virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
 
   virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
-  virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;
+  virtual Expr getIndicesExpr() const override;
 };
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
index 73169f21..1830741e 100644
--- a/src/layers/logits.cpp
+++ b/src/layers/logits.cpp
@@ -117,28 +117,22 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
         out->val()->set(masks);
       };
 
-      int currBeamSize = sel->shape()[0];
-      int batchSize = sel->shape()[2];
-      Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
-      //std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
+      //int currBeamSize = sel->shape()[0];
+      //int batchSize = sel->shape()[2];
+      Expr lastIndices = shortlist->getIndicesExpr();
+      //assert(lastIndices->shape()[0] == currBeamSize || lastIndices->shape()[0] == 1);
+      //assert(lastIndices->shape()[1] == batchSize || lastIndices->shape()[1] == 1);
+
       factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
-      //std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
       const Shape &s = factorMasks->shape();
       factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
-      //std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
     }
     factorMaxima = cast(factorMaxima, sel->value_type());
     factorMasks = cast(factorMasks, sel->value_type());
-    //std::cerr << "factorMaxima=" << factorMaxima->shape() << std::endl;
-    //std::cerr << "factorMasks.4=" << factorMasks->shape() << std::endl;
-    //std::cerr << "sel.1=" << sel->shape() << std::endl;
     Expr tmp = factorMaxima * factorMasks;
-    //std::cerr << "tmp=" << tmp->shape() << std::endl;
     sel = sel + tmp; // those lemmas that don't have a factor
-    //std::cerr << "sel.2=" << sel->shape() << std::endl;
-    //std::cerr << std::endl;
   }
 }
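For context, a minimal standalone sketch of the API change in this commit. This is not Marian code: Indices and Shortlist below are simplified stand-ins for marian::Expr and data::Shortlist. The point it illustrates is that getIndicesExpr() no longer takes batchSize/currBeamSize, because the shortlist already knows the shape of its cached indices expression; callers such as Logits::getFactoredLogits now read the dimensions from the returned expression instead of passing them in.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for marian::Expr: raw data plus an explicit shape.
struct Indices {
  std::vector<unsigned> data;
  std::vector<int> shape; // e.g. {beamSize, batchSize, k}
};

struct Shortlist {
  Indices indicesExpr_; // filled by filter(), as in the diff

  // After the change: no batchSize/beamSize parameters; the cached
  // expression is returned as-is and carries its own shape.
  const Indices& getIndicesExpr() const { return indicesExpr_; }
};

int main() {
  Shortlist sl;
  sl.indicesExpr_ = {{1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 2}}; // beam=2, batch=2, k=2

  // Caller side (cf. Logits::getFactoredLogits): instead of passing the
  // sizes in and asserting they match, derive them from the expression.
  const Indices& lastIndices = sl.getIndicesExpr();
  int currBeamSize = lastIndices.shape[0];
  int batchSize    = lastIndices.shape[1];
  std::printf("beam=%d batch=%d k=%d\n", currBeamSize, batchSize, lastIndices.shape[2]);
  return 0;
}

This mirrors the diff above, where the batch/beam checks move out of LSHShortlist::getIndicesExpr (the deleted asserts) and the call site in logits.cpp keeps them only as commented-out assertions on the returned shape.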