Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-04-29 10:44:30 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-04-29 10:44:30 +0300
commit6b2b7d11880013c9574e6bcd8d67bef4f28be97c (patch)
tree69499bc94bf2cce6ddc5d482e4f8007f061fa015
parentf41acb1aa86da3b7f357e63218550026835564da (diff)
factor mask
-rw-r--r--src/layers/logits.cpp44
-rw-r--r--src/layers/logits.h1
2 files changed, 44 insertions, 1 deletions
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
index 4f0ad815..c327bd0d 100644
--- a/src/layers/logits.cpp
+++ b/src/layers/logits.cpp
@@ -101,7 +101,28 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>()));
}
else {
- factorMasks = constant(getFactorMasks(g, shortlist->indices()));
+ //std::cerr << "sel=" << sel->shape() << std::endl;
+ int currBeamSize = sel->shape()[0];
+ int batchSize = sel->shape()[2];
+
+ auto forward = [this, g, currBeamSize, batchSize](Expr out, const std::vector<Expr>& inputs) {
+ std::vector<WordIndex> indices;
+ Expr lastIndices = inputs[0];
+ lastIndices->val()->get(indices);
+ std::vector<float> masks = getFactorMasks2(batchSize, currBeamSize, g, indices);
+ out->val()->set(masks);
+ };
+
+ Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
+ //std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
+ factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
+ //std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
+ factorMasks = transpose(factorMasks, {1, 0, 2});
+ //std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;
+
+ const Shape &s = factorMasks->shape();
+ factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
+ //std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
}
factorMaxima = cast(factorMaxima, sel->value_type());
factorMasks = cast(factorMasks, sel->value_type());
@@ -219,6 +240,27 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
return res;
}
+std::vector<float> Logits::getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices)
+ const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
+ size_t n
+ = indices.empty()
+ ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
+ : indices.size() / currBeamSize;
+ std::vector<float> res;
+ res.reserve(currBeamSize * n);
+
+ // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
+ // into FactoredVocab
+ for (size_t currBeam = 0; currBeam < currBeamSize; ++currBeam) {
+ for(size_t i = 0; i < n; i++) {
+ size_t idx = currBeam * n + i;
+ size_t lemma = indices.empty() ? i : (indices[idx] - factoredVocab_->getGroupRange(0).first);
+ res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+ }
+ }
+ return res;
+}
+
Logits Logits::applyUnaryFunction(
const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
std::vector<Ptr<RationalLoss>> newLogits;
diff --git a/src/layers/logits.h b/src/layers/logits.h
index c61a9e74..1c93926d 100644
--- a/src/layers/logits.h
+++ b/src/layers/logits.h
@@ -80,6 +80,7 @@ private:
} // actually the same as constant(data) for this data type
std::vector<float> getFactorMasks(size_t factorGroup,
const std::vector<WordIndex>& indices) const;
+ std::vector<float> getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices) const;
private:
// members