github.com/marian-nmt/marian.git
Diffstat (limited to 'src/layers/embedding.cpp')
 src/layers/embedding.cpp | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 54 insertions(+), 20 deletions(-)
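
This change adds a "concat" mode for factored embeddings: factors get their own embedding matrix of width dimFactorEmb, and each word's factor embedding is concatenated to its lemma embedding (the existing multiRows path remains the default).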
diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp
index 92c4ad6d..26d6b7fe 100644
--- a/src/layers/embedding.cpp
+++ b/src/layers/embedding.cpp
@@ -8,19 +8,31 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
std::string name = opt<std::string>("prefix");
int dimVoc = opt<int>("dimVocab");
int dimEmb = opt<int>("dimEmb");
+ int dimFactorEmb = opt<int>("dimFactorEmb");
bool fixed = opt<bool>("fixed", false);
+ // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+ auto initFunc = inits::glorotUniform(
+ /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
+
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
if(factoredVocab_) {
dimVoc = (int)factoredVocab_->factorVocabSize();
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
+ if(opt<std::string>("factorsCombine") == "concat") {
+ ABORT_IF(dimFactorEmb == 0,
+ "Embedding: If concatenation is chosen to combine the factor embeddings, a factor "
+ "embedding size must be specified.");
+ int numberOfFactors = (int)factoredVocab_->getTotalFactorCount();
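+ // split the factor units out of the joint factored vocabulary: E_ keeps only the
+ // lemma rows, the factors get their own embedding matrix (FactorEmbMatrix_) below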
+ dimVoc -= numberOfFactors;
+ FactorEmbMatrix_
+ = graph_->param("factor_" + name, {numberOfFactors, dimFactorEmb}, initFunc, fixed);
+ LOG_ONCE(info,
+ "[embedding] Concatenation of lemma and factor embeddings enabled");
+ }
}
- // Embedding layer initialization should depend only on embedding size, hence fanIn=false
- auto initFunc = inits::glorotUniform(
- /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
-
if(options_->has("embFile")) {
std::string file = opt<std::string>("embFile");
if(!file.empty()) {
@@ -32,6 +44,26 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
}
+/**
+ * Embeds a sequence of words (given as indices) that carry factor information; the lemma
+ * and factor embeddings are looked up separately and concatenated.
+ * @param data vector of words
+ * @returns Expression that is the concatenation of the lemma and factor embeddings
+ */
+/*private*/ Expr Embedding::embedWithConcat(const Words& data) const {
+ auto graph = E_->graph();
+ std::vector<IndexType> lemmaIndices;
+ std::vector<float> factorIndices;
+ factoredVocab_->lemmaAndFactorsIndexes(data, lemmaIndices, factorIndices);
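+ // factorIndices encodes, per word, which factors are set: a dense {words x factors}
+ // matrix of 0/1 entries, produced alongside the lemma indices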
+ auto lemmaEmbs = rows(E_, lemmaIndices);
+ int dimFactors = FactorEmbMatrix_->shape()[0];
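+ // the dot product of the {words x factors} selection matrix with FactorEmbMatrix_
+ // {factors x dimFactorEmb} sums the embeddings of all factors active on each word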
+ auto factEmbs
+ = dot(graph->constant(
+ {(int)data.size(), dimFactors}, inits::fromVector(factorIndices), Type::float32),
+ FactorEmbMatrix_);
+
+ return concatenate({lemmaEmbs, factEmbs}, -1);
+}
+
// helper to embed a sequence of words (given as indices) via factored embeddings
Expr Embedding::multiRows(const Words& data, float dropProb) const {
auto graph = E_->graph();
@@ -61,7 +93,9 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat
/*override final*/ {
auto graph = E_->graph();
int dimBatch = (int)subBatch->batchSize();
- int dimEmb = E_->shape()[-1];
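+ // with "concat", each factor embedding is appended to the lemma embedding, so the
+ // effective embedding width grows by the factor embedding size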
+ int dimEmb = (factoredVocab_ && opt<std::string>("factorsCombine") == "concat")
+ ? E_->shape()[-1] + FactorEmbMatrix_->shape()[-1]
+ : E_->shape()[-1];
int dimWidth = (int)subBatch->batchWidth();
// factored embeddings:
@@ -96,14 +130,8 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat
// more slowly
auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
-#if 1
+
auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask()));
-#else // @TODO: this is dead code now, get rid of it
- // experimental: hide inline-fix source tokens from cross attention
- auto batchMask
- = graph->constant({dimWidth, dimBatch, 1},
- inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
-#endif
// give the graph inputs readable names for debugging and ONNX
batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
@@ -112,8 +140,12 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat
Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
if(factoredVocab_) {
- Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
- selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ Expr selectedEmbs;
+ if(opt<std::string>("factorsCombine") == "concat")
+ selectedEmbs = embedWithConcat(words); // [(B*W) x E]
+ else
+ selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
// selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), {
// selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
return selectedEmbs;
@@ -141,13 +173,15 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape&
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
// clang-format off
auto options = New<Options>(
- "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
- "dimEmb", opt<int>("dim-emb"),
- "dropout", dropoutEmbeddings_,
- "inference", inference_,
- "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
+ "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
+ "dimEmb", opt<int>("dim-emb"),
+ "dropout", dropoutEmbeddings_,
+ "inference", inference_,
+ "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
: prefix_ + "_Wemb",
- "fixed", embeddingFix_,
+ "fixed", embeddingFix_,
+ "dimFactorEmb", opt<int>("factors-dim-emb"), // for factored embeddings
+ "factorsCombine", opt<std::string>("factors-combine"), // for factored embeddings
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
// clang-format on
if(options_->hasAndNotEmpty("embedding-vectors")) {
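
For orientation, a minimal usage sketch (not part of the patch) of how the new options feed the layer. The option keys are the ones read in the diff above; the concrete values, the graph/vocab setup, and the variable names are illustrative assumptions:

// hypothetical sketch, assuming an existing graph and a factored vocabulary file
auto options = New<Options>(
    "dimVocab",       32000,                       // lemma + factor units in the vocabulary
    "dimEmb",         512,                         // lemma embedding size
    "dimFactorEmb",   16,                          // factor embedding size, required for "concat"
    "factorsCombine", std::string("concat"),
    "prefix",         std::string("encoder_Wemb"),
    "fixed",          false,
    "vocab",          std::string("train.fsv"));   // factored vocabulary -> factoredVocab_ is set
auto embedding = New<Embedding>(graph, options);
// each word now embeds to dimEmb + dimFactorEmb = 528 values:
auto embs = embedding->apply(words, {dimWidth, dimBatch, 512 + 16});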