author     Hieu Hoang <hihoan@microsoft.com>    2021-06-16 21:19:23 +0300
committer  Hieu Hoang <hihoan@microsoft.com>    2021-06-16 21:19:23 +0300
commit     9b4a845cc7db127c8d29c990128893fc974128a4
tree       98bff6afeaf0f958943926a094aeb376b30e3641
parent     892554129e2a1700755ed0e97ccb1cd5fa6b76f5
clean up bias
 src/data/shortlist.cpp |  4 ++--
 src/layers/output.cpp  | 28 ++++++++++++++++++------------
 2 files changed, 18 insertions(+), 14 deletions(-)
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 808ffd7a..a965f249 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -184,9 +184,9 @@ void LSHShortlist::createCachedTensors(Expr weights,
   cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
 
   if (b) {
-    ABORT("Bias not yet tested");
+    ABORT("Bias not supported with LSH");
     cachedShortb_ = index_select(b, -1, indicesExprFlatten);
-    cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
+    cachedShortb_ = reshape(cachedShortb_, {currBeamSize, batchSize, k, cachedShortb_->shape()[0]}); // not tested
   }
 
   if (lemmaEt) {
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 03e77545..055f8cae 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -56,16 +56,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
   lazyConstruct(input->shape()[-1]);
 
   auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+    /*
+    std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
+    std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
+    std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
+    std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
+    */
     if(b)
       return affine(x, W, b, transA, transB);
     else
       return dot(x, W, transA, transB);
   };
 
-  auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
-    //std::cerr << "x=" << x->shape() << std::endl;
-    //std::cerr << "W=" << W->shape() << std::endl;
-    Expr ret = bdot(x, W, false, true);
+  auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+    /*
+    std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
+    std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
+    std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
+    std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
+    */
+    ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
+    ABORT_IF(b, "affineShortlist not tested with bias");
+    Expr ret = bdot(x, W, transA, transB);
     //std::cerr << "ret.2=" << ret->shape() << std::endl;
     //std::cerr << std::endl;
@@ -171,11 +183,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
       // matrix
       Expr factorLogits;
       if(g == 0 && shortlist_) {
-        //std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
         Expr tmp = transpose(input1, {0, 2, 1, 3});
-        //std::cerr << "tmp=" << tmp->shape() << std::endl;
-        //std::cerr << "x=" << x->shape() << std::endl;
-        //std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
         factorLogits = affineShortlist(
             tmp,
             factorWt,
@@ -183,18 +191,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
             false,
             /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
         factorLogits = transpose(factorLogits, {0, 2, 1, 3});
-        //std::cerr << "affineShortlist.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
       } else {
-        //std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
-        //std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
         factorLogits = affineOrDot(
             input1,
             factorWt,
             factorB,
             false,
             /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
-        //std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
       }
 
       // optionally add lemma-dependent bias
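
Note on the new ABORT_IF guard: in the shortlist branch, bdot(x, W, /*transA=*/false, /*transB=*/true) is a batched product in which each batch slice of x is multiplied against the transpose of the matching batch slice of the per-batch shortlist weights. A minimal self-contained sketch of that shape contract, under the one transpose combination the guard still allows (plain C++ with nested vectors as an assumed stand-in, not Marian's actual bdot implementation):

#include <cassert>
#include <cstddef>
#include <vector>

using Tensor3 = std::vector<std::vector<std::vector<float>>>; // [batch][row][dim]

// Sketch of the batched product assumed by the shortlist path, for the one
// case the new ABORT_IF allows (transA == false, transB == true):
//   y[b][i][j] = sum_d x[b][i][d] * W[b][j][d]
// x: [batches, rowsX, dim], W: [batches, k, dim]  ->  y: [batches, rowsX, k]
Tensor3 bdotSketch(const Tensor3& x, const Tensor3& W) {
  assert(x.size() == W.size()); // one shortlisted weight slice per batch entry
  Tensor3 y(x.size());
  for (size_t b = 0; b < x.size(); ++b) {
    const size_t rowsX = x[b].size(), k = W[b].size();
    y[b].assign(rowsX, std::vector<float>(k, 0.f));
    for (size_t i = 0; i < rowsX; ++i)
      for (size_t j = 0; j < k; ++j) {
        assert(x[b][i].size() == W[b][j].size()); // shared inner dimension
        for (size_t d = 0; d < x[b][i].size(); ++d)
          y[b][i][j] += x[b][i][d] * W[b][j][d]; // row of x · row of W (transB)
      }
  }
  return y;
}

Because every batch element carries its own k shortlisted rows of the output weights, the surrounding code transposes input1 with {0, 2, 1, 3} so its axes line up with cachedShortWt_'s {currBeamSize, batchSize, k, dim} layout before the batched product, then transposes the logits back. The reshape change in shortlist.cpp moves cachedShortb_ to that same {currBeamSize, batchSize, k, ...} axis order, although the bias path itself stays disabled behind the ABORT.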