Diffstat (limited to 'src/layers/output.cpp')
-rw-r--r-- | src/layers/output.cpp | 44 |
1 files changed, 20 insertions, 24 deletions
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 92cccdfb..4d6e488a 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -36,12 +36,12 @@ namespace mlp {
   b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
 
   /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+  std::string lemmaDependency = options_->get<std::string>("lemma-dependency", "");
   ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
-  if(lemmaDimEmb > 0) {  // > 0 means to embed the (expected) word with a different embedding matrix
-#define HARDMAX_HACK
-#ifdef HARDMAX_HACK
-    lemmaDimEmb = lemmaDimEmb & 0xfffffffe;  // hack to select hard-max: use an odd number
-#endif
+  if(lemmaDependency == "re-embedding") {  // embed the (expected) word with a different embedding matrix
+    ABORT_IF(
+        lemmaDimEmb <= 0,
+        "In order to predict factors by re-embedding them, a lemma-dim-emb must be specified.");
     auto range = factoredVocab_->getGroupRange(0);
     auto lemmaVocabDim = (int)(range.second - range.first);
     auto initFunc = inits::glorotUniform(
@@ -109,8 +109,12 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
     std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr);  // (note: null entries for absent factors)
     Expr input1 = input;       // [B... x D]
-    Expr Plemma = nullptr;     // used for lemmaDimEmb=-1
-    Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
+    Expr Plemma = nullptr;     // used for lemmaDependency = lemma-dependent-bias
+    Expr inputLemma = nullptr; // used for lemmaDependency = hard-transformer-layer and soft-transformer-layer
+
+    std::string factorsCombine = options_->get<std::string>("factors-combine", "");
+    ABORT_IF(factorsCombine == "concat", "Combining lemma and factors embeddings with concatenation on the target side is currently not supported");
+
     for(size_t g = 0; g < numGroups; g++) {
       auto range = factoredVocab_->getGroupRange(g);
       if(g > 0 && range.first == range.second)  // empty entry
@@ -130,9 +134,8 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
         factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
       }
       /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
-      if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
-         && g > 0) {  // -2/-3 means a gated transformer-like structure (-3 = hard-max)
-        LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
+      std::string lemmaDependency = options_->get<std::string>("lemma-dependency", "");
+      if((lemmaDependency == "soft-transformer-layer" || lemmaDependency == "hard-transformer-layer") && g > 0) {
         // this mimics one transformer layer
         //  - attention over two inputs:
         //     - e = current lemma. We use the original embedding vector; specifically, expectation
@@ -229,7 +232,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
       allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
       // optionally add a soft embedding of lemma back to create some lemma dependency
       // @TODO: if this works, move it into lazyConstruct
-      if(lemmaDimEmb == -2 && g == 0) {  // -2 means a gated transformer-like structure
+      if(lemmaDependency == "soft-transformer-layer" && g == 0) {
        LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
         // get expected lemma embedding vector
         auto factorLogSoftmax = logsoftmax(
@@ -239,7 +242,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
                          factorWt,
                          false,
                          /*transB=*/isLegacyUntransposedW ? true : false);  // [B... x D]
-      } else if(lemmaDimEmb == -3 && g == 0) {  // same as -2 except with hard max
+      } else if(lemmaDependency == "hard-transformer-layer" && g == 0) {
         LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
         // get max-lemma embedding vector
         auto maxVal = max(factorLogits,
@@ -249,29 +252,22 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
                          factorWt,
                          false,
                          /*transB=*/isLegacyUntransposedW ? true : false);  // [B... x D]
-      } else if(lemmaDimEmb == -1 && g == 0) {  // -1 means learn a lemma-dependent bias
+      } else if(lemmaDependency == "lemma-dependent-bias" && g == 0) {
         ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
         LOG_ONCE(info, "[embedding] using lemma-dependent bias");
         auto factorLogSoftmax = logsoftmax(factorLogits);  // (we do that again later, CSE will kick in)
         auto z = /*stopGradient*/ (factorLogSoftmax);
         Plemma = exp(z);  // [B... x U]
-      } else if(lemmaDimEmb > 0 && g == 0) {  // > 0 means learn a re-embedding matrix
+      } else if(lemmaDependency == "re-embedding" && g == 0) {
+        ABORT_IF(
+            lemmaDimEmb <= 0,
+            "In order to predict factors by re-embedding them, a lemma-dim-emb must be specified.");
         LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
         // compute softmax. We compute logsoftmax() separately because this way, computation will be
         // reused later via CSE
         auto factorLogSoftmax = logsoftmax(factorLogits);
         auto factorSoftmax = exp(factorLogSoftmax);
-#ifdef HARDMAX_HACK
-        bool hardmax = (lemmaDimEmb & 1)
-                       != 0;  // odd value triggers hardmax for now (for quick experimentation)
-        if(hardmax) {
-          lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
-          LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
-          auto maxVal = max(factorSoftmax, -1);
-          factorSoftmax = eq(factorSoftmax, maxVal);
-        }
-#endif
         // re-embedding lookup, soft-indexed by softmax
         Expr e;
         if(shortlist_) {  // short-listed version of re-embedding matrix
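Note: the comments removed by this diff document how the legacy integer values of --lemma-dim-emb map onto the new --lemma-dependency strings. The standalone C++ sketch below summarizes that mapping for readers migrating old configurations; the helper function name and its use are illustrative only and are not part of this commit.

// Illustrative only (not part of this commit): translate the legacy integer
// encoding of --lemma-dim-emb into the new --lemma-dependency string, based
// on the comments removed in this diff.
#include <string>

std::string legacyLemmaDependency(int lemmaDimEmb) {
  if(lemmaDimEmb == -1) return "lemma-dependent-bias";    // was: learn a lemma-dependent bias
  if(lemmaDimEmb == -2) return "soft-transformer-layer";  // was: gated transformer-like structure (soft-max)
  if(lemmaDimEmb == -3) return "hard-transformer-layer";  // was: same as -2 except with hard max
  if(lemmaDimEmb > 0)   return "re-embedding";            // was: learn a re-embedding matrix of this dim
  return "";                                              // no lemma dependency configured
}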