diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-06-12 01:47:15 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-06-12 01:47:15 +0300 |
commit | 49998217d9de2c5bfbd65f6ba9b2f4058c0e88f5 (patch) | |
tree | 3ed4505bcfa6fa43a77903f40dc1025fa97897fc | |
parent | f0251889f2a22cfb641fa2a0b287df2464eff3b8 (diff) |
Don't transpose lastIndices. Works for the LSH shortlist.
-rw-r--r-- | src/data/shortlist.cpp | 5 | ||||
-rw-r--r-- | src/layers/logits.cpp | 6 | ||||
-rw-r--r-- | src/layers/output.cpp | 10 |
3 files changed, 14 insertions, 7 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 7fba4b67..8efd70d4 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -149,8 +149,9 @@ WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const { Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const { assert(indicesExpr_->shape()[0] == currBeamSize); assert(indicesExpr_->shape()[1] == batchSize); - Expr ret = transpose(indicesExpr_, {1, 0, 2}); - return ret; + return indicesExpr_; + //Expr ret = transpose(indicesExpr_, {1, 0, 2}); + //return ret; } #define BLAS_FOUND 1 diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index 06bafb1c..5005f601 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -120,11 +120,13 @@ Expr Logits::getFactoredLogits(size_t groupIndex, int currBeamSize = sel->shape()[0]; int batchSize = sel->shape()[2]; Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize); + std::cerr << "lastIndices=" << lastIndices->shape() << std::endl; factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward); - factorMasks = transpose(factorMasks, {1, 0, 2}); - + std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl; + const Shape &s = factorMasks->shape(); factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]}); + std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl; } factorMaxima = cast(factorMaxima, sel->value_type()); factorMasks = cast(factorMasks, sel->value_type()); diff --git a/src/layers/output.cpp b/src/layers/output.cpp index 8b3d1af0..eab81124 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -171,9 +171,11 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { // matrix Expr factorLogits; if(g == 0 && shortlist_) { + std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl; Expr tmp = transpose(input1, {0, 2, 1, 3}); + std::cerr << "tmp=" << tmp->shape() << std::endl; //std::cerr << "x=" << x->shape() << 
std::endl; - //std::cerr << "W=" << W->shape() << std::endl; + std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl; factorLogits = affineShortlist( tmp, factorWt, @@ -181,16 +183,18 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits factorLogits = transpose(factorLogits, {0, 2, 1, 3}); + std::cerr << "affineShortlist.factorLogits=" << factorLogits->shape() << std::endl << std::endl; } else { - //factorWt = transpose(factorWt, {1, 0, 2, 3}); - //std::cerr << "affineOrDot.factorWt.2=" << factorWt->shape() << std::endl; + std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl; + std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl; factorLogits = affineOrDot( input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl << std::endl; } // optionally add lemma-dependent bias |