Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-06-16 21:19:23 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-06-16 21:19:23 +0300
commit9b4a845cc7db127c8d29c990128893fc974128a4 (patch)
tree98bff6afeaf0f958943926a094aeb376b30e3641
parent892554129e2a1700755ed0e97ccb1cd5fa6b76f5 (diff)
clean up bias
-rw-r--r--src/data/shortlist.cpp4
-rw-r--r--src/layers/output.cpp28
2 files changed, 18 insertions, 14 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 808ffd7a..a965f249 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -184,9 +184,9 @@ void LSHShortlist::createCachedTensors(Expr weights,
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
if (b) {
- ABORT("Bias not yet tested");
+ ABORT("Bias not supported with LSH");
cachedShortb_ = index_select(b, -1, indicesExprFlatten);
- cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
+ cachedShortb_ = reshape(cachedShortb_, {currBeamSize, batchSize, k, cachedShortb_->shape()[0]}); // not tested
}
if (lemmaEt) {
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 03e77545..055f8cae 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -56,16 +56,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
lazyConstruct(input->shape()[-1]);
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ /*
+ std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
+ std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
+ std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
+ std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
+ */
if(b)
return affine(x, W, b, transA, transB);
else
return dot(x, W, transA, transB);
};
- auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
- //std::cerr << "x=" << x->shape() << std::endl;
- //std::cerr << "W=" << W->shape() << std::endl;
- Expr ret = bdot(x, W, false, true);
+ auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ /*
+ std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
+ std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
+ std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
+ std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
+ */
+ ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
+ ABORT_IF(b, "affineShortlist not tested with bias");
+ Expr ret = bdot(x, W, transA, transB);
//std::cerr << "ret.2=" << ret->shape() << std::endl;
//std::cerr << std::endl;
@@ -171,11 +183,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// matrix
Expr factorLogits;
if(g == 0 && shortlist_) {
- //std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
Expr tmp = transpose(input1, {0, 2, 1, 3});
- //std::cerr << "tmp=" << tmp->shape() << std::endl;
- //std::cerr << "x=" << x->shape() << std::endl;
- //std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
factorLogits = affineShortlist(
tmp,
factorWt,
@@ -183,18 +191,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
- //std::cerr << "affineShortlist.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
}
else {
- //std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
- //std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
factorLogits = affineOrDot(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
- //std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
}
// optionally add lemma-dependent bias