diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-03-05 09:12:28 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-03-05 09:12:28 +0300 |
commit | 55f4216552bca148091f15b72c5c2e5b486d4c79 (patch) | |
tree | e8c343d0334213d2e253fd311d1cbdc1cf10fe56 /src | |
parent | 7c1cb8462a2adc6540ee78f111a75ed4fbdd66ad (diff) |
add .h
Diffstat (limited to 'src')
-rw-r--r-- | src/layers/logits.h | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/src/layers/logits.h b/src/layers/logits.h new file mode 100644 index 00000000..4196e0d0 --- /dev/null +++ b/src/layers/logits.h @@ -0,0 +1,76 @@ +#pragma once + +#include "marian.h" +#include "data/shortlist.h" +#include "generic.h" + +namespace marian { + +class FactoredVocab; + +// To support factors, any output projection (that is followed by a softmax) must +// retain multiple outputs, one for each factor. Such layer returns not a single Expr, +// but a Logits object that contains multiple. +// This allows to compute softmax values in a factored manner, where we never create +// a fully expanded list of all factor combinations. +class RationalLoss; +class Logits { +public: + Logits() {} + explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor + logits_.push_back(logits); + } + explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) + Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor + : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} + Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors + Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle + //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that + Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const; + Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values + Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to to all other values + + struct MaskedFactorIndices { + std::vector<WordIndex> indices; // factor index, or 0 if masked + std::vector<float> masks; + void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } + void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries + MaskedFactorIndices() {} + MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case + }; + std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices + Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only + size_t getNumFactorGroups() const { return logits_.size(); } + bool empty() const { return logits_.empty(); } + Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ +private: + // helper functions + Ptr<ExpressionGraph> graph() const; + Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); } + Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); } + template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector + Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type + std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const; +private: + // members + // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr + std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group] + Ptr<FactoredVocab> factoredVocab_; +}; + +// Unary function that returns a Logits object +// Also implements IUnaryLayer, since Logits can be cast to Expr. +// This interface is implemented by all layers that are of the form of a unary function +// that returns multiple logits, to support factors. +struct IUnaryLogitLayer : public IUnaryLayer { + virtual Logits applyAsLogits(Expr) = 0; + virtual Logits applyAsLogits(const std::vector<Expr>& es) { + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + return applyAsLogits(es.front()); + } + virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } + virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); } +}; + +} + |