github.com/marian-nmt/marian.git
author Qianqian Zhu <qianqian.zhu@hotmail.com> 2022-01-26 18:17:38 +0300
committer GitHub <noreply@github.com> 2022-01-26 18:17:38 +0300
commit 71b5454b9eb441b2d802c2a6a3be6c0be3f6a30c
tree f3454498efaff2f718afa6acd39e93b43effad75
parent 3b458b044e6b2695ba6ad0786320ea043d076772
Layer documentation (#892)
* More examples for MLP layers and docs about RNN layers
* Docs about embedding layer and more doxygen code docs
* Add layer and factors docs into index.rst
* Update layer documentation
* Fix typos

Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
Co-authored-by: Graeme Nail <graemenail.work@gmail.com>
-rw-r--r-- doc/factors.md | 2
-rw-r--r-- doc/index.rst | 3
-rw-r--r-- doc/layer.md | 241
-rw-r--r-- doc/operators.md | 2
-rw-r--r-- src/layers/constructors.h | 101
-rw-r--r-- src/layers/embedding.h | 37
-rw-r--r-- src/layers/factory.h | 19
-rw-r--r-- src/layers/generic.h | 62
-rw-r--r-- src/rnn/attention_constructors.h | 1
-rw-r--r-- src/rnn/cells.h | 1
-rw-r--r-- src/rnn/constructors.h | 17
11 files changed, 446 insertions, 40 deletions
diff --git a/doc/factors.md b/doc/factors.md
index 59e14b68..dbd953b9 100644
--- a/doc/factors.md
+++ b/doc/factors.md
@@ -1,4 +1,4 @@
-# Using marian with factors
+# Using Marian with factors
Following this README should allow the user to train a model with source and/or target side factors. To train with factors, the data must be formatted in a certain way. A special vocabulary file format is also required, and its extension should be `.fsv` as providing a source and/or target vocabulary file with this extension is what triggers the usage of source and/or target factors. See details below.
diff --git a/doc/index.rst b/doc/index.rst
index d0a4fefb..a790e624 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -14,7 +14,8 @@ This is developer documentation. User documentation is available at https://mari
graph
operators
-
+ layer
+ factors
api/library_index
contributing
diff --git a/doc/layer.md b/doc/layer.md
new file mode 100644
index 00000000..295a3153
--- /dev/null
+++ b/doc/layer.md
@@ -0,0 +1,241 @@
+# Layers
+
+In a typical deep neural network, the highest-level blocks, which perform different kinds of
+transformations on their inputs, are called layers. A layer wraps a group of nodes and performs a
+specific mathematical computation, offering a shortcut for building a more complex neural network.
+
+In Marian, for example, the `mlp::dense` layer represents a fully connected layer, which implements
+the operation `output = activation(input * weight + bias)`. A dense layer in the graph can be
+constructed with the following code:
+```cpp
+// add input node x
+auto x = graph->constant({120,5}, inits::fromVector(inputData));
+// construct a dense layer in the graph
+auto layer1 = mlp::dense()
+ ("prefix", "layer1") // prefix name is layer1
+ ("dim", 5) // output dimension is 5
+ ("activation", (int)mlp::act::tanh) // activation function is tanh
+ .construct(graph)->apply(x); // construct this layer in graph
+ // and link node x as the input
+```
+The options are passed to the layer using pairs of `(key, value)`, where `key` is a predefined
+option, and `value` is the option value. Then `construct()` is called to create a layer instance in
+the graph, and `apply()` to link the input with this layer.
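The chaining of `(key, value)` pairs can be pictured with a small stand-alone sketch. The class `OptionAccumulator` and the helper `denseOptionsExample()` are hypothetical names introduced here for illustration; they are not Marian types, and they only mimic the idea of a call operator that stores an option and returns the object itself so that further calls can follow.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a layer factory: it only accumulates options.
// Marian keeps its options in an Options object held by the factory;
// here every value is stored as a string to keep the sketch short.
class OptionAccumulator {
public:
  // Each call stores one (key, value) pair and returns *this,
  // which is what allows the calls to be chained.
  OptionAccumulator& operator()(const std::string& key, const std::string& value) {
    options_[key] = value;
    return *this;
  }
  OptionAccumulator& operator()(const std::string& key, int value) {
    options_[key] = std::to_string(value);
    return *this;
  }

  bool has(const std::string& key) const { return options_.count(key) > 0; }

  const std::string& get(const std::string& key) const {
    assert(has(key) && "option was never set");
    return options_.at(key);
  }

private:
  std::map<std::string, std::string> options_;
};

// Accumulates two of the options used in the dense-layer example.
inline OptionAccumulator denseOptionsExample() {
  return OptionAccumulator()("prefix", "layer1")("dim", 5);
}
```

In this sketch nothing is built until the options are read back, which loosely mirrors the lazy construction that `construct()` performs on the real factories.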
+
+Alternatively, the same layer can be created by defining nodes and operations directly:
+```cpp
+// construct a dense layer using nodes
+auto W1 = graph->param("W1", {5, 5}, inits::glorotUniform()); // {input dim, output dim}
+auto b1 = graph->param("b1", {1, 5}, inits::zeros());
+auto h = tanh(affine(x, W1, b1));
+```
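The arithmetic of a dense layer, `output = activation(input * weight + bias)`, can be checked independently of Marian with a plain C++ sketch. The `Matrix` struct and `denseTanh()` function are hypothetical helpers written for this illustration. The sketch assumes an input of shape `{rows, inDim}`, a weight of shape `{inDim, outDim}` and a bias of shape `{1, outDim}` that is broadcast over the rows.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major matrix; a plain stand-in for a graph node, not a Marian type.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float fill = 0.f) : rows(r), cols(c), data(r * c, fill) {}
  float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Computes tanh(x * W + b), with b of shape {1, outDim} broadcast over the rows of x.
inline Matrix denseTanh(const Matrix& x, const Matrix& W, const Matrix& b) {
  assert(x.cols == W.rows);                 // inner dimensions must agree
  assert(b.rows == 1 && b.cols == W.cols);  // one bias value per output unit
  Matrix out(x.rows, W.cols);
  for(std::size_t i = 0; i < x.rows; ++i) {
    for(std::size_t j = 0; j < W.cols; ++j) {
      float sum = b.at(0, j);
      for(std::size_t k = 0; k < x.cols; ++k)
        sum += x.at(i, k) * W.at(k, j);
      out.at(i, j) = std::tanh(sum);
    }
  }
  return out;
}
```

For a `{120, 5}` input and an output dimension of 5, the shape checks in this sketch require a `{5, 5}` weight and produce a `{120, 5}` result.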
+There are four categories of layers implemented in Marian, described in the sections below.
+
+## Convolution layer
+
+To use a `convolution` layer, you first need to install [NVIDIA cuDNN](https://developer.nvidia.com/cudnn).
+The convolution layer supported by Marian is a 2D
+[convolution layer](https://en.wikipedia.org/wiki/Convolutional_neural_network#Convolutional_layers).
+This layer creates a convolution kernel which is convolved with the input. The options that
+can be passed to a `convolution` layer are the following:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| prefix | Prefix name (used to form the parameter names) | `std::string` | `None` |
+| kernel-dims | The height and width of the kernel | `std::pair<int, int>` | `None`|
+| kernel-num | The number of kernels | `int` | `None` |
+| paddings | The height and width of paddings | `std::pair<int, int>` | `(0,0)`|
+| strides | The height and width of strides | `std::pair<int, int>` | `(1,1)` |
+
+Example:
+```cpp
+// construct a convolution layer
+auto conv_1 = convolution(graph) // pass graph pointer to the layer
+ ("prefix", "conv_1") // prefix name is conv_1
+ ("kernel-dims", std::make_pair(3,3)) // kernel is 3*3
+ ("kernel-num", 32) // kernel no. is 32
+ .apply(x); // link node x as the input
+```
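How `kernel-dims`, `paddings` and `strides` interact is easiest to see from the output size they produce. The sketch below uses the standard convolution arithmetic `out = (in + 2 * padding - kernel) / stride + 1`; it is generic, is not taken from Marian or cuDNN code, and the function name `convOutputDims()` is hypothetical.

```cpp
#include <cassert>
#include <utility>

// Output height and width of a 2D convolution, using the standard formula
//   out = (in + 2 * padding - kernel) / stride + 1   (integer division).
// The defaults mirror the (0,0) paddings and (1,1) strides listed in the table.
inline std::pair<int, int> convOutputDims(std::pair<int, int> inputDims,
                                          std::pair<int, int> kernelDims,
                                          std::pair<int, int> paddings = {0, 0},
                                          std::pair<int, int> strides = {1, 1}) {
  assert(strides.first > 0 && strides.second > 0);
  assert(inputDims.first + 2 * paddings.first >= kernelDims.first);
  assert(inputDims.second + 2 * paddings.second >= kernelDims.second);
  int outH = (inputDims.first + 2 * paddings.first - kernelDims.first) / strides.first + 1;
  int outW = (inputDims.second + 2 * paddings.second - kernelDims.second) / strides.second + 1;
  return {outH, outW};
}
```

Under this formula, a 3*3 kernel with default paddings and strides shrinks each spatial dimension by 2.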
+
+## MLP layers
+
+Marian offers `mlp::mlp`, which creates a
+[multilayer perceptron (MLP)](https://en.wikipedia.org/wiki/Multilayer_perceptron) network.
+It is a container which can stack multiple layers using the `push_back()` function. There are two types
+of MLP layers provided by Marian: `mlp::dense` and `mlp::output`.
+
+The `mlp::dense` layer, as introduced before, is a fully connected layer, and it accepts the
+following options:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| prefix | Prefix name (used to form the parameter names) | `std::string` | `None` |
+| dim | Output dimension | `int` | `None` |
+| layer-normalization | Whether to normalise the layer output or not | `bool` | `false` |
+| nematus-normalization | Whether to use Nematus layer normalisation or not | `bool` | `false` |
+| activation | Activation function | `int` | `mlp::act::linear` |
+
+The available activation functions for mlp are `mlp::act::linear`, `mlp::act::tanh`,
+`mlp::act::sigmoid`, `mlp::act::ReLU`, `mlp::act::LeakyReLU`, `mlp::act::PReLU`, and
+`mlp::act::swish`.
+
+Example:
+```cpp
+// construct a mlp::dense layer
+auto dense_layer = mlp::dense()
+ ("prefix", "dense_layer") // prefix name is dense_layer
+ ("dim", 3) // output dimension is 3
+ ("activation", (int)mlp::act::sigmoid) // activation function is sigmoid
+ .construct(graph)->apply(x); // construct this layer in graph and link node x as the input
+```
+
+The `mlp::output` layer is used, as the name suggests, to construct an output layer. You can tie
+embedding layers to an `mlp::output` layer using `tieTransposed()`, or set shortlisted words using
+`setShortlist()`. The general options of the `mlp::output` layer are listed below:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| prefix | Prefix name (used to form the parameter names) | `std::string` | `None` |
+| dim | Output dimension | `int` | `None` |
+| vocab | File path to the factored vocabulary | `std::string` | `None` |
+| output-omit-bias | Whether to omit the bias parameter of this layer | `bool` | `false` |
+| lemma-dim-emb | Re-embedding dimension of lemma in factors, must be used with `vocab` option | `int` | `0` |
+| output-approx-knn | Parameters for LSH-based output approximation, i.e., `k` (the first element) and `nbit` (the second element) | `std::vector<int>` | `None` |
+
+Example:
+```cpp
+// construct a mlp::output layer
+auto last = mlp::output()
+ ("prefix", "last") // prefix name is last
+ ("dim", 5); // output dimension is 5
+```
+Finally, an example showing how to create a `mlp::mlp` network containing multiple layers:
+```cpp
+// construct a mlp::mlp network
+auto mlp_networks = mlp::mlp() // construct an mlp container
+ .push_back(mlp::dense() // construct a dense layer
+ ("prefix", "dense") // prefix name is dense
+ ("dim", 5) // dimension is 5
+ ("activation", (int)mlp::act::tanh))// activation function is tanh
+ .push_back(mlp::output() // construct an output layer
+ ("dim", 5)) // dimension is 5
+ ("prefix", "mlp_network") // prefix name is mlp_network
+ .construct(graph); // construct these mlp layers in graph
+```
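The container idea behind `push_back()` can be sketched without Marian: layers are stored in insertion order and applied one after another. `LayerStack`, `addOne()` and `timesTwo()` are hypothetical names used only in this illustration, and the "layers" are plain functions on vectors rather than graph nodes.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical miniature of a layer container: each "layer" is a function
// from a vector to a vector, and apply() runs them in insertion order.
class LayerStack {
public:
  using Layer = std::function<std::vector<float>(const std::vector<float>&)>;

  // Returns *this so that push_back() calls can be chained.
  LayerStack& push_back(Layer layer) {
    layers_.push_back(std::move(layer));
    return *this;
  }

  std::vector<float> apply(std::vector<float> input) const {
    assert(!layers_.empty() && "no layers were added");
    for(const auto& layer : layers_)
      input = layer(input);  // the output of one layer feeds the next
    return input;
  }

private:
  std::vector<Layer> layers_;
};

// Two toy layers used only for illustration.
inline std::vector<float> addOne(const std::vector<float>& v) {
  std::vector<float> out(v);
  for(float& x : out) x += 1.f;
  return out;
}
inline std::vector<float> timesTwo(const std::vector<float>& v) {
  std::vector<float> out(v);
  for(float& x : out) x *= 2.f;
  return out;
}
```

A call such as `LayerStack().push_back(addOne).push_back(timesTwo)` then applies `addOne` first and `timesTwo` second, in the same way that the order of `push_back()` calls fixes the order of the stacked layers.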
+
+## RNN layers
+Marian offers `rnn::rnn` for creating a [recurrent neural network
+(RNN)](https://en.wikipedia.org/wiki/Recurrent_neural_network). Just like `mlp::mlp`,
+`rnn::rnn` is a container which can stack multiple layers using the `push_back()` function. Unlike
+MLP layers, Marian only provides cell-level APIs to construct an RNN. RNN cells process a single
+timestep rather than whole batches of input sequences. There are two types of RNN layers provided
+by Marian: `rnn::cell` and `rnn::stacked_cell`.
+
+The `rnn::cell` is the base component of an RNN and `rnn::stacked_cell` is a stack of `rnn::cell`. The
+options of the `rnn::cell` layer are listed below:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| type | Type of RNN cell | `std::string` | `None` |
+
+There are nine types of RNN cells provided by Marian: `gru`, `gru-nematus`, `lstm`, `mlstm`, `mgru`,
+`tanh`, `relu`, `sru`, `ssru`. The general options for all RNN cells are the following:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| dimInput | Input dimension | `int` | `None` |
+| dimState | Dimension of hidden state | `int` | `None` |
+| prefix | Prefix name (used to form the parameter names) | `std::string` | `None` |
+| layer-normalization | Whether to normalise the layer output or not | `bool` | `false` |
+| dropout | Dropout probability | `float` | `0` |
+| transition | Whether it is a transition layer | `bool` | `false` |
+| final | Whether it is an RNN final layer or hidden layer | `bool` | `false` |
+
+```{note}
+Not all the options listed above are available for all the cells. For example, the `final` option is
+only used for `gru` and `gru-nematus` cells.
+```
+
+Example for `rnn::cell`:
+```cpp
+// construct a rnn cell
+auto rnn_cell = rnn::cell()
+ ("type", "gru") // type of rnn cell is gru
+ ("prefix", "gru_cell") // prefix name is gru_cell
+ ("final", false); // this cell is not the final layer
+```
+Example for `rnn::stacked_cell`:
+```cpp
+// construct a stack of rnn cells
+auto highCell = rnn::stacked_cell();
+// for loop to add rnn cells into the stack
+for(size_t j = 1; j <= 512; j++) {
+ auto paramPrefix = "cell" + std::to_string(j);
+ highCell.push_back(rnn::cell()("prefix", paramPrefix));
+}
+```
+
+The list of available options for `rnn::rnn` layers:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| type | Type of RNN layer | `std::string` | `gru` |
+| direction | RNN direction | `int` | `rnn::dir::forward` |
+| dimInput | Input dimension | `int` | `None` |
+| dimState | Dimension of hidden state | `int` | `None` |
+| prefix | Prefix name (used to form the parameter names) | `std::string` | `None` |
+| layer-normalization | Whether to normalise the layer output or not | `bool` | `false` |
+| nematus-normalization | Whether to use Nematus layer normalisation or not | `bool` | `false` |
+| dropout | Dropout probability | `float` | `0` |
+| skip | Whether to use skip connections | `bool` | `false` |
+| skipFirst | Whether to use skip connections for the layer(s) with `index > 0` | `bool` | `false` |
+
+Examples for `rnn::rnn()`:
+```cpp
+// construct a `rnn::rnn()` container
+auto rnn_container = rnn::rnn(
+ "type", "gru", // type of rnn cell is gru
+ "prefix", "rnn_layers", // prefix name is rnn_layers
+ "dimInput", 10, // input dimension is 10
+ "dimState", 5, // dimension of hidden state is 5
+ "dropout", 0, // dropout probability is 0
+ "layer-normalization", false) // do not normalise the layer output
+ .push_back(rnn::cell()) // add a rnn::cell in this rnn container
+ .construct(graph); // construct this rnn container in graph
+```
+Marian provides four RNN directions in `rnn::dir` enumerator: `rnn::dir::forward`,
+`rnn::dir::backward`, `rnn::dir::alternating_forward` and `rnn::dir::alternating_backward`.
+For `rnn::rnn()`, you can use `transduce()` to map the input state to the output state.
+
+An example for `transduce()`:
+```cpp
+auto output = rnn.construct(graph)->transduce(input);
+```
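Because cells handle one timestep at a time, mapping a whole sequence amounts to a loop that carries the hidden state from step to step. The following stand-alone sketch illustrates that loop for a forward and a backward direction. `ToyTanhCell` and `transduceSketch()` are hypothetical names, the cell uses scalar weights instead of matrices, and storing each state at the position of the input that produced it is an assumption of this sketch rather than a statement about Marian's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy scalar RNN cell: newState = tanh(wInput * input + wState * state).
// Real cells (GRU, LSTM, ...) use matrices and gates; the weights here are
// plain floats so that the control flow stays visible.
struct ToyTanhCell {
  float wInput;
  float wState;
  // Processes exactly one timestep.
  float step(float input, float state) const { return std::tanh(wInput * input + wState * state); }
};

// Runs the cell over a whole sequence and returns one state per timestep,
// stored at the position of the input that produced it.
// With backward == true the sequence is consumed from the last element to the first.
inline std::vector<float> transduceSketch(const ToyTanhCell& cell,
                                          const std::vector<float>& inputs,
                                          bool backward = false) {
  assert(!inputs.empty() && "expects at least one timestep");
  std::vector<float> states(inputs.size());
  float state = 0.f;  // initial hidden state
  for(std::size_t t = 0; t < inputs.size(); ++t) {
    std::size_t pos = backward ? inputs.size() - 1 - t : t;
    state = cell.step(inputs[pos], state);
    states[pos] = state;
  }
  return states;
}
```

The only difference between the two directions in this sketch is the order in which positions are visited, so the state at a given position summarises either the inputs before it or the inputs after it.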
+
+## Embedding layer
+Marian provides a shortcut, `embedding`, to construct a regular embedding layer for word embeddings.
+For `embedding` layers, the following options are available:
+
+| Option Name | Definition | Value Type | Default Value |
+| ------------- |----------------|---------------|---------------|
+| dimVocab | Size of vocabulary| `int` | `None` |
+| dimEmb | Size of embedding vector | `int` | `None` |
+| dropout | Dropout probability | `float` | `0` |
+| inference | Whether it is used for inference | `bool` | `false` |
+| prefix | Prefix name (used to form the parameter names) | `std::string` | `None` |
+| fixed | Whether this layer is fixed (not trainable) | `bool` | `false` |
+| dimFactorEmb | Size of factored embedding vector | `int` | `None` |
+| factorsCombine | Which strategy is chosen to combine the factor embeddings; it can be `"concat"` | `std::string` | `None` |
+| vocab | File path to the factored vocabulary | `std::string` | `None` |
+| embFile | Paths to the factored embedding vectors | `std::string` | `None` |
+| normalization | Whether to normalise the layer output or not | `bool` | `false` |
+
+Example to construct an embedding layer:
+```cpp
+// construct an embedding layer
+auto embedding_layer = embedding()
+ ("prefix", "embedding") // prefix name is embedding
+ ("dimVocab", 1024) // vocabulary size is 1024
+ ("dimEmb", 512) // size of embedding vector is 512
+ .construct(graph); // construct this embedding layer in graph
+```
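Conceptually, an embedding layer with `dimVocab` rows and `dimEmb` columns is a lookup table indexed by word id. The sketch below shows that lookup in plain C++. `EmbeddingTable` is a hypothetical struct written for this illustration and ignores dropout, factors and masking.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical lookup table with dimVocab rows and dimEmb columns (row-major).
struct EmbeddingTable {
  std::size_t dimVocab;
  std::size_t dimEmb;
  std::vector<float> weights;  // size dimVocab * dimEmb, initialised to zero

  EmbeddingTable(std::size_t vocab, std::size_t emb)
      : dimVocab(vocab), dimEmb(emb), weights(vocab * emb, 0.f) {}

  // Returns the embedding vectors of the given word indices, concatenated
  // in order: the result has words.size() * dimEmb entries.
  std::vector<float> lookup(const std::vector<std::size_t>& words) const {
    std::vector<float> out;
    out.reserve(words.size() * dimEmb);
    for(std::size_t w : words) {
      assert(w < dimVocab && "word index out of vocabulary");
      for(std::size_t j = 0; j < dimEmb; ++j)
        out.push_back(weights[w * dimEmb + j]);
    }
    return out;
  }
};
```

With a vocabulary size of 1024 and an embedding size of 512, looking up two word ids in this sketch yields 2 * 512 values.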
diff --git a/doc/operators.md b/doc/operators.md
index 2cca391b..1e7bba96 100644
--- a/doc/operators.md
+++ b/doc/operators.md
@@ -1,4 +1,4 @@
-# Operations in the Expression Graph
+# Operations in the expression graph
Operations are responsible for manipulating the elements of an expression graph.
In Marian, many useful operations have already been implemented and can be found
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index 9e9de207..5597a6a4 100644
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -12,6 +12,11 @@ namespace mlp {
* Base class for layer factories, can be used in a multi-layer network factory.
*/
struct LayerFactory : public Factory {
+ /**
+ * Construct a layer instance in a given graph.
+ * @param graph a shared pointer to a graph
+ * @return a shared pointer to the layer object
+ */
virtual Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
@@ -31,18 +36,24 @@ public:
}
};
-// @TODO: change naming convention
+/**
+ * A convenient typedef for constructing a MLP dense layer.
+ * @TODO: change naming convention
+ */
typedef Accumulator<DenseFactory> dense;
/**
- * Factory for output layers, can be used in a multi-layer network factory.
+ * Base factory for output layers, can be used in a multi-layer network factory.
*/
struct LogitLayerFactory : public Factory {
using Factory::Factory;
virtual Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
-// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
+/**
+ * Implementation of Output layer factory, can be used in a multi-layer network factory.
+ * @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
+ */
class OutputFactory : public LogitLayerFactory {
using LogitLayerFactory::LogitLayerFactory;
@@ -74,12 +85,13 @@ public:
}
};
-// @TODO: change naming convention
-typedef Accumulator<OutputFactory> output;
-
/**
- * Multi-layer network, holds and applies layers.
+ * A convenient typedef for constructing a MLP output layer.
+ * @TODO: change naming convention
*/
+typedef Accumulator<OutputFactory> output;
+
+/** Multi-layer network, holds and applies layers. */
class MLP : public IUnaryLogitLayer, public IHasShortList {
protected:
Ptr<ExpressionGraph> graph_;
@@ -88,8 +100,17 @@ protected:
std::vector<Ptr<IUnaryLayer>> layers_;
public:
+ /**
+ * Construct a MLP container in the graph.
+ * @param graph The expression graph.
+ * @param options The options used for this mlp container.
+ */
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
-
+ /**
+ * Apply/Link a vector of mlp layers (with the given inputs) to the expression graph.
+ * @param av The vector of input expressions
+ * @return The expression holding the mlp container
+ */
Expr apply(const std::vector<Expr>& av) override {
Expr output;
if(av.size() == 1)
@@ -102,7 +123,12 @@ public:
return output;
}
-
+ /**
+ * Apply/Link a vector of mlp layers (with the given inputs) to the expression graph.
+ * @param av The vector of input expressions
+ * @return The expression holding the mlp container as a
+ * <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object
+ */
Logits applyAsLogits(const std::vector<Expr>& av) override {
// same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
// return type
@@ -126,13 +152,33 @@ public:
return lastLayer->applyAsLogits(output);
}
}
-
+ /**
+ * Apply/Link a mlp layer (with the given input) to the expression graph.
+ * @param e The input expression
+ * @return The expression holding the mlp container
+ */
Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
+ /**
+ * Apply/Link a mlp layer (with the given input) to the expression graph.
+ * @param e The input expression
+ * @return The expression holding the mlp container as a
+ * <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object
+ */
Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
-
+ /**
+ * Stack a mlp layer to the mlp container.
+ * @param layer The mlp layer
+ */
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
+ /**
+ * Stack a mlp layer with <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object to the mlp container.
+ * @param layer The mlp layer with <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object
+ */
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
-
+ /**
+ * Set shortlisted words to the mlp container.
+ * @param shortlist The given shortlisted words
+ */
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
auto p = tryAsHasShortlist();
ABORT_IF(
@@ -140,7 +186,7 @@ public:
"setShortlist() called on an MLP with an output layer that does not support short lists");
p->setShortlist(shortlist);
}
-
+ /** Remove shortlisted words from the mlp container. */
void clear() override final {
auto p = tryAsHasShortlist();
if(p)
@@ -154,8 +200,8 @@ private:
};
/**
- * Multi-layer network factory. Can hold layer factories. Used
- * to accumulate options for later lazy construction.
+ * Multi-layer network factory. Can hold layer factories.
+ * Used to accumulate options for later lazy construction.
*/
class MLPFactory : public Factory {
using Factory::Factory;
@@ -164,6 +210,12 @@ private:
std::vector<Ptr<LayerFactory>> layers_;
public:
+ /**
+ * Create a MLP container instance in the expression graph.
+ * Used to accumulate options for later lazy construction.
+ * @param graph The expression graph
+ * @return The shared pointer to the MLP container
+ */
Ptr<MLP> construct(Ptr<ExpressionGraph> graph) {
auto mlp = New<MLP>(graph, options_);
for(auto layer : layers_) {
@@ -172,7 +224,11 @@ public:
}
return mlp;
}
-
+ /**
+ * Stack a layer to the mlp container.
+ * @param lf The layer
+ * @return The Accumulator object holding the mlp container
+ */
template <class LF>
Accumulator<MLPFactory> push_back(const LF& lf) {
layers_.push_back(New<LF>(lf));
@@ -201,6 +257,11 @@ private:
}
public:
+ /**
+ * Stack a mlp output layer to the mlp container.
+ * @param lf The mlp output layer
+ * @return The Accumulator object holding the mlp container
+ */
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
push_back(AsLayerFactory<OutputFactory>(lf));
// layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
@@ -208,13 +269,19 @@ public:
}
};
-// @TODO: change naming convention.
+
+/**
+ * A convenient typedef for constructing MLP layers.
+ * @TODO: change naming convention.
+ */
typedef Accumulator<MLPFactory> mlp;
} // namespace mlp
typedef ConstructingFactory<Embedding> EmbeddingFactory;
typedef ConstructingFactory<ULREmbedding> ULREmbeddingFactory;
+/** A convenient typedef for constructing standard embedding layers. */
typedef Accumulator<EmbeddingFactory> embedding;
+/** A convenient typedef for constructing ULR word embedding layers. */
typedef Accumulator<ULREmbeddingFactory> ulr_embedding;
} // namespace marian
diff --git a/src/layers/embedding.h b/src/layers/embedding.h
index d34c7ffb..af22b980 100644
--- a/src/layers/embedding.h
+++ b/src/layers/embedding.h
@@ -6,10 +6,12 @@ namespace marian {
class FactoredVocab;
-// A regular embedding layer.
-// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
-// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
-// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
+/**
+ * A regular embedding layer.
+ * Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
+ * It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
+ * EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
+ */
class Embedding : public LayerBase, public IEmbeddingLayer {
Expr E_;
Expr FactorEmbMatrix_; // Factors embedding matrix if combining lemma and factors embeddings with concatenation
@@ -19,16 +21,43 @@ class Embedding : public LayerBase, public IEmbeddingLayer {
bool inference_{false};
public:
+ /**
+ * Construct a regular embedding layer in the graph.
+ * @param graph The expression graph.
+ * @param options The options used for this embedding layer.
+ */
Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
+ /**
+ * Apply/Link this embedding layer (with the given batch of sentences) to the expression graph.
+ * @param subBatch The batch of sentences
+ * @return The expression tuple holding the embedding layer and the masking layer
+ */
std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
Ptr<data::SubBatch> subBatch) const override final;
+ /**
+ * Apply/Link this embedding layer (with the given words and shape) to the expression graph.
+ * @param words Sequence of vocabulary items
+ * @param shape Shape of the words
+ * @return The expression holding the embedding layer
+ */
Expr apply(const Words& words, const Shape& shape) const override final;
+ /**
+ * Apply/Link this embedding layer (with the given WordIndex vector and shape) to the expression graph.
+ * @param embIdx The vector of WordIndex objects
+ * @param shape Shape of the WordIndex vector
+ * @return The expression holding the embedding layer
+ */
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
};
+/**
+ * Universal Language Representation (ULR) word embedding layer.
+ * It is under development.
+ * @todo applyIndices() is not implemented
+ */
class ULREmbedding : public LayerBase, public IEmbeddingLayer {
std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
bool inference_{false};
diff --git a/src/layers/factory.h b/src/layers/factory.h
index f9e4ddf9..df092199 100644
--- a/src/layers/factory.h
+++ b/src/layers/factory.h
@@ -3,7 +3,10 @@
#include "marian.h"
namespace marian {
-
+/**
+ * Base class for constructing models or layers.
+ * Its main attribute is options which hold the basic characteristics of the model or the layer.
+ */
class Factory : public std::enable_shared_from_this<Factory> {
protected:
Ptr<Options> options_;
@@ -68,8 +71,7 @@ public:
template <class Cast>
inline bool is() { return std::dynamic_pointer_cast<Cast>(shared_from_this()) != nullptr; }
};
-
-// simplest form of Factory that just passes on options to the constructor of a layer type
+/** Simplest form of Factory that just passes on options to the constructor of a layer. */
template<class Class>
struct ConstructingFactory : public Factory {
using Factory::Factory;
@@ -79,6 +81,17 @@ struct ConstructingFactory : public Factory {
}
};
+/**
+ * Accumulator<Factory> pattern offers a shortcut to construct models or layers.
+ * The options can be passed by a pair of parentheses. E.g., to construct a fully-connected layer:
+ * \code{.cpp}
+ * auto hidden = mlp::dense()
+ ("prefix", "hidden_layer") // layer name
+ ("dim", outDim) // output dimension
+ ("activation", (int)mlp::act::sigmoid) // activation function
+ .construct(graph); // construct this layer in graph
+ \endcode
+ */
template <class BaseFactory> // where BaseFactory : Factory
class Accumulator : public BaseFactory {
typedef BaseFactory Factory;
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 9af033df..b423befe 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -9,18 +9,19 @@
namespace marian {
namespace mlp {
-/**
- * @brief Activation functions
- */
+/** Activation functions for MLP layers. */
enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
} // namespace mlp
} // namespace marian
namespace marian {
-// Each layer consists of LayerBase and IXXXLayer which defines one or more apply()
-// functions for the respective layer type (different layers may require different signatures).
-// This base class contains configuration info for creating parameters and executing apply().
+/**
+ * Base class for a layer.
+ * Each layer consists of LayerBase and IXXXLayer which defines one or more apply()
+ * functions for the respective layer type (different layers may require different signatures).
+ * This base class contains configuration info for creating parameters and executing apply().
+ */
class LayerBase {
protected:
Ptr<ExpressionGraph> graph_;
@@ -40,22 +41,25 @@ public:
}
};
-// Simplest layer interface: Unary function
+/** Simplest layer interface: Unary function. */
struct IUnaryLayer {
virtual ~IUnaryLayer() {}
+ /** Link a node as the input for this layer. */
virtual Expr apply(Expr) = 0;
+ /** Link a list of nodes as the inputs for this layer. */
virtual Expr apply(const std::vector<Expr>& es) {
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return apply(es.front());
}
};
+/** Shortlist interface for layers. */
struct IHasShortList {
virtual void setShortlist(Ptr<data::Shortlist> shortlist) = 0;
virtual void clear() = 0;
};
-// Embedding from corpus sub-batch to (emb, mask)
+/** Embedding from corpus sub-batch to (emb, mask). */
struct IEmbeddingLayer {
virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
Ptr<data::SubBatch> subBatch) const = 0;
@@ -67,8 +71,10 @@ struct IEmbeddingLayer {
virtual ~IEmbeddingLayer() {}
};
-// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream
-// index)
+/**
+ * Base class for Encoder and Decoder classes.
+ * Have embeddings and a batch index (=stream index).
+ */
class EncoderDecoderLayerBase : public LayerBase {
protected:
const std::string prefix_;
@@ -98,16 +104,42 @@ private:
Ptr<IEmbeddingLayer> createULREmbeddingLayer() const;
public:
- // get embedding layer; lazily create on first call
+ /**
+ * Get all embedding layer(s).
+ * It lazily creates the embedding layer on first call.
+ * This is lazy mostly because the constructors of the consuming objects are not
+ * guaranteed presently to have access to their graph.
+ * @param ulr whether to use ULREmbedding layer. false by default.
+ * @return a shared pointer to the embedding layer
+ */
Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};
+/**
+ * The namespace mlp.
+ * Declare class Dense and all the available functions for creating
+ * <a href=https://en.wikipedia.org/wiki/Multilayer_perceptron>multilayer perceptron (MLP)</a>
+ * network.
+ */
namespace mlp {
+/**
+ * Base class for a fully connected layer.
+ * Implements the operation `output = activation(input * weight + bias)`.
+ */
class Dense : public LayerBase, public IUnaryLayer {
public:
+ /**
+ * Construct a dense layer in the graph.
+ * @param graph The expression graph.
+ * @param options The options used for this dense layer.
+ */
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}
-
+ /**
+ * Apply/Link a vector of dense layers (with the given inputs) to the expression graph.
+ * @param inputs The vector of the input expressions
+ * @return The expression holding the dense layers
+ */
Expr apply(const std::vector<Expr>& inputs) override {
ABORT_IF(inputs.empty(), "No inputs");
@@ -161,7 +193,11 @@ public:
}
// clang-format on
};
-
+ /**
+ * Apply/Link this dense layer (with the given input) to the expression graph.
+ * @param input The input expression
+ * @return The expression holding the dense layer
+ */
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};
diff --git a/src/rnn/attention_constructors.h b/src/rnn/attention_constructors.h
index a878f57f..4ad1975e 100644
--- a/src/rnn/attention_constructors.h
+++ b/src/rnn/attention_constructors.h
@@ -33,6 +33,7 @@ public:
}
};
+/** A convenient typedef for constructing RNN attention layers. */
typedef Accumulator<AttentionFactory> attention;
} // namespace rnn
} // namespace marian
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index cddfd26e..18ac4d1d 100644
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
@@ -197,6 +197,7 @@ public:
Expr gruOps(const std::vector<Expr>& nodes, bool final = false);
+/** Base class for a gated recurrent unit (GRU) cell. */
class GRU : public Cell {
protected:
std::string prefix_;
diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h
index beb1fce1..22acfe9e 100644
--- a/src/rnn/constructors.h
+++ b/src/rnn/constructors.h
@@ -5,6 +5,12 @@
#include "rnn/rnn.h"
namespace marian {
+/**
+ * The namespace rnn.
+ * Declares all the available functions for creating a
+ * <a href=https://en.wikipedia.org/wiki/Recurrent_neural_network>recurrent neural network (RNN)</a>
+ * network.
+ */
namespace rnn {
typedef Factory StackableFactory;
@@ -28,6 +34,12 @@ struct InputFactory : public StackableFactory {
virtual Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) = 0;
};
+/**
+ * Base class for constructing RNN cells.
+ * RNN cells only process a single timestep instead of the whole batches of input sequences.
+ * There are nine types of RNN cells provided by Marian, i.e., `gru`, `gru-nematus`, `lstm`,
+ * `mlstm`, `mgru`, `tanh`, `relu`, `sru`, `ssru`.
+ */
class CellFactory : public StackableFactory {
protected:
std::vector<std::function<Expr(Ptr<rnn::RNN>)>> inputs_;
@@ -92,8 +104,10 @@ public:
}
};
+/** A convenience typedef for constructing RNN cells. */
typedef Accumulator<CellFactory> cell;
+/** Base class for constructing a stack of RNN cells (`rnn::cell`). */
class StackedCellFactory : public CellFactory {
protected:
std::vector<Ptr<StackableFactory>> stackableFactories_;
@@ -137,8 +151,10 @@ public:
}
};
+/** A convenience typedef for constructing a stack of RNN cells. */
typedef Accumulator<StackedCellFactory> stacked_cell;
+/** Base class for constructing RNN layers. */
class RNNFactory : public Factory {
using Factory::Factory;
protected:
@@ -195,6 +211,7 @@ public:
}
};
+/** A convenience typedef for constructing RNN containers/layers. */
typedef Accumulator<RNNFactory> rnn;
} // namespace rnn
} // namespace marian