diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 10:44:30 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 10:44:30 +0300 |
commit | 6b2b7d11880013c9574e6bcd8d67bef4f28be97c (patch) | |
tree | 69499bc94bf2cce6ddc5d482e4f8007f061fa015 | |
parent | f41acb1aa86da3b7f357e63218550026835564da (diff) |
factor mask
-rw-r--r-- | src/layers/logits.cpp | 44 | ||||
-rw-r--r-- | src/layers/logits.h | 1 |
2 files changed, 44 insertions, 1 deletion
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index 4f0ad815..c327bd0d 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -101,7 +101,28 @@ Expr Logits::getFactoredLogits(size_t groupIndex, factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>())); } else { - factorMasks = constant(getFactorMasks(g, shortlist->indices())); + //std::cerr << "sel=" << sel->shape() << std::endl; + int currBeamSize = sel->shape()[0]; + int batchSize = sel->shape()[2]; + + auto forward = [this, g, currBeamSize, batchSize](Expr out, const std::vector<Expr>& inputs) { + std::vector<WordIndex> indices; + Expr lastIndices = inputs[0]; + lastIndices->val()->get(indices); + std::vector<float> masks = getFactorMasks2(batchSize, currBeamSize, g, indices); + out->val()->set(masks); + }; + + Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize); + //std::cerr << "lastIndices=" << lastIndices->shape() << std::endl; + factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward); + //std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl; + factorMasks = transpose(factorMasks, {1, 0, 2}); + //std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl; + + const Shape &s = factorMasks->shape(); + factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]}); + //std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl; } factorMaxima = cast(factorMaxima, sel->value_type()); factorMasks = cast(factorMasks, sel->value_type()); @@ -219,6 +240,27 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector< return res; } +std::vector<float> Logits::getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices) + const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 + size_t n + = indices.empty() + ? 
(factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) + : indices.size() / currBeamSize; + std::vector<float> res; + res.reserve(currBeamSize * n); + + // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this + // into FactoredVocab + for (size_t currBeam = 0; currBeam < currBeamSize; ++currBeam) { + for(size_t i = 0; i < n; i++) { + size_t idx = currBeam * n + i; + size_t lemma = indices.empty() ? i : (indices[idx] - factoredVocab_->getGroupRange(0).first); + res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); + } + } + return res; +} + Logits Logits::applyUnaryFunction( const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values std::vector<Ptr<RationalLoss>> newLogits; diff --git a/src/layers/logits.h b/src/layers/logits.h index c61a9e74..1c93926d 100644 --- a/src/layers/logits.h +++ b/src/layers/logits.h @@ -80,6 +80,7 @@ private: } // actually the same as constant(data) for this data type std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const; + std::vector<float> getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices) const; private: // members |