github.com/marian-nmt/marian.git
author     Hieu Hoang <hihoan@microsoft.com>   2021-03-06 08:54:05 +0300
committer  Hieu Hoang <hihoan@microsoft.com>   2021-03-06 08:54:05 +0300
commit     ba196637847c50c76d5d0edfcfe39b9cedb0d1d0 (patch)
tree       a65b300fa3d7fcbdff4b84a9fceb1241ebaf7e87 /src
parent     55f4216552bca148091f15b72c5c2e5b486d4c79 (diff)
clang-format -i
Diffstat (limited to 'src')
-rw-r--r--  src/layers/constructors.h |  70
-rw-r--r--  src/layers/embedding.cpp  | 282
-rw-r--r--  src/layers/embedding.h    | 108
-rw-r--r--  src/layers/generic.cpp    |  11
-rw-r--r--  src/layers/generic.h      |  98
-rw-r--r--  src/layers/logits.cpp     | 424
-rw-r--r--  src/layers/logits.h       | 110
-rw-r--r--  src/layers/loss.cpp       |  32
-rw-r--r--  src/layers/loss.h         | 181
-rw-r--r--  src/layers/output.cpp     | 336
-rw-r--r--  src/layers/output.h       |  37
-rw-r--r--  src/models/costs.cpp      |  14
-rw-r--r--  src/models/costs.h        | 158
-rw-r--r--  src/models/states.h       |  70
14 files changed, 1068 insertions, 863 deletions
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index e25449aa..9e9de207 100644
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -1,8 +1,8 @@
#pragma once
+#include "layers/embedding.h"
#include "layers/factory.h"
#include "layers/generic.h"
-#include "layers/embedding.h"
#include "layers/output.h"
namespace marian {
@@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory {
// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
class OutputFactory : public LogitLayerFactory {
using LogitLayerFactory::LogitLayerFactory;
+
protected:
std::string tiedTransposedName_;
Ptr<data::Shortlist> shortlist_;
@@ -55,9 +56,7 @@ public:
return Accumulator<OutputFactory>(*this);
}
- void setShortlist(Ptr<data::Shortlist> shortlist) {
- shortlist_ = shortlist;
- }
+ void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
auto output = New<Output>(graph, options_);
@@ -89,8 +88,7 @@ protected:
std::vector<Ptr<IUnaryLayer>> layers_;
public:
- MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : graph_(graph), options_(options) {}
+ MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
Expr apply(const std::vector<Expr>& av) override {
Expr output;
@@ -106,46 +104,53 @@ public:
}
Logits applyAsLogits(const std::vector<Expr>& av) override {
- // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type
+ // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
+ // return type
auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
- ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
- if (layers_.size() == 1) {
- if (av.size() == 1)
+ ABORT_IF(
+ !lastLayer,
+ "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
+ if(layers_.size() == 1) {
+ if(av.size() == 1)
return lastLayer->applyAsLogits(av[0]);
else
return lastLayer->applyAsLogits(av);
- }
- else {
+ } else {
Expr output;
- if (av.size() == 1)
+ if(av.size() == 1)
output = layers_[0]->apply(av[0]);
else
output = layers_[0]->apply(av);
- for (size_t i = 1; i < layers_.size() - 1; ++i)
+ for(size_t i = 1; i < layers_.size() - 1; ++i)
output = layers_[i]->apply(output);
return lastLayer->applyAsLogits(output);
}
}
- Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); }
- Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); }
+ Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
+ Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
auto p = tryAsHasShortlist();
- ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists");
+ ABORT_IF(
+ !p,
+ "setShortlist() called on an MLP with an output layer that does not support short lists");
p->setShortlist(shortlist);
}
void clear() override final {
auto p = tryAsHasShortlist();
- if (p)
+ if(p)
p->clear();
}
+
private:
- Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); }
+ Ptr<IHasShortList> tryAsHasShortlist() const {
+ return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
+ }
};
/**
@@ -154,6 +159,7 @@ private:
*/
class MLPFactory : public Factory {
using Factory::Factory;
+
private:
std::vector<Ptr<LayerFactory>> layers_;
@@ -177,23 +183,27 @@ public:
// which will go away if we get rid of the abstract factories, and instead just construct
// all layers immediately, which is my long-term goal for Marian.
private:
- template<class WrappedFactory>
+ template <class WrappedFactory>
class AsLayerFactory : public LayerFactory {
- WrappedFactory us;
+ WrappedFactory us;
+
public:
- AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
- Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
- auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
- ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
- return p;
- }
+ AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
+ Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
+ auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
+ ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
+ return p;
+ }
};
- template<class WrappedFactory>
- static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; }
+ template <class WrappedFactory>
+ static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
+ return wrapped;
+ }
+
public:
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
push_back(AsLayerFactory<OutputFactory>(lf));
- //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
+ // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
return Accumulator<MLPFactory>(*this);
}
};
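
(Illustrative aside, not part of the commit.) The MLP::applyAsLogits() hunk above only reflows existing control flow: every layer but the last is applied as a plain unary layer, and the last one must implement the logit-producing interface, which is checked with a dynamic_pointer_cast. A minimal standalone sketch of that dispatch pattern in plain C++; IUnaryLayer, Dense, Output and the string-based Expr below are toy stand-ins, not Marian's actual types:

#include <cassert>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for Marian's Expr / Logits and layer interfaces.
using Expr = std::string;
struct Logits { Expr value; };

struct IUnaryLayer {
  virtual ~IUnaryLayer() = default;
  virtual Expr apply(Expr x) = 0;
};
struct IUnaryLogitLayer : IUnaryLayer {
  virtual Logits applyAsLogits(Expr x) = 0;
};

struct Dense : IUnaryLayer {
  Expr apply(Expr x) override { return "dense(" + x + ")"; }
};
struct Output : IUnaryLogitLayer {
  Expr apply(Expr x) override { return "output(" + x + ")"; }
  Logits applyAsLogits(Expr x) override { return {"logits(" + x + ")"}; }
};

// Same shape as MLP::applyAsLogits(): run all but the last layer as plain
// unary layers, then require the last layer to be an IUnaryLogitLayer.
Logits applyAsLogits(const std::vector<std::shared_ptr<IUnaryLayer>>& layers, Expr x) {
  auto last = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers.back());
  assert(last && "last layer must produce logits");
  for (size_t i = 0; i + 1 < layers.size(); ++i)
    x = layers[i]->apply(x);
  return last->applyAsLogits(x);
}

int main() {
  std::vector<std::shared_ptr<IUnaryLayer>> mlp = {std::make_shared<Dense>(),
                                                   std::make_shared<Output>()};
  std::printf("%s\n", applyAsLogits(mlp, "x").value.c_str());  // logits(dense(x))
}
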
diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp
index 488fbb8b..5a448f61 100644
--- a/src/layers/embedding.cpp
+++ b/src/layers/embedding.cpp
@@ -3,173 +3,205 @@
namespace marian {
-Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
-: LayerBase(graph, options), inference_(opt<bool>("inference")) {
-std::string name = opt<std::string>("prefix");
-int dimVoc = opt<int>("dimVocab");
-int dimEmb = opt<int>("dimEmb");
+Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options), inference_(opt<bool>("inference")) {
+ std::string name = opt<std::string>("prefix");
+ int dimVoc = opt<int>("dimVocab");
+ int dimEmb = opt<int>("dimEmb");
-bool fixed = opt<bool>("fixed", false);
+ bool fixed = opt<bool>("fixed", false);
-factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
-if (factoredVocab_) {
+ factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+ if(factoredVocab_) {
dimVoc = (int)factoredVocab_->factorVocabSize();
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
-}
+ }
-// Embedding layer initialization should depend only on embedding size, hence fanIn=false
-auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
+ // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+ auto initFunc = inits::glorotUniform(
+ /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
-if (options_->has("embFile")) {
+ if(options_->has("embFile")) {
std::string file = opt<std::string>("embFile");
- if (!file.empty()) {
- bool norm = opt<bool>("normalization", false);
- initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
+ if(!file.empty()) {
+ bool norm = opt<bool>("normalization", false);
+ initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
}
-}
+ }
-E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
+ E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
}
// helper to embed a sequence of words (given as indices) via factored embeddings
Expr Embedding::multiRows(const Words& data, float dropProb) const {
-auto graph = E_->graph();
-auto factoredData = factoredVocab_->csr_rows(data);
-// multi-hot factor vectors are represented as a sparse CSR matrix
-// [row index = word position index] -> set of factor indices for word at this position
-ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
-// the CSR matrix is passed in pieces
-auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
-auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
-auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
-// apply dropout
-// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
-if(!inference_)
+ auto graph = E_->graph();
+ auto factoredData = factoredVocab_->csr_rows(data);
+ // multi-hot factor vectors are represented as a sparse CSR matrix
+ // [row index = word position index] -> set of factor indices for word at this position
+ ABORT_IF(factoredData.shape
+ != Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}),
+ "shape mismatch??");
+ // the CSR matrix is passed in pieces
+ auto weights = graph->constant({(int)factoredData.weights.size()},
+ inits::fromVector(factoredData.weights));
+ auto indices = graph->constant(
+ {(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32);
+ auto offsets = graph->constant(
+ {(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32);
+ // apply dropout
+ // We apply it to the weights, i.e. factors get dropped out separately, but always as entire
+ // vectors.
+ if(!inference_)
weights = dropout(weights, dropProb);
-// perform the product
-return csr_dot(factoredData.shape, weights, indices, offsets, E_);
+ // perform the product
+ return csr_dot(factoredData.shape, weights, indices, offsets, E_);
}
-std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
-auto graph = E_->graph();
-int dimBatch = (int)subBatch->batchSize();
-int dimEmb = E_->shape()[-1];
-int dimWidth = (int)subBatch->batchWidth();
-
-// factored embeddings:
-// - regular:
-// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
-// - factored:
-// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
-// - each row of M contains the set of factors for one word => we want a CSR matrix
-// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
-// - first compute x @ M on the CPU
-// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
-// - shape (U, specifically) not actually needed here
-// - foreach input x[i]
-// - locate row M[i,*]
-// - copy through its index values (std::vector<push_back>)
-// - create a matching ones vector (we can keep growing)
-// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
-// - CSR matrix product with E
-// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
-// - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
-// - weighting:
-// - core factors' gradients are sums over all words that use the factors;
-// - core factors' embeddings move very fast
-// - words will need to make up for the move; rare words cannot
-// - so, we multiply each factor with 1/refCount
-// - core factors get weighed down a lot
-// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
-// - but forward pass weighs them down, so that all factors are in a similar numeric range
-// - if it is required to be in a different range, the embeddings can still learn that, but more slowly
-
-auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
+std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const
+/*override final*/ {
+ auto graph = E_->graph();
+ int dimBatch = (int)subBatch->batchSize();
+ int dimEmb = E_->shape()[-1];
+ int dimWidth = (int)subBatch->batchWidth();
+
+ // factored embeddings:
+ // - regular:
+ // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
+ // - factored:
+ // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
+ // - each row of M contains the set of factors for one word => we want a CSR matrix
+ // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
+ // - first compute x @ M on the CPU
+ // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
+ // - shape (U, specifically) not actually needed here
+ // - foreach input x[i]
+ // - locate row M[i,*]
+ // - copy through its index values (std::vector<push_back>)
+ // - create a matching ones vector (we can keep growing)
+ // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
+ // - CSR matrix product with E
+ // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
+ // - double-check if all dimensions are specified. Probably not for transpose (which would
+ // be like csc_dot()).
+ // - weighting:
+ // - core factors' gradients are sums over all words that use the factors;
+ // - core factors' embeddings move very fast
+ // - words will need to make up for the move; rare words cannot
+ // - so, we multiply each factor with 1/refCount
+ // - core factors get weighed down a lot
+ // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as
+ // before
+ // - but forward pass weighs them down, so that all factors are in a similar numeric range
+ // - if it is required to be in a different range, the embeddings can still learn that, but
+ // more slowly
+
+ auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
#if 1
-auto batchMask = graph->constant({dimWidth, dimBatch, 1},
- inits::fromVector(subBatch->mask()));
-#else // @TODO: this is dead code now, get rid of it
-// experimental: hide inline-fix source tokens from cross attention
-auto batchMask = graph->constant({dimWidth, dimBatch, 1},
- inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
+ auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask()));
+#else // @TODO: this is dead code now, get rid of it
+ // experimental: hide inline-fix source tokens from cross attention
+ auto batchMask
+ = graph->constant({dimWidth, dimBatch, 1},
+ inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
#endif
-// give the graph inputs readable names for debugging and ONNX
-batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
+ // give the graph inputs readable names for debugging and ONNX
+ batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
-return std::make_tuple(batchEmbeddings, batchMask);
+ return std::make_tuple(batchEmbeddings, batchMask);
}
Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
-if (factoredVocab_) {
- Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
- selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
- //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
+ if(factoredVocab_) {
+ Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ // selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), {
+ // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
return selectedEmbs;
-}
-else
+ } else
return applyIndices(toWordIndexVector(words), shape);
}
-Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
-ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
-auto embIdxExpr = E_->graph()->indices(embIdx);
-embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
-auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
-selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
-// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
-if(!inference_)
- selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
-return selectedEmbs;
+Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const
+/*override final*/ {
+ ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
+ auto embIdxExpr = E_->graph()->indices(embIdx);
+ embIdxExpr->set_name("data_"
+ + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
+ auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape()
+ // (test that separately)
+ if(!inference_)
+ selectedEmbs = dropout(
+ selectedEmbs, options_->get<float>("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1});
+ return selectedEmbs;
}
// standard encoder word embeddings
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
-auto options = New<Options>(
- "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
- "dimEmb", opt<int>("dim-emb"),
- "dropout", dropoutEmbeddings_,
- "inference", inference_,
- "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
- "fixed", embeddingFix_,
- "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
-if(options_->hasAndNotEmpty("embedding-vectors")) {
+ auto options = New<Options>(
+ "dimVocab",
+ opt<std::vector<int>>("dim-vocabs")[batchIndex_],
+ "dimEmb",
+ opt<int>("dim-emb"),
+ "dropout",
+ dropoutEmbeddings_,
+ "inference",
+ inference_,
+ "prefix",
+ (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
+ : prefix_ + "_Wemb",
+ "fixed",
+ embeddingFix_,
+ "vocab",
+ opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
+ if(options_->hasAndNotEmpty("embedding-vectors")) {
auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
options->set(
- "embFile", embFiles[batchIndex_],
- "normalization", opt<bool>("embedding-normalization"));
-}
-return New<Embedding>(graph_, options);
+ "embFile", embFiles[batchIndex_], "normalization", opt<bool>("embedding-normalization"));
+ }
+ return New<Embedding>(graph_, options);
}
// ULR word embeddings
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
-return New<ULREmbedding>(graph_, New<Options>(
- "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
- "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
- "dimUlrEmb", opt<int>("ulr-dim-emb"),
- "dimEmb", opt<int>("dim-emb"),
- "ulr-dropout", opt<float>("ulr-dropout"),
- "dropout", dropoutEmbeddings_,
- "inference", inference_,
- "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
- "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
- "ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
+ return New<ULREmbedding>(
+ graph_,
+ New<Options>("dimSrcVoc",
+ opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
+ "dimTgtVoc",
+ opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
+ "dimUlrEmb",
+ opt<int>("ulr-dim-emb"),
+ "dimEmb",
+ opt<int>("dim-emb"),
+ "ulr-dropout",
+ opt<float>("ulr-dropout"),
+ "dropout",
+ dropoutEmbeddings_,
+ "inference",
+ inference_,
+ "ulrTrainTransform",
+ opt<bool>("ulr-trainable-transformation"),
+ "ulrQueryFile",
+ opt<std::string>("ulr-query-vectors"),
+ "ulrKeysFile",
+ opt<std::string>("ulr-keys-vectors")));
}
// get embedding layer for this encoder or decoder
// This is lazy mostly because the constructors of the consuming objects are not
// guaranteed presently to have access to their graph.
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
-if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
- if (embeddingLayers_.size() <= batchIndex_)
- embeddingLayers_.resize(batchIndex_ + 1);
- if (ulr)
- embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
+ if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
+ if(embeddingLayers_.size() <= batchIndex_)
+ embeddingLayers_.resize(batchIndex_ + 1);
+ if(ulr)
+ embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
else
- embeddingLayers_[batchIndex_] = createEmbeddingLayer();
-}
-return embeddingLayers_[batchIndex_];
-}
-
+ embeddingLayers_[batchIndex_] = createEmbeddingLayer();
+ }
+ return embeddingLayers_[batchIndex_];
}
+} // namespace marian
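
(Illustrative aside, not part of the commit.) The factored-embedding path that multiRows() wraps boils down to a sparse product: each word position becomes one CSR row whose columns are that word's factor indices, and multiplying that row into the dense factor matrix E sums the embeddings of the word's factors; dropout is applied to the CSR weights, so a factor is dropped as an entire vector. A small self-contained sketch of the CPU-side computation; the weights/indices/offsets layout mirrors factoredData, but the toy dimensions and the csrDot helper are illustrative, not Marian's csr_dot API:

#include <cstdio>
#include <vector>

// Toy factor-embedding matrix E: U factor units x D embedding dims.
constexpr int U = 5, D = 3;

// y[r] = sum over factors f of word r: weight(r,f) * E[f]  (CSR row times E)
std::vector<float> csrDot(const std::vector<float>& weights,  // nnz values
                          const std::vector<int>& indices,    // nnz factor ids
                          const std::vector<int>& offsets,    // rows + 1
                          const float (&E)[U][D]) {
  int rows = (int)offsets.size() - 1;
  std::vector<float> y(rows * D, 0.f);
  for (int r = 0; r < rows; ++r)
    for (int k = offsets[r]; k < offsets[r + 1]; ++k)
      for (int d = 0; d < D; ++d)
        y[r * D + d] += weights[k] * E[indices[k]][d];
  return y;
}

int main() {
  float E[U][D] = {{1,0,0},{0,1,0},{0,0,1},{1,1,0},{0,1,1}};
  // Two word positions: word 0 uses factors {0,3}, word 1 uses factors {2,4}.
  std::vector<float> weights = {1, 1, 1, 1};
  std::vector<int>   indices = {0, 3, 2, 4};
  std::vector<int>   offsets = {0, 2, 4};
  auto y = csrDot(weights, indices, offsets, E);
  for (int r = 0; r < 2; ++r)
    std::printf("word %d: %.0f %.0f %.0f\n", r, y[r*3], y[r*3+1], y[r*3+2]);
}

In the layer itself the same three arrays are uploaded as graph constants and the product runs on the device via csr_dot().
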
diff --git a/src/layers/embedding.h b/src/layers/embedding.h
index b7898c76..6edb3140 100644
--- a/src/layers/embedding.h
+++ b/src/layers/embedding.h
@@ -1,6 +1,6 @@
#pragma once
-#include "marian.h"
#include "generic.h"
+#include "marian.h"
namespace marian {
@@ -19,7 +19,8 @@ class Embedding : public LayerBase, public IEmbeddingLayer {
public:
Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
- std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
+ std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
+ Ptr<data::SubBatch> subBatch) const override final;
Expr apply(const Words& words, const Shape& shape) const override final;
@@ -27,17 +28,18 @@ public:
};
class ULREmbedding : public LayerBase, public IEmbeddingLayer {
- std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
+ std::vector<Expr>
+ ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
bool inference_{false};
public:
- ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options), inference_(opt<bool>("inference")) {
- std::string name = "url_embed"; //opt<std::string>("prefix");
+ ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options), inference_(opt<bool>("inference")) {
+ std::string name = "url_embed"; // opt<std::string>("prefix");
int dimKeys = opt<int>("dimTgtVoc");
int dimQueries = opt<int>("dimSrcVoc");
int dimEmb = opt<int>("dimEmb");
- int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
+ int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
bool fixed = opt<bool>("fixed", false);
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
@@ -46,58 +48,61 @@ public:
std::string queryFile = opt<std::string>("ulrQueryFile");
std::string keyFile = opt<std::string>("ulrKeysFile");
bool trainTrans = opt<bool>("ulrTrainTransform", false);
- if (!queryFile.empty() && !keyFile.empty()) {
+ if(!queryFile.empty() && !keyFile.empty()) {
initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
name = "ulr_query";
fixed = true;
- auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
+ auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(query_embed);
// keys embeds
initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
name = "ulr_keys";
fixed = true;
- auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
+ auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(key_embed);
// actual trainable embedding
initFunc = inits::glorotUniform();
name = "ulr_embed";
fixed = false;
- auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
+ auto ulr_embed
+ = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim
ulrEmbeddings_.push_back(ulr_embed);
// init trainable src embedding
name = "ulr_src_embed";
- auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
+ auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(ulr_src_embed);
// ulr transformation matrix
- //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only
- if (trainTrans) {
+ // initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall
+ // we make this to the fixed case only
+ if(trainTrans) {
initFunc = inits::glorotUniform();
fixed = false;
- }
- else
- {
- initFunc = inits::eye(); // identity matrix
+ } else {
+ initFunc = inits::eye(); // identity matrix
fixed = true;
}
name = "ulr_transform";
- auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
+ auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(ulrTransform);
- initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
+ initFunc = inits::fromValue(
+ 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no
+ // universal embeddings - should be zero for top freq only
fixed = true;
name = "ulr_shared";
- auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
+ auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed);
ulrEmbeddings_.push_back(share_embed);
}
}
- std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
- auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
- auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
- auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
- auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
- auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
- auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
+ std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
+ Ptr<data::SubBatch> subBatch) const override final {
+ auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
+ auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
+ auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
+ auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
+ auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
+ auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
int dimBatch = (int)subBatch->batchSize();
int dimEmb = uniEmbed->shape()[-1];
int dimWords = (int)subBatch->batchWidth();
@@ -106,34 +111,42 @@ public:
// dim A = uni_embed_size*uni_embed_size
// dim Q: uni_embed_size * total_merged_vocab_size
// dim D = univ_tok_vocab * total_merged_vocab_size
- // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD)
- // here we need to handle the mini-batch
- // extract raws corresponding to Xs in this minibatch from Q
+ // note all above can be precombuted and serialized if A is not trainiable and during decoding
+ // (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this
+ // minibatch from Q
auto embIdx = toWordIndexVector(subBatch->data());
auto queryEmbeddings = rows(queryEmbed, embIdx);
- auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
- auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
- auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
- auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
- qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes
- auto z = dot(qt, keyEmbed, false, true); // query-key similarity
+ auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
+ auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
+ auto qt = dot(queryEmbeddings,
+ ulrTransform,
+ false,
+ false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
+ auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
+ qt = qt / sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in
+ // magnitude with larger embeds sizes
+ auto z = dot(qt, keyEmbed, false, true); // query-key similarity
float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
if(!inference_)
z = dropout(z, dropProb);
- float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
+ float tau
+ = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
// temperature in softmax is to control randomness of predictions
// high temperature Softmax outputs are more close to each other
// low temperatures the softmax become more similar to "hardmax"
- auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
+ auto weights
+ = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
- auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
- auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
+ auto chosenEmbeddings_mix
+ = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
+ auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb});
auto graph = ulrEmbeddings_.front()->graph();
- auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
- inits::fromVector(subBatch->mask()));
+ auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask()));
if(!inference_)
- batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
+ batchEmbeddings = dropout(batchEmbeddings,
+ options_->get<float>("dropout-embeddings", 0.0f),
+ {batchEmbeddings->shape()[-3], 1, 1});
return std::make_tuple(batchEmbeddings, batchMask);
}
@@ -142,9 +155,10 @@ public:
}
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
- embIdx; shape;
- ABORT("not implemented"); // @TODO: implement me
+ embIdx;
+ shape;
+ ABORT("not implemented"); // @TODO: implement me
}
};
-}
+} // namespace marian
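
(Illustrative aside, not part of the commit.) ULREmbedding::apply() above is easier to read as plain math: a word's query row Q[x] is transformed by A and scaled by 1/sqrt(dimUlrEmb), matched against all key rows K to get similarities, converted to attention weights with a temperature-controlled softmax (dropout omitted here), used to average the universal embeddings E, and the result is added to the trainable source embedding I[x] gated by the per-word sharable flag alpha. A toy single-word version in standalone C++; all values and dimensions are made up for illustration:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

using Vec = std::vector<float>;

static float dot(const Vec& a, const Vec& b) {
  float s = 0.f;
  for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
  return s;
}

int main() {
  // Toy sizes: dimUlrEmb = 2, dimEmb = 2, two keys / universal embeddings.
  Vec q = {1.f, 0.f};                             // query row Q[x] for word x
  float A[2][2] = {{1.f, 0.f}, {0.f, 1.f}};       // ulr transform (identity here)
  std::vector<Vec> K = {{1.f, 0.f}, {0.f, 1.f}};  // key rows
  std::vector<Vec> E = {{2.f, 0.f}, {0.f, 2.f}};  // universal embedding rows
  Vec src = {0.5f, 0.5f};                         // trainable source embedding I[x]
  float alpha = 1.f;                              // sharable flag for this word
  float tau = 1.f;                                // softmax temperature

  // qt = (q @ A) / sqrt(dimUlrEmb)
  float sqrtDim = std::sqrt((float)q.size());
  Vec qt(q.size(), 0.f);
  for (size_t j = 0; j < qt.size(); ++j) {
    for (size_t i = 0; i < q.size(); ++i) qt[j] += q[i] * A[i][j];
    qt[j] /= sqrtDim;
  }

  // z = qt @ K^T, weights = softmax(z / tau)
  Vec z = {dot(qt, K[0]), dot(qt, K[1])};
  float m = std::max(z[0], z[1]) / tau;
  Vec w = {std::exp(z[0] / tau - m), std::exp(z[1] / tau - m)};
  float Z = w[0] + w[1];
  w[0] /= Z;
  w[1] /= Z;

  // embedding = src + alpha * (weights @ E)
  Vec emb(src.size(), 0.f);
  for (size_t d = 0; d < emb.size(); ++d)
    emb[d] = src[d] + alpha * (w[0] * E[0][d] + w[1] * E[1][d]);

  std::printf("mixed embedding: %.3f %.3f\n", emb[0], emb[1]);
}
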
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index 02e820e5..8e2ecfd7 100644
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -1,13 +1,10 @@
#include "marian.h"
-#include "layers/generic.h"
+#include "data/factored_vocab.h"
#include "layers/constructors.h"
+#include "layers/generic.h"
#include "layers/loss.h"
-#include "data/factored_vocab.h"
-#include "models/states.h" // for EncoderState
#include "layers/lsh.h"
+#include "models/states.h" // for EncoderState
-namespace marian {
-
-
-} // namespace marian
+namespace marian {} // namespace marian
diff --git a/src/layers/generic.h b/src/layers/generic.h
index eddd597e..89f5c1e9 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -5,12 +5,14 @@
#include "data/shortlist.h"
#include "layers/factory.h"
-namespace marian { namespace mlp {
- /**
- * @brief Activation functions
- */
- enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
-}}
+namespace marian {
+namespace mlp {
+/**
+ * @brief Activation functions
+ */
+enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
+} // namespace mlp
+} // namespace marian
namespace marian {
@@ -23,8 +25,7 @@ protected:
Ptr<Options> options_;
public:
- LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : graph_(graph), options_(options) {}
+ LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
template <typename T>
T opt(const std::string key) const {
@@ -42,7 +43,7 @@ struct IUnaryLayer {
virtual ~IUnaryLayer() {}
virtual Expr apply(Expr) = 0;
virtual Expr apply(const std::vector<Expr>& es) {
- ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
+ ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return apply(es.front());
}
};
@@ -54,7 +55,8 @@ struct IHasShortList {
// Embedding from corpus sub-batch to (emb, mask)
struct IEmbeddingLayer {
- virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0;
+ virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
+ Ptr<data::SubBatch> subBatch) const = 0;
virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0;
@@ -63,28 +65,29 @@ struct IEmbeddingLayer {
virtual ~IEmbeddingLayer() {}
};
-// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index)
+// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream
+// index)
class EncoderDecoderLayerBase : public LayerBase {
protected:
const std::string prefix_;
const bool embeddingFix_;
- const float dropoutEmbeddings_; // this drops out full embedding vectors
+ const float dropoutEmbeddings_; // this drops out full embedding vectors
const bool inference_;
const size_t batchIndex_;
- mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
+ mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
- EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
- Ptr<Options> options,
- const std::string& prefix,
+ EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
+ Ptr<Options> options,
+ const std::string& prefix,
size_t batchIndex,
float dropoutEmbeddings,
- bool embeddingFix) :
- LayerBase(graph, options),
- prefix_(options->get<std::string>("prefix", prefix)),
- embeddingFix_(embeddingFix),
- dropoutEmbeddings_(dropoutEmbeddings),
- inference_(options->get<bool>("inference", false)),
- batchIndex_(options->get<size_t>("index", batchIndex)) {}
+ bool embeddingFix)
+ : LayerBase(graph, options),
+ prefix_(options->get<std::string>("prefix", prefix)),
+ embeddingFix_(embeddingFix),
+ dropoutEmbeddings_(dropoutEmbeddings),
+ inference_(options->get<bool>("inference", false)),
+ batchIndex_(options->get<size_t>("index", batchIndex)) {}
virtual ~EncoderDecoderLayerBase() {}
@@ -101,8 +104,7 @@ namespace mlp {
class Dense : public LayerBase, public IUnaryLayer {
public:
- Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options) {}
+ Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}
Expr apply(const std::vector<Expr>& inputs) override {
ABORT_IF(inputs.empty(), "No inputs");
@@ -124,21 +126,17 @@ public:
if(inputs.size() > 1)
num = std::to_string(i);
- Expr W = g->param(
- name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
+ Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros());
if(useLayerNorm) {
if(useNematusNorm) {
- auto ln_s = g->param(
- name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
+ auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros());
- outputs.push_back(
- layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
+ outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
} else {
- auto gamma = g->param(
- name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
+ auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
outputs.push_back(layerNorm(dot(in, W), gamma, b));
}
@@ -165,39 +163,35 @@ public:
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};
-} // namespace mlp
-
+} // namespace mlp
// --- a few layers with built-in parameters created on the fly, without proper object
// @TODO: change to a proper layer object
// like affine() but with built-in parameters, activation, and dropout
-static inline
-Expr denseInline(Expr x,
- std::string prefix,
- std::string suffix,
- int outDim,
- Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
- const std::function<Expr(Expr)>& actFn = nullptr,
- float dropProb = 0.0f)
-{
+static inline Expr denseInline(Expr x,
+ std::string prefix,
+ std::string suffix,
+ int outDim,
+ Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
+ const std::function<Expr(Expr)>& actFn = nullptr,
+ float dropProb = 0.0f) {
auto graph = x->graph();
- auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform());
- auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros());
+ auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
+ auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
x = affine(x, W, b);
- if (actFn)
+ if(actFn)
x = actFn(x);
- x = dropout(x, dropProb); // @TODO: check for infernce?
+ x = dropout(x, dropProb); // @TODO: check for infernce?
return x;
}
-static inline
-Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
+static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
int dimModel = x->shape()[-1];
- auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones());
- auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros());
+ auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones());
+ auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros());
return marian::layerNorm(x, scale, bias, 1e-6f);
}
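
(Illustrative aside, not part of the commit.) denseInline() and layerNorm() above are thin helpers: an affine transform with optional activation and dropout (dropout omitted below), and a layer normalization with learned scale/bias and eps = 1e-6. A standalone sketch of the per-row formulas; it assumes the conventional layer-norm definition, since the normalization kernel itself is not part of this diff:

#include <cmath>
#include <cstdio>
#include <vector>

static float relu(float v) { return v > 0.f ? v : 0.f; }

// The affine-plus-activation core of denseInline() for a single input row:
// y = act(x @ W + b), with W: inDim x outDim and b: outDim.
static std::vector<float> denseRow(const std::vector<float>& x,
                                   const std::vector<std::vector<float>>& W,
                                   const std::vector<float>& b,
                                   float (*act)(float) = nullptr) {
  std::vector<float> y(b);
  for (size_t o = 0; o < y.size(); ++o)
    for (size_t i = 0; i < x.size(); ++i)
      y[o] += x[i] * W[i][o];
  if (act)
    for (auto& v : y) v = act(v);
  return y;
}

// Conventional layer norm with learned scale/bias, as wrapped by layerNorm():
// y = scale * (x - mean) / sqrt(var + eps) + bias, eps = 1e-6
static std::vector<float> layerNormRow(std::vector<float> x,
                                       const std::vector<float>& scale,
                                       const std::vector<float>& bias,
                                       float eps = 1e-6f) {
  float mean = 0.f, var = 0.f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  for (size_t i = 0; i < x.size(); ++i)
    x[i] = scale[i] * (x[i] - mean) / std::sqrt(var + eps) + bias[i];
  return x;
}

int main() {
  auto h = denseRow({1.f, 2.f}, {{1.f, 0.f}, {0.f, 1.f}}, {0.f, 0.f}, relu);
  auto n = layerNormRow(h, {1.f, 1.f}, {0.f, 0.f});
  std::printf("%.3f %.3f\n", n[0], n[1]);  // approximately -1.000 1.000
}
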
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
index cd2203e4..772c5715 100644
--- a/src/layers/logits.cpp
+++ b/src/layers/logits.cpp
@@ -1,212 +1,250 @@
#include "logits.h"
-#include "loss.h"
#include "data/factored_vocab.h"
-#include "rnn/types.h" // for State::select()
+#include "loss.h"
+#include "rnn/types.h" // for State::select()
namespace marian {
- Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count)
-
- Ptr<ExpressionGraph> Logits::graph() const {
- ABORT_IF(logits_.empty(), "Empty logits object??");
- return logits_.front()->loss()->graph();
+Logits::Logits(Expr logits)
+ : Logits(New<RationalLoss>(logits, nullptr)) {
+} // single-output constructor from Expr only (RationalLoss has no count)
+
+Ptr<ExpressionGraph> Logits::graph() const {
+ ABORT_IF(logits_.empty(), "Empty logits object??");
+ return logits_.front()->loss()->graph();
+}
+
+// This function assumes that the object holds one or more factor logits.
+// It applies the supplied loss function to each, and then returns the aggregate loss over all
+// factors.
+Expr Logits::applyLossFunction(
+ const Words& labels,
+ const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const {
+ LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+ auto firstLogits = logits_.front()->loss();
+ ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
+ "Labels not matching logits shape ({} != {}, {})??",
+ labels.size() * firstLogits->shape()[-1],
+ firstLogits->shape().elements(),
+ firstLogits->shape());
+
+ // base case (no factors)
+ if(!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return lossFn(firstLogits, indices(toWordIndexVector(labels)));
}
- // This function assumes that the object holds one or more factor logits.
- // It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
- Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
- LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
-
- auto firstLogits = logits_.front()->loss();
- ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
- "Labels not matching logits shape ({} != {}, {})??",
- labels.size() * firstLogits->shape()[-1],
- firstLogits->shape().elements(),
- firstLogits->shape());
-
- // base case (no factors)
- if (!factoredVocab_) {
- ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
- return lossFn(firstLogits, indices(toWordIndexVector(labels)));
- }
-
- auto numGroups = factoredVocab_->getNumGroups();
-
- // split labels into individual factor labels
- auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
-
- //Expr indices = this->indices(toWordIndexVector(labels));
- // accumulate all CEs for all words that have the factor
- // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
- Expr loss;
- for (size_t g = 0; g < numGroups; g++) {
- if (!logits_[g])
- continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
- const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
- auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
- auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
- auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
- // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
- auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
- if(loss)
- factorLoss = cast(factorLoss, loss->value_type());
- factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor
- loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
- }
- return loss;
+ auto numGroups = factoredVocab_->getNumGroups();
+
+ // split labels into individual factor labels
+ auto allMaskedFactoredLabels
+ = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
+
+ // Expr indices = this->indices(toWordIndexVector(labels));
+ // accumulate all CEs for all words that have the factor
+ // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
+ Expr loss;
+ for(size_t g = 0; g < numGroups; g++) {
+ if(!logits_[g])
+ continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
+ const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
+ auto factorIndices = indices(
+ maskedFactoredLabels
+ .indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
+ auto factorMask
+ = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with
+ // 0 for labels that don't have this factor
+ auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
+ // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask
+ // it out next.
+ auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
+ if(loss)
+ factorLoss = cast(factorLoss, loss->value_type());
+ factorLoss
+ = factorLoss
+ * cast(
+ reshape(factorMask, factorLoss->shape()),
+ factorLoss->value_type()); // mask out factor for words that do not have that factor
+ loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
}
-
- // This function assumes this object holds a single factor that represents a rational loss (with count).
- //Ptr<RationalLoss> Logits::getRationalLoss() const {
- // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs");
- // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count");
- // return logits_.front();
- //}
-
- // get logits for one factor group
- // For groupIndex == 0, the function also requires the shortlist if there is one.
- Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const {
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
-
- auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
-
- // normalize for decoding:
- // - all secondary factors: subtract their max
- // - lemma: add all maxes of applicable factors
- if (groupIndex > 0) {
- sel = sel - max(sel, -1);
- }
- else {
- auto numGroups = getNumFactorGroups();
- for (size_t g = 1; g < numGroups; g++) {
- auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32
- auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
- sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0
- }
+ return loss;
+}
+
+// This function assumes this object holds a single factor that represents a rational loss (with
+// count).
+// Ptr<RationalLoss> Logits::getRationalLoss() const {
+// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on
+// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational
+// loss without count"); return logits_.front();
+//}
+
+// get logits for one factor group
+// For groupIndex == 0, the function also requires the shortlist if there is one.
+Expr Logits::getFactoredLogits(size_t groupIndex,
+ Ptr<data::Shortlist> shortlist /*= nullptr*/,
+ const std::vector<IndexType>& hypIndices /*= {}*/,
+ size_t beamSize /*= 0*/) const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+ auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
+
+ // normalize for decoding:
+ // - all secondary factors: subtract their max
+ // - lemma: add all maxes of applicable factors
+ if(groupIndex > 0) {
+ sel = sel - max(sel, -1);
+ } else {
+ auto numGroups = getNumFactorGroups();
+ for(size_t g = 1; g < numGroups; g++) {
+ auto factorMaxima = max(logits_[g]->loss(),
+ -1); // we cast since loss is likely ce-loss which has type float32
+ auto factorMasks = constant(
+ getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
+ sel = sel
+ + cast(factorMaxima, sel->value_type())
+ * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor
+ // get multiplied with 0
}
-
- // if selIdx are given, then we must reshuffle accordingly
- if (!hypIndices.empty()) // use the same function that shuffles decoder state
- sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
-
- return sel;
}
- // used for breakDown() only
- // Index is flattened
- Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
- return logits_[groupIndex]->loss()->val();
+ // if selIdx are given, then we must reshuffle accordingly
+ if(!hypIndices.empty()) // use the same function that shuffles decoder state
+ sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
+
+ return sel;
+}
+
+// used for breakDown() only
+// Index is flattened
+Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ return logits_[groupIndex]->loss()->val();
+}
+
+// This function assumes that the object holds one or more factor logits, which are summed up
+// into output-vocab logits according to the factored model (with correct normalization of factors).
+// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
+// @TODO: remove altogether
+Expr Logits::getLogits() const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ if(!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return getFactoredLogits(0);
}
- // This function assumes that the object holds one or more factor logits, which are summed up
- // into output-vocab logits according to the factored model (with correct normalization of factors).
- // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
- // @TODO: remove altogether
- Expr Logits::getLogits() const {
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
- if (!factoredVocab_) {
- ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
- return getFactoredLogits(0);
- }
-
#ifdef FACTOR_FULL_EXPANSION
- // compute normalized factor log probs
- std::vector<Expr> logProbs(logits_.size());
- for (size_t g = 0; g < logits_.size(); g++)
- logProbs[g] = logsoftmax(logits_[g]->loss());
- auto y = concatenate(logProbs, /*axis=*/ -1);
-
- // sum up the unit logits across factors for each target word
- auto graph = y->graph();
- auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
- y = dot_csr(
- y, // [B x U]
- factorMatrix.shape,
- graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
- graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
- graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
- /*transB=*/ true); // -> [B x V]
-
- // mask out gaps
- auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
- y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask));
-
- return y;
+ // compute normalized factor log probs
+ std::vector<Expr> logProbs(logits_.size());
+ for(size_t g = 0; g < logits_.size(); g++)
+ logProbs[g] = logsoftmax(logits_[g]->loss());
+ auto y = concatenate(logProbs, /*axis=*/-1);
+
+ // sum up the unit logits across factors for each target word
+ auto graph = y->graph();
+ auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
+ y = dot_csr(
+ y, // [B x U]
+ factorMatrix.shape,
+ graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
+ graph->constant({(int)factorMatrix.indices.size()},
+ inits::fromVector(factorMatrix.indices),
+ Type::uint32),
+ graph->constant({(int)factorMatrix.offsets.size()},
+ inits::fromVector(factorMatrix.offsets),
+ Type::uint32),
+ /*transB=*/true); // -> [B x V]
+
+ // mask out gaps
+ auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
+ y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask));
+
+ return y;
#else
- ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
+ ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
#endif
+}
+
+void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
+ bool isValid = FactoredVocab::isFactorValid(factorIndex);
+ indices.push_back(isValid ? (WordIndex)factorIndex : 0);
+ masks.push_back((float)isValid);
+}
+
+std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words)
+ const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
+ if(!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return {MaskedFactorIndices(words)};
}
-
- void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
- bool isValid = FactoredVocab::isFactorValid(factorIndex);
- indices.push_back(isValid ? (WordIndex)factorIndex : 0);
- masks.push_back((float)isValid);
+ auto numGroups = factoredVocab_->getNumGroups();
+ std::vector<MaskedFactorIndices> res(numGroups);
+ for(size_t g = 0; g < numGroups; g++) {
+ auto& resg = res[g];
+ resg.reserve(words.size());
+ for(const auto& word : words)
+ resg.push_back(factoredVocab_->getFactor(word, g));
}
-
- std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
- if (!factoredVocab_) {
- ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
- return {MaskedFactorIndices(words)};
- }
- auto numGroups = factoredVocab_->getNumGroups();
- std::vector<MaskedFactorIndices> res(numGroups);
- for (size_t g = 0; g < numGroups; g++) {
- auto& resg = res[g];
- resg.reserve(words.size());
- for (const auto& word : words)
- resg.push_back(factoredVocab_->getFactor(word, g));
- }
- return res;
- }
-
- //// use first factor of each word to determine whether it has a specific factor
- //std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
- // std::vector<float> res;
- // res.reserve(words.size());
- // for (const auto& word : words) {
- // auto lemma = factoredVocab_->getFactor(word, 0);
- // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
- // }
- // return res;
- //}
-
- // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
- // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
- std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
- size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
- std::vector<float> res;
- res.reserve(n);
- // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab
- for (size_t i = 0; i < n; i++) {
- auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
- res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
- }
- return res;
+ return res;
+}
+
+//// use first factor of each word to determine whether it has a specific factor
+// std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0
+// for words that do have this factor; else 0
+// std::vector<float> res;
+// res.reserve(words.size());
+// for (const auto& word : words) {
+// auto lemma = factoredVocab_->getFactor(word, 0);
+// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+// }
+// return res;
+//}
+
+// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
+// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
+std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices)
+ const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
+ size_t n
+ = indices.empty()
+ ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
+ : indices.size();
+ std::vector<float> res;
+ res.reserve(n);
+ // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
+ // into FactoredVocab
+ for(size_t i = 0; i < n; i++) {
+ auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
+ res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
}
-
- Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
- std::vector<Ptr<RationalLoss>> newLogits;
- for (const auto& l : logits_)
- newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
- return Logits(std::move(newLogits), factoredVocab_);
- }
-
- Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
- std::vector<Ptr<RationalLoss>> newLogits;
- bool first = true;
- for (const auto& l : logits_) {
- newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others
- first = false;
- }
- return Logits(std::move(newLogits), factoredVocab_);
- }
-
- // @TODO: code dup with above; we can merge it into applyToRationalLoss()
- Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_
- std::vector<Ptr<RationalLoss>> newLogits;
- for (const auto& l : logits_)
- newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
- return Logits(std::move(newLogits), factoredVocab_);
+ return res;
+}
+
+Logits Logits::applyUnaryFunction(
+ const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
+ std::vector<Ptr<RationalLoss>> newLogits;
+ for(const auto& l : logits_)
+ newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
+ return Logits(std::move(newLogits), factoredVocab_);
+}
+
+Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
+ const std::function<Expr(Expr)>& fother) const {
+ std::vector<Ptr<RationalLoss>> newLogits;
+ bool first = true;
+ for(const auto& l : logits_) {
+ newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()),
+ l->count())); // f1 for first, fother for all others
+ first = false;
}
-} \ No newline at end of file
+ return Logits(std::move(newLogits), factoredVocab_);
+}
+
+// @TODO: code dup with above; we can merge it into applyToRationalLoss()
+Logits Logits::withCounts(
+ const Expr& count) const { // create new Logits with 'count' implanted into all logits_
+ std::vector<Ptr<RationalLoss>> newLogits;
+ for(const auto& l : logits_)
+ newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
+ return Logits(std::move(newLogits), factoredVocab_);
+}
+} // namespace marian \ No newline at end of file
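
For intuition, a minimal standalone sketch of the masking idea behind Logits::getFactorMasks() above: given a lemma-has-factor predicate, emit 1.0 for lemmas that carry the factor group and 0.0 otherwise. The predicate and the lemma range are hypothetical stand-ins, not the FactoredVocab API.

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a lemmaHasFactorGroup()-style predicate.
using HasFactorFn = std::function<bool(size_t lemma, size_t factorGroup)>;

// 1.0 for lemmas that do have this factor group; else 0.0.
std::vector<float> factorMasks(size_t factorGroup,
                               size_t numLemmas,
                               const std::vector<size_t>& indices,  // empty -> all lemmas
                               const HasFactorFn& hasFactor) {
  size_t n = indices.empty() ? numLemmas : indices.size();
  std::vector<float> res;
  res.reserve(n);
  for(size_t i = 0; i < n; i++) {
    size_t lemma = indices.empty() ? i : indices[i];
    res.push_back(hasFactor(lemma, factorGroup) ? 1.f : 0.f);
  }
  return res;
}

int main() {
  // Toy vocabulary: even-numbered lemmas carry factor group 1.
  auto hasFactor = [](size_t lemma, size_t group) { return group == 1 && lemma % 2 == 0; };
  for(float m : factorMasks(/*factorGroup=*/1, /*numLemmas=*/6, /*indices=*/{}, hasFactor))
    std::cout << m << " ";  // prints: 1 0 1 0 1 0
  std::cout << "\n";
}
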
diff --git a/src/layers/logits.h b/src/layers/logits.h
index 4196e0d0..c61a9e74 100644
--- a/src/layers/logits.h
+++ b/src/layers/logits.h
@@ -1,8 +1,8 @@
#pragma once
-#include "marian.h"
#include "data/shortlist.h"
#include "generic.h"
+#include "marian.h"
namespace marian {
@@ -16,46 +16,77 @@ class FactoredVocab;
class RationalLoss;
class Logits {
public:
- Logits() {}
- explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
- logits_.push_back(logits);
- }
- explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
- Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
+ Logits() {}
+ explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
+ logits_.push_back(logits);
+ }
+ explicit Logits(
+ Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
+ Logits(std::vector<Ptr<RationalLoss>>&& logits,
+ Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
- Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
- Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
- //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
- Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
- Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
- Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to all other values
+ Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
+ Expr getFactoredLogits(
+ size_t groupIndex,
+ Ptr<data::Shortlist> shortlist = nullptr,
+ const std::vector<IndexType>& hypIndices = {},
+ size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
+ // Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
+ Expr applyLossFunction(
+ const Words& labels,
+ const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const;
+ Logits applyUnaryFunction(
+ const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
+ Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
+ const std::function<Expr(Expr)>& fother)
+ const; // clone this but apply f1 to first and fother to all other values
- struct MaskedFactorIndices {
- std::vector<WordIndex> indices; // factor index, or 0 if masked
- std::vector<float> masks;
- void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
- void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
- MaskedFactorIndices() {}
- MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
- };
- std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
- Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
- size_t getNumFactorGroups() const { return logits_.size(); }
- bool empty() const { return logits_.empty(); }
- Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
+ struct MaskedFactorIndices {
+ std::vector<WordIndex> indices; // factor index, or 0 if masked
+ std::vector<float> masks;
+ void reserve(size_t n) {
+ indices.reserve(n);
+ masks.reserve(n);
+ }
+ void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0
+ // for invalid entries
+ MaskedFactorIndices() {}
+ MaskedFactorIndices(const Words& words) {
+ indices = toWordIndexVector(words);
+ } // we can leave masks uninitialized for this special use case
+ };
+ std::vector<MaskedFactorIndices> factorizeWords(
+ const Words& words) const; // breaks encoded Word into individual factor indices
+ Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
+ size_t getNumFactorGroups() const { return logits_.size(); }
+ bool empty() const { return logits_.empty(); }
+ Logits withCounts(
+ const Expr& count) const; // create new Logits with 'count' implanted into all logits_
private:
- // helper functions
- Ptr<ExpressionGraph> graph() const;
- Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
- Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
- template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector
- Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
- std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
+ // helper functions
+ Ptr<ExpressionGraph> graph() const;
+ Expr constant(const Shape& shape, const std::vector<float>& data) const {
+ return graph()->constant(shape, inits::fromVector(data));
+ }
+ Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const {
+ return graph()->constant(shape, inits::fromVector(data));
+ }
+ template <typename T>
+ Expr constant(const std::vector<T>& data) const {
+ return constant(Shape{(int)data.size()}, data);
+ } // same as constant() but assuming vector
+ Expr indices(const std::vector<uint32_t>& data) const {
+ return graph()->indices(data);
+ } // actually the same as constant(data) for this data type
+ std::vector<float> getFactorMasks(size_t factorGroup,
+ const std::vector<WordIndex>& indices) const;
+
private:
- // members
- // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
- std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
- Ptr<FactoredVocab> factoredVocab_;
+ // members
+ // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just
+ // by the Expr
+ std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
+ Ptr<FactoredVocab> factoredVocab_;
};
// Unary function that returns a Logits object
@@ -65,12 +96,11 @@ private:
struct IUnaryLogitLayer : public IUnaryLayer {
virtual Logits applyAsLogits(Expr) = 0;
virtual Logits applyAsLogits(const std::vector<Expr>& es) {
- ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
+ ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return applyAsLogits(es.front());
}
virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
};
-}
-
+} // namespace marian
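
The MaskedFactorIndices contract declared above (push back into both arrays, setting mask and index to 0 for invalid entries) can be illustrated with a small standalone sketch; kInvalid is a hypothetical sentinel, not the marian constant.

#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// Standalone sketch; kInvalid is a hypothetical sentinel for "factor not applicable".
constexpr size_t kInvalid = std::numeric_limits<size_t>::max();

struct MaskedFactorIndicesSketch {
  std::vector<unsigned> indices;  // factor index, or 0 if masked
  std::vector<float> masks;
  void push_back(size_t factorIndex) {
    bool valid = factorIndex != kInvalid;
    indices.push_back(valid ? (unsigned)factorIndex : 0u);  // index 0 for invalid entries
    masks.push_back(valid ? 1.f : 0.f);                     // mask 0 for invalid entries
  }
};

int main() {
  MaskedFactorIndicesSketch m;
  m.push_back(3);
  m.push_back(kInvalid);
  m.push_back(7);
  for(size_t i = 0; i < m.indices.size(); i++)
    std::cout << m.indices[i] << "/" << m.masks[i] << " ";  // 3/1 0/0 7/1
  std::cout << "\n";
}
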
diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp
index 67d38832..695276af 100644
--- a/src/layers/loss.cpp
+++ b/src/layers/loss.cpp
@@ -13,26 +13,30 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
bool wordScores = options->get<bool>("word-scores", false);
return New<RescorerLoss>(wordScores);
} else if(unlikelihood) {
- ABORT_IF(!options->hasAndNotEmpty("data-weighting")
- && options->get<std::string>("data-weighting-type") != "word",
- "Unlikelihood loss training requires error annotation in form of per-target-label scores");
- return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on values given for data-weighting
- } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
+ ABORT_IF(
+ !options->hasAndNotEmpty("data-weighting")
+ && options->get<std::string>("data-weighting-type") != "word",
+ "Unlikelihood loss training requires error annotation in form of per-target-label scores");
+ return New<SequenceUnlikelihoodLoss>(
+ smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on
+ // values given for data-weighting
+ } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones.
+ // E.g. what about ce-sum?
return New<CrossEntropyLoss>(smoothing, factorWeight);
}
}
// see loss.h for detailed explanations of each class
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
- std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
- if(multiLossType == "sum") // sum of sums
- return New<SumMultiRationalLoss>();
- else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
- return New<ScaledMultiRationalLoss>();
- else if(multiLossType == "mean") // sum of means
- return New<MeanMultiRationalLoss>();
- else
- ABORT("Unknown multi-loss-type {}", multiLossType);
+ std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
+ if(multiLossType == "sum") // sum of sums
+ return New<SumMultiRationalLoss>();
+ else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
+ return New<ScaledMultiRationalLoss>();
+ else if(multiLossType == "mean") // sum of means
+ return New<MeanMultiRationalLoss>();
+ else
+ ABORT("Unknown multi-loss-type {}", multiLossType);
}
} // namespace marian
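
The three multi-loss-type policies selected above differ only in how (loss, count) pairs accumulate. A rough standalone sketch of the arithmetic suggested by the comments and code in this diff (sum: add numerators and denominators; scaled: rescale later losses to the first label count and keep that count; mean: add per-loss means and keep a count of 1); struct L and the combine functions are illustrative only.

#include <iostream>
#include <vector>

struct L { float loss, count; };  // numerator / denominator, as in StaticLoss

// sum of sums: add numerators and denominators
L combineSum(const std::vector<L>& ls) {
  L r{0.f, 0.f};
  for(auto& l : ls) { r.loss += l.loss; r.count += l.count; }
  return r;
}

// scaled: rescale each later loss to the label count of the first, keep the first count
L combineScaled(const std::vector<L>& ls) {
  L r{ls[0].loss, ls[0].count};
  for(size_t i = 1; i < ls.size(); i++)
    r.loss += ls[i].loss * ls[0].count / ls[i].count;
  return r;
}

// mean: add per-loss means; labels are folded into the loss, so the count stays 1
L combineMean(const std::vector<L>& ls) {
  L r{0.f, 1.f};
  for(auto& l : ls) r.loss += l.loss / l.count;
  return r;
}

int main() {
  std::vector<L> parts = {{12.f, 4.f}, {3.f, 2.f}};  // e.g. CE loss plus an auxiliary loss
  auto s = combineSum(parts), c = combineScaled(parts), m = combineMean(parts);
  std::cout << s.loss / s.count << " " << c.loss / c.count << " " << m.loss / m.count << "\n";
}
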
diff --git a/src/layers/loss.h b/src/layers/loss.h
index ba93cdac..c662f991 100644
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -1,8 +1,8 @@
#pragma once
-#include "graph/expression_operators.h"
-#include "layers/logits.h" // for Logits (Frank's factor hack)
#include "data/types.h"
+#include "graph/expression_operators.h"
+#include "layers/logits.h" // for Logits (Frank's factor hack)
namespace marian {
@@ -22,21 +22,18 @@ namespace marian {
*/
class RationalLoss {
protected:
- Expr loss_; // numerator
- Expr count_; // denominator
+ Expr loss_; // numerator
+ Expr count_; // denominator
- RationalLoss() = default; // protected
+ RationalLoss() = default; // protected
public:
- RationalLoss(Expr loss, Expr count)
- : loss_(loss), count_(count) {}
+ RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {}
RationalLoss(Expr loss, float count)
- : loss_(loss),
- count_(constant_like(loss, inits::fromValue(count))) {}
+ : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {}
- RationalLoss(const RationalLoss& other)
- : loss_(other.loss_), count_(other.count_) {}
+ RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {}
virtual ~RationalLoss() = default;
@@ -50,7 +47,7 @@ public:
}
template <typename T>
- T loss() const { // this will fail if loss is not a single value
+ T loss() const { // this will fail if loss is not a single value
ABORT_IF(!loss_, "Loss has not been defined");
return loss_->val()->scalar<T>();
}
@@ -65,7 +62,7 @@ public:
}
template <typename T>
- T count() const { // this will fail if count is not a single value
+ T count() const { // this will fail if count is not a single value
ABORT_IF(!count_, "Labels have not been defined");
return count_->val()->scalar<T>();
}
@@ -85,21 +82,21 @@ public:
* RationalLoss object.
*/
struct StaticLoss {
- float loss; // numerator
- float count; // denominator
+ float loss; // numerator
+ float count; // denominator
StaticLoss() : loss(0.f), count(0.f) {}
StaticLoss(const RationalLoss& dynamic)
- : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
+ : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
- StaticLoss operator +(const StaticLoss& other) const {
+ StaticLoss operator+(const StaticLoss& other) const {
StaticLoss res(*this);
res += other;
return res;
}
- StaticLoss& operator +=(const StaticLoss& other) {
+ StaticLoss& operator+=(const StaticLoss& other) {
loss = loss + other.loss;
count = count + other.count;
return *this;
@@ -139,32 +136,21 @@ protected:
public:
MultiRationalLoss() : RationalLoss() {}
- MultiRationalLoss(const RationalLoss& rl) : RationalLoss() {
- push_back(rl);
- }
+ MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); }
virtual void push_back(const RationalLoss& current) {
- loss_ = accumulateLoss(current);
- count_ = accumulateCount(current);
+ loss_ = accumulateLoss(current);
+ count_ = accumulateCount(current);
partialLosses_.push_back(current);
}
- const RationalLoss& operator[](size_t i) {
- return partialLosses_[i];
- }
+ const RationalLoss& operator[](size_t i) { return partialLosses_[i]; }
- auto begin() -> decltype(partialLosses_.begin()) const {
- return partialLosses_.begin();
- }
+ auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); }
- auto end() -> decltype(partialLosses_.end()) const {
- return partialLosses_.end();
- }
-
- size_t size() const {
- return partialLosses_.size();
- }
+ auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); }
+ size_t size() const { return partialLosses_.size(); }
};
/**
@@ -212,17 +198,19 @@ private:
virtual Expr accumulateLoss(const RationalLoss& current) override {
if(loss_) {
const auto& first = partialLosses_.front();
- return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss
+ return loss_
+ + current.loss() * first.count()
+ / current.count(); // scale up/down to match scale of first loss
} else {
- return current.loss(); // first reference loss, keeps to scale with this one
+ return current.loss(); // first reference loss, keeps to scale with this one
}
}
virtual Expr accumulateCount(const RationalLoss& current) override {
if(count_) {
- return count_; // Keep first label count // or: count_ + first.count() / current.count();
+ return count_; // Keep first label count // or: count_ + first.count() / current.count();
} else {
- return current.count(); // This is the first loss
+ return current.count(); // This is the first loss
}
}
@@ -253,9 +241,10 @@ private:
virtual Expr accumulateCount(const RationalLoss& current) override {
if(count_)
- return count_; // keep the existing '1'
+ return count_; // keep the existing '1'
else
- return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
+ return current.count()->graph()->ones(
+ {1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
}
public:
@@ -279,18 +268,21 @@ class LabelwiseLoss {
protected:
std::vector<int> axes_;
- virtual Expr compute(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) = 0;
+ virtual Expr compute(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr)
+ = 0;
// label counts are available, reduce together with loss to obtain counts
RationalLoss reduce(Expr loss, Expr labels) {
ABORT_IF(!loss, "Loss has not been computed");
ABORT_IF(!labels, "Labels have not been computed");
- Expr lossSum = cast(loss, Type::float32); // accumulate in float32
- Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
+ Expr lossSum = cast(loss, Type::float32); // accumulate in float32
+ Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
for(int i = 0; i < axes_.size(); ++i) {
- lossSum = sum(lossSum, axes_[i]);
+ lossSum = sum(lossSum, axes_[i]);
labelsSum = sum(labelsSum, axes_[i]);
}
@@ -301,7 +293,7 @@ protected:
RationalLoss reduce(Expr loss) {
ABORT_IF(!loss, "Loss has not been computed");
- Expr lossSum = cast(loss, Type::float32);
+ Expr lossSum = cast(loss, Type::float32);
for(int i = 0; i < axes_.size(); ++i)
lossSum = sum(lossSum, axes_[i]);
@@ -311,17 +303,18 @@ protected:
}
public:
- LabelwiseLoss(const std::vector<int>& axes)
- : axes_(axes) { }
+ LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {}
- virtual RationalLoss apply(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) {
+ virtual RationalLoss apply(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr) {
Expr loss = compute(logits, labels, mask, labelWeights);
if(mask)
- return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
+ return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
else
- return reduce(loss); // we have no mask, assume all items are labels
+ return reduce(loss); // we have no mask, assume all items are labels
}
};
@@ -331,28 +324,34 @@ public:
class CrossEntropyLoss : public LabelwiseLoss {
public:
CrossEntropyLoss(float labelSmoothing, float factorWeight)
- : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
+ : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {
+ } // cross-entropy already reduces over axis -1
CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
- : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
- labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {}
+ : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
+ labelSmoothing_(labelSmoothing),
+ factorWeight_(factorWeight) {}
virtual ~CrossEntropyLoss() {}
-protected:
- float labelSmoothing_; // interpolation factor for label smoothing, see below
- float factorWeight_; // give extra weight to factors
- virtual Expr compute(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) override {
- // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up
+protected:
+ float labelSmoothing_; // interpolation factor for label smoothing, see below
+ float factorWeight_; // give extra weight to factors
+
+ virtual Expr compute(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr) override {
+ // logits may be factored; in that case, the getLoss() function computes one loss for each, and
+ // sums them up
int inFactor = false;
auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
- logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
+ logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
// for bert training or classification the time dimension is lost.
// Here safeguard against 2d classifier output, adds 1 on the left, non-op.
-
+
Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32);
- if (inFactor && factorWeight_ != 1.0f) {
+ if(inFactor && factorWeight_ != 1.0f) {
LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
ce = ce * factorWeight_;
}
@@ -365,8 +364,10 @@ protected:
if(labelWeights) {
// We currently do not know how to use target factors and word-level label weights together
- bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
- ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors");
+ bool wordlevel = labelWeights->shape()[-3]
+ > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
+ ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1,
+ "CE loss with word-level label weights is not implemented for factors");
ce = ce * cast(labelWeights, Type::float32);
}
@@ -374,13 +375,12 @@ protected:
}
};
-
/**
* @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an
* implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319.
- * We rely on word-level label weights where 1 is correct and 0 marks an error. If there are no
- * zeros for a sentence, it is going to be trained with normal CE loss; if there is at least one 0, it is
- * going to flip over to use SUL for that sentence to penalize the selected word.
+ * We rely on word-level label weights where 1 is correct and 0 marks an error. If there are no
+ * zeros for a sentence, it is going to be trained with normal CE loss; if there is at least one 0,
+ * it is going to flip over to use SUL for that sentence to penalize the selected word.
*
* SUL is implemented as:
* -log(gather(1 - softmax(logits), -1, indices))
@@ -390,35 +390,45 @@ protected:
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
- : CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
+ : CrossEntropyLoss(labelSmoothing, factorWeight) {
+ } // cross-entropy already reduces over axis -1
SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
- : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
+ : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
protected:
- virtual Expr compute(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) override {
- auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
+ virtual Expr compute(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr) override {
+ auto ce = CrossEntropyLoss::compute(
+ logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
if(!labelWeights)
- return ce; // for validation, @TODO: maybe rather abort or LOG_ONCE(warn, ...)?
+ return ce; // for validation, @TODO: maybe rather abort or LOG_ONCE(warn, ...)?
// We currently do not know how to use target factors and word-level label weights together
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
- ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
- // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete)
+ ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by
+ // default 1, which would make this obsolete.
+ // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask
+ // again to eliminate padding (might be obsolete)
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
return cast(unlikelihood(logits, indices), Type::float32);
});
- // compute whether to use CE or UL. If there are no errors, train with CE; otherwise train _only on_ the errors with UL. This is the "mixed" training
- // schedule from https://arxiv.org/abs/1908.04319. By providing labels with or without error scores, we can easily switch between CE and UL.
- auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
- ceUl = errorMask * ceUl; // don't use for correct label or padding
+ // compute whether to use CE or UL. If there are no errors, train with CE; otherwise train _only
+ // on_ the errors with UL. This is the "mixed" training schedule from
+ // https://arxiv.org/abs/1908.04319. By providing labels with or without error scores, we can
+ // easily switch between CE and UL.
+ auto onlyCe = eq(sum(errorMask, /*axis=*/-3),
+ 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
+ ceUl = errorMask * ceUl; // don't use for correct label or padding
- auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce and unlikelihood parts are never simultaneously used as cost per batch entry
+ auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce and unlikelihood parts are never
+ // simultaneously used as cost per batch entry
return cost;
}
@@ -463,7 +473,6 @@ public:
}
};
-
/**
* @brief Factory for label-wise loss functions
*/
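
For intuition, a scalar standalone sketch of the mixed CE/UL schedule in SequenceUnlikelihoodLoss above: with word-level label weights (1 = correct, 0 = error), a sentence without zeros keeps its CE loss, otherwise only the error positions contribute via the unlikelihood term -log(1 - p). The toy probabilities and loop structure are illustrative only.

#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // One sentence: per-token probability of the reference label, and label weights
  // (1 = correct, 0 = marked error), as in the data-weighting annotation.
  std::vector<float> p      = {0.7f, 0.2f, 0.9f};
  std::vector<float> weight = {1.f,  0.f,  1.f};   // token 1 is marked as an error

  float sumErr = 0.f;
  for(float w : weight) sumErr += 1.f - w;          // errorMask summed over time
  float onlyCe = (sumErr == 0.f) ? 1.f : 0.f;       // 1 if the sentence has no errors

  float cost = 0.f;
  for(size_t t = 0; t < p.size(); t++) {
    float ce = -std::log(p[t]);                     // cross-entropy on the reference
    float ul = -std::log(1.f - p[t]);               // unlikelihood of the reference
    float errorMask = 1.f - weight[t];
    cost += onlyCe * ce + (1.f - onlyCe) * errorMask * ul;
  }
  std::cout << cost << "\n";  // here: only the error token contributes, via UL
}
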
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index bf8fa588..1d9c7b4b 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -1,120 +1,131 @@
#include "output.h"
-#include "data/factored_vocab.h"
#include "common/timer.h"
-#include "layers/lsh.h"
+#include "data/factored_vocab.h"
#include "layers/loss.h"
+#include "layers/lsh.h"
namespace marian {
namespace mlp {
/*private*/ void Output::lazyConstruct(int inputDim) {
- // We must construct lazily since we won't know tying or the input dim in the constructor.
- if (Wt_)
+ // We must construct lazily since we won't know tying or the input dim in the constructor.
+ if(Wt_)
return;
- // this option is only set in the decoder
- if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
- auto k = opt<std::vector<int>>("output-approx-knn")[0];
+ // this option is only set in the decoder
+ if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
+ auto k = opt<std::vector<int>>("output-approx-knn")[0];
auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
lsh_ = New<LSH>(k, nbits);
- }
+ }
- auto name = options_->get<std::string>("prefix");
- auto numOutputClasses = options_->get<int>("dim");
+ auto name = options_->get<std::string>("prefix");
+ auto numOutputClasses = options_->get<int>("dim");
- factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
- if (factoredVocab_) {
+ factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+ if(factoredVocab_) {
numOutputClasses = (int)factoredVocab_->factorVocabSize();
LOG_ONCE(info, "[embedding] Factored outputs enabled");
- }
+ }
- if(tiedParam_) {
+ if(tiedParam_) {
Wt_ = tiedParam_;
- } else {
- if (graph_->get(name + "_W")) { // support of legacy models that did not transpose
- Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
- isLegacyUntransposedW = true;
- }
- else // this is the regular case:
- Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
- }
+ } else {
+ if(graph_->get(name + "_W")) { // support of legacy models that did not transpose
+ Wt_ = graph_->param(
+ name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
+ isLegacyUntransposedW = true;
+ } else // this is the regular case:
+ Wt_ = graph_->param(
+ name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
+ }
- if(hasBias_)
+ if(hasBias_)
b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
- /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
- ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
- if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
+ if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
#define HARDMAX_HACK
#ifdef HARDMAX_HACK
- lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
#endif
auto range = factoredVocab_->getGroupRange(0);
auto lemmaVocabDim = (int)(range.second - range.first);
- auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
- lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
- }
+ auto initFunc = inits::glorotUniform(
+ /*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
+ lemmaEt_ = graph_->param(name + "_lemmaEt",
+ {lemmaDimEmb, lemmaVocabDim},
+ initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
+ }
}
Logits Output::applyAsLogits(Expr input) /*override final*/ {
- lazyConstruct(input->shape()[-1]);
+ lazyConstruct(input->shape()[-1]);
- auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
if(b)
- return affine(x, W, b, transA, transB);
+ return affine(x, W, b, transA, transB);
else
- return dot(x, W, transA, transB);
- };
+ return dot(x, W, transA, transB);
+ };
- auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
if(lsh_) {
- ABORT_IF( transA, "Transposed query not supported for LSH");
- ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
- return lsh_->apply(x, W, b); // knows how to deal with undefined bias
+ ABORT_IF(transA, "Transposed query not supported for LSH");
+ ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
+ return lsh_->apply(x, W, b); // knows how to deal with undefined bias
} else {
- return affineOrDot(x, W, b, transA, transB);
+ return affineOrDot(x, W, b, transA, transB);
}
- };
+ };
- if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
- cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
+ if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one
+ // batch, then clear()ed
+ cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
if(hasBias_)
- cachedShortb_ = index_select(b_ , -1, shortlist_->indices());
- }
+ cachedShortb_ = index_select(b_, -1, shortlist_->indices());
+ }
- if (factoredVocab_) {
+ if(factoredVocab_) {
auto graph = input->graph();
// project each factor separately
auto numGroups = factoredVocab_->getNumGroups();
- std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors)
- Expr input1 = input; // [B... x D]
- Expr Plemma = nullptr; // used for lemmaDimEmb=-1
- Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
- for (size_t g = 0; g < numGroups; g++) {
- auto range = factoredVocab_->getGroupRange(g);
- if (g > 0 && range.first == range.second) // empty entry
+ std::vector<Ptr<RationalLoss>> allLogits(numGroups,
+ nullptr); // (note: null entries for absent factors)
+ Expr input1 = input; // [B... x D]
+ Expr Plemma = nullptr; // used for lemmaDimEmb=-1
+ Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
+ for(size_t g = 0; g < numGroups; g++) {
+ auto range = factoredVocab_->getGroupRange(g);
+ if(g > 0 && range.first == range.second) // empty entry
continue;
- ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g);
- // slice this group's section out of W_
- Expr factorWt, factorB;
- if (g == 0 && shortlist_) {
+ ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second,
+ "Factor groups must be consecutive (group {} vs predecessor)",
+ g);
+ // slice this group's section out of W_
+ Expr factorWt, factorB;
+ if(g == 0 && shortlist_) {
factorWt = cachedShortWt_;
- factorB = cachedShortb_;
- }
- else {
- factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
+ factorB = cachedShortb_;
+ } else {
+ factorWt = slice(
+ Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
if(hasBias_)
- factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
- }
- /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
- if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
+ factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
+ }
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
+ && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
// this mimics one transformer layer
// - attention over two inputs:
- // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas.
+ // - e = current lemma. We use the original embedding vector; specifically, expectation
+ // over all lemmas.
// - input = hidden state FF(h_enc+h_dec)
- // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention)
+ // - dot-prod attention to allow both sides to influence (unlike our recurrent
+ // self-attention)
// - multi-head to allow for multiple conditions to be modeled
// - add & norm, for gradient flow and scaling
// - FF layer --this is expensive; it is per-factor
@@ -122,112 +133,161 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
int inputDim = input->shape()[-1];
int heads = 8;
auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
- auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform());
- auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform());
- auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform());
+ auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform());
+ auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform());
+ auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform());
auto toMultiHead = [&](Expr x, int heads) {
- const auto& shape = x->shape();
- int inputDim = shape[-1];
- int otherDim = shape.elements() / inputDim;
- ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads);
- return reshape(x, { otherDim, heads, 1, inputDim / heads });
+ const auto& shape = x->shape();
+ int inputDim = shape[-1];
+ int otherDim = shape.elements() / inputDim;
+ ABORT_IF(inputDim / heads * heads != inputDim,
+ "inputDim ({}) must be multiple of number of heads ({})",
+ inputDim,
+ heads);
+ return reshape(x, {otherDim, heads, 1, inputDim / heads});
};
input1 = inputLemma;
- auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
- auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax.
- auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values
- auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other
- auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
- auto sm = sigmoid(zm); // [B... x H x 1]
- auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
- auto r = reshape(rm, input->shape()); // [B... x D]
+ auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
+ auto kdm = toMultiHead(dot(input1 - input, Wk),
+ heads); // [B... x H x D/H] the two data vectors projected as keys.
+ // Use diff and sigmoid, instead of softmax.
+ auto vem = toMultiHead(
+ dot(input1, Wv),
+ heads); // [B... x H x D/H] one of the two data vectors projected as values
+ auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other
+ auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
+ auto sm = sigmoid(zm); // [B... x H x 1]
+ auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
+ auto r = reshape(rm, input->shape()); // [B... x D]
// add & norm
input1 = r + input1;
input1 = layerNorm(input1, name + "_att");
// FF layer
- auto ffnDropProb = 0.1f; // @TODO: get as a parameter
- auto ffnDim = inputDim * 2; // @TODO: get as a parameter
- auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb);
- f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
+ auto ffnDropProb = 0.1f; // @TODO: get as a parameter
+ auto ffnDim = inputDim * 2; // @TODO: get as a parameter
+ auto f = denseInline(input1,
+ name + "_ffn",
+ /*suffix=*/"1",
+ ffnDim,
+ inits::glorotUniform(),
+ (ActivationFunction*)relu,
+ ffnDropProb);
+ f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
// add & norm
input1 = f + input1;
input1 = layerNorm(input1, name + "_ffn");
- }
- // @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a matrix
- Expr factorLogits;
- if(g == 0)
- factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
- else
- factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
-
- // optionally add lemma-dependent bias
- if (Plemma) { // [B... x U0]
+ }
+ // @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a
+ // matrix
+ Expr factorLogits;
+ if(g == 0)
+ factorLogits = affineOrLSH(
+ input1,
+ factorWt,
+ factorB,
+ false,
+ /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
+ else
+ factorLogits = affineOrDot(
+ input1,
+ factorWt,
+ factorB,
+ false,
+ /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
+
+ // optionally add lemma-dependent bias
+ if(Plemma) { // [B... x U0]
int lemmaVocabDim = Plemma->shape()[-1];
int factorVocabDim = factorLogits->shape()[-1];
auto name = options_->get<std::string>("prefix");
- Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
- auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
+ Expr lemmaBt
+ = graph_->param(name + "_lemmaBt_" + std::to_string(g),
+ {factorVocabDim, lemmaVocabDim},
+ inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
+ auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
factorLogits = factorLogits + b;
- }
- allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
- // optionally add a soft embedding of lemma back to create some lemma dependency
- // @TODO: if this works, move it into lazyConstruct
- if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
+ }
+ allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
+ // optionally add a soft embedding of lemma back to create some lemma dependency
+ // @TODO: if this works, move it into lazyConstruct
+ if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
// get expected lemma embedding vector
- auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
+ auto factorLogSoftmax = logsoftmax(
+ factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
auto factorSoftmax = exp(factorLogSoftmax);
- inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
- }
- else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
+ inputLemma = dot(factorSoftmax,
+ factorWt,
+ false,
+ /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
+ } else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
// get max-lemma embedding vector
- auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set
+ auto maxVal = max(factorLogits,
+ -1); // [B... x U] note: with shortlist, this is not the full lemma set
auto factorHardmax = eq(factorLogits, maxVal);
- inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
- }
- else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
+ inputLemma = dot(factorHardmax,
+ factorWt,
+ false,
+ /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
+ } else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
LOG_ONCE(info, "[embedding] using lemma-dependent bias");
- auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
- auto z = /*stopGradient*/(factorLogSoftmax);
- Plemma = exp(z); // [B... x U]
- }
- else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
+ auto factorLogSoftmax
+ = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
+ auto z = /*stopGradient*/ (factorLogSoftmax);
+ Plemma = exp(z); // [B... x U]
+ } else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
- // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
+ // compute softmax. We compute logsoftmax() separately because this way, computation will be
+ // reused later via CSE
auto factorLogSoftmax = logsoftmax(factorLogits);
auto factorSoftmax = exp(factorLogSoftmax);
#ifdef HARDMAX_HACK
- bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
- if (hardmax) {
- lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
- LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
- auto maxVal = max(factorSoftmax, -1);
- factorSoftmax = eq(factorSoftmax, maxVal);
+ bool hardmax = (lemmaDimEmb & 1)
+ != 0; // odd value triggers hardmax for now (for quick experimentation)
+ if(hardmax) {
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+ LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
+ auto maxVal = max(factorSoftmax, -1);
+ factorSoftmax = eq(factorSoftmax, maxVal);
}
#endif
// re-embedding lookup, soft-indexed by softmax
- if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
- cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
- auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
+ if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
+ cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
+ auto e = dot(factorSoftmax,
+ cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_,
+ false,
+ true); // [B... x L]
// project it back to regular hidden dim
int inputDim = input1->shape()[-1];
auto name = options_->get<std::string>("prefix");
- // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1
- Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension
- auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
+ // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also
+ // length 1
+ Expr lemmaWt
+ = inputDim == lemmaDimEmb
+ ? nullptr
+ : graph_->param(name + "_lemmaWt",
+ {inputDim, lemmaDimEmb},
+ inits::glorotUniform()); // [D x L] D=hidden-vector dimension
+ auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
// augment the original hidden vector with this additional information
input1 = input1 + f;
- }
+ }
}
return Logits(std::move(allLogits), factoredVocab_);
- } else if (shortlist_) {
- return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
- } else {
- return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
- }
+ } else if(shortlist_) {
+ return Logits(affineOrLSH(input,
+ cachedShortWt_,
+ cachedShortb_,
+ false,
+ /*transB=*/isLegacyUntransposedW ? false : true));
+ } else {
+ return Logits(
+ affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+ }
}
-}
-} \ No newline at end of file
+} // namespace mlp
+} // namespace marian \ No newline at end of file
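
A scalar sketch of the gating arithmetic used for the lemma conditioning above: r = s * (ve - vi) + vi is a per-head sigmoid interpolation, equivalent to s * ve + (1 - s) * vi. The values below are arbitrary and one dimension stands in for the per-head vectors.

#include <cmath>
#include <iostream>

int main() {
  float z  = 0.5f;                       // query/key dot product for one head
  float s  = 1.f / (1.f + std::exp(-z)); // sigmoid gate in [0, 1]
  float ve = 2.0f, vi = -1.0f;           // the two projected value vectors (one dim shown)
  float r1 = s * (ve - vi) + vi;         // form used in the gated structure above
  float r2 = s * ve + (1.f - s) * vi;    // equivalent interpolation
  std::cout << r1 << " == " << r2 << "\n";
}
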
diff --git a/src/layers/output.h b/src/layers/output.h
index 92e7eb25..2b6f4986 100644
--- a/src/layers/output.h
+++ b/src/layers/output.h
@@ -1,10 +1,10 @@
#pragma once
-#include "marian.h"
-#include "generic.h"
-#include "logits.h"
#include "data/shortlist.h"
+#include "generic.h"
#include "layers/factory.h"
+#include "logits.h"
+#include "marian.h"
namespace marian {
class LSH;
@@ -14,42 +14,45 @@ namespace mlp {
class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
private:
// parameters held by this layer
- Expr Wt_; // weight matrix is stored transposed for efficiency
+ Expr Wt_; // weight matrix is stored transposed for efficiency
Expr b_;
- Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
- bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
+ Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
+ bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
bool hasBias_{true};
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
Ptr<FactoredVocab> factoredVocab_;
-
+
// optional parameters set/updated after construction
Expr tiedParam_;
Ptr<data::Shortlist> shortlist_;
Ptr<LSH> lsh_;
void lazyConstruct(int inputDim);
+
public:
Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options),
- hasBias_{!options->get<bool>("output-omit-bias", false)} {
+ : LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} {
clear();
}
void tieTransposed(Expr tied) {
- if (Wt_)
- ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created");
+ if(Wt_)
+ ABORT_IF(tiedParam_.get() != tied.get(),
+ "Tied output projection cannot be changed once weights have been created");
else
tiedParam_ = tied;
}
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
- if (shortlist_)
- ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()");
+ if(shortlist_)
+ ABORT_IF(shortlist.get() != shortlist_.get(),
+ "Output shortlist cannot be changed except after clear()");
else {
- ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??");
+ ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_,
+ "No shortlist but cached parameters??");
shortlist_ = shortlist;
}
// cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
@@ -60,7 +63,7 @@ public:
void clear() override final {
shortlist_ = nullptr;
cachedShortWt_ = nullptr;
- cachedShortb_ = nullptr;
+ cachedShortb_ = nullptr;
cachedShortLemmaEt_ = nullptr;
}
@@ -69,6 +72,4 @@ public:
} // namespace mlp
-}
-
-
+} // namespace marian
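
A rough standalone sketch of the caching contract in Output above: the short-listed parameter slice is built lazily once per shortlist and dropped again by clear(). The flat vector and index-based slicing below are placeholders, not the marian tensors or shortlist class.

#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

struct OutputSketch {
  std::vector<float> Wt = {0, 1, 2, 3, 4, 5};  // full output matrix (flattened placeholder)
  std::vector<int> shortlist;                  // active shortlist indices
  std::vector<float> cachedShortWt;            // cached slice, cleared per batch

  void setShortlist(std::vector<int> s) {
    assert(cachedShortWt.empty());             // mirrors "No shortlist but cached parameters??"
    shortlist = std::move(s);
  }
  const std::vector<float>& shortWt() {        // lazily build the short-listed slice
    if(cachedShortWt.empty())
      for(int i : shortlist) cachedShortWt.push_back(Wt[i]);
    return cachedShortWt;
  }
  void clear() { shortlist.clear(); cachedShortWt.clear(); }
};

int main() {
  OutputSketch o;
  o.setShortlist({1, 3, 5});
  std::cout << o.shortWt().size() << "\n";  // 3
  o.clear();                                // next batch may set a different shortlist
}
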
diff --git a/src/models/costs.cpp b/src/models/costs.cpp
index 5105f590..c688b211 100644
--- a/src/models/costs.cpp
+++ b/src/models/costs.cpp
@@ -4,13 +4,11 @@ namespace marian {
namespace models {
Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
-// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
-state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
-// @TODO: This is becoming more and more opaque ^^. Can we simplify this?
-return state;
-}
-
-
-}
+ // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
+ state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
+ // @TODO: This is becoming more and more opaque ^^. Can we simplify this?
+ return state;
}
+} // namespace models
+} // namespace marian
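
LogSoftmaxStep above only renormalizes the decoder's scores. For reference, a minimal numerically stable standalone log-softmax (subtract the maximum, then the log-sum-exp); this is a generic sketch, not the marian operator.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

std::vector<float> logSoftmax(std::vector<float> x) {
  float m = x[0];
  for(float v : x) m = std::max(m, v);       // subtract max for numeric stability
  float sum = 0.f;
  for(float v : x) sum += std::exp(v - m);
  float lse = m + std::log(sum);             // log-sum-exp
  for(float& v : x) v -= lse;
  return x;
}

int main() {
  for(float v : logSoftmax({1.f, 2.f, 3.f}))
    std::cout << v << " ";                   // log-probabilities; exp() of them sums to 1
  std::cout << "\n";
}
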
diff --git a/src/models/costs.h b/src/models/costs.h
index 2d34c53a..e5463bfd 100644
--- a/src/models/costs.h
+++ b/src/models/costs.h
@@ -4,8 +4,8 @@
#include "layers/guided_alignment.h"
#include "layers/loss.h"
#include "layers/weight.h"
-#include "models/encoder_decoder.h"
#include "models/encoder_classifier.h"
+#include "models/encoder_decoder.h"
#include "models/encoder_pooler.h"
namespace marian {
@@ -22,10 +22,12 @@ namespace models {
class ICost {
public:
- virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
- Ptr<data::Batch> batch,
- bool clearGraph = true) = 0;
+ virtual Ptr<MultiRationalLoss> apply(
+ Ptr<IModel> model,
+ Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
+ Ptr<data::Batch> batch,
+ bool clearGraph = true)
+ = 0;
virtual ~ICost() {}
};
@@ -45,10 +47,9 @@ public:
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
- toBeWeighted_
- = (options_->hasAndNotEmpty("data-weighting") && !inference_)
- || (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting")
- && !inference_);
+ toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_)
+ || (options_->has("dynamic-weighting")
+ && options_->get<bool>("dynamic-weighting") && !inference_);
if(toBeWeighted_)
weighter_ = WeightingFactory(options_);
}
@@ -56,9 +57,9 @@ public:
virtual ~EncoderDecoderCECost() {}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
auto encdec = std::static_pointer_cast<EncoderDecoder>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
@@ -72,17 +73,17 @@ public:
Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);
// @TODO: adapt to multi-objective training with multiple decoders
- auto partialLoss = loss_->apply(state->getLogProbs(),
- state->getTargetWords(),
- state->getTargetMask(),
- weights);
+ auto partialLoss = loss_->apply(
+ state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights);
multiLoss->push_back(partialLoss);
if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) {
- auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
+ auto attentionVectors
+ = encdec->getDecoders()[0]
+ ->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments");
- auto attention = concatenate(attentionVectors, /*axis =*/ -1);
+ auto attention = concatenate(attentionVectors, /*axis =*/-1);
auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention);
multiLoss->push_back(alignmentLoss);
@@ -109,10 +110,9 @@ public:
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
-
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
auto enccls = std::static_pointer_cast<EncoderClassifier>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
@@ -141,21 +141,20 @@ protected:
public:
EncoderPoolerRankCost(Ptr<Options> options)
- : options_(options),
- inference_(options->get<bool>("inference", false)) {
- auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
- ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set");
-
- margin_ = std::stof(trainEmbedderRank[0]);
- if(trainEmbedderRank.size() > 1)
- normalizer_ = std::stof(trainEmbedderRank[1]);
+ : options_(options), inference_(options->get<bool>("inference", false)) {
+ auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
+ ABORT_IF(trainEmbedderRank.empty(),
+ "EncoderPoolerRankCost expects train-embedder-rank to be set");
+
+ margin_ = std::stof(trainEmbedderRank[0]);
+ if(trainEmbedderRank.size() > 1)
+ normalizer_ = std::stof(trainEmbedderRank[1]);
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
-
auto encpool = std::static_pointer_cast<EncoderPooler>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph);
@@ -167,28 +166,41 @@ public:
ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss");
// multi-objective training
- auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
- auto exponent = dotProducts[0] - maxDot - margin_; // subtract maximum and margin from dot product
+ auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
+ auto exponent
+ = dotProducts[0] - maxDot - margin_; // subtract maximum and margin from dot product
auto dp = exp(exponent);
Expr dn1, dn2;
- if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator.
- dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
- dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
+ if(normalizer_
+ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the
+ // magnitude of the sum of negative examples in the denominator.
+ dn1 = normalizer_
+ * mean(exp(dotProducts[1] - maxDot),
+ -1); // dot product of anchor and first negative example
+ dn2 = normalizer_
+ * mean(exp(dotProducts[2] - maxDot),
+ -1); // dot product of positive examples and first negative example
} else {
- dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
- dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
+ dn1 = sum(exp(dotProducts[1] - maxDot),
+ -1); // dot product of anchor and first negative example
+ dn2 = sum(exp(dotProducts[2] - maxDot),
+ -1); // dot product of positive examples and first negative example
}
// We rewrite the loss so it looks more like a log-softmax, presumably more stable?
- // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
- auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
- auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples
- auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
-
+ // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp
+ // + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
+ auto marginLoss1
+ = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
+ auto marginLoss2
+ = log(dp + dn2)
+ - exponent; // symmetric version of the above with positive example vs negative examples
+ auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
+
RationalLoss loss(marginLoss, (float)dimBatch);
multiLoss->push_back(loss);
-
+
return multiLoss;
}
};
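The loss rewrite in the comment above, written out in full (this only restates the comment; phi is the shifted positive dot product dotProducts[0] - maxDot, m is margin_, and the d_i are the shifted negative dot products):

\[
-\log\frac{e^{\phi - m}}{e^{\phi - m} + \sum_i e^{d_i}}
  = \log\Big(e^{\phi - m} + \sum_i e^{d_i}\Big) - (\phi - m)
\]

When normalizer_ is non-zero, the sum over negatives \(\sum_i e^{d_i}\) is replaced by \(\text{normalizer} \cdot \operatorname{mean}_i e^{d_i}\), as in the branch above; the loss is computed for dn1 and dn2 and summed over axis -2.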
@@ -199,8 +211,7 @@ protected:
Ptr<ICost> cost_;
public:
- Trainer(Ptr<IModel> model, Ptr<ICost> cost)
- : model_(model), cost_(cost) {}
+ Trainer(Ptr<IModel> model, Ptr<ICost> cost) : model_(model), cost_(cost) {}
virtual ~Trainer() {}
@@ -219,8 +230,8 @@ public:
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
return cost_->apply(model_, graph, batch, clearGraph);
};
@@ -230,24 +241,25 @@ public:
class ILogProb {
public:
virtual Logits apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) = 0;
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true)
+ = 0;
};
-// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
-// Beam search uses it for the former meaning, while 'marian score' and validation in the latter.
-// This class is for the former use. The latter is done using Trainer.
+// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for
+// the ground truth?
+// Beam search uses it for the former meaning, while 'marian score' and validation use it in
+// the latter. This class is for the former use. The latter is done using Trainer.
class Scorer : public IModel {
protected:
Ptr<IModel> model_;
Ptr<ILogProb> logProb_;
public:
- Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
- : model_(model), logProb_(cost) {}
+ Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) : model_(model), logProb_(cost) {}
- virtual ~Scorer(){}
+ virtual ~Scorer() {}
Ptr<IModel> getModel() { return model_; }
@@ -264,8 +276,8 @@ public:
}
virtual Logits build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
return logProb_->apply(model_, graph, batch, clearGraph);
};
@@ -293,10 +305,10 @@ public:
virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
- [](Expr logits){ // lemma gets gumbelled
- return logsoftmax(logits + constant_like(logits, inits::gumbel()));
- },
- logsoftmax)); // factors don't
+ [](Expr logits) { // lemma gets gumbelled
+ return logsoftmax(logits + constant_like(logits, inits::gumbel()));
+ },
+ logsoftmax)); // factors don't
return state;
}
};
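For orientation, GumbelSoftmaxStep above perturbs only the lemma logits with Gumbel(0,1) noise and then renormalizes; the factor logits just get a plain log-softmax. A minimal standalone sketch of the lemma transform in plain C++ (gumbelLogSoftmax is a hypothetical helper for illustration, not the marian Expr API used above):

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

// Add Gumbel(0,1) noise g = -log(-log(u)), u ~ U(0,1), to each logit, then apply a
// numerically stable log-softmax, mirroring
// logsoftmax(logits + constant_like(logits, inits::gumbel())) in the step above.
std::vector<float> gumbelLogSoftmax(std::vector<float> logits, std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(1e-20f, 1.0f);
  for(auto& l : logits)
    l += -std::log(-std::log(uniform(rng)));
  float maxLogit = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for(auto l : logits)
    sum += std::exp(l - maxLogit);
  float logZ = maxLogit + std::log(sum);
  for(auto& l : logits)
    l -= logZ;  // l - log(sum_j exp(l_j))
  return logits;
}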
@@ -311,8 +323,7 @@ protected:
Ptr<ILogProbStep> cost_;
public:
- Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
- : encdec_(encdec), cost_(cost) {}
+ Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@@ -346,12 +357,13 @@ public:
return encdec_->startState(graph, batch);
}
- virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
- Ptr<DecoderState> state,
- const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
- const Words& words, // [beamIndex * activeBatchSize + batchIndex]
- const std::vector<IndexType>& batchIndices, // [batchIndex]
- int beamSize) override {
+ virtual Ptr<DecoderState> step(
+ Ptr<ExpressionGraph> graph,
+ Ptr<DecoderState> state,
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const Words& words, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
+ int beamSize) override {
auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
return cost_->apply(nextState);
}
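The [beamIndex * activeBatchSize + batchIndex] annotation on hypIndices and words above describes a flat layout with the beam as the outer dimension. A tiny illustration of that layout (helper names are hypothetical, not part of the code):

#include <cstddef>

// Flatten (beamIndex, batchIndex) and recover the two components again.
inline size_t flatHypIndex(size_t beamIndex, size_t batchIndex, size_t activeBatchSize) {
  return beamIndex * activeBatchSize + batchIndex;
}
inline size_t beamOf(size_t flatIndex, size_t activeBatchSize) {
  return flatIndex / activeBatchSize;
}
inline size_t batchOf(size_t flatIndex, size_t activeBatchSize) {
  return flatIndex % activeBatchSize;
}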
@@ -369,9 +381,7 @@ public:
encdec_->setShortlistGenerator(shortlistGenerator);
};
- virtual Ptr<data::Shortlist> getShortlist() override {
- return encdec_->getShortlist();
- };
+ virtual Ptr<data::Shortlist> getShortlist() override { return encdec_->getShortlist(); };
virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
};
diff --git a/src/models/states.h b/src/models/states.h
index cfb6fd1b..20dd59c9 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -1,7 +1,7 @@
#pragma once
+#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "marian.h"
-#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"
namespace marian {
@@ -9,7 +9,7 @@ namespace marian {
class EncoderState {
private:
Expr context_;
- Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
+ Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Ptr<data::CorpusBatch> batch_;
public:
@@ -19,31 +19,34 @@ public:
EncoderState() {}
virtual ~EncoderState() {}
- virtual Expr getContext() const { return context_; }
- virtual Expr getAttended() const { return context_; }
- virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed
+ virtual Expr getContext() const { return context_; }
+ virtual Expr getAttended() const { return context_; }
+ virtual Expr getMask() const {
+ return mask_;
+ } // source batch mask; may have additional positions suppressed
- virtual const Words& getSourceWords() {
- return batch_->front()->data();
- }
+ virtual const Words& getSourceWords() { return batch_->front()->data(); }
// Sub-select active batch entries from encoder context and context mask
- Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
- // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
- return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
+ Ptr<EncoderState> select(
+ const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
+ // Dimension -2 is OK for both RNN and Transformer models, as the encoder context in Transformer
+ // gets transposed to the same dimension layout
+ return New<EncoderState>(
+ index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
}
};
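To make EncoderState::select() above concrete: dimension -2 is the batch axis of the encoder context and mask, so selection keeps only the listed batch entries. A standalone sketch over a [time][batch][dim] nested vector (illustration only, under the assumption of that layout; the real code uses index_select on marian Expr tensors):

#include <cstddef>
#include <vector>

using Tensor3 = std::vector<std::vector<std::vector<float>>>;  // [time][batch][dim]

// Keep only the active batch entries, mirroring index_select(context_, -2, batchIndices).
Tensor3 selectBatchEntries(const Tensor3& context, const std::vector<size_t>& batchIndices) {
  Tensor3 selected(context.size());
  for(size_t t = 0; t < context.size(); ++t)
    for(size_t idx : batchIndices)
      selected[t].push_back(context[t][idx]);
  return selected;
}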
class DecoderState {
protected:
- rnn::States states_; // states of individual decoder layers
+ rnn::States states_; // states of individual decoder layers
Logits logProbs_;
std::vector<Ptr<EncoderState>> encStates_;
Ptr<data::CorpusBatch> batch_;
- Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
+ Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetMask_;
- Words targetWords_; // target labels
+ Words targetWords_; // target labels
// Keep track of current target token position during translation
size_t position_{0};
@@ -57,26 +60,30 @@ public:
virtual ~DecoderState() {}
// @TODO: Do we need all these to be virtual?
- virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
- return encStates_;
- }
+ virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; }
virtual Logits getLogProbs() const { return logProbs_; }
virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }
- // @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
- virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
- const std::vector<IndexType>& batchIndices, // [batchIndex]
- int beamSize) const {
-
+ // @TODO: should this be a constructor? Then derived classes can call this without the New<> in
+ // the loop
+ virtual Ptr<DecoderState> select(
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
+ int beamSize) const {
std::vector<Ptr<EncoderState>> newEncStates;
for(auto& es : encStates_)
- // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
- newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
+ // If the size of the batch dimension of the encoder state context changed, subselect the
+ // correct batch entries
+ newEncStates.push_back(
+ es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
// hypindices matches batchIndices in terms of batch dimension, so we only need hypIndices
- auto selectedState = New<DecoderState>(
- states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);
+ auto selectedState
+ = New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false),
+ logProbs_,
+ newEncStates,
+ batch_);
// Set position of new state based on the target token position of current state
selectedState->setPosition(getPosition());
@@ -86,7 +93,9 @@ public:
virtual const rnn::States& getStates() const { return states_; }
virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
- virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }
+ virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) {
+ targetHistoryEmbeddings_ = targetHistoryEmbeddings;
+ }
virtual const Words& getTargetWords() const { return targetWords_; };
virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
@@ -94,9 +103,7 @@ public:
virtual Expr getTargetMask() const { return targetMask_; };
virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
- virtual const Words& getSourceWords() const {
- return getEncoderStates()[0]->getSourceWords();
- }
+ virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); }
Ptr<data::CorpusBatch> getBatch() const { return batch_; }
@@ -111,7 +118,8 @@ public:
/**
* Classifier output based on DecoderState
- * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output.
+ * @TODO: should be unified with DecoderState or not be used at all, as classifiers do not really
+ * have stateful output.
*/
class ClassifierState {
private: