diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-03-04 07:18:19 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-03-04 07:18:19 +0300 |
commit | f7266886f0d478d802a88f2ce82b71f27c37bf07 (patch) | |
tree | 7440cbd865bf2afd596bdaa46ab3862f9dd8b594 /src | |
parent | ca47eabca5cb9bb11a3e4fe45afa77501128b4b9 (diff) |
move logits to its own file
Diffstat (limited to 'src')
-rw-r--r-- | src/layers/embedding.h | 2 | ||||
-rw-r--r-- | src/layers/generic.h | 66 | ||||
-rw-r--r-- | src/layers/loss.h | 2 | ||||
-rw-r--r-- | src/layers/output.h | 1 | ||||
-rw-r--r-- | src/models/states.h | 2 |
5 files changed, 5 insertions, 68 deletions
diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 91fd0b9d..b7898c76 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -4,6 +4,8 @@ namespace marian { +class FactoredVocab; + // A regular embedding layer. // Note that this also applies dropout if the option is passed (pass 0 when in inference mode). // It is best to not use Embedding directly, but rather via getEmbeddingLayer() in diff --git a/src/layers/generic.h b/src/layers/generic.h index 6d953fd8..eddd597e 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -97,72 +97,6 @@ public: Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const; }; -class FactoredVocab; - -// To support factors, any output projection (that is followed by a softmax) must -// retain multiple outputs, one for each factor. Such layer returns not a single Expr, -// but a Logits object that contains multiple. -// This allows to compute softmax values in a factored manner, where we never create -// a fully expanded list of all factor combinations. -class RationalLoss; -class Logits { -public: - Logits() {} - explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor - logits_.push_back(logits); - } - explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) - Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor - : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} - Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors - Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle - //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that - Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const; - Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values - Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to to all other values - - struct MaskedFactorIndices { - std::vector<WordIndex> indices; // factor index, or 0 if masked - std::vector<float> masks; - void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } - void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries - MaskedFactorIndices() {} - MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case - }; - std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices - Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only - size_t getNumFactorGroups() const { return logits_.size(); } - bool empty() const { return logits_.empty(); } - Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ -private: - // helper functions - Ptr<ExpressionGraph> graph() const; - Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); } - Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); } - template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector - Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type - std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const; -private: - // members - // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr - std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group] - Ptr<FactoredVocab> factoredVocab_; -}; - -// Unary function that returns a Logits object -// Also implements IUnaryLayer, since Logits can be cast to Expr. -// This interface is implemented by all layers that are of the form of a unary function -// that returns multiple logits, to support factors. -struct IUnaryLogitLayer : public IUnaryLayer { - virtual Logits applyAsLogits(Expr) = 0; - virtual Logits applyAsLogits(const std::vector<Expr>& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub - return applyAsLogits(es.front()); - } - virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } - virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); } -}; - namespace mlp { class Dense : public LayerBase, public IUnaryLayer { diff --git a/src/layers/loss.h b/src/layers/loss.h index d7bc19e4..ba93cdac 100644 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -1,7 +1,7 @@ #pragma once #include "graph/expression_operators.h" -#include "layers/generic.h" // for Logits (Frank's factor hack) +#include "layers/logits.h" // for Logits (Frank's factor hack) #include "data/types.h" namespace marian { diff --git a/src/layers/output.h b/src/layers/output.h index d091556a..92e7eb25 100644 --- a/src/layers/output.h +++ b/src/layers/output.h @@ -2,6 +2,7 @@ #include "marian.h" #include "generic.h" +#include "logits.h" #include "data/shortlist.h" #include "layers/factory.h" diff --git a/src/models/states.h b/src/models/states.h index c2f9ee05..cfb6fd1b 100644 --- a/src/models/states.h +++ b/src/models/states.h @@ -1,7 +1,7 @@ #pragma once #include "marian.h" -#include "layers/generic.h" // @HACK: for factored embeddings only so far +#include "layers/logits.h" // @HACK: for factored embeddings only so far #include "rnn/types.h" namespace marian { |