Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/layers/generic.h')
-rw-r--r-- src/layers/generic.h 62
1 files changed, 49 insertions, 13 deletions
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 9af033df..b423befe 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -9,18 +9,19 @@
namespace marian {
namespace mlp {
-/**
- * @brief Activation functions
- */
+/** Activation functions for MLP layers. */
enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
} // namespace mlp
} // namespace marian
namespace marian {
-// Each layer consists of LayerBase and IXXXLayer which defines one or more apply()
-// functions for the respective layer type (different layers may require different signatures).
-// This base class contains configuration info for creating parameters and executing apply().
+/**
+ * Base class for a layer.
+ * Each layer consists of LayerBase and IXXXLayer which defines one or more apply()
+ * functions for the respective layer type (different layers may require different signatures).
+ * This base class contains configuration info for creating parameters and executing apply().
+ */
class LayerBase {
protected:
Ptr<ExpressionGraph> graph_;
@@ -40,22 +41,25 @@ public:
}
};
-// Simplest layer interface: Unary function
+/** Simplest layer interface: Unary function. */
struct IUnaryLayer {
virtual ~IUnaryLayer() {}
+ /** Link a node as the input for this layer. */
virtual Expr apply(Expr) = 0;
+ /** Link a list of nodes as the inputs for this layer. */
virtual Expr apply(const std::vector<Expr>& es) {
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return apply(es.front());
}
};
+/** Shortlist interface for layers. */
struct IHasShortList {
virtual void setShortlist(Ptr<data::Shortlist> shortlist) = 0;
virtual void clear() = 0;
};
-// Embedding from corpus sub-batch to (emb, mask)
+/** Embedding from corpus sub-batch to (emb, mask). */
struct IEmbeddingLayer {
virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
Ptr<data::SubBatch> subBatch) const = 0;
@@ -67,8 +71,10 @@ struct IEmbeddingLayer {
virtual ~IEmbeddingLayer() {}
};
-// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream
-// index)
+/**
+ * Base class for Encoder and Decoder classes.
+ * These classes have embeddings and a batch index (= stream index).
+ */
class EncoderDecoderLayerBase : public LayerBase {
protected:
const std::string prefix_;
@@ -98,16 +104,42 @@ private:
Ptr<IEmbeddingLayer> createULREmbeddingLayer() const;
public:
- // get embedding layer; lazily create on first call
+ /**
+ * Get all embedding layer(s).
+ * It lazily creates the embedding layer on first call.
+ * This is lazy mostly because the constructors of the consuming objects are not
+ * guaranteed presently to have access to their graph.
+ * @param ulr whether to use ULREmbedding layer. false by default.
+ * @return a shared pointer to the embedding layer
+ */
Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};
+/**
+ * The namespace mlp.
+ * Declares class Dense and all the available functions for creating a
+ * <a href=https://en.wikipedia.org/wiki/Multilayer_perceptron>multilayer perceptron (MLP)</a>
+ * network.
+ */
namespace mlp {
+/**
+ * Base class for a fully connected layer.
+ * Implements the operation `output = activation(input * weight + bias)`.
+ */
class Dense : public LayerBase, public IUnaryLayer {
public:
+ /**
+ * Construct a dense layer in the graph.
+ * @param graph The expression graph.
+ * @param options The options used for this dense layer.
+ */
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}
-
+ /**
+ * Apply/Link a vector of dense layers (with the given inputs) to the expression graph.
+ * @param inputs The vector of the input expressions
+ * @return The expression holding the dense layers
+ */
Expr apply(const std::vector<Expr>& inputs) override {
ABORT_IF(inputs.empty(), "No inputs");
@@ -161,7 +193,11 @@ public:
}
// clang-format on
};
-
+ /**
+ * Apply/Link this dense layer (with the given input) to the expression graph.
+ * @param input The input expression
+ * @return The expression holding the dense layer
+ */
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};