author    | Hieu Hoang <hihoan@microsoft.com> | 2021-03-06 08:54:05 +0300
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-03-06 08:54:05 +0300
commit    | ba196637847c50c76d5d0edfcfe39b9cedb0d1d0 (patch)
tree      | a65b300fa3d7fcbdff4b84a9fceb1241ebaf7e87 /src
parent    | 55f4216552bca148091f15b72c5c2e5b486d4c79 (diff)
clang-format -i
Diffstat (limited to 'src')
-rw-r--r-- | src/layers/constructors.h | 70
-rw-r--r-- | src/layers/embedding.cpp | 282
-rw-r--r-- | src/layers/embedding.h | 108
-rw-r--r-- | src/layers/generic.cpp | 11
-rw-r--r-- | src/layers/generic.h | 98
-rw-r--r-- | src/layers/logits.cpp | 424
-rw-r--r-- | src/layers/logits.h | 110
-rw-r--r-- | src/layers/loss.cpp | 32
-rw-r--r-- | src/layers/loss.h | 181
-rw-r--r-- | src/layers/output.cpp | 336
-rw-r--r-- | src/layers/output.h | 37
-rw-r--r-- | src/models/costs.cpp | 14
-rw-r--r-- | src/models/costs.h | 158
-rw-r--r-- | src/models/states.h | 70 |
14 files changed, 1068 insertions, 863 deletions
diff --git a/src/layers/constructors.h b/src/layers/constructors.h index e25449aa..9e9de207 100644 --- a/src/layers/constructors.h +++ b/src/layers/constructors.h @@ -1,8 +1,8 @@ #pragma once +#include "layers/embedding.h" #include "layers/factory.h" #include "layers/generic.h" -#include "layers/embedding.h" #include "layers/output.h" namespace marian { @@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory { // @TODO: In the long run, I hope we can get rid of the abstract factories altogether. class OutputFactory : public LogitLayerFactory { using LogitLayerFactory::LogitLayerFactory; + protected: std::string tiedTransposedName_; Ptr<data::Shortlist> shortlist_; @@ -55,9 +56,7 @@ public: return Accumulator<OutputFactory>(*this); } - void setShortlist(Ptr<data::Shortlist> shortlist) { - shortlist_ = shortlist; - } + void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; } Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override { auto output = New<Output>(graph, options_); @@ -89,8 +88,7 @@ protected: std::vector<Ptr<IUnaryLayer>> layers_; public: - MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : graph_(graph), options_(options) {} + MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {} Expr apply(const std::vector<Expr>& av) override { Expr output; @@ -106,46 +104,53 @@ public: } Logits applyAsLogits(const std::vector<Expr>& av) override { - // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type + // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different + // return type auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back()); - ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); - if (layers_.size() == 1) { - if (av.size() == 1) + ABORT_IF( + !lastLayer, + "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); + if(layers_.size() == 1) { + if(av.size() == 1) return lastLayer->applyAsLogits(av[0]); else return lastLayer->applyAsLogits(av); - } - else { + } else { Expr output; - if (av.size() == 1) + if(av.size() == 1) output = layers_[0]->apply(av[0]); else output = layers_[0]->apply(av); - for (size_t i = 1; i < layers_.size() - 1; ++i) + for(size_t i = 1; i < layers_.size() - 1; ++i) output = layers_[i]->apply(output); return lastLayer->applyAsLogits(output); } } - Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); } - Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); } + Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); } + Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); } void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); } void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); } void setShortlist(Ptr<data::Shortlist> shortlist) override final { auto p = tryAsHasShortlist(); - ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists"); + ABORT_IF( + !p, + "setShortlist() called on an MLP with an output layer that does not support short lists"); p->setShortlist(shortlist); } void clear() override final { auto p = tryAsHasShortlist(); - if (p) + if(p) p->clear(); } + private: - Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); } + Ptr<IHasShortList> 
tryAsHasShortlist() const { + return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); + } }; /** @@ -154,6 +159,7 @@ private: */ class MLPFactory : public Factory { using Factory::Factory; + private: std::vector<Ptr<LayerFactory>> layers_; @@ -177,23 +183,27 @@ public: // which will go away if we get rid of the abstract factories, and instead just construct // all layers immediately, which is my long-term goal for Marian. private: - template<class WrappedFactory> + template <class WrappedFactory> class AsLayerFactory : public LayerFactory { - WrappedFactory us; + WrappedFactory us; + public: - AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} - Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final { - auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph)); - ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); - return p; - } + AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} + Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final { + auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph)); + ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); + return p; + } }; - template<class WrappedFactory> - static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; } + template <class WrappedFactory> + static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { + return wrapped; + } + public: Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) { push_back(AsLayerFactory<OutputFactory>(lf)); - //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf))); + // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf))); return Accumulator<MLPFactory>(*this); } }; diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp index 488fbb8b..5a448f61 100644 --- a/src/layers/embedding.cpp +++ b/src/layers/embedding.cpp @@ -3,173 +3,205 @@ namespace marian { -Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) -: LayerBase(graph, options), inference_(opt<bool>("inference")) { -std::string name = opt<std::string>("prefix"); -int dimVoc = opt<int>("dimVocab"); -int dimEmb = opt<int>("dimEmb"); +Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) + : LayerBase(graph, options), inference_(opt<bool>("inference")) { + std::string name = opt<std::string>("prefix"); + int dimVoc = opt<int>("dimVocab"); + int dimEmb = opt<int>("dimEmb"); -bool fixed = opt<bool>("fixed", false); + bool fixed = opt<bool>("fixed", false); -factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); -if (factoredVocab_) { + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); + if(factoredVocab_) { dimVoc = (int)factoredVocab_->factorVocabSize(); LOG_ONCE(info, "[embedding] Factored embeddings enabled"); -} + } -// Embedding layer initialization should depend only on embedding size, hence fanIn=false -auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length + // Embedding layer initialization should depend only on embedding size, hence fanIn=false + auto initFunc = inits::glorotUniform( + /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length -if (options_->has("embFile")) { + if(options_->has("embFile")) { std::string file = 
opt<std::string>("embFile"); - if (!file.empty()) { - bool norm = opt<bool>("normalization", false); - initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); + if(!file.empty()) { + bool norm = opt<bool>("normalization", false); + initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); } -} + } -E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); + E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); } // helper to embed a sequence of words (given as indices) via factored embeddings Expr Embedding::multiRows(const Words& data, float dropProb) const { -auto graph = E_->graph(); -auto factoredData = factoredVocab_->csr_rows(data); -// multi-hot factor vectors are represented as a sparse CSR matrix -// [row index = word position index] -> set of factor indices for word at this position -ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??"); -// the CSR matrix is passed in pieces -auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights)); -auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32); -auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32); -// apply dropout -// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors. -if(!inference_) + auto graph = E_->graph(); + auto factoredData = factoredVocab_->csr_rows(data); + // multi-hot factor vectors are represented as a sparse CSR matrix + // [row index = word position index] -> set of factor indices for word at this position + ABORT_IF(factoredData.shape + != Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}), + "shape mismatch??"); + // the CSR matrix is passed in pieces + auto weights = graph->constant({(int)factoredData.weights.size()}, + inits::fromVector(factoredData.weights)); + auto indices = graph->constant( + {(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32); + auto offsets = graph->constant( + {(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32); + // apply dropout + // We apply it to the weights, i.e. factors get dropped out separately, but always as entire + // vectors. 
+ if(!inference_) weights = dropout(weights, dropProb); -// perform the product -return csr_dot(factoredData.shape, weights, indices, offsets, E_); + // perform the product + return csr_dot(factoredData.shape, weights, indices, offsets, E_); } -std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ { -auto graph = E_->graph(); -int dimBatch = (int)subBatch->batchSize(); -int dimEmb = E_->shape()[-1]; -int dimWidth = (int)subBatch->batchWidth(); - -// factored embeddings: -// - regular: -// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] -// - factored: -// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) -// - each row of M contains the set of factors for one word => we want a CSR matrix -// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] -// - first compute x @ M on the CPU -// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): -// - shape (U, specifically) not actually needed here -// - foreach input x[i] -// - locate row M[i,*] -// - copy through its index values (std::vector<push_back>) -// - create a matching ones vector (we can keep growing) -// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) -// - CSR matrix product with E -// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) -// - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()). -// - weighting: -// - core factors' gradients are sums over all words that use the factors; -// - core factors' embeddings move very fast -// - words will need to make up for the move; rare words cannot -// - so, we multiply each factor with 1/refCount -// - core factors get weighed down a lot -// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before -// - but forward pass weighs them down, so that all factors are in a similar numeric range -// - if it is required to be in a different range, the embeddings can still learn that, but more slowly - -auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); +std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const +/*override final*/ { + auto graph = E_->graph(); + int dimBatch = (int)subBatch->batchSize(); + int dimEmb = E_->shape()[-1]; + int dimWidth = (int)subBatch->batchWidth(); + + // factored embeddings: + // - regular: + // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] + // - factored: + // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) + // - each row of M contains the set of factors for one word => we want a CSR matrix + // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] + // - first compute x @ M on the CPU + // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): + // - shape (U, specifically) not actually needed here + // - foreach input x[i] + // - locate row M[i,*] + // - copy through its index values (std::vector<push_back>) + // - create a matching ones vector (we can keep growing) + // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) + // - CSR matrix product with E + // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) + // - double-check if all dimensions are specified. Probably not for transpose (which would + // be like csc_dot()). 
+ // - weighting: + // - core factors' gradients are sums over all words that use the factors; + // - core factors' embeddings move very fast + // - words will need to make up for the move; rare words cannot + // - so, we multiply each factor with 1/refCount + // - core factors get weighed down a lot + // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as + // before + // - but forward pass weighs them down, so that all factors are in a similar numeric range + // - if it is required to be in a different range, the embeddings can still learn that, but + // more slowly + + auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); #if 1 -auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->mask())); -#else // @TODO: this is dead code now, get rid of it -// experimental: hide inline-fix source tokens from cross attention -auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); + auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask())); +#else // @TODO: this is dead code now, get rid of it + // experimental: hide inline-fix source tokens from cross attention + auto batchMask + = graph->constant({dimWidth, dimBatch, 1}, + inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); #endif -// give the graph inputs readable names for debugging and ONNX -batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); + // give the graph inputs readable names for debugging and ONNX + batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); -return std::make_tuple(batchEmbeddings, batchMask); + return std::make_tuple(batchEmbeddings, batchMask); } Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { -if (factoredVocab_) { - Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] - //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout + if(factoredVocab_) { + Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + // selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { + // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout return selectedEmbs; -} -else + } else return applyIndices(toWordIndexVector(words), shape); } -Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ { -ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); -auto embIdxExpr = E_->graph()->indices(embIdx); -embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? -auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] -selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] -// @BUGBUG: We should not broadcast along dimBatch=[-2]. 
Then we can also dropout before reshape() (test that separately) -if(!inference_) - selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); -return selectedEmbs; +Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const +/*override final*/ { + ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); + auto embIdxExpr = E_->graph()->indices(embIdx); + embIdxExpr->set_name("data_" + + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? + auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() + // (test that separately) + if(!inference_) + selectedEmbs = dropout( + selectedEmbs, options_->get<float>("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1}); + return selectedEmbs; } // standard encoder word embeddings /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const { -auto options = New<Options>( - "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_], - "dimEmb", opt<int>("dim-emb"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb", - "fixed", embeddingFix_, - "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings -if(options_->hasAndNotEmpty("embedding-vectors")) { + auto options = New<Options>( + "dimVocab", + opt<std::vector<int>>("dim-vocabs")[batchIndex_], + "dimEmb", + opt<int>("dim-emb"), + "dropout", + dropoutEmbeddings_, + "inference", + inference_, + "prefix", + (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? 
"Wemb" + : prefix_ + "_Wemb", + "fixed", + embeddingFix_, + "vocab", + opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings + if(options_->hasAndNotEmpty("embedding-vectors")) { auto embFiles = opt<std::vector<std::string>>("embedding-vectors"); options->set( - "embFile", embFiles[batchIndex_], - "normalization", opt<bool>("embedding-normalization")); -} -return New<Embedding>(graph_, options); + "embFile", embFiles[batchIndex_], "normalization", opt<bool>("embedding-normalization")); + } + return New<Embedding>(graph_, options); } // ULR word embeddings /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const { -return New<ULREmbedding>(graph_, New<Options>( - "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src - "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt - "dimUlrEmb", opt<int>("ulr-dim-emb"), - "dimEmb", opt<int>("dim-emb"), - "ulr-dropout", opt<float>("ulr-dropout"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"), - "ulrQueryFile", opt<std::string>("ulr-query-vectors"), - "ulrKeysFile", opt<std::string>("ulr-keys-vectors"))); + return New<ULREmbedding>( + graph_, + New<Options>("dimSrcVoc", + opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src + "dimTgtVoc", + opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt + "dimUlrEmb", + opt<int>("ulr-dim-emb"), + "dimEmb", + opt<int>("dim-emb"), + "ulr-dropout", + opt<float>("ulr-dropout"), + "dropout", + dropoutEmbeddings_, + "inference", + inference_, + "ulrTrainTransform", + opt<bool>("ulr-trainable-transformation"), + "ulrQueryFile", + opt<std::string>("ulr-query-vectors"), + "ulrKeysFile", + opt<std::string>("ulr-keys-vectors"))); } // get embedding layer for this encoder or decoder // This is lazy mostly because the constructors of the consuming objects are not // guaranteed presently to have access to their graph. 
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const { -if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy - if (embeddingLayers_.size() <= batchIndex_) - embeddingLayers_.resize(batchIndex_ + 1); - if (ulr) - embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR + if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy + if(embeddingLayers_.size() <= batchIndex_) + embeddingLayers_.resize(batchIndex_ + 1); + if(ulr) + embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR else - embeddingLayers_[batchIndex_] = createEmbeddingLayer(); -} -return embeddingLayers_[batchIndex_]; -} - + embeddingLayers_[batchIndex_] = createEmbeddingLayer(); + } + return embeddingLayers_[batchIndex_]; } +} // namespace marian diff --git a/src/layers/embedding.h b/src/layers/embedding.h index b7898c76..6edb3140 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -1,6 +1,6 @@ #pragma once -#include "marian.h" #include "generic.h" +#include "marian.h" namespace marian { @@ -19,7 +19,8 @@ class Embedding : public LayerBase, public IEmbeddingLayer { public: Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options); - std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final; + std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( + Ptr<data::SubBatch> subBatch) const override final; Expr apply(const Words& words, const Shape& shape) const override final; @@ -27,17 +28,18 @@ public: }; class ULREmbedding : public LayerBase, public IEmbeddingLayer { - std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members + std::vector<Expr> + ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members bool inference_{false}; public: - ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options), inference_(opt<bool>("inference")) { - std::string name = "url_embed"; //opt<std::string>("prefix"); + ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) + : LayerBase(graph, options), inference_(opt<bool>("inference")) { + std::string name = "url_embed"; // opt<std::string>("prefix"); int dimKeys = opt<int>("dimTgtVoc"); int dimQueries = opt<int>("dimSrcVoc"); int dimEmb = opt<int>("dimEmb"); - int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size + int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size bool fixed = opt<bool>("fixed", false); // Embedding layer initialization should depend only on embedding size, hence fanIn=false @@ -46,58 +48,61 @@ public: std::string queryFile = opt<std::string>("ulrQueryFile"); std::string keyFile = opt<std::string>("ulrKeysFile"); bool trainTrans = opt<bool>("ulrTrainTransform", false); - if (!queryFile.empty() && !keyFile.empty()) { + if(!queryFile.empty() && !keyFile.empty()) { initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); name = "ulr_query"; fixed = true; - auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed); + auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(query_embed); // keys embeds initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); name = "ulr_keys"; fixed = true; - auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed); + auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed); 
ulrEmbeddings_.push_back(key_embed); // actual trainable embedding initFunc = inits::glorotUniform(); name = "ulr_embed"; fixed = false; - auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim + auto ulr_embed + = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim ulrEmbeddings_.push_back(ulr_embed); // init trainable src embedding name = "ulr_src_embed"; - auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed); + auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulr_src_embed); // ulr transformation matrix - //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only - if (trainTrans) { + // initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall + // we make this to the fixed case only + if(trainTrans) { initFunc = inits::glorotUniform(); fixed = false; - } - else - { - initFunc = inits::eye(); // identity matrix + } else { + initFunc = inits::eye(); // identity matrix fixed = true; } name = "ulr_transform"; - auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed); + auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulrTransform); - initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only + initFunc = inits::fromValue( + 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no + // universal embeddings - should be zero for top freq only fixed = true; name = "ulr_shared"; - auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed); + auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed); ulrEmbeddings_.push_back(share_embed); } } - std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final { - auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb - auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb - auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb - auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb - auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb - auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 + std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( + Ptr<data::SubBatch> subBatch) const override final { + auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb + auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb + auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb + auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb + auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb + auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 int dimBatch = (int)subBatch->batchSize(); int dimEmb = uniEmbed->shape()[-1]; int dimWords = (int)subBatch->batchWidth(); @@ -106,34 +111,42 @@ public: // dim A = uni_embed_size*uni_embed_size // dim Q: uni_embed_size * total_merged_vocab_size // dim D = univ_tok_vocab * total_merged_vocab_size - // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD) - // here we need to handle the mini-batch - // extract raws corresponding to Xs in this minibatch from Q + // note all above can be precombuted and serialized if A is not trainiable and 
during decoding + // (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this + // minibatch from Q auto embIdx = toWordIndexVector(subBatch->data()); auto queryEmbeddings = rows(queryEmbed, embIdx); - auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings - auto alpha = rows(ulrSharable, embIdx); // extract sharable flags - auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb - auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]); - qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes - auto z = dot(qt, keyEmbed, false, true); // query-key similarity + auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings + auto alpha = rows(ulrSharable, embIdx); // extract sharable flags + auto qt = dot(queryEmbeddings, + ulrTransform, + false, + false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb + auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); + qt = qt / sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in + // magnitude with larger embeds sizes + auto z = dot(qt, keyEmbed, false, true); // query-key similarity float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout if(!inference_) z = dropout(z, dropProb); - float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature + float tau + = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature // temperature in softmax is to control randomness of predictions // high temperature Softmax outputs are more close to each other // low temperatures the softmax become more similar to "hardmax" - auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? + auto weights + = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? 
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE - auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast - auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb }); + auto chosenEmbeddings_mix + = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast + auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb}); auto graph = ulrEmbeddings_.front()->graph(); - auto batchMask = graph->constant({ dimWords, dimBatch, 1 }, - inits::fromVector(subBatch->mask())); + auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask())); if(!inference_) - batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1}); + batchEmbeddings = dropout(batchEmbeddings, + options_->get<float>("dropout-embeddings", 0.0f), + {batchEmbeddings->shape()[-3], 1, 1}); return std::make_tuple(batchEmbeddings, batchMask); } @@ -142,9 +155,10 @@ public: } Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final { - embIdx; shape; - ABORT("not implemented"); // @TODO: implement me + embIdx; + shape; + ABORT("not implemented"); // @TODO: implement me } }; -} +} // namespace marian diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 02e820e5..8e2ecfd7 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -1,13 +1,10 @@ #include "marian.h" -#include "layers/generic.h" +#include "data/factored_vocab.h" #include "layers/constructors.h" +#include "layers/generic.h" #include "layers/loss.h" -#include "data/factored_vocab.h" -#include "models/states.h" // for EncoderState #include "layers/lsh.h" +#include "models/states.h" // for EncoderState -namespace marian { - - -} // namespace marian +namespace marian {} // namespace marian diff --git a/src/layers/generic.h b/src/layers/generic.h index eddd597e..89f5c1e9 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -5,12 +5,14 @@ #include "data/shortlist.h" #include "layers/factory.h" -namespace marian { namespace mlp { - /** - * @brief Activation functions - */ - enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish }; -}} +namespace marian { +namespace mlp { +/** + * @brief Activation functions + */ +enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish }; +} // namespace mlp +} // namespace marian namespace marian { @@ -23,8 +25,7 @@ protected: Ptr<Options> options_; public: - LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : graph_(graph), options_(options) {} + LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {} template <typename T> T opt(const std::string key) const { @@ -42,7 +43,7 @@ struct IUnaryLayer { virtual ~IUnaryLayer() {} virtual Expr apply(Expr) = 0; virtual Expr apply(const std::vector<Expr>& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub return apply(es.front()); } }; @@ -54,7 +55,8 @@ struct IHasShortList { // Embedding from corpus sub-batch to (emb, mask) struct IEmbeddingLayer { - virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0; + virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( + Ptr<data::SubBatch> subBatch) const = 0; virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0; @@ 
-63,28 +65,29 @@ struct IEmbeddingLayer { virtual ~IEmbeddingLayer() {} }; -// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index) +// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream +// index) class EncoderDecoderLayerBase : public LayerBase { protected: const std::string prefix_; const bool embeddingFix_; - const float dropoutEmbeddings_; // this drops out full embedding vectors + const float dropoutEmbeddings_; // this drops out full embedding vectors const bool inference_; const size_t batchIndex_; - mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created) + mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created) - EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph, - Ptr<Options> options, - const std::string& prefix, + EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph, + Ptr<Options> options, + const std::string& prefix, size_t batchIndex, float dropoutEmbeddings, - bool embeddingFix) : - LayerBase(graph, options), - prefix_(options->get<std::string>("prefix", prefix)), - embeddingFix_(embeddingFix), - dropoutEmbeddings_(dropoutEmbeddings), - inference_(options->get<bool>("inference", false)), - batchIndex_(options->get<size_t>("index", batchIndex)) {} + bool embeddingFix) + : LayerBase(graph, options), + prefix_(options->get<std::string>("prefix", prefix)), + embeddingFix_(embeddingFix), + dropoutEmbeddings_(dropoutEmbeddings), + inference_(options->get<bool>("inference", false)), + batchIndex_(options->get<size_t>("index", batchIndex)) {} virtual ~EncoderDecoderLayerBase() {} @@ -101,8 +104,7 @@ namespace mlp { class Dense : public LayerBase, public IUnaryLayer { public: - Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options) {} + Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {} Expr apply(const std::vector<Expr>& inputs) override { ABORT_IF(inputs.empty(), "No inputs"); @@ -124,21 +126,17 @@ public: if(inputs.size() > 1) num = std::to_string(i); - Expr W = g->param( - name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform()); + Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform()); Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros()); if(useLayerNorm) { if(useNematusNorm) { - auto ln_s = g->param( - name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f)); + auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f)); auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros()); - outputs.push_back( - layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); + outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); } else { - auto gamma = g->param( - name + "_gamma" + num, {1, dim}, inits::fromValue(1.0)); + auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0)); outputs.push_back(layerNorm(dot(in, W), gamma, b)); } @@ -165,39 +163,35 @@ public: Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); } }; -} // namespace mlp - +} // namespace mlp // --- a few layers with built-in parameters created on the fly, without proper object // @TODO: change to a proper layer object // like affine() but with built-in parameters, activation, and dropout -static inline -Expr denseInline(Expr x, - std::string prefix, - std::string suffix, - int outDim, - Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(), - const std::function<Expr(Expr)>& actFn 
= nullptr, - float dropProb = 0.0f) -{ +static inline Expr denseInline(Expr x, + std::string prefix, + std::string suffix, + int outDim, + Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(), + const std::function<Expr(Expr)>& actFn = nullptr, + float dropProb = 0.0f) { auto graph = x->graph(); - auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform()); - auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros()); + auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform()); + auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros()); x = affine(x, W, b); - if (actFn) + if(actFn) x = actFn(x); - x = dropout(x, dropProb); // @TODO: check for infernce? + x = dropout(x, dropProb); // @TODO: check for infernce? return x; } -static inline -Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) { +static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) { int dimModel = x->shape()[-1]; - auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones()); - auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros()); + auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones()); + auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros()); return marian::layerNorm(x, scale, bias, 1e-6f); } diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index cd2203e4..772c5715 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -1,212 +1,250 @@ #include "logits.h" -#include "loss.h" #include "data/factored_vocab.h" -#include "rnn/types.h" // for State::select() +#include "loss.h" +#include "rnn/types.h" // for State::select() namespace marian { - Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count) - - Ptr<ExpressionGraph> Logits::graph() const { - ABORT_IF(logits_.empty(), "Empty logits object??"); - return logits_.front()->loss()->graph(); +Logits::Logits(Expr logits) + : Logits(New<RationalLoss>(logits, nullptr)) { +} // single-output constructor from Expr only (RationalLoss has no count) + +Ptr<ExpressionGraph> Logits::graph() const { + ABORT_IF(logits_.empty(), "Empty logits object??"); + return logits_.front()->loss()->graph(); +} + +// This function assumes that the object holds one or more factor logits. +// It applies the supplied loss function to each, and then returns the aggregate loss over all +// factors. +Expr Logits::applyLossFunction( + const Words& labels, + const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const { + LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto firstLogits = logits_.front()->loss(); + ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), + "Labels not matching logits shape ({} != {}, {})??", + labels.size() * firstLogits->shape()[-1], + firstLogits->shape().elements(), + firstLogits->shape()); + + // base case (no factors) + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return lossFn(firstLogits, indices(toWordIndexVector(labels))); } - // This function assumes that the object holds one or more factor logits. 
- // It applies the supplied loss function to each, and then returns the aggregate loss over all factors. - Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const { - LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto firstLogits = logits_.front()->loss(); - ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), - "Labels not matching logits shape ({} != {}, {})??", - labels.size() * firstLogits->shape()[-1], - firstLogits->shape().elements(), - firstLogits->shape()); - - // base case (no factors) - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return lossFn(firstLogits, indices(toWordIndexVector(labels))); - } - - auto numGroups = factoredVocab_->getNumGroups(); - - // split labels into individual factor labels - auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] - - //Expr indices = this->indices(toWordIndexVector(labels)); - // accumulate all CEs for all words that have the factor - // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. - Expr loss; - for (size_t g = 0; g < numGroups; g++) { - if (!logits_[g]) - continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] - const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) - auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply - auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor - auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) - // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. - auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] - if(loss) - factorLoss = cast(factorLoss, loss->value_type()); - factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor - loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] - } - return loss; + auto numGroups = factoredVocab_->getNumGroups(); + + // split labels into individual factor labels + auto allMaskedFactoredLabels + = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] + + // Expr indices = this->indices(toWordIndexVector(labels)); + // accumulate all CEs for all words that have the factor + // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. + Expr loss; + for(size_t g = 0; g < numGroups; g++) { + if(!logits_[g]) + continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] + const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) + auto factorIndices = indices( + maskedFactoredLabels + .indices); // [B... flattened] factor-label indices, or 0 if factor does not apply + auto factorMask + = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with + // 0 for labels that don't have this factor + auto factorLogits = logits_[g]; // [B... 
* Ug] label-wise loss values (not aggregated yet) + // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask + // it out next. + auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + if(loss) + factorLoss = cast(factorLoss, loss->value_type()); + factorLoss + = factorLoss + * cast( + reshape(factorMask, factorLoss->shape()), + factorLoss->value_type()); // mask out factor for words that do not have that factor + loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] } - - // This function assumes this object holds a single factor that represents a rational loss (with count). - //Ptr<RationalLoss> Logits::getRationalLoss() const { - // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs"); - // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count"); - // return logits_.front(); - //} - - // get logits for one factor group - // For groupIndex == 0, the function also requires the shortlist if there is one. - Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] - - // normalize for decoding: - // - all secondary factors: subtract their max - // - lemma: add all maxes of applicable factors - if (groupIndex > 0) { - sel = sel - max(sel, -1); - } - else { - auto numGroups = getNumFactorGroups(); - for (size_t g = 1; g < numGroups; g++) { - auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32 - auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>())); - sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0 - } + return loss; +} + +// This function assumes this object holds a single factor that represents a rational loss (with +// count). +// Ptr<RationalLoss> Logits::getRationalLoss() const { +// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on +// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational +// loss without count"); return logits_.front(); +//} + +// get logits for one factor group +// For groupIndex == 0, the function also requires the shortlist if there is one. +Expr Logits::getFactoredLogits(size_t groupIndex, + Ptr<data::Shortlist> shortlist /*= nullptr*/, + const std::vector<IndexType>& hypIndices /*= {}*/, + size_t beamSize /*= 0*/) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] + + // normalize for decoding: + // - all secondary factors: subtract their max + // - lemma: add all maxes of applicable factors + if(groupIndex > 0) { + sel = sel - max(sel, -1); + } else { + auto numGroups = getNumFactorGroups(); + for(size_t g = 1; g < numGroups; g++) { + auto factorMaxima = max(logits_[g]->loss(), + -1); // we cast since loss is likely ce-loss which has type float32 + auto factorMasks = constant( + getFactorMasks(g, shortlist ? 
shortlist->indices() : std::vector<WordIndex>())); + sel = sel + + cast(factorMaxima, sel->value_type()) + * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor + // get multiplied with 0 } - - // if selIdx are given, then we must reshuffle accordingly - if (!hypIndices.empty()) // use the same function that shuffles decoder state - sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); - - return sel; } - // used for breakDown() only - // Index is flattened - Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - return logits_[groupIndex]->loss()->val(); + // if selIdx are given, then we must reshuffle accordingly + if(!hypIndices.empty()) // use the same function that shuffles decoder state + sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); + + return sel; +} + +// used for breakDown() only +// Index is flattened +Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + return logits_[groupIndex]->loss()->val(); +} + +// This function assumes that the object holds one or more factor logits, which are summed up +// into output-vocab logits according to the factored model (with correct normalization of factors). +// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. +// @TODO: remove altogether +Expr Logits::getLogits() const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return getFactoredLogits(0); } - // This function assumes that the object holds one or more factor logits, which are summed up - // into output-vocab logits according to the factored model (with correct normalization of factors). - // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. 
- // @TODO: remove altogether - Expr Logits::getLogits() const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return getFactoredLogits(0); - } - #ifdef FACTOR_FULL_EXPANSION - // compute normalized factor log probs - std::vector<Expr> logProbs(logits_.size()); - for (size_t g = 0; g < logits_.size(); g++) - logProbs[g] = logsoftmax(logits_[g]->loss()); - auto y = concatenate(logProbs, /*axis=*/ -1); - - // sum up the unit logits across factors for each target word - auto graph = y->graph(); - auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] - y = dot_csr( - y, // [B x U] - factorMatrix.shape, - graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), - graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), - graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), - /*transB=*/ true); // -> [B x V] - - // mask out gaps - auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] - y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask)); - - return y; + // compute normalized factor log probs + std::vector<Expr> logProbs(logits_.size()); + for(size_t g = 0; g < logits_.size(); g++) + logProbs[g] = logsoftmax(logits_[g]->loss()); + auto y = concatenate(logProbs, /*axis=*/-1); + + // sum up the unit logits across factors for each target word + auto graph = y->graph(); + auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] + y = dot_csr( + y, // [B x U] + factorMatrix.shape, + graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), + graph->constant({(int)factorMatrix.indices.size()}, + inits::fromVector(factorMatrix.indices), + Type::uint32), + graph->constant({(int)factorMatrix.offsets.size()}, + inits::fromVector(factorMatrix.offsets), + Type::uint32), + /*transB=*/true); // -> [B x V] + + // mask out gaps + auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] + y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask)); + + return y; #else - ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible + ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible #endif +} + +void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { + bool isValid = FactoredVocab::isFactorValid(factorIndex); + indices.push_back(isValid ? (WordIndex)factorIndex : 0); + masks.push_back((float)isValid); +} + +std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) + const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return {MaskedFactorIndices(words)}; } - - void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { - bool isValid = FactoredVocab::isFactorValid(factorIndex); - indices.push_back(isValid ? 
(WordIndex)factorIndex : 0); - masks.push_back((float)isValid); + auto numGroups = factoredVocab_->getNumGroups(); + std::vector<MaskedFactorIndices> res(numGroups); + for(size_t g = 0; g < numGroups; g++) { + auto& resg = res[g]; + resg.reserve(words.size()); + for(const auto& word : words) + resg.push_back(factoredVocab_->getFactor(word, g)); } - - std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return {MaskedFactorIndices(words)}; - } - auto numGroups = factoredVocab_->getNumGroups(); - std::vector<MaskedFactorIndices> res(numGroups); - for (size_t g = 0; g < numGroups; g++) { - auto& resg = res[g]; - resg.reserve(words.size()); - for (const auto& word : words) - resg.push_back(factoredVocab_->getFactor(word, g)); - } - return res; - } - - //// use first factor of each word to determine whether it has a specific factor - //std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0 - // std::vector<float> res; - // res.reserve(words.size()); - // for (const auto& word : words) { - // auto lemma = factoredVocab_->getFactor(word, 0); - // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - // } - // return res; - //} - - // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor - // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas - std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 - size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size(); - std::vector<float> res; - res.reserve(n); - // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab - for (size_t i = 0; i < n; i++) { - auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first); - res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - } - return res; + return res; +} + +//// use first factor of each word to determine whether it has a specific factor +// std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 +// for words that do have this factor; else 0 +// std::vector<float> res; +// res.reserve(words.size()); +// for (const auto& word : words) { +// auto lemma = factoredVocab_->getFactor(word, 0); +// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); +// } +// return res; +//} + +// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor +// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas +std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) + const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 + size_t n + = indices.empty() + ? 
(factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) + : indices.size(); + std::vector<float> res; + res.reserve(n); + // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this + // into FactoredVocab + for(size_t i = 0; i < n; i++) { + auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first); + res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); } - - Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values - std::vector<Ptr<RationalLoss>> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count())); - return Logits(std::move(newLogits), factoredVocab_); - } - - Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const { - std::vector<Ptr<RationalLoss>> newLogits; - bool first = true; - for (const auto& l : logits_) { - newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others - first = false; - } - return Logits(std::move(newLogits), factoredVocab_); - } - - // @TODO: code dup with above; we can merge it into applyToRationalLoss() - Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_ - std::vector<Ptr<RationalLoss>> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New<RationalLoss>(l->loss(), count)); - return Logits(std::move(newLogits), factoredVocab_); + return res; +} + +Logits Logits::applyUnaryFunction( + const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values + std::vector<Ptr<RationalLoss>> newLogits; + for(const auto& l : logits_) + newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count())); + return Logits(std::move(newLogits), factoredVocab_); +} + +Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, + const std::function<Expr(Expr)>& fother) const { + std::vector<Ptr<RationalLoss>> newLogits; + bool first = true; + for(const auto& l : logits_) { + newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()), + l->count())); // f1 for first, fother for all others + first = false; } -}
\ No newline at end of file + return Logits(std::move(newLogits), factoredVocab_); +} + +// @TODO: code dup with above; we can merge it into applyToRationalLoss() +Logits Logits::withCounts( + const Expr& count) const { // create new Logits with 'count' implanted into all logits_ + std::vector<Ptr<RationalLoss>> newLogits; + for(const auto& l : logits_) + newLogits.emplace_back(New<RationalLoss>(l->loss(), count)); + return Logits(std::move(newLogits), factoredVocab_); +} +} // namespace marian
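For context on the factor helpers reformatted above: getFactorMasks() returns, per lemma, a 1.0/0.0 indicator for whether that lemma carries a given factor group, and an empty index list means "all lemmas of group 0". Below is a minimal, self-contained sketch of that logic; ToyFactoredVocab and all values are illustrative stand-ins, not the real FactoredVocab API.

#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for the two FactoredVocab calls used by getFactorMasks() above.
struct ToyFactoredVocab {
  std::pair<std::size_t, std::size_t> getGroupRange(std::size_t /*g*/) const { return {0, 4}; }  // 4 lemmas
  bool lemmaHasFactorGroup(std::size_t lemma, std::size_t group) const { return (lemma + group) % 2 == 0; }
};

// Same idea as Logits::getFactorMasks(): 1.0 if the lemma carries the factor group,
// else 0.0; an empty 'indices' vector means "all lemmas of group 0".
std::vector<float> factorMasks(const ToyFactoredVocab& vocab,
                               std::size_t factorGroup,
                               const std::vector<std::size_t>& indices) {
  auto range = vocab.getGroupRange(0);
  std::size_t n = indices.empty() ? (range.second - range.first) : indices.size();
  std::vector<float> masks;
  masks.reserve(n);
  for(std::size_t i = 0; i < n; ++i) {
    // global word indices are shifted back into the lemma range when explicit indices are given
    std::size_t lemma = indices.empty() ? i : (indices[i] - range.first);
    masks.push_back(vocab.lemmaHasFactorGroup(lemma, factorGroup) ? 1.f : 0.f);
  }
  return masks;
}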
\ No newline at end of file diff --git a/src/layers/logits.h b/src/layers/logits.h index 4196e0d0..c61a9e74 100644 --- a/src/layers/logits.h +++ b/src/layers/logits.h @@ -1,8 +1,8 @@ #pragma once -#include "marian.h" #include "data/shortlist.h" #include "generic.h" +#include "marian.h" namespace marian { @@ -16,46 +16,77 @@ class FactoredVocab; class RationalLoss; class Logits { public: - Logits() {} - explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor - logits_.push_back(logits); - } - explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) - Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor + Logits() {} + explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor + logits_.push_back(logits); + } + explicit Logits( + Expr logits); // single-output constructor from Expr only (RationalLoss has no count) + Logits(std::vector<Ptr<RationalLoss>>&& logits, + Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} - Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors - Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle - //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that - Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const; - Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values - Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to to all other values + Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors + Expr getFactoredLogits( + size_t groupIndex, + Ptr<data::Shortlist> shortlist = nullptr, + const std::vector<IndexType>& hypIndices = {}, + size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle + // Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that + Expr applyLossFunction( + const Words& labels, + const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const; + Logits applyUnaryFunction( + const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values + Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, + const std::function<Expr(Expr)>& fother) + const; // clone this but apply f1 to first and fother to to all other values - struct MaskedFactorIndices { - std::vector<WordIndex> indices; // factor index, or 0 if masked - std::vector<float> masks; - void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } - void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries - MaskedFactorIndices() {} - MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case - }; - std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices - Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only - 
size_t getNumFactorGroups() const { return logits_.size(); } - bool empty() const { return logits_.empty(); } - Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ + struct MaskedFactorIndices { + std::vector<WordIndex> indices; // factor index, or 0 if masked + std::vector<float> masks; + void reserve(size_t n) { + indices.reserve(n); + masks.reserve(n); + } + void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 + // for invalid entries + MaskedFactorIndices() {} + MaskedFactorIndices(const Words& words) { + indices = toWordIndexVector(words); + } // we can leave masks uninitialized for this special use case + }; + std::vector<MaskedFactorIndices> factorizeWords( + const Words& words) const; // breaks encoded Word into individual factor indices + Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only + size_t getNumFactorGroups() const { return logits_.size(); } + bool empty() const { return logits_.empty(); } + Logits withCounts( + const Expr& count) const; // create new Logits with 'count' implanted into all logits_ private: - // helper functions - Ptr<ExpressionGraph> graph() const; - Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); } - Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); } - template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector - Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type - std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const; + // helper functions + Ptr<ExpressionGraph> graph() const; + Expr constant(const Shape& shape, const std::vector<float>& data) const { + return graph()->constant(shape, inits::fromVector(data)); + } + Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { + return graph()->constant(shape, inits::fromVector(data)); + } + template <typename T> + Expr constant(const std::vector<T>& data) const { + return constant(Shape{(int)data.size()}, data); + } // same as constant() but assuming vector + Expr indices(const std::vector<uint32_t>& data) const { + return graph()->indices(data); + } // actually the same as constant(data) for this data type + std::vector<float> getFactorMasks(size_t factorGroup, + const std::vector<WordIndex>& indices) const; + private: - // members - // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr - std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group] - Ptr<FactoredVocab> factoredVocab_; + // members + // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just + // by the Expr + std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group] + Ptr<FactoredVocab> factoredVocab_; }; // Unary function that returns a Logits object @@ -65,12 +96,11 @@ private: struct IUnaryLogitLayer : public IUnaryLayer { virtual Logits applyAsLogits(Expr) = 0; virtual Logits applyAsLogits(const std::vector<Expr>& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub 
return applyAsLogits(es.front()); } virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); } }; -} - +} // namespace marian diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp index 67d38832..695276af 100644 --- a/src/layers/loss.cpp +++ b/src/layers/loss.cpp @@ -13,26 +13,30 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) { bool wordScores = options->get<bool>("word-scores", false); return New<RescorerLoss>(wordScores); } else if(unlikelihood) { - ABORT_IF(!options->hasAndNotEmpty("data-weighting") - && options->get<std::string>("data-weighting-type") != "word", - "Unlikelihood loss training requires error annotation in form of per-target-label scores"); - return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on values given for data-weighting - } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum? + ABORT_IF( + !options->hasAndNotEmpty("data-weighting") + && options->get<std::string>("data-weighting-type") != "word", + "Unlikelihood loss training requires error annotation in form of per-target-label scores"); + return New<SequenceUnlikelihoodLoss>( + smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on + // values given for data-weighting + } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. + // E.g. what about ce-sum? return New<CrossEntropyLoss>(smoothing, factorWeight); } } // see loss.h for detailed explanations of each class Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) { - std::string multiLossType = options->get<std::string>("multi-loss-type", "sum"); - if(multiLossType == "sum") // sum of sums - return New<SumMultiRationalLoss>(); - else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale - return New<ScaledMultiRationalLoss>(); - else if(multiLossType == "mean") // sum of means - return New<MeanMultiRationalLoss>(); - else - ABORT("Unknown multi-loss-type {}", multiLossType); + std::string multiLossType = options->get<std::string>("multi-loss-type", "sum"); + if(multiLossType == "sum") // sum of sums + return New<SumMultiRationalLoss>(); + else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale + return New<ScaledMultiRationalLoss>(); + else if(multiLossType == "mean") // sum of means + return New<MeanMultiRationalLoss>(); + else + ABORT("Unknown multi-loss-type {}", multiLossType); } } // namespace marian diff --git a/src/layers/loss.h b/src/layers/loss.h index ba93cdac..c662f991 100644 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -1,8 +1,8 @@ #pragma once -#include "graph/expression_operators.h" -#include "layers/logits.h" // for Logits (Frank's factor hack) #include "data/types.h" +#include "graph/expression_operators.h" +#include "layers/logits.h" // for Logits (Frank's factor hack) namespace marian { @@ -22,21 +22,18 @@ namespace marian { */ class RationalLoss { protected: - Expr loss_; // numerator - Expr count_; // denominator + Expr loss_; // numerator + Expr count_; // denominator - RationalLoss() = default; // protected + RationalLoss() = default; // protected public: - RationalLoss(Expr loss, Expr count) - : loss_(loss), count_(count) {} + RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {} RationalLoss(Expr loss, 
float count) - : loss_(loss), - count_(constant_like(loss, inits::fromValue(count))) {} + : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {} - RationalLoss(const RationalLoss& other) - : loss_(other.loss_), count_(other.count_) {} + RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {} virtual ~RationalLoss() = default; @@ -50,7 +47,7 @@ public: } template <typename T> - T loss() const { // this will fail if loss is not a single value + T loss() const { // this will fail if loss is not a single value ABORT_IF(!loss_, "Loss has not been defined"); return loss_->val()->scalar<T>(); } @@ -65,7 +62,7 @@ public: } template <typename T> - T count() const { // this will fail if loss is not a single value + T count() const { // this will fail if loss is not a single value ABORT_IF(!count_, "Labels have not been defined"); return count_->val()->scalar<T>(); } @@ -85,21 +82,21 @@ public: * RationalLoss object. */ struct StaticLoss { - float loss; // numerator - float count; // denominator + float loss; // numerator + float count; // denominator StaticLoss() : loss(0.f), count(0.f) {} StaticLoss(const RationalLoss& dynamic) - : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {} + : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {} - StaticLoss operator +(const StaticLoss& other) const { + StaticLoss operator+(const StaticLoss& other) const { StaticLoss res(*this); res += other; return res; } - StaticLoss& operator +=(const StaticLoss& other) { + StaticLoss& operator+=(const StaticLoss& other) { loss = loss + other.loss; count = count + other.count; return *this; @@ -139,32 +136,21 @@ protected: public: MultiRationalLoss() : RationalLoss() {} - MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { - push_back(rl); - } + MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); } virtual void push_back(const RationalLoss& current) { - loss_ = accumulateLoss(current); - count_ = accumulateCount(current); + loss_ = accumulateLoss(current); + count_ = accumulateCount(current); partialLosses_.push_back(current); } - const RationalLoss& operator[](size_t i) { - return partialLosses_[i]; - } + const RationalLoss& operator[](size_t i) { return partialLosses_[i]; } - auto begin() -> decltype(partialLosses_.begin()) const { - return partialLosses_.begin(); - } + auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); } - auto end() -> decltype(partialLosses_.end()) const { - return partialLosses_.end(); - } - - size_t size() const { - return partialLosses_.size(); - } + auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); } + size_t size() const { return partialLosses_.size(); } }; /** @@ -212,17 +198,19 @@ private: virtual Expr accumulateLoss(const RationalLoss& current) override { if(loss_) { const auto& first = partialLosses_.front(); - return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss + return loss_ + + current.loss() * first.count() + / current.count(); // scale up/down to match scale of first loss } else { - return current.loss(); // first reference loss, keeps to scale with this one + return current.loss(); // first reference loss, keeps to scale with this one } } virtual Expr accumulateCount(const RationalLoss& current) override { if(count_) { - return count_; // Keep first label count // or: count_ + first.count() / current.count(); + return count_; // Keep first label count // or: 
count_ + first.count() / current.count(); } else { - return current.count(); // This is the first loss + return current.count(); // This is the first loss } } @@ -253,9 +241,10 @@ private: virtual Expr accumulateCount(const RationalLoss& current) override { if(count_) - return count_; // keep the existing '1' + return count_; // keep the existing '1' else - return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_ + return current.count()->graph()->ones( + {1}, current.loss()->value_type()); // just '1' as labels are factored into loss_ } public: @@ -279,18 +268,21 @@ class LabelwiseLoss { protected: std::vector<int> axes_; - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) = 0; + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) + = 0; // label counts are available, reduce together with loss to obtain counts RationalLoss reduce(Expr loss, Expr labels) { ABORT_IF(!loss, "Loss has not been computed"); ABORT_IF(!labels, "Labels have not been computed"); - Expr lossSum = cast(loss, Type::float32); // accumulate in float32 - Expr labelsSum = cast(labels, Type::float32); // accumulate in float32 + Expr lossSum = cast(loss, Type::float32); // accumulate in float32 + Expr labelsSum = cast(labels, Type::float32); // accumulate in float32 for(int i = 0; i < axes_.size(); ++i) { - lossSum = sum(lossSum, axes_[i]); + lossSum = sum(lossSum, axes_[i]); labelsSum = sum(labelsSum, axes_[i]); } @@ -301,7 +293,7 @@ protected: RationalLoss reduce(Expr loss) { ABORT_IF(!loss, "Loss has not been computed"); - Expr lossSum = cast(loss, Type::float32); + Expr lossSum = cast(loss, Type::float32); for(int i = 0; i < axes_.size(); ++i) lossSum = sum(lossSum, axes_[i]); @@ -311,17 +303,18 @@ protected: } public: - LabelwiseLoss(const std::vector<int>& axes) - : axes_(axes) { } + LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {} - virtual RationalLoss apply(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) { + virtual RationalLoss apply(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) { Expr loss = compute(logits, labels, mask, labelWeights); if(mask) - return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting + return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting else - return reduce(loss); // we have no mask, assume all items are labels + return reduce(loss); // we have no mask, assume all items are labels } }; @@ -331,28 +324,34 @@ public: class CrossEntropyLoss : public LabelwiseLoss { public: CrossEntropyLoss(float labelSmoothing, float factorWeight) - : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1 + : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) { + } // cross-entropy already reduces over axis -1 CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight) - : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1 - labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {} + : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1 + labelSmoothing_(labelSmoothing), + factorWeight_(factorWeight) {} virtual ~CrossEntropyLoss() {} -protected: - float labelSmoothing_; // interpolation factor for label smoothing, see below 
- float factorWeight_; // give extra weight to factors - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) override { - // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up +protected: + float labelSmoothing_; // interpolation factor for label smoothing, see below + float factorWeight_; // give extra weight to factors + + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) override { + // logits may be factored; in that case, the getLoss() function computes one loss for each, and + // sums them up int inFactor = false; auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { - logits = atleast_3d(logits); // we always assume a time and batch dimension exists. + logits = atleast_3d(logits); // we always assume a time and batch dimension exists. // for bert training or classification the time dimension is lost. // Here safeguard against 2d classifier output, adds 1 on the left, non-op. - + Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32); - if (inFactor && factorWeight_ != 1.0f) { + if(inFactor && factorWeight_ != 1.0f) { LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_); ce = ce * factorWeight_; } @@ -365,8 +364,10 @@ protected: if(labelWeights) { // We currently do not know how to use target factors and word-level label weights together - bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights. - ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors"); + bool wordlevel = labelWeights->shape()[-3] + > 1; // Time-dimension is not trivially 1, hence we have word-level weights. + ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, + "CE loss with word-level label weights is not implemented for factors"); ce = ce * cast(labelWeights, Type::float32); } @@ -374,13 +375,12 @@ protected: } }; - /** * @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an * implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319. - * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are not - * zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it is going - * to flip over to use SUL for that sentence to penalize the selected word. + * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are + * not zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it + * is going to flip over to use SUL for that sentence to penalize the selected word. 
* * SUL is implemented as: * -log(gather(1 - softmax(logits), -1, indices)) @@ -390,35 +390,45 @@ protected: class SequenceUnlikelihoodLoss : public CrossEntropyLoss { public: SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight) - : CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1 + : CrossEntropyLoss(labelSmoothing, factorWeight) { + } // cross-entropy already reduces over axis -1 SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight) - : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {} + : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {} protected: - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) override { - auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) override { + auto ce = CrossEntropyLoss::compute( + logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE if(!labelWeights) - return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)? + return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)? // We currently do not know how to use target factors and word-level label weights together ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors"); - ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete. - // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete) + ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by + // default 1, which would make this obsolete. + // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask + // again to eliminate padding (might be obsolete) auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32); auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { return cast(unlikelihood(logits, indices), Type::float32); }); - // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only on_ the errors with UL. This is the "mixed" training - // schedule from https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily switch between CE and UL. - auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present - ceUl = errorMask * ceUl; // don't use for correct label or padding + // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only + // on_ the errors with UL. This is the "mixed" training schedule from + // https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily + // switch between CE and UL. 
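As an aside to the mixed CE/UL schedule described in the comments above, here is a toy, framework-free sketch of the two token-level quantities being mixed, written with plain std::vector instead of the graph operators cross_entropy() and unlikelihood() used in the actual code; the softmax helper and all names are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy version of the two losses mixed by SequenceUnlikelihoodLoss:
//   CE  = -log p(y)        (whole sentence, when no error is annotated)
//   SUL = -log(1 - p(y))   (only on tokens whose label weight marks an error)
static std::vector<float> softmax(const std::vector<float>& logits) {
  float maxVal = *std::max_element(logits.begin(), logits.end());
  std::vector<float> p(logits.size());
  float sum = 0.f;
  for(std::size_t i = 0; i < logits.size(); ++i) {
    p[i] = std::exp(logits[i] - maxVal);
    sum += p[i];
  }
  for(float& v : p)
    v /= sum;
  return p;
}

float ceLoss(const std::vector<float>& logits, std::size_t target) {
  return -std::log(softmax(logits)[target]);
}

float sulLoss(const std::vector<float>& logits, std::size_t target) {
  return -std::log(1.f - softmax(logits)[target]);  // pushes probability mass away from a wrong token
}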
+ auto onlyCe = eq(sum(errorMask, /*axis=*/-3), + 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present + ceUl = errorMask * ceUl; // don't use for correct label or padding - auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never simultanously used as cost per batch entry + auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never + // simultanously used as cost per batch entry return cost; } @@ -463,7 +473,6 @@ public: } }; - /** * @brief Factory for label-wise loss functions */ diff --git a/src/layers/output.cpp b/src/layers/output.cpp index bf8fa588..1d9c7b4b 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -1,120 +1,131 @@ #include "output.h" -#include "data/factored_vocab.h" #include "common/timer.h" -#include "layers/lsh.h" +#include "data/factored_vocab.h" #include "layers/loss.h" +#include "layers/lsh.h" namespace marian { namespace mlp { /*private*/ void Output::lazyConstruct(int inputDim) { - // We must construct lazily since we won't know tying nor input dim in constructor. - if (Wt_) + // We must construct lazily since we won't know tying nor input dim in constructor. + if(Wt_) return; - // this option is only set in the decoder - if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { - auto k = opt<std::vector<int>>("output-approx-knn")[0]; + // this option is only set in the decoder + if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { + auto k = opt<std::vector<int>>("output-approx-knn")[0]; auto nbits = opt<std::vector<int>>("output-approx-knn")[1]; lsh_ = New<LSH>(k, nbits); - } + } - auto name = options_->get<std::string>("prefix"); - auto numOutputClasses = options_->get<int>("dim"); + auto name = options_->get<std::string>("prefix"); + auto numOutputClasses = options_->get<int>("dim"); - factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); - if (factoredVocab_) { + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); + if(factoredVocab_) { numOutputClasses = (int)factoredVocab_->factorVocabSize(); LOG_ONCE(info, "[embedding] Factored outputs enabled"); - } + } - if(tiedParam_) { + if(tiedParam_) { Wt_ = tiedParam_; - } else { - if (graph_->get(name + "_W")) { // support of legacy models that did not transpose - Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); - isLegacyUntransposedW = true; - } - else // this is the regular case: - Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); - } + } else { + if(graph_->get(name + "_W")) { // support of legacy models that did not transpose + Wt_ = graph_->param( + name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); + isLegacyUntransposedW = true; + } else // this is the regular case: + Wt_ = graph_->param( + name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); + } - if(hasBias_) + if(hasBias_) b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros()); - /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); - ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); - if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix + /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); + ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); + if(lemmaDimEmb > 0) 
{ // > 0 means to embed the (expected) word with a different embedding matrix #define HARDMAX_HACK #ifdef HARDMAX_HACK - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number #endif auto range = factoredVocab_->getGroupRange(0); auto lemmaVocabDim = (int)(range.second - range.first); - auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length - lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed - } + auto initFunc = inits::glorotUniform( + /*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length + lemmaEt_ = graph_->param(name + "_lemmaEt", + {lemmaDimEmb, lemmaVocabDim}, + initFunc); // [L x U] L=lemmaDimEmb; transposed for speed + } } Logits Output::applyAsLogits(Expr input) /*override final*/ { - lazyConstruct(input->shape()[-1]); + lazyConstruct(input->shape()[-1]); - auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { + auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { if(b) - return affine(x, W, b, transA, transB); + return affine(x, W, b, transA, transB); else - return dot(x, W, transA, transB); - }; + return dot(x, W, transA, transB); + }; - auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { + auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { if(lsh_) { - ABORT_IF( transA, "Transposed query not supported for LSH"); - ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); - return lsh_->apply(x, W, b); // knows how to deal with undefined bias + ABORT_IF(transA, "Transposed query not supported for LSH"); + ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); + return lsh_->apply(x, W, b); // knows how to deal with undefined bias } else { - return affineOrDot(x, W, b, transA, transB); + return affineOrDot(x, W, b, transA, transB); } - }; + }; - if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed - cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); + if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one + // batch, then clear()ed + cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); if(hasBias_) - cachedShortb_ = index_select(b_ , -1, shortlist_->indices()); - } + cachedShortb_ = index_select(b_, -1, shortlist_->indices()); + } - if (factoredVocab_) { + if(factoredVocab_) { auto graph = input->graph(); // project each factor separately auto numGroups = factoredVocab_->getNumGroups(); - std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors) - Expr input1 = input; // [B... x D] - Expr Plemma = nullptr; // used for lemmaDimEmb=-1 - Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 - for (size_t g = 0; g < numGroups; g++) { - auto range = factoredVocab_->getGroupRange(g); - if (g > 0 && range.first == range.second) // empty entry + std::vector<Ptr<RationalLoss>> allLogits(numGroups, + nullptr); // (note: null entries for absent factors) + Expr input1 = input; // [B... 
x D] + Expr Plemma = nullptr; // used for lemmaDimEmb=-1 + Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 + for(size_t g = 0; g < numGroups; g++) { + auto range = factoredVocab_->getGroupRange(g); + if(g > 0 && range.first == range.second) // empty entry continue; - ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g); - // slice this group's section out of W_ - Expr factorWt, factorB; - if (g == 0 && shortlist_) { + ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second, + "Factor groups must be consecutive (group {} vs predecessor)", + g); + // slice this group's section out of W_ + Expr factorWt, factorB; + if(g == 0 && shortlist_) { factorWt = cachedShortWt_; - factorB = cachedShortb_; - } - else { - factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); + factorB = cachedShortb_; + } else { + factorWt = slice( + Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); if(hasBias_) - factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); - } - /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); - if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) + factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); + } + /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); + if((lemmaDimEmb == -2 || lemmaDimEmb == -3) + && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) LOG_ONCE(info, "[embedding] using lemma conditioning with gate"); // this mimics one transformer layer // - attention over two inputs: - // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas. + // - e = current lemma. We use the original embedding vector; specifically, expectation + // over all lemmas. 
// - input = hidden state FF(h_enc+h_dec) - // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention) + // - dot-prod attention to allow both sides to influence (unlike our recurrent + // self-attention) // - multi-head to allow for multiple conditions to be modeled // - add & norm, for gradient flow and scaling // - FF layer --this is expensive; it is per-factor @@ -122,112 +133,161 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { int inputDim = input->shape()[-1]; int heads = 8; auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g); - auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform()); - auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform()); - auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform()); + auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform()); + auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform()); + auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform()); auto toMultiHead = [&](Expr x, int heads) { - const auto& shape = x->shape(); - int inputDim = shape[-1]; - int otherDim = shape.elements() / inputDim; - ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads); - return reshape(x, { otherDim, heads, 1, inputDim / heads }); + const auto& shape = x->shape(); + int inputDim = shape[-1]; + int otherDim = shape.elements() / inputDim; + ABORT_IF(inputDim / heads * heads != inputDim, + "inputDim ({}) must be multiple of number of heads ({})", + inputDim, + heads); + return reshape(x, {otherDim, heads, 1, inputDim / heads}); }; input1 = inputLemma; - auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query - auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax. - auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values - auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other - auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] - auto sm = sigmoid(zm); // [B... x H x 1] - auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] - auto r = reshape(rm, input->shape()); // [B... x D] + auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query + auto kdm = toMultiHead(dot(input1 - input, Wk), + heads); // [B... x H x D/H] the two data vectors projected as keys. + // Use diff and sigmoid, instead of softmax. + auto vem = toMultiHead( + dot(input1, Wv), + heads); // [B... x H x D/H] one of the two data vectors projected as values + auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other + auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] + auto sm = sigmoid(zm); // [B... x H x 1] + auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] + auto r = reshape(rm, input->shape()); // [B... 
x D] // add & norm input1 = r + input1; input1 = layerNorm(input1, name + "_att"); // FF layer - auto ffnDropProb = 0.1f; // @TODO: get as a parameter - auto ffnDim = inputDim * 2; // @TODO: get as a parameter - auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb); - f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); + auto ffnDropProb = 0.1f; // @TODO: get as a parameter + auto ffnDim = inputDim * 2; // @TODO: get as a parameter + auto f = denseInline(input1, + name + "_ffn", + /*suffix=*/"1", + ffnDim, + inits::glorotUniform(), + (ActivationFunction*)relu, + ffnDropProb); + f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); // add & norm input1 = f + input1; input1 = layerNorm(input1, name + "_ffn"); - } - // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix - Expr factorLogits; - if(g == 0) - factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - else - factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - - // optionally add lemma-dependent bias - if (Plemma) { // [B... x U0] + } + // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a + // matrix + Expr factorLogits; + if(g == 0) + factorLogits = affineOrLSH( + input1, + factorWt, + factorB, + false, + /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + else + factorLogits = affineOrDot( + input1, + factorWt, + factorB, + false, + /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + + // optionally add lemma-dependent bias + if(Plemma) { // [B... x U0] int lemmaVocabDim = Plemma->shape()[-1]; int factorVocabDim = factorLogits->shape()[-1]; auto name = options_->get<std::string>("prefix"); - Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma - auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] + Expr lemmaBt + = graph_->param(name + "_lemmaBt_" + std::to_string(g), + {factorVocabDim, lemmaVocabDim}, + inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma + auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] factorLogits = factorLogits + b; - } - allLogits[g] = New<RationalLoss>(factorLogits, nullptr); - // optionally add a soft embedding of lemma back to create some lemma dependency - // @TODO: if this works, move it into lazyConstruct - if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure + } + allLogits[g] = New<RationalLoss>(factorLogits, nullptr); + // optionally add a soft embedding of lemma back to create some lemma dependency + // @TODO: if this works, move it into lazyConstruct + if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version"); // get expected lemma embedding vector - auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set + auto factorLogSoftmax = logsoftmax( + factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set auto factorSoftmax = exp(factorLogSoftmax); - inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? 
true : false); // [B... x D] - } - else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max + inputLemma = dot(factorSoftmax, + factorWt, + false, + /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version"); // get max-lemma embedding vector - auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set + auto maxVal = max(factorLogits, + -1); // [B... x U] note: with shortlist, this is not the full lemma set auto factorHardmax = eq(factorLogits, maxVal); - inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias + inputLemma = dot(factorHardmax, + factorWt, + false, + /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented"); LOG_ONCE(info, "[embedding] using lemma-dependent bias"); - auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) - auto z = /*stopGradient*/(factorLogSoftmax); - Plemma = exp(z); // [B... x U] - } - else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix + auto factorLogSoftmax + = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) + auto z = /*stopGradient*/ (factorLogSoftmax); + Plemma = exp(z); // [B... x U] + } else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb); - // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE + // compute softmax. We compute logsoftmax() separately because this way, computation will be + // reused later via CSE auto factorLogSoftmax = logsoftmax(factorLogits); auto factorSoftmax = exp(factorLogSoftmax); #ifdef HARDMAX_HACK - bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation) - if (hardmax) { - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; - LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); - auto maxVal = max(factorSoftmax, -1); - factorSoftmax = eq(factorSoftmax, maxVal); + bool hardmax = (lemmaDimEmb & 1) + != 0; // odd value triggers hardmax for now (for quick experimentation) + if(hardmax) { + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; + LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); + auto maxVal = max(factorSoftmax, -1); + factorSoftmax = eq(factorSoftmax, maxVal); } #endif // re-embedding lookup, soft-indexed by softmax - if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix - cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); - auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L] + if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix + cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); + auto e = dot(factorSoftmax, + cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, + false, + true); // [B... 
x L] // project it back to regular hidden dim int inputDim = input1->shape()[-1]; auto name = options_->get<std::string>("prefix"); - // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1 - Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension - auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] + // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also + // length 1 + Expr lemmaWt + = inputDim == lemmaDimEmb + ? nullptr + : graph_->param(name + "_lemmaWt", + {inputDim, lemmaDimEmb}, + inits::glorotUniform()); // [D x L] D=hidden-vector dimension + auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] // augment the original hidden vector with this additional information input1 = input1 + f; - } + } } return Logits(std::move(allLogits), factoredVocab_); - } else if (shortlist_) { - return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } else { - return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } + } else if(shortlist_) { + return Logits(affineOrLSH(input, + cachedShortWt_, + cachedShortb_, + false, + /*transB=*/isLegacyUntransposedW ? false : true)); + } else { + return Logits( + affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); + } } -} -}
\ No newline at end of file +} // namespace mlp +} // namespace marian
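A rough illustration of the soft-indexed re-embedding path above (lemma-dim-emb > 0): the softmax over the lemma logits acts as soft lookup weights into the re-embedding matrix lemmaEt_ [L x U0], i.e. e = softmax(factorLogits) * lemmaEt^T. The sketch below spells this out with plain nested vectors; names and dimensions are illustrative, not the real tensor API.

#include <cstddef>
#include <vector>

// Toy version of: e = dot(softmax(factorLogits), lemmaEt, /*transA=*/false, /*transB=*/true)
// p       : [U0]     soft lemma distribution (already softmax-normalized)
// lemmaEt : [L x U0] re-embedding matrix, one column per lemma (stored transposed for speed)
// returns : [L]      expected lemma embedding, later projected back to the hidden dimension
std::vector<float> softReembed(const std::vector<float>& p,
                               const std::vector<std::vector<float>>& lemmaEt) {
  std::vector<float> e(lemmaEt.size(), 0.f);
  for(std::size_t l = 0; l < lemmaEt.size(); ++l)
    for(std::size_t u = 0; u < p.size(); ++u)
      e[l] += lemmaEt[l][u] * p[u];  // weighted average of lemma embedding columns
  return e;
}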
\ No newline at end of file diff --git a/src/layers/output.h b/src/layers/output.h index 92e7eb25..2b6f4986 100644 --- a/src/layers/output.h +++ b/src/layers/output.h @@ -1,10 +1,10 @@ #pragma once -#include "marian.h" -#include "generic.h" -#include "logits.h" #include "data/shortlist.h" +#include "generic.h" #include "layers/factory.h" +#include "logits.h" +#include "marian.h" namespace marian { class LSH; @@ -14,42 +14,45 @@ namespace mlp { class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { private: // parameters held by this layer - Expr Wt_; // weight matrix is stored transposed for efficiency + Expr Wt_; // weight matrix is stored transposed for efficiency Expr b_; - Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] - bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form + Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] + bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form bool hasBias_{true}; Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) Expr cachedShortb_; // these match the current value of shortlist_ Expr cachedShortLemmaEt_; Ptr<FactoredVocab> factoredVocab_; - + // optional parameters set/updated after construction Expr tiedParam_; Ptr<data::Shortlist> shortlist_; Ptr<LSH> lsh_; void lazyConstruct(int inputDim); + public: Output(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options), - hasBias_{!options->get<bool>("output-omit-bias", false)} { + : LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} { clear(); } void tieTransposed(Expr tied) { - if (Wt_) - ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created"); + if(Wt_) + ABORT_IF(tiedParam_.get() != tied.get(), + "Tied output projection cannot be changed once weights have been created"); else tiedParam_ = tied; } void setShortlist(Ptr<data::Shortlist> shortlist) override final { - if (shortlist_) - ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()"); + if(shortlist_) + ABORT_IF(shortlist.get() != shortlist_.get(), + "Output shortlist cannot be changed except after clear()"); else { - ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??"); + ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, + "No shortlist but cached parameters??"); shortlist_ = shortlist; } // cachedShortWt_ and cachedShortb_ will be created lazily inside apply() @@ -60,7 +63,7 @@ public: void clear() override final { shortlist_ = nullptr; cachedShortWt_ = nullptr; - cachedShortb_ = nullptr; + cachedShortb_ = nullptr; cachedShortLemmaEt_ = nullptr; } @@ -69,6 +72,4 @@ public: } // namespace mlp -} - - +} // namespace marian diff --git a/src/models/costs.cpp b/src/models/costs.cpp index 5105f590..c688b211 100644 --- a/src/models/costs.cpp +++ b/src/models/costs.cpp @@ -4,13 +4,11 @@ namespace marian { namespace models { Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) { -// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) -state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); -// @TODO: This is becoming more and more opaque ^^. Can we simplify this? 
-return state; -} - - -} + // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) + state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); + // @TODO: This is becoming more and more opaque ^^. Can we simplify this? + return state; } +} // namespace models +} // namespace marian diff --git a/src/models/costs.h b/src/models/costs.h index 2d34c53a..e5463bfd 100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -4,8 +4,8 @@ #include "layers/guided_alignment.h" #include "layers/loss.h" #include "layers/weight.h" -#include "models/encoder_decoder.h" #include "models/encoder_classifier.h" +#include "models/encoder_decoder.h" #include "models/encoder_pooler.h" namespace marian { @@ -22,10 +22,12 @@ namespace models { class ICost { public: - virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model? - Ptr<data::Batch> batch, - bool clearGraph = true) = 0; + virtual Ptr<MultiRationalLoss> apply( + Ptr<IModel> model, + Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model? + Ptr<data::Batch> batch, + bool clearGraph = true) + = 0; virtual ~ICost() {} }; @@ -45,10 +47,9 @@ public: : options_(options), inference_(options->get<bool>("inference", false)) { loss_ = newLoss(options_, inference_); - toBeWeighted_ - = (options_->hasAndNotEmpty("data-weighting") && !inference_) - || (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting") - && !inference_); + toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_) + || (options_->has("dynamic-weighting") + && options_->get<bool>("dynamic-weighting") && !inference_); if(toBeWeighted_) weighter_ = WeightingFactory(options_); } @@ -56,9 +57,9 @@ public: virtual ~EncoderDecoderCECost() {} Ptr<MultiRationalLoss> apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { + Ptr<ExpressionGraph> graph, + Ptr<data::Batch> batch, + bool clearGraph = true) override { auto encdec = std::static_pointer_cast<EncoderDecoder>(model); auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch); @@ -72,17 +73,17 @@ public: Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_); // @TODO: adapt to multi-objective training with multiple decoders - auto partialLoss = loss_->apply(state->getLogProbs(), - state->getTargetWords(), - state->getTargetMask(), - weights); + auto partialLoss = loss_->apply( + state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights); multiLoss->push_back(partialLoss); if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) { - auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1] + auto attentionVectors + = encdec->getDecoders()[0] + ->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1] ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments"); - auto attention = concatenate(attentionVectors, /*axis =*/ -1); + auto attention = concatenate(attentionVectors, /*axis =*/-1); auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention); multiLoss->push_back(alignmentLoss); @@ -109,10 +110,9 @@ public: } Ptr<MultiRationalLoss> apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { - + Ptr<ExpressionGraph> graph, + Ptr<data::Batch> batch, + bool 
clearGraph = true) override { auto enccls = std::static_pointer_cast<EncoderClassifier>(model); auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch); @@ -141,21 +141,20 @@ protected: public: EncoderPoolerRankCost(Ptr<Options> options) - : options_(options), - inference_(options->get<bool>("inference", false)) { - auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {}); - ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set"); - - margin_ = std::stof(trainEmbedderRank[0]); - if(trainEmbedderRank.size() > 1) - normalizer_ = std::stof(trainEmbedderRank[1]); + : options_(options), inference_(options->get<bool>("inference", false)) { + auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {}); + ABORT_IF(trainEmbedderRank.empty(), + "EncoderPoolerRankCost expects train-embedder-rank to be set"); + + margin_ = std::stof(trainEmbedderRank[0]); + if(trainEmbedderRank.size() > 1) + normalizer_ = std::stof(trainEmbedderRank[1]); } Ptr<MultiRationalLoss> apply(Ptr<IModel> model, Ptr<ExpressionGraph> graph, Ptr<data::Batch> batch, bool clearGraph = true) override { - auto encpool = std::static_pointer_cast<EncoderPooler>(model); auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch); std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph); @@ -167,28 +166,41 @@ public: ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss"); // multi-objective training - auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability - auto exponent = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product + auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability + auto exponent + = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product auto dp = exp(exponent); Expr dn1, dn2; - if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator. - dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example - dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example + if(normalizer_ + != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the + // magnitude of the sum of negative examples in the denominator. + dn1 = normalizer_ + * mean(exp(dotProducts[1] - maxDot), + -1); // dot product of anchor and first negative example + dn2 = normalizer_ + * mean(exp(dotProducts[2] - maxDot), + -1); // dot product of positive examples and first negative example } else { - dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example - dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example + dn1 = sum(exp(dotProducts[1] - maxDot), + -1); // dot product of anchor and first negative example + dn2 = sum(exp(dotProducts[2] - maxDot), + -1); // dot product of positive examples and first negative example } // We rewrite the loss so it looks more like a log-softmax, presumably more stable? 
- // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m) - auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples - auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples - auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2); - + // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + // + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m) + auto marginLoss1 + = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples + auto marginLoss2 + = log(dp + dn2) + - exponent; // symmetric version of the above with positive example vs negative examples + auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2); + RationalLoss loss(marginLoss, (float)dimBatch); multiLoss->push_back(loss); - + return multiLoss; } }; @@ -199,8 +211,7 @@ protected: Ptr<ICost> cost_; public: - Trainer(Ptr<IModel> model, Ptr<ICost> cost) - : model_(model), cost_(cost) {} + Trainer(Ptr<IModel> model, Ptr<ICost> cost) : model_(model), cost_(cost) {} virtual ~Trainer() {} @@ -219,8 +230,8 @@ public: } virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { + Ptr<data::Batch> batch, + bool clearGraph = true) override { return cost_->apply(model_, graph, batch, clearGraph); }; @@ -230,24 +241,25 @@ public: class ILogProb { public: virtual Logits apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) = 0; + Ptr<ExpressionGraph> graph, + Ptr<data::Batch> batch, + bool clearGraph = true) + = 0; }; -// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth? -// Beam search uses it for the former meaning, while 'marian score' and validation in the latter. -// This class is for the former use. The latter is done using Trainer. +// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for +// the ground truth? +// Beam search uses it for the former meaning, while 'marian score' and validation in the +// latter. This class is for the former use. The latter is done using Trainer. 
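The softmax-margin rearrangement quoted a few hunks above, -log(dp / (dp + sum(dn))) = log(dp + sum(dn)) - (phi - m) with dp = exp(phi - m), is plain algebra and easy to sanity-check numerically. The standalone snippet below evaluates both forms for illustrative toy values; only the variable names follow the comment.

#include <cmath>
#include <cstdio>
#include <vector>

// Numeric sanity check of the rewrite used by EncoderPoolerRankCost:
//   -log(dp / (dp + sum(dn)))  ==  log(dp + sum(dn)) - (phi - m),  with dp = exp(phi - m)
int main() {
  float phi = 0.8f;                      // anchor/positive dot product (illustrative value)
  float m   = 0.3f;                      // margin (illustrative value)
  std::vector<float> dn = {0.2f, 0.5f};  // already-exponentiated negative terms (illustrative)

  float dp = std::exp(phi - m);
  float sumDn = 0.f;
  for(float v : dn)
    sumDn += v;

  float direct    = -std::log(dp / (dp + sumDn));
  float rewritten = std::log(dp + sumDn) - (phi - m);
  std::printf("%.6f vs %.6f\n", direct, rewritten);  // agree up to rounding
  return 0;
}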
 class Scorer : public IModel {
 protected:
   Ptr<IModel> model_;
   Ptr<ILogProb> logProb_;

 public:
-  Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
-      : model_(model), logProb_(cost) {}
+  Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) : model_(model), logProb_(cost) {}

-  virtual ~Scorer(){}
+  virtual ~Scorer() {}

   Ptr<IModel> getModel() { return model_; }
@@ -264,8 +276,8 @@ public:
   }

   virtual Logits build(Ptr<ExpressionGraph> graph,
-                       Ptr<data::Batch> batch,
-                       bool clearGraph = true) override {
+                       Ptr<data::Batch> batch,
+                       bool clearGraph = true) override {
     return logProb_->apply(model_, graph, batch, clearGraph);
   };
@@ -293,10 +305,10 @@ public:
   virtual ~GumbelSoftmaxStep() {}
   virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
     state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
-        [](Expr logits){ // lemma gets gumbelled
-          return logsoftmax(logits + constant_like(logits, inits::gumbel()));
-        },
-        logsoftmax)); // factors don't
+        [](Expr logits) {  // lemma gets gumbelled
+          return logsoftmax(logits + constant_like(logits, inits::gumbel()));
+        },
+        logsoftmax));  // factors don't
     return state;
   }
 };
@@ -311,8 +323,7 @@ protected:
   Ptr<ILogProbStep> cost_;

 public:
-  Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
-      : encdec_(encdec), cost_(cost) {}
+  Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}

   virtual void load(Ptr<ExpressionGraph> graph,
                     const std::string& name,
@@ -346,12 +357,13 @@ public:
     return encdec_->startState(graph, batch);
   }

-  virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
-                                 Ptr<DecoderState> state,
-                                 const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
-                                 const Words& words, // [beamIndex * activeBatchSize + batchIndex]
-                                 const std::vector<IndexType>& batchIndices, // [batchIndex]
-                                 int beamSize) override {
+  virtual Ptr<DecoderState> step(
+      Ptr<ExpressionGraph> graph,
+      Ptr<DecoderState> state,
+      const std::vector<IndexType>& hypIndices,    // [beamIndex * activeBatchSize + batchIndex]
+      const Words& words,                          // [beamIndex * activeBatchSize + batchIndex]
+      const std::vector<IndexType>& batchIndices,  // [batchIndex]
+      int beamSize) override {
     auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
     return cost_->apply(nextState);
   }
@@ -369,9 +381,7 @@ public:
     encdec_->setShortlistGenerator(shortlistGenerator);
   };

-  virtual Ptr<data::Shortlist> getShortlist() override {
-    return encdec_->getShortlist();
-  };
+  virtual Ptr<data::Shortlist> getShortlist() override { return encdec_->getShortlist(); };

   virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
 };
diff --git a/src/models/states.h b/src/models/states.h
index cfb6fd1b..20dd59c9 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -1,7 +1,7 @@
 #pragma once

+#include "layers/logits.h"  // @HACK: for factored embeddings only so far
 #include "marian.h"
-#include "layers/logits.h"  // @HACK: for factored embeddings only so far
 #include "rnn/types.h"

 namespace marian {
@@ -9,7 +9,7 @@ namespace marian {
 class EncoderState {
 private:
   Expr context_;
-  Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
+  Expr mask_;  // [beam depth=1, max length, batch size, vector dim=1] source mask
   Ptr<data::CorpusBatch> batch_;

 public:
@@ -19,31 +19,34 @@ public:
   EncoderState() {}
   virtual ~EncoderState() {}

-  virtual Expr getContext() const { return context_; }
-  virtual Expr getAttended() const { return context_; }
-  virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed
+  virtual Expr getContext() const { return context_; }
+  virtual Expr getAttended() const { return context_; }
+  virtual Expr getMask() const {
+    return mask_;
+  }  // source batch mask; may have additional positions suppressed

-  virtual const Words& getSourceWords() {
-    return batch_->front()->data();
-  }
+  virtual const Words& getSourceWords() { return batch_->front()->data(); }

   // Sub-select active batch entries from encoder context and context mask
-  Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
-    // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
-    return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
+  Ptr<EncoderState> select(
+      const std::vector<IndexType>& batchIndices) {  // [batchIndex] indices of active batch entries
+    // Dimension -2 is OK for both RNN and Transformer models as the encoder context in Transformer
+    // gets transposed to the same dimension layout
+    return New<EncoderState>(
+        index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
   }
 };

 class DecoderState {
 protected:
-  rnn::States states_; // states of individual decoder layers
+  rnn::States states_;  // states of individual decoder layers
   Logits logProbs_;
   std::vector<Ptr<EncoderState>> encStates_;
   Ptr<data::CorpusBatch> batch_;

-  Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
+  Expr targetHistoryEmbeddings_;  // decoder history (teacher-forced or from decoding), embedded
   Expr targetMask_;
-  Words targetWords_; // target labels
+  Words targetWords_;  // target labels

   // Keep track of current target token position during translation
   size_t position_{0};
@@ -57,26 +60,30 @@ public:
   virtual ~DecoderState() {}

   // @TODO: Do we need all these to be virtual?
-  virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
-    return encStates_;
-  }
+  virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; }

   virtual Logits getLogProbs() const { return logProbs_; }
   virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }

-  // @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
-  virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
-                                   const std::vector<IndexType>& batchIndices, // [batchIndex]
-                                   int beamSize) const {
-
+  // @TODO: should this be a constructor? Then derived classes can call this without the New<> in
+  // the loop
+  virtual Ptr<DecoderState> select(
+      const std::vector<IndexType>& hypIndices,    // [beamIndex * activeBatchSize + batchIndex]
+      const std::vector<IndexType>& batchIndices,  // [batchIndex]
+      int beamSize) const {
     std::vector<Ptr<EncoderState>> newEncStates;
     for(auto& es : encStates_)
-      // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
-      newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
+      // If the size of the batch dimension of the encoder state context changed, subselect the
+      // correct batch entries
+      newEncStates.push_back(
+          es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));

     // hypIndices matches batchIndices in terms of batch dimension, so we only need hypIndices
-    auto selectedState = New<DecoderState>(
-        states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);
+    auto selectedState
+        = New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false),
+                            logProbs_,
+                            newEncStates,
+                            batch_);

     // Set position of new state based on the target token position of current state
     selectedState->setPosition(getPosition());
@@ -86,7 +93,9 @@ public:
   virtual const rnn::States& getStates() const { return states_; }

   virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
-  virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }
+  virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) {
+    targetHistoryEmbeddings_ = targetHistoryEmbeddings;
+  }

   virtual const Words& getTargetWords() const { return targetWords_; };
   virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
@@ -94,9 +103,7 @@ public:
   virtual Expr getTargetMask() const { return targetMask_; };
   virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }

-  virtual const Words& getSourceWords() const {
-    return getEncoderStates()[0]->getSourceWords();
-  }
+  virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); }

   Ptr<data::CorpusBatch> getBatch() const { return batch_; }
@@ -111,7 +118,8 @@ public:
 /**
  * Classifier output based on DecoderState
- * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output.
+ * @TODO: should be unified with DecoderState or not be used at all as Classifiers do not really have
+ * stateful output.
 */
 class ClassifierState {
 private:
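
Note on the softmax-margin loss in src/models/costs.h above: the rewrite described in the comment can be spelled out as follows. This is only a sketch of the algebra implied by that comment, where phi stands for dotProducts[0], m for margin_, max for maxDot, and d_n for the (optionally normalized) sum of exponentiated negative dot products:

\[
-\log\frac{d_p}{d_p + d_n}
  \;=\; \log(d_p + d_n) - \log d_p
  \;=\; \log\!\big(e^{\phi - \max - m} + d_n\big) - (\phi - \max - m),
\qquad d_p = e^{\phi - \max - m},
\]

which matches marginLoss1 = log(dp + dn1) - exponent (and symmetrically marginLoss2 with dn2). Subtracting maxDot inside every exponent rescales numerator and denominator by the same factor, so it leaves the loss unchanged and only improves numerical range.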
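
For readers who want to check the arithmetic without building Marian, the following is a minimal plain-C++ sketch of the same computation on scalar dot products. It is illustrative only: the function and variable names are invented here, and the real EncoderPoolerRankCost operates on batched Expr tensors, reducing over the negative-example axis rather than over a std::vector.

// Scalar re-implementation of the softmax-margin loss from EncoderPoolerRankCost::apply().
// Hypothetical helper, not Marian API.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// dotPos:  similarity of anchor and positive example (phi)
// dotNeg1: similarities of anchor vs. negative examples
// dotNeg2: similarities of positive example vs. negative examples
// margin:  additive margin m; normalizer == 0 disables mean-based normalization
float softmaxMarginLoss(float dotPos,
                        const std::vector<float>& dotNeg1,
                        const std::vector<float>& dotNeg2,
                        float margin,
                        float normalizer = 0.f) {
  // maximum over all dot products, subtracted for numerical stability
  float maxDot = dotPos;
  for(float d : dotNeg1) maxDot = std::max(maxDot, d);
  for(float d : dotNeg2) maxDot = std::max(maxDot, d);

  float exponent = dotPos - maxDot - margin;  // phi - max - m
  float dp = std::exp(exponent);

  auto accumulate = [&](const std::vector<float>& dots) {
    float s = 0.f;
    for(float d : dots) s += std::exp(d - maxDot);
    if(normalizer != 0.f) s = normalizer * s / dots.size();  // mean instead of sum
    return s;
  };
  float dn1 = accumulate(dotNeg1);
  float dn2 = accumulate(dotNeg2);

  // -log(dp / (dp + dn)) written as log(dp + dn) - exponent, summed symmetrically
  return (std::log(dp + dn1) - exponent) + (std::log(dp + dn2) - exponent);
}

int main() {
  float loss = softmaxMarginLoss(/*dotPos=*/0.9f, {0.2f, 0.1f}, {0.3f, 0.0f}, /*margin=*/0.3f);
  std::printf("loss = %f\n", loss);
  return 0;
}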
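
A note on GumbelSoftmaxStep in src/models/costs.h above: adding Gumbel(0,1) noise g = -log(-log(u)), u ~ Uniform(0,1), to the lemma logits makes a subsequent argmax behave like sampling from softmax(logits); the class keeps the perturbed scores as a log-softmax so that beam search's usual argmax turns into sampling, while factor logits are left untouched. The standalone sketch below only demonstrates that sampling equivalence; it is not Marian code and its names are invented for illustration.

// Gumbel-max sampling sketch (hypothetical helper, not Marian API).
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <random>
#include <vector>

std::size_t gumbelSample(const std::vector<float>& logits, std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(1e-7f, 1.f);
  std::size_t best = 0;
  float bestScore = -std::numeric_limits<float>::infinity();
  for(std::size_t i = 0; i < logits.size(); ++i) {
    float g = -std::log(-std::log(uniform(rng)));  // Gumbel(0,1) noise
    float score = logits[i] + g;                   // perturbed logit
    if(score > bestScore) {
      bestScore = score;
      best = i;
    }
  }
  return best;  // index distributed (approximately) according to softmax(logits)
}

int main() {
  std::mt19937 rng(42);
  std::vector<float> logits = {2.0f, 0.5f, 0.1f};
  std::printf("sampled index: %zu\n", gumbelSample(logits, rng));
  return 0;
}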