From 9b4a845cc7db127c8d29c990128893fc974128a4 Mon Sep 17 00:00:00 2001
From: Hieu Hoang
Date: Wed, 16 Jun 2021 11:19:23 -0700
Subject: clean up bias

---
 src/data/shortlist.cpp |  4 ++--
 src/layers/output.cpp  | 28 ++++++++++++++++------------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 808ffd7a..a965f249 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -184,9 +184,9 @@ void LSHShortlist::createCachedTensors(Expr weights,
   cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
 
   if (b) {
-    ABORT("Bias not yet tested");
+    ABORT("Bias not supported with LSH");
     cachedShortb_ = index_select(b, -1, indicesExprFlatten);
-    cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
+    cachedShortb_ = reshape(cachedShortb_, {currBeamSize, batchSize, k, cachedShortb_->shape()[0]}); // not tested
   }
 
   if (lemmaEt) {
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 03e77545..055f8cae 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -56,16 +56,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
   lazyConstruct(input->shape()[-1]);
 
   auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+    /*
+    std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
+    std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
+    std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
+    std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
+    */
     if(b)
       return affine(x, W, b, transA, transB);
     else
       return dot(x, W, transA, transB);
   };
 
-  auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
-    //std::cerr << "x=" << x->shape() << std::endl;
-    //std::cerr << "W=" << W->shape() << std::endl;
-    Expr ret = bdot(x, W, false, true);
+  auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+    /*
+    std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
+    std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
+    std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
+    std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
+    */
+    ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
+    ABORT_IF(b, "affineShortlist not tested with bias");
+    Expr ret = bdot(x, W, transA, transB);
 
     //std::cerr << "ret.2=" << ret->shape() << std::endl;
     //std::cerr << std::endl;
@@ -171,11 +183,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
       // matrix
       Expr factorLogits;
       if(g == 0 && shortlist_) {
-        //std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
         Expr tmp = transpose(input1, {0, 2, 1, 3});
-        //std::cerr << "tmp=" << tmp->shape() << std::endl;
-        //std::cerr << "x=" << x->shape() << std::endl;
-        //std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
         factorLogits = affineShortlist(
             tmp,
             factorWt,
@@ -183,18 +191,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
             false,
             /*transB=*/isLegacyUntransposedW ? false : true);  // [B... x U] factor logits
         factorLogits = transpose(factorLogits, {0, 2, 1, 3});
-        //std::cerr << "affineShortlist.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
       } else {
-        //std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
-        //std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
         factorLogits = affineOrDot(
             input1,
             factorWt,
             factorB,
             false,
             /*transB=*/isLegacyUntransposedW ? false : true);  // [B... x U] factor logits
-        //std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
       }
 
       // optionally add lemma-dependent bias
--
cgit v1.2.3
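
Reviewer note (not part of the patch): the transpose/bdot/transpose sequence in the shortlist branch is easier to follow as shape bookkeeping. The standalone C++ sketch below illustrates it under two assumptions: that input1 arrives as [beam, T, batch, dim], and that cachedShortWt_ (the factorWt passed in) is laid out as [beam, batch, k, dim] as built in LSHShortlist::createCachedTensors. The helpers transposeAxes and bdotShape are hypothetical stand-ins for the shape rules of Marian's transpose and bdot, not Marian API.

// Standalone shape sketch; compiles with any C++11 compiler.
#include <cassert>
#include <iostream>
#include <vector>

using Shape = std::vector<int>;

// Reorder axes, mirroring transpose(x, {0, 2, 1, 3}) in the patch.
Shape transposeAxes(const Shape& s, const std::vector<int>& perm) {
  Shape out;
  out.reserve(perm.size());
  for(int p : perm)
    out.push_back(s[p]);
  return out;
}

// Assumed shape rule for bdot(x, W, /*transA=*/false, /*transB=*/true):
// the two leading axes act as batch dimensions and must match; the last
// two axes multiply as [m x dim] * [k x dim]^T -> [m x k].
Shape bdotShape(const Shape& x, const Shape& W) {
  assert(x.size() == 4 && W.size() == 4);
  assert(x[0] == W[0] && x[1] == W[1]);  // batched over beam and batch
  assert(x[3] == W[3]);                  // inner (model) dimension
  return {x[0], x[1], x[2], W[2]};
}

int main() {
  const int beam = 4, batch = 2, T = 1, dim = 512, k = 100;

  Shape input1   = {beam, T, batch, dim};  // assumed decoder-side layout
  Shape factorWt = {beam, batch, k, dim};  // cachedShortWt_ from createCachedTensors

  Shape tmp    = transposeAxes(input1, {0, 2, 1, 3});  // [beam, batch, T, dim]
  Shape logits = bdotShape(tmp, factorWt);             // [beam, batch, T, k]
  Shape out    = transposeAxes(logits, {0, 2, 1, 3});  // [beam, T, batch, k]

  std::cout << out[0] << " x " << out[1] << " x "
            << out[2] << " x " << out[3] << std::endl;
  return 0;
}

Under that reading, the batched product only works when x is not transposed, W is transposed, and no bias is folded in, which is exactly what the new ABORT_IF guards in affineShortlist check before calling bdot.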