diff options
author | Roman Grundkiewicz <rogrundk@microsoft.com> | 2021-03-08 14:09:03 +0300 |
---|---|---|
committer | Roman Grundkiewicz <rogrundk@microsoft.com> | 2021-03-08 14:09:03 +0300 |
commit | cd018e8d0404687c0bd13f64962bd22617b80331 (patch) | |
tree | 555ce956bf9ffc13fe96d767d6f3380b45ab1d86 /src | |
parent | ba196637847c50c76d5d0edfcfe39b9cedb0d1d0 (diff) |
Update formatting
Diffstat (limited to 'src')
-rw-r--r-- | src/layers/embedding.cpp | 77 | ||||
-rw-r--r-- | src/layers/embedding.h | 87 | ||||
-rw-r--r-- | src/layers/logits.cpp | 31 |
3 files changed, 85 insertions, 110 deletions
diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp index 5a448f61..92c4ad6d 100644 --- a/src/layers/embedding.cpp +++ b/src/layers/embedding.cpp @@ -6,8 +6,8 @@ namespace marian { Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options), inference_(opt<bool>("inference")) { std::string name = opt<std::string>("prefix"); - int dimVoc = opt<int>("dimVocab"); - int dimEmb = opt<int>("dimEmb"); + int dimVoc = opt<int>("dimVocab"); + int dimEmb = opt<int>("dimEmb"); bool fixed = opt<bool>("fixed", false); @@ -25,7 +25,7 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) std::string file = opt<std::string>("embFile"); if(!file.empty()) { bool norm = opt<bool>("normalization", false); - initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); + initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); } } @@ -34,7 +34,7 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) // helper to embed a sequence of words (given as indices) via factored embeddings Expr Embedding::multiRows(const Words& data, float dropProb) const { - auto graph = E_->graph(); + auto graph = E_->graph(); auto factoredData = factoredVocab_->csr_rows(data); // multi-hot factor vectors are represented as a sparse CSR matrix // [row index = word position index] -> set of factor indices for word at this position @@ -59,9 +59,9 @@ Expr Embedding::multiRows(const Words& data, float dropProb) const { std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ { - auto graph = E_->graph(); + auto graph = E_->graph(); int dimBatch = (int)subBatch->batchSize(); - int dimEmb = E_->shape()[-1]; + int dimEmb = E_->shape()[-1]; int dimWidth = (int)subBatch->batchWidth(); // factored embeddings: @@ -113,7 +113,7 @@ std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBat Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { if(factoredVocab_) { Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] // selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout return selectedEmbs; @@ -128,7 +128,7 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() // (test that separately) if(!inference_) @@ -139,22 +139,17 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& // standard encoder word embeddings /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const { + // clang-format off auto options = New<Options>( - "dimVocab", - opt<std::vector<int>>("dim-vocabs")[batchIndex_], - "dimEmb", - opt<int>("dim-emb"), - "dropout", - dropoutEmbeddings_, - "inference", - inference_, - "prefix", - (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" - : prefix_ + "_Wemb", - "fixed", - embeddingFix_, - "vocab", - opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings + "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_], + "dimEmb", opt<int>("dim-emb"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" + : prefix_ + "_Wemb", + "fixed", embeddingFix_, + "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings + // clang-format on if(options_->hasAndNotEmpty("embedding-vectors")) { auto embFiles = opt<std::vector<std::string>>("embedding-vectors"); options->set( @@ -165,28 +160,20 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& // ULR word embeddings /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const { - return New<ULREmbedding>( - graph_, - New<Options>("dimSrcVoc", - opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src - "dimTgtVoc", - opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt - "dimUlrEmb", - opt<int>("ulr-dim-emb"), - "dimEmb", - opt<int>("dim-emb"), - "ulr-dropout", - opt<float>("ulr-dropout"), - "dropout", - dropoutEmbeddings_, - "inference", - inference_, - "ulrTrainTransform", - opt<bool>("ulr-trainable-transformation"), - "ulrQueryFile", - opt<std::string>("ulr-query-vectors"), - "ulrKeysFile", - opt<std::string>("ulr-keys-vectors"))); + // clang-format off + return New<ULREmbedding>(graph_, New<Options>( + "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src + "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt + "dimUlrEmb", opt<int>("ulr-dim-emb"), + "dimEmb", opt<int>("dim-emb"), + "ulr-dropout", opt<float>("ulr-dropout"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"), + "ulrQueryFile", opt<std::string>("ulr-query-vectors"), + "ulrKeysFile", opt<std::string>("ulr-keys-vectors") + )); + // clang-format on } // get embedding layer for this encoder or decoder diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 6edb3140..2fa7b78d 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -28,47 +28,45 @@ public: }; class ULREmbedding : public LayerBase, public IEmbeddingLayer { - std::vector<Expr> - ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members + std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members bool inference_{false}; public: ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options), inference_(opt<bool>("inference")) { std::string name = "url_embed"; // opt<std::string>("prefix"); - int dimKeys = opt<int>("dimTgtVoc"); - int dimQueries = opt<int>("dimSrcVoc"); - int dimEmb = opt<int>("dimEmb"); - int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size - bool fixed = opt<bool>("fixed", false); + int dimKeys = opt<int>("dimTgtVoc"); + int dimQueries = opt<int>("dimSrcVoc"); + int dimEmb = opt<int>("dimEmb"); + int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size + bool fixed = opt<bool>("fixed", false); // Embedding layer initialization should depend only on embedding size, hence fanIn=false auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); std::string queryFile = opt<std::string>("ulrQueryFile"); - std::string keyFile = opt<std::string>("ulrKeysFile"); - bool trainTrans = opt<bool>("ulrTrainTransform", false); + std::string keyFile = opt<std::string>("ulrKeysFile"); + bool trainTrans = opt<bool>("ulrTrainTransform", false); if(!queryFile.empty() && !keyFile.empty()) { - initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); - name = "ulr_query"; - fixed = true; + initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); + name = "ulr_query"; + fixed = true; auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(query_embed); // keys embeds - initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); - name = "ulr_keys"; - fixed = true; + initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); + name = "ulr_keys"; + fixed = true; auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(key_embed); // actual trainable embedding initFunc = inits::glorotUniform(); - name = "ulr_embed"; - fixed = false; - auto ulr_embed - = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim + name = "ulr_embed"; + fixed = false; + auto ulr_embed = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim ulrEmbeddings_.push_back(ulr_embed); // init trainable src embedding - name = "ulr_src_embed"; + name = "ulr_src_embed"; auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulr_src_embed); // ulr transformation matrix @@ -76,20 +74,20 @@ public: // we make this to the fixed case only if(trainTrans) { initFunc = inits::glorotUniform(); - fixed = false; + fixed = false; } else { initFunc = inits::eye(); // identity matrix - fixed = true; + fixed = true; } - name = "ulr_transform"; + name = "ulr_transform"; auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulrTransform); initFunc = inits::fromValue( 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no // universal embeddings - should be zero for top freq only - fixed = true; - name = "ulr_shared"; + fixed = true; + name = "ulr_shared"; auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed); ulrEmbeddings_.push_back(share_embed); } @@ -97,15 +95,15 @@ public: std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( Ptr<data::SubBatch> subBatch) const override final { - auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb - auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb - auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb - auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb + auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb + auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb + auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb + auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb - auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 - int dimBatch = (int)subBatch->batchSize(); - int dimEmb = uniEmbed->shape()[-1]; - int dimWords = (int)subBatch->batchWidth(); + auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 + int dimBatch = (int)subBatch->batchSize(); + int dimEmb = uniEmbed->shape()[-1]; + int dimWords = (int)subBatch->batchWidth(); // D = K.A.QT // dimm(K) = univ_tok_vocab*uni_embed_size // dim A = uni_embed_size*uni_embed_size @@ -114,18 +112,15 @@ public: // note all above can be precombuted and serialized if A is not trainiable and during decoding // (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this // minibatch from Q - auto embIdx = toWordIndexVector(subBatch->data()); + auto embIdx = toWordIndexVector(subBatch->data()); auto queryEmbeddings = rows(queryEmbed, embIdx); - auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings - auto alpha = rows(ulrSharable, embIdx); // extract sharable flags - auto qt = dot(queryEmbeddings, - ulrTransform, - false, - false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb - auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); + auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings + auto alpha = rows(ulrSharable, embIdx); // extract sharable flags + auto qt = dot(queryEmbeddings, ulrTransform, false, false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb + auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); qt = qt / sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in // magnitude with larger embeds sizes - auto z = dot(qt, keyEmbed, false, true); // query-key similarity + auto z = dot(qt, keyEmbed, false, true); // query-key similarity float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout if(!inference_) z = dropout(z, dropProb); @@ -135,13 +130,11 @@ public: // temperature in softmax is to control randomness of predictions // high temperature Softmax outputs are more close to each other // low temperatures the softmax become more similar to "hardmax" - auto weights - = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? + auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE - auto chosenEmbeddings_mix - = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast + auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb}); - auto graph = ulrEmbeddings_.front()->graph(); + auto graph = ulrEmbeddings_.front()->graph(); auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask())); if(!inference_) batchEmbeddings = dropout(batchEmbeddings, diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index 772c5715..8c4d69bd 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -48,17 +48,14 @@ Expr Logits::applyLossFunction( for(size_t g = 0; g < numGroups; g++) { if(!logits_[g]) continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] - const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) - auto factorIndices = indices( - maskedFactoredLabels - .indices); // [B... flattened] factor-label indices, or 0 if factor does not apply - auto factorMask - = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with - // 0 for labels that don't have this factor - auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) - // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask - // it out next. - auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + // clang-format off + const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) + auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply + auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor + auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) + // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. + auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + // clang-format on if(loss) factorLoss = cast(factorLoss, loss->value_type()); factorLoss @@ -140,6 +137,7 @@ Expr Logits::getLogits() const { logProbs[g] = logsoftmax(logits_[g]->loss()); auto y = concatenate(logProbs, /*axis=*/-1); + // clang-format off // sum up the unit logits across factors for each target word auto graph = y->graph(); auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] @@ -147,13 +145,10 @@ Expr Logits::getLogits() const { y, // [B x U] factorMatrix.shape, graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), - graph->constant({(int)factorMatrix.indices.size()}, - inits::fromVector(factorMatrix.indices), - Type::uint32), - graph->constant({(int)factorMatrix.offsets.size()}, - inits::fromVector(factorMatrix.offsets), - Type::uint32), + graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), + graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), /*transB=*/true); // -> [B x V] + // clang-format on // mask out gaps auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] @@ -247,4 +242,4 @@ Logits Logits::withCounts( newLogits.emplace_back(New<RationalLoss>(l->loss(), count)); return Logits(std::move(newLogits), factoredVocab_); } -} // namespace marian
\ No newline at end of file +} // namespace marian |