
github.com/marian-nmt/marian.git
author     Hieu Hoang <hihoan@microsoft.com>  2021-03-04 07:18:19 +0300
committer  Hieu Hoang <hihoan@microsoft.com>  2021-03-04 07:18:19 +0300
commit     f7266886f0d478d802a88f2ce82b71f27c37bf07 (patch)
tree       7440cbd865bf2afd596bdaa46ab3862f9dd8b594 /src
parent     ca47eabca5cb9bb11a3e4fe45afa77501128b4b9 (diff)
move logits to its own file
Diffstat (limited to 'src')
-rw-r--r--  src/layers/embedding.h |  2
-rw-r--r--  src/layers/generic.h   | 66
-rw-r--r--  src/layers/loss.h      |  2
-rw-r--r--  src/layers/output.h    |  1
-rw-r--r--  src/models/states.h    |  2
5 files changed, 5 insertions, 68 deletions
diff --git a/src/layers/embedding.h b/src/layers/embedding.h
index 91fd0b9d..b7898c76 100644
--- a/src/layers/embedding.h
+++ b/src/layers/embedding.h
@@ -4,6 +4,8 @@
namespace marian {
+class FactoredVocab;
+
// A regular embedding layer.
// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 6d953fd8..eddd597e 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -97,72 +97,6 @@ public:
Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};
-class FactoredVocab;
-
-// To support factors, any output projection (that is followed by a softmax) must
-// retain multiple outputs, one for each factor. Such layer returns not a single Expr,
-// but a Logits object that contains multiple.
-// This allows to compute softmax values in a factored manner, where we never create
-// a fully expanded list of all factor combinations.
-class RationalLoss;
-class Logits {
-public:
- Logits() {}
- explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
- logits_.push_back(logits);
- }
- explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
- Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
- : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
- Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
- Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
- //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
- Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
- Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
- Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to the first and fother to all other values
-
- struct MaskedFactorIndices {
- std::vector<WordIndex> indices; // factor index, or 0 if masked
- std::vector<float> masks;
- void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
- void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
- MaskedFactorIndices() {}
- MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
- };
- std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
- Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
- size_t getNumFactorGroups() const { return logits_.size(); }
- bool empty() const { return logits_.empty(); }
- Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
-private:
- // helper functions
- Ptr<ExpressionGraph> graph() const;
- Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
- Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
- template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector
- Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
- std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
-private:
- // members
- // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
- std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
- Ptr<FactoredVocab> factoredVocab_;
-};
-
-// Unary function that returns a Logits object
-// Also implements IUnaryLayer, since Logits can be cast to Expr.
-// This interface is implemented by all layers that are of the form of a unary function
-// that returns multiple logits, to support factors.
-struct IUnaryLogitLayer : public IUnaryLayer {
- virtual Logits applyAsLogits(Expr) = 0;
- virtual Logits applyAsLogits(const std::vector<Expr>& es) {
- ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
- return applyAsLogits(es.front());
- }
- virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
- virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
-};
-
namespace mlp {
class Dense : public LayerBase, public IUnaryLayer {
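
For context (not part of this commit): a minimal sketch of how a caller might use the Logits interface whose declaration is removed above. Only the member names (applyAsLogits, getLogits, getNumFactorGroups, getFactoredLogits) come from the declarations shown in the diff; the layer object and input expression are hypothetical placeholders.

// Illustrative sketch only -- assumes an IUnaryLogitLayer `outputLayer` and a decoder
// context expression `decoderState` have been built elsewhere; nullptr is a placeholder.
Ptr<IUnaryLogitLayer> outputLayer = /* e.g. an mlp output layer */ nullptr;
Expr decoderState = /* last decoder hidden state */ nullptr;

// The output projection returns a Logits object instead of a single Expr, so each
// factor group keeps its own logits and the factor combinations are never fully expanded.
Logits logits = outputLayer->applyAsLogits(decoderState);

// Aggregate over factor groups when a single distribution is needed.
Expr combined = logits.getLogits();

// Or inspect one factor group at a time, e.g. during factored beam search.
for(size_t g = 0; g < logits.getNumFactorGroups(); ++g) {
  Expr groupLogits = logits.getFactoredLogits(g);  // optional shortlist/beam arguments omitted
}
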
diff --git a/src/layers/loss.h b/src/layers/loss.h
index d7bc19e4..ba93cdac 100644
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -1,7 +1,7 @@
#pragma once
#include "graph/expression_operators.h"
-#include "layers/generic.h" // for Logits (Frank's factor hack)
+#include "layers/logits.h" // for Logits (Frank's factor hack)
#include "data/types.h"
namespace marian {
diff --git a/src/layers/output.h b/src/layers/output.h
index d091556a..92e7eb25 100644
--- a/src/layers/output.h
+++ b/src/layers/output.h
@@ -2,6 +2,7 @@
#include "marian.h"
#include "generic.h"
+#include "logits.h"
#include "data/shortlist.h"
#include "layers/factory.h"
diff --git a/src/models/states.h b/src/models/states.h
index c2f9ee05..cfb6fd1b 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -1,7 +1,7 @@
#pragma once
#include "marian.h"
-#include "layers/generic.h" // @HACK: for factored embeddings only so far
+#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"
namespace marian {
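
Not shown on this page: the new header that the updated includes above point to. The diffstat lists only modified files, so the contents of src/layers/logits.h are an assumption here; presumably it carries the declarations removed from layers/generic.h more or less verbatim. A hypothetical skeleton:

// src/layers/logits.h -- hypothetical skeleton, inferred from the block removed from
// layers/generic.h above; the actual new file is not part of the diff shown on this page.
#pragma once

#include "layers/generic.h"  // assumption: for IUnaryLayer, Expr and related types

namespace marian {

class FactoredVocab;  // forward declarations, as in the removed block
class RationalLoss;

// Logits and IUnaryLogitLayer, as declared in the removed lines above.
class Logits { /* ...full declaration as removed from generic.h... */ };
struct IUnaryLogitLayer : public IUnaryLayer { /* ... */ };

}  // namespace marian
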