Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/layers/constructors.h')
-rw-r--r--src/layers/constructors.h101
1 files changed, 84 insertions, 17 deletions
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index 9e9de207..5597a6a4 100644
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -12,6 +12,11 @@ namespace mlp {
* Base class for layer factories, can be used in a multi-layer network factory.
*/
struct LayerFactory : public Factory {
+ /**
+ * Construct a layer instance in a given graph.
+ * @param graph a shared pointer to a graph
+ * @return a shared pointer to the layer object
+ */
virtual Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
@@ -31,18 +36,24 @@ public:
}
};
-// @TODO: change naming convention
+/**
+ * A convenient typedef for constructing an MLP dense layer.
+ * @TODO: change naming convention
+ */
typedef Accumulator<DenseFactory> dense;
/**
- * Factory for output layers, can be used in a multi-layer network factory.
+ * Base factory for output layers, can be used in a multi-layer network factory.
*/
struct LogitLayerFactory : public Factory {
using Factory::Factory;
virtual Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
-// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
+/**
+ * Implementation of Output layer factory, can be used in a multi-layer network factory.
+ * @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
+ */
class OutputFactory : public LogitLayerFactory {
using LogitLayerFactory::LogitLayerFactory;
@@ -74,12 +85,13 @@ public:
}
};
-// @TODO: change naming convention
-typedef Accumulator<OutputFactory> output;
-
/**
- * Multi-layer network, holds and applies layers.
+ * A convenient typedef for constructing an MLP output layer.
+ * @TODO: change naming convention
*/
+typedef Accumulator<OutputFactory> output;
+
+/** Multi-layer network, holds and applies layers. */
class MLP : public IUnaryLogitLayer, public IHasShortList {
protected:
Ptr<ExpressionGraph> graph_;
@@ -88,8 +100,17 @@ protected:
std::vector<Ptr<IUnaryLayer>> layers_;
public:
+ /**
+ * Construct an MLP container in the graph.
+ * @param graph The expression graph.
+ * @param options The options used for this mlp container.
+ */
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
-
+ /**
+ * Apply/Link a vector of mlp layers (with the given inputs) to the expression graph.
+ * @param av The vector of input expressions
+ * @return The expression holding the mlp container
+ */
Expr apply(const std::vector<Expr>& av) override {
Expr output;
if(av.size() == 1)
@@ -102,7 +123,12 @@ public:
return output;
}
-
+ /**
+ * Apply/Link a vector of mlp layers (with the given inputs) to the expression graph.
+ * @param av The vector of input expressions
+ * @return The expression holding the mlp container as a
+ * <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object
+ */
Logits applyAsLogits(const std::vector<Expr>& av) override {
// same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
// return type
@@ -126,13 +152,33 @@ public:
return lastLayer->applyAsLogits(output);
}
}
-
+ /**
+ * Apply/Link a mlp layer (with the given input) to the expression graph.
+ * @param e The input expression
+ * @return The expression holding the mlp container
+ */
Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
+ /**
+ * Apply/Link a mlp layer (with the given input) to the expression graph.
+ * @param e The input expression
+ * @return The expression holding the mlp container as a
+ * <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object
+ */
Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
-
+ /**
+ * Stack a mlp layer to the mlp container.
+ * @param layer The mlp layer
+ */
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
+ /**
+ * Stack a mlp layer with <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object to the mlp container.
+ * @param layer The mlp layer with <a href=https://en.wikipedia.org/wiki/Logit>Logits</a> object
+ */
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
-
+ /**
+ * Set shortlisted words to the mlp container.
+ * @param shortlist The given shortlisted words
+ */
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
auto p = tryAsHasShortlist();
ABORT_IF(
@@ -140,7 +186,7 @@ public:
"setShortlist() called on an MLP with an output layer that does not support short lists");
p->setShortlist(shortlist);
}
-
+ /** Remove shortlisted words from the mlp container. */
void clear() override final {
auto p = tryAsHasShortlist();
if(p)
@@ -154,8 +200,8 @@ private:
};
/**
- * Multi-layer network factory. Can hold layer factories. Used
- * to accumulate options for later lazy construction.
+ * Multi-layer network factory. Can hold layer factories.
+ * Used to accumulate options for later lazy construction.
*/
class MLPFactory : public Factory {
using Factory::Factory;
@@ -164,6 +210,12 @@ private:
std::vector<Ptr<LayerFactory>> layers_;
public:
+ /**
+ * Create an MLP container instance in the expression graph.
+ * Used to accumulate options for later lazy construction.
+ * @param graph The expression graph
+ * @return The shared pointer to the MLP container
+ */
Ptr<MLP> construct(Ptr<ExpressionGraph> graph) {
auto mlp = New<MLP>(graph, options_);
for(auto layer : layers_) {
@@ -172,7 +224,11 @@ public:
}
return mlp;
}
-
+ /**
+ * Add a layer factory to this MLP factory; it is stored for later construction of the MLP.
+ * @param lf The layer factory
+ * @return The Accumulator object holding the mlp container
+ */
template <class LF>
Accumulator<MLPFactory> push_back(const LF& lf) {
layers_.push_back(New<LF>(lf));
@@ -201,6 +257,11 @@ private:
}
public:
+ /**
+ * Stack a mlp output layer to the mlp container.
+ * @param lf The mlp output layer
+ * @return The Accumulator object holding the mlp container
+ */
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
push_back(AsLayerFactory<OutputFactory>(lf));
// layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
@@ -208,13 +269,19 @@ public:
}
};
-// @TODO: change naming convention.
+
+/**
+ * A convenient typedef for constructing MLP layers.
+ * @TODO: change naming convention.
+ */
typedef Accumulator<MLPFactory> mlp;
} // namespace mlp
typedef ConstructingFactory<Embedding> EmbeddingFactory;
typedef ConstructingFactory<ULREmbedding> ULREmbeddingFactory;
+/** A convenient typedef for constructing standard embedding layers. */
typedef Accumulator<EmbeddingFactory> embedding;
+/** A convenient typedef for constructing ULR word embedding layers. */
typedef Accumulator<ULREmbeddingFactory> ulr_embedding;
} // namespace marian