
github.com/marian-nmt/marian.git
Diffstat (limited to 'src/layers/output.cpp')
-rw-r--r--  src/layers/output.cpp  44
1 file changed, 20 insertions(+), 24 deletions(-)
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 92cccdfb..4d6e488a 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -36,12 +36,12 @@ namespace mlp {
b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ std::string lemmaDependency = options_->get<std::string>("lemma-dependency", "");
ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
- if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
-#define HARDMAX_HACK
-#ifdef HARDMAX_HACK
- lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
-#endif
+ if(lemmaDependency == "re-embedding") { // embed the (expected) word with a different embedding matrix
+ ABORT_IF(
+ lemmaDimEmb <= 0,
+ "In order to predict factors by re-embedding them, a lemma-dim-emb must be specified.");
auto range = factoredVocab_->getGroupRange(0);
auto lemmaVocabDim = (int)(range.second - range.first);
auto initFunc = inits::glorotUniform(
@@ -109,8 +109,12 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
std::vector<Ptr<RationalLoss>> allLogits(numGroups,
nullptr); // (note: null entries for absent factors)
Expr input1 = input; // [B... x D]
- Expr Plemma = nullptr; // used for lemmaDimEmb=-1
- Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
+ Expr Plemma = nullptr; // used for lemmaDependency = lemma-dependent-bias
+ Expr inputLemma = nullptr; // used for lemmaDependency = hard-transformer-layer and soft-transformer-layer
+
+ std::string factorsCombine = options_->get<std::string>("factors-combine", "");
+ ABORT_IF(factorsCombine == "concat", "Combining lemma and factors embeddings with concatenation on the target side is currently not supported");
+
for(size_t g = 0; g < numGroups; g++) {
auto range = factoredVocab_->getGroupRange(g);
if(g > 0 && range.first == range.second) // empty entry
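Note on the retired integer encoding: the comments above spell out how the old magic values of --lemma-dim-emb correspond to the new --lemma-dependency strings. The standalone helper below only illustrates that mapping and is not part of the patch (the function name is made up); the odd-value hard-max hack that used to piggyback on positive values has no counterpart here because this change removes it.

#include <string>

// Illustration only (not in the patch): how the legacy numeric lemma-dim-emb
// values removed by this change correspond to the new --lemma-dependency strings.
static std::string legacyLemmaDependencyName(int lemmaDimEmb) {
  if(lemmaDimEmb == -1)
    return "lemma-dependent-bias";    // old -1: learn a lemma-dependent bias
  if(lemmaDimEmb == -2)
    return "soft-transformer-layer";  // old -2: gated transformer-like structure, soft-max
  if(lemmaDimEmb == -3)
    return "hard-transformer-layer";  // old -3: same structure with hard-max
  if(lemmaDimEmb > 0)
    return "re-embedding";            // old > 0: learn a re-embedding matrix of that dimension
  return "";                          // 0: no lemma dependency
}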
@@ -130,9 +134,8 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
}
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
- if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
- && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
- LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
+ std::string lemmaDependency = options_->get<std::string>("lemma-dependency", "");
+ if((lemmaDependency == "soft-transformer-layer" || lemmaDependency == "hard-transformer-layer") && g > 0) {
// this mimics one transformer layer
// - attention over two inputs:
// - e = current lemma. We use the original embedding vector; specifically, expectation
@@ -229,7 +232,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
// optionally add a soft embedding of lemma back to create some lemma dependency
// @TODO: if this works, move it into lazyConstruct
- if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
+ if(lemmaDependency == "soft-transformer-layer" && g == 0) {
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
// get expected lemma embedding vector
auto factorLogSoftmax = logsoftmax(
@@ -239,7 +242,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
factorWt,
false,
/*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
- } else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
+ } else if(lemmaDependency == "hard-transformer-layer" && g == 0) {
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
// get max-lemma embedding vector
auto maxVal = max(factorLogits,
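The two branches above differ only in how the rows of factorWt are mixed into the lemma embedding: the soft variant weights every row by the factor softmax, while the hard variant keeps the single row at the arg-max factor. A minimal self-contained sketch of that idea in plain C++ (std::vector in place of Marian's Expr graph; W, z, and the function names are illustrative, not the actual code):

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch only: W is a U x D embedding matrix (one row per lemma), z holds the U lemma logits.
static std::vector<float> softLemmaEmbedding(const std::vector<std::vector<float>>& W,
                                             const std::vector<float>& z) {
  float zMax = *std::max_element(z.begin(), z.end());  // stabilize the softmax
  std::vector<float> p(z.size());
  float sum = 0.f;
  for(size_t u = 0; u < z.size(); ++u) {
    p[u] = std::exp(z[u] - zMax);
    sum += p[u];
  }
  std::vector<float> e(W[0].size(), 0.f);
  for(size_t u = 0; u < W.size(); ++u)
    for(size_t d = 0; d < e.size(); ++d)
      e[d] += (p[u] / sum) * W[u][d];                  // expectation of the embedding under the softmax
  return e;
}

static std::vector<float> hardLemmaEmbedding(const std::vector<std::vector<float>>& W,
                                             const std::vector<float>& z) {
  size_t best = std::max_element(z.begin(), z.end()) - z.begin();
  return W[best];                                      // hard-max: a single row instead of the expectation
}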
@@ -249,29 +252,22 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
factorWt,
false,
/*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
- } else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
+ } else if(lemmaDependency == "lemma-dependent-bias" && g == 0) {
ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
LOG_ONCE(info, "[embedding] using lemma-dependent bias");
auto factorLogSoftmax
= logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
auto z = /*stopGradient*/ (factorLogSoftmax);
Plemma = exp(z); // [B... x U]
- } else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
+ } else if(lemmaDependency == "re-embedding" && g == 0) {
+ ABORT_IF(
+ lemmaDimEmb <= 0,
+ "In order to predict factors by re-embedding them, a lemma-dim-emb must be specified.");
LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
// compute softmax. We compute logsoftmax() separately because this way, computation will be
// reused later via CSE
auto factorLogSoftmax = logsoftmax(factorLogits);
auto factorSoftmax = exp(factorLogSoftmax);
-#ifdef HARDMAX_HACK
- bool hardmax = (lemmaDimEmb & 1)
- != 0; // odd value triggers hardmax for now (for quick experimentation)
- if(hardmax) {
- lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
- LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
- auto maxVal = max(factorSoftmax, -1);
- factorSoftmax = eq(factorSoftmax, maxVal);
- }
-#endif
// re-embedding lookup, soft-indexed by softmax
Expr e;
if(shortlist_) { // short-listed version of re-embedding matrix
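Taken together, the checks added by this patch (plus the pre-existing factored-vocabulary check) constrain the new options as follows. This is a compact standalone restatement for reference only; the function and the use of exceptions are illustrative, while the real code reports these conditions through ABORT_IF.

#include <stdexcept>
#include <string>

// Illustrative restatement of the option checks visible in this diff (not the actual code).
static void checkLemmaOptions(const std::string& lemmaDependency,
                              int lemmaDimEmb,
                              const std::string& factorsCombine,
                              bool hasFactoredVocab) {
  if(lemmaDimEmb && !hasFactoredVocab)
    throw std::runtime_error("--lemma-dim-emb requires a factored vocabulary");
  if(lemmaDependency == "re-embedding" && lemmaDimEmb <= 0)
    throw std::runtime_error(
        "In order to predict factors by re-embedding them, a lemma-dim-emb must be specified.");
  if(factorsCombine == "concat")
    throw std::runtime_error(
        "Combining lemma and factors embeddings with concatenation"
        " on the target side is currently not supported");
}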