Diffstat (limited to 'src/layers/embedding.cpp')
-rw-r--r-- | src/layers/embedding.cpp | 74
1 file changed, 54 insertions(+), 20 deletions(-)
diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp
index 92c4ad6d..26d6b7fe 100644
--- a/src/layers/embedding.cpp
+++ b/src/layers/embedding.cpp
@@ -8,19 +8,31 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
   std::string name = opt<std::string>("prefix");
   int dimVoc = opt<int>("dimVocab");
   int dimEmb = opt<int>("dimEmb");
+  int dimFactorEmb = opt<int>("dimFactorEmb");
   bool fixed = opt<bool>("fixed", false);
 
+  // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+  auto initFunc = inits::glorotUniform(
+      /*fanIn=*/false, /*fanOut=*/true);  // -> embedding vectors have roughly unit length
+
   factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
   if(factoredVocab_) {
     dimVoc = (int)factoredVocab_->factorVocabSize();
     LOG_ONCE(info, "[embedding] Factored embeddings enabled");
+    if(opt<std::string>("factorsCombine") == "concat") {
+      ABORT_IF(dimFactorEmb == 0,
+               "Embedding: If concatenation is chosen to combine the factor embeddings, a factor "
+               "embedding size must be specified.");
+      int numberOfFactors = (int)factoredVocab_->getTotalFactorCount();
+      dimVoc -= numberOfFactors;
+      FactorEmbMatrix_
+          = graph_->param("factor_" + name, {numberOfFactors, dimFactorEmb}, initFunc, fixed);
+      LOG_ONCE(info,
+               "[embedding] Combining lemma and factors embeddings with concatenation enabled");
+    }
   }
 
-  // Embedding layer initialization should depend only on embedding size, hence fanIn=false
-  auto initFunc = inits::glorotUniform(
-      /*fanIn=*/false, /*fanOut=*/true);  // -> embedding vectors have roughly unit length
-
   if(options_->has("embFile")) {
     std::string file = opt<std::string>("embFile");
     if(!file.empty()) {
@@ -32,6 +44,26 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
     E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
 }
 
+/**
+ * Embeds a sequence of words (given as indices), where they have factor information. The
+ * matrices are concatenated.
+ * @param words vector of words
+ * @returns Expression that is the concatenation of the lemma and factor embeddings
+ */
+/*private*/ Expr Embedding::embedWithConcat(const Words& data) const {
+  auto graph = E_->graph();
+  std::vector<IndexType> lemmaIndices;
+  std::vector<float> factorIndices;
+  factoredVocab_->lemmaAndFactorsIndexes(data, lemmaIndices, factorIndices);
+  auto lemmaEmbs = rows(E_, lemmaIndices);
+  int dimFactors = FactorEmbMatrix_->shape()[0];
+  auto factEmbs
+      = dot(graph->constant(
+                {(int)data.size(), dimFactors}, inits::fromVector(factorIndices), Type::float32),
+            FactorEmbMatrix_);
+
+  return concatenate({lemmaEmbs, factEmbs}, -1);
+}
+
 // helper to embed a sequence of words (given as indices) via factored embeddings
 Expr Embedding::multiRows(const Words& data, float dropProb) const {
   auto graph = E_->graph();
@@ -61,7 +93,9 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat
 /*override final*/ {
   auto graph = E_->graph();
   int dimBatch = (int)subBatch->batchSize();
-  int dimEmb = E_->shape()[-1];
+  int dimEmb = (factoredVocab_ && opt<std::string>("factorsCombine") == "concat")
+                   ? E_->shape()[-1] + FactorEmbMatrix_->shape()[-1]
+                   : E_->shape()[-1];
   int dimWidth = (int)subBatch->batchWidth();
 
   // factored embeddings:
@@ -96,14 +130,8 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat
   // more slowly
   auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
-#if 1
+
   auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask()));
-#else  // @TODO: this is dead code now, get rid of it
-  // experimental: hide inline-fix source tokens from cross attention
-  auto batchMask
-      = graph->constant({dimWidth, dimBatch, 1},
-                        inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
-#endif
   // give the graph inputs readable names for debugging and ONNX
   batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
@@ -112,8 +140,12 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat
 Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
   if(factoredVocab_) {
-    Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f));  // [(B*W) x E]
-    selectedEmbs = reshape(selectedEmbs, shape);                                  // [W, B, E]
+    Expr selectedEmbs;
+    if(opt<std::string>("factorsCombine") == "concat")
+      selectedEmbs = embedWithConcat(words);  // [(B*W) x E]
+    else
+      selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f));  // [(B*W) x E]
+    selectedEmbs = reshape(selectedEmbs, shape);  // [W, B, E]
     // selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), {
     // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
     return selectedEmbs;
@@ -141,13 +173,15 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape&
 /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
   // clang-format off
   auto options = New<Options>(
-    "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
-    "dimEmb", opt<int>("dim-emb"),
-    "dropout", dropoutEmbeddings_,
-    "inference", inference_,
-    "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
+    "dimVocab",       opt<std::vector<int>>("dim-vocabs")[batchIndex_],
+    "dimEmb",         opt<int>("dim-emb"),
+    "dropout",        dropoutEmbeddings_,
+    "inference",      inference_,
+    "prefix",         (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
                                                                                              : prefix_ + "_Wemb",
-    "fixed", embeddingFix_,
+    "fixed",          embeddingFix_,
+    "dimFactorEmb",   opt<int>("factors-dim-emb"),          // for factored embeddings
+    "factorsCombine", opt<std::string>("factors-combine"),  // for factored embeddings
     "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]);  // for factored embeddings
   // clang-format on
   if(options_->hasAndNotEmpty("embedding-vectors")) {
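For reference, here is a minimal standalone sketch of what the new factorsCombine == "concat" path computes per word. This is illustrative C++ only, not Marian code: plain std::vector matrices stand in for Expr, and the dimensions in main() are made-up toy values. It mirrors embedWithConcat above: the lemma's row of E_ is concatenated with the product of the word's multi-hot factor indicator and FactorEmbMatrix_, which is also why apply() now reports dimEmb as E_->shape()[-1] + FactorEmbMatrix_->shape()[-1].

// Illustrative sketch only (not Marian code): per-word "concat" combination of
// lemma and factor embeddings, mirroring Embedding::embedWithConcat above.
#include <cstddef>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // toy stand-in for an Expr matrix

std::vector<float> embedWithConcat(const Matrix& E,  // [dimLemmaVocab x dimEmb]
                                   const Matrix& F,  // [numFactors x dimFactorEmb]
                                   std::size_t lemmaIndex,
                                   const std::vector<float>& factorIndicator) {  // multi-hot, size numFactors
  std::vector<float> out = E[lemmaIndex];  // rows(E_, lemmaIndices), restricted to one word
  // dot(indicator, F): sums the factor-embedding rows of the word's active factors
  std::vector<float> fact(F[0].size(), 0.f);
  for(std::size_t i = 0; i < F.size(); ++i)
    for(std::size_t j = 0; j < fact.size(); ++j)
      fact[j] += factorIndicator[i] * F[i][j];
  out.insert(out.end(), fact.begin(), fact.end());  // concatenate along the embedding axis
  return out;  // size dimEmb + dimFactorEmb
}

int main() {
  Matrix E = {{1, 2}, {3, 4}};              // two lemmas, dimEmb = 2 (made-up numbers)
  Matrix F = {{10, 20, 30}, {40, 50, 60}};  // two factors, dimFactorEmb = 3
  auto e = embedWithConcat(E, F, /*lemmaIndex=*/1, /*factorIndicator=*/{1.f, 0.f});
  for(float v : e)
    std::cout << v << ' ';  // prints: 3 4 10 20 30
  std::cout << '\n';
}

In the layer itself these sizes arrive through the new dimFactorEmb and factorsCombine options, which createEmbeddingLayer() wires from the factors-dim-emb and factors-combine settings.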