Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGraeme <graemenail.work@gmail.com>2021-02-28 11:38:32 +0300
committerGitHub <noreply@github.com>2021-02-28 11:38:32 +0300
commitac71ee85181783d5d23c8dcc4a16e23ff32de355 (patch)
treef0916a1e2a8fdc1a024634e1c5d91665383b0681 /src
parent2a9c0bb3773c32c20fce74a8b0f7149478ebd8cb (diff)
Add graph operations documentation (#801)
* Doxygen structure for expression graph operators * Document arithmetic expression operations * Document comparison expression operations * Document exp/log and trig operations * Add missing implementation for cos/tan * Document expression manipulation operations * Document misc math operations * Overview of operators * Document activation functions * Document element-wise min/max * Document debugging/checkpoint operators * Document topk/argmin/argmax operations * Document index-based operations * Document reduction operations * Document lambda expression operators * Document product operations * Document softmax, cross-entropy, unlikelihood operations * Document dropout operations * Document scalar product and weighted average operations * Document layer normalization, highway and pooling operations * Document shift expression operator * Extra details on rules for adding specializations to .inc files * Add SinNodeOp example for specialization documentation * Additional details in tensor operator documentation * Remove brief command from doxygen comments * Prefer @ style doxygen functions to \ * Document n-ary function macros * Enable .cu and .inc files in documentation * Add a comment about ONNX mapping * Remove empty lines in doxygen * Update CHANGELOG Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
Diffstat (limited to 'src')
-rw-r--r--src/functional/predicates.h18
-rw-r--r--src/graph/expression_operators.cpp8
-rwxr-xr-xsrc/graph/expression_operators.h833
-rw-r--r--src/graph/node_operators_unary.h4
-rwxr-xr-xsrc/tensors/gpu/add.inc2
-rw-r--r--src/tensors/gpu/add_all.inc5
-rwxr-xr-xsrc/tensors/gpu/element.inc2
7 files changed, 790 insertions, 82 deletions
diff --git a/src/functional/predicates.h b/src/functional/predicates.h
index 420a88a3..e71d225b 100644
--- a/src/functional/predicates.h
+++ b/src/functional/predicates.h
@@ -39,6 +39,12 @@ struct BinaryFunctor {
}
};
+/**
+ * Macro to set up unary-functions from marian::functional::Ops.
+ * @param name name for the struct
+ * @param name2 callable typedef
+ * @param func function wrapped
+ */
#define UNARY(name, name2, func) \
namespace elem { \
struct name { \
@@ -55,6 +61,12 @@ struct BinaryFunctor {
} \
static inline name<Capture> name2(Capture x) { return name<Capture>(x); }
+/**
+ * Macro to set up binary-functions from marian::functional::Ops.
+ * @param name name for the struct
+ * @param name2 callable typedef
+ * @param func function wrapped
+ */
#define BINARY(name, name2, func) \
namespace elem { \
struct name { \
@@ -95,6 +107,12 @@ struct TernaryFunctor {
}
};
+/**
+ * Macro to set up ternary-functions from marian::functional::Ops.
+ * @param name name for the struct
+ * @param name2 callable typedef
+ * @param func function wrapped
+ */
#define TERNARY(name, name2, func) \
namespace elem { \
struct name { \
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 3d42c600..25da1e38 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -72,6 +72,14 @@ Expr sin(Expr a) {
return Expression<SinNodeOp>(a);
};
+Expr cos(Expr a) {
+ return Expression<CosNodeOp>(a);
+};
+
+Expr tan(Expr a) {
+ return Expression<TanNodeOp>(a);
+};
+
Expr swish(Expr a) {
return Expression<SwishNodeOp>(a);
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index a4a2eeee..ca0739e4 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -3,145 +3,489 @@
#include "graph/node_initializers.h"
namespace marian {
+///@defgroup graph_ops Expression Graph Operators
+///@{
+/**
+ * Assigns a debug message to the expression.
+ */
Expr debug(Expr a, const std::string& message = "");
+/**
+ * Marks the expression as a gradient-checkpoint.
+ */
Expr checkpoint(Expr a);
-typedef Expr(ActivationFunction)(Expr);
-
-typedef std::function<void(Expr, const std::vector<Expr>&)> LambdaNodeFunctor;
-Expr lambda(const std::vector<Expr>&, Shape, Type, LambdaNodeFunctor);
-Expr lambda(const std::vector<Expr>&, Shape, Type, LambdaNodeFunctor, LambdaNodeFunctor);
-
-Expr plus(const std::vector<Expr>&);
-
-// TODO: should be logistic(), not sigmoid()
+typedef Expr(ActivationFunction)(Expr); ///< ActivationFunction has signature Expr(Expr)
+
+/**
+ * Convience typedef for graph @ref lambda expressions.
+ */
+typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFunctor;
+
+/**
+ * Arbitrary node with forward operation only.
+ */
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd);
+
+/**
+ * Arbitrary node with forward and backward operation.
+ */
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd);
+
+/**
+ * @addtogroup graph_ops_activation Activation Functions
+ * Provides various activation functions for use in the expression.
+ * @ingroup graph_ops
+ * @{
+ */
+
+/**
+ * Linear Activation Function.
+ * Returns @p nodes[0]
+ */
+Expr plus(const std::vector<Expr>& nodes);
+
+/**
+ * Logistic Activation Function.
+ * Computes the <a href="https://en.wikipedia.org/wiki/Logistic_function">logistic function</a>
+ * of the given expression
+ * @todo rename sigmoid to logistic
+ */
Expr sigmoid(Expr a);
-Expr sigmoid(const std::vector<Expr>&);
+/**
+ * @copybrief sigmoid
+ * @warning not implemented
+ */
+Expr sigmoid(const std::vector<Expr>& nodes);
+
+/**
+ * Swish node.
+ * Computes the Swish activation function with @f$\beta=1 @f$
+ * @f[
+ * \operatorname{swish}(x) = x \cdot \operatorname{sigmoid}(\beta x)
+ * @f]
+ * @see SwishNodeOp
+ */
Expr swish(Expr a);
-Expr swish(const std::vector<Expr>&);
+/**
+ * @copybrief swish
+ * @warning not implemented for @p nodes of size > 1
+ * @returns swish(nodes[0])
+ */
+Expr swish(const std::vector<Expr>& nodes);
+
+/**
+ * Gaussian Error Linear Unit (GELU).
+ * Computes an _approxmiation_ to the Gaussian Error Linear Unit
+ * @f[
+ * \operatorname{gelu}(x) = x \cdot \Phi(x)
+ * = x \cdot \frac{1}{2}\left[
+ * 1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)
+ * \right]
+ * \sim \operatorname{swish}(x, 1.702)
+ * @f]
+ * using @ref SwishNodeOp(a, 1.702)
+ * @see SwishNodeOp
+ */
Expr gelu(Expr a);
+
+/**
+ * @copybrief gelu
+ * @warning not implemented for @p nodes of size > 1
+ * @returns gelu(nodes[0])
+ */
Expr gelu(const std::vector<Expr>&);
-Expr tanh(const std::vector<Expr>&);
+/**
+ * Tanh.
+ * @see TanhNodeOp
+ */
+Expr tanh(const std::vector<Expr>& nodes);
+/**
+ * @copybrief tanh
+ * Convience function to put parameter pack @p Args into a Expr vector
+ */
template <typename... Args>
Expr tanh(Args... args) {
std::vector<Expr> nodes{args...};
return tanh(nodes);
}
+/**
+ * Rectified Linear Unit (ReLU).
+ * Computes the ReLU activation for the Expr
+ * @see ReLUNodeOp
+ */
Expr relu(Expr a);
-Expr relu(const std::vector<Expr>&);
+/**
+ * @copybrief relu
+ * @warning not implemented for @p nodes of size > 1
+ * @returns relu(nodes[0])
+ */
+Expr relu(const std::vector<Expr>& nodes);
+
+/**
+ * Leaky ReLU (LeakyReLU).
+ * Computes the <a href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)#LeakyReLU">
+ * LeakyReLU</a> activation for the expression
+ * Activation function:
+ * @f[
+ * \operatorname{leakyrelu}(x) =
+ * \begin{cases}
+ * 0.01x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
+ * \end{cases}
+ * @f]
+ * @see PReLUNodeOp
+ */
Expr leakyrelu(Expr a);
-Expr leakyrelu(const std::vector<Expr>&);
+/**
+ * @copybrief leakyrelu
+ * @warning not implemented
+ */
+Expr leakyrelu(const std::vector<Expr>& nodes);
+
+/**
+ * Parametric Rectified Linear Unit (PReLU).
+ * Computes the <a href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)#Parametric_ReLU">
+ * Parametric ReLU</a> activation for the expression
+ * @f[
+ * \operatorname{leakyrelu}(x) =
+ * \begin{cases}
+ * \alpha x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
+ * \end{cases}
+ * @f]
+ * @see PReLUNodeOp
+ * @note @p alpha is **not** trainable.
+ */
Expr prelu(Expr a, float alpha = 0.01);
-Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+/**
+ * @copybrief prelu
+ * @warning not implemented
+ */
+Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+///@}
+
+/**
+ * @addtogroup graph_ops_mathematical Mathematical
+ * Performs mathematical operations in the expression graph.
+ * @ingroup graph_ops
+ * @{
+ */
+
+///@name Exponentiation and Logarithmic functions
+///@{
+/**
+ * Natural logarithm.
+ * Computes the element-wise natural logarithm of the expression: @f$ \log(a) @f$
+ * @see LogNodeOp
+ */
Expr log(Expr a);
-Expr exp(Expr a);
+/**
+ * Natural exponentiation.
+ * Computes the element-wise natural logarithm of the expression: @f$ e^a @f$
+ * @see ExpNodeOp
+ */
+Expr exp(Expr a);
+///@}
+
+///@name Trigonometric functions
+///@{
+/**
+* Sine. Computes the element-wise sine of the expression: @f$ \sin(a) @f$.
+* @see SinNodeOp
+*/
Expr sin(Expr a);
-Expr cos(Expr a);
-Expr tan(Expr a);
-Expr clip(Expr a, float c);
+/**
+* Cosine. Computes the element-wise cosine of the expression: @f$ \cos(a) @f$.
+* @see CosNodeOp
+*/
+Expr cos(Expr a);
+/**
+* Tangent. Computes the element-wise tangent of the expression: @f$ \tan(a) @f$.
+* @see TanNodeOp
+*/
+Expr tan(Expr a);
+///@}
+///@}
+
+/**
+ * @addtogroup graph_ops_arithmetic Arithmetic
+ * Performs arithmetic in the expression graph.
+ * @ingroup graph_ops
+ * @{
+ */
+
+/**
+ * Returns @f$ -a @f$.
+ * @see NegNodeOp for implementation.
+ */
+///@{
Expr operator-(Expr a);
+///@}
/*********************************************************/
-Expr operator+(Expr a, Expr b);
-Expr operator+(float a, Expr b);
-Expr operator+(Expr a, float b);
-
-Expr operator-(Expr a, Expr b);
-Expr operator-(float a, Expr b);
-Expr operator-(Expr a, float b);
-
-Expr operator*(Expr a, Expr b);
-Expr operator*(float a, Expr b);
-Expr operator*(Expr a, float b);
+/**
+ * @name Addition
+ * Performs @f$ a + b @f$ in the expression graph.
+*/
+///@{
+Expr operator+(Expr a, Expr b); ///< @see Implementation in PlusNodeOp
+Expr operator+(float a, Expr b); ///< @see Implementation in ScalarAddNodeOp
+Expr operator+(Expr a, float b); ///< @see Implementation in ScalarAddNodeOp
+///@}
+
+/**
+ * @name Subtraction
+ * Performs @f$ a - b @f$ in the expression graph.
+ */
+///@{
+Expr operator-(Expr a, Expr b); ///< @see Implementation in MinusNodeOp
+Expr operator-(float a, Expr b); ///< @see Implementation in ScalarAddNodeOp
+Expr operator-(Expr a, float b); ///< @see Implementation in ScalarAddNodeOp
+///@}
+
+/**
+ * @name Multiplication
+ * Performs @f$ a * b @f$ in the expression graph.
+ */
+///@{
+Expr operator*(Expr a, Expr b); ///< @see Implementation in MultNodeOp
+Expr operator*(float a, Expr b); ///< @see Implementation in ScalarMultNodeOp
+Expr operator*(Expr a, float b); ///< @see Implementation in ScalarMultNodeOp
+///@}
+
+/**
+ * @name Division
+ * Performs @f$ a / b @f$ in the expression graph.
+ */
+///@{
+Expr operator/(Expr a, Expr b); ///< @see Implementation in DivNodeOp
+Expr operator/(float a, Expr b); ///< Promotes @p a to Expression<ConstantNode> and uses operator/(Expr a, Expr b).
+ ///< @todo efficient version of this without ExpressionGraph::constant
+Expr operator/(Expr a, float b); ///< Implementation via @f$ a * \frac{1}{b} @f$.
+///@}
+
+///@}
+
+/**
+ * Computes the square root of an expression.
+ * Evaluates @f$\sqrt{a + \mathrm{eps}} @f$ element-wise on the expression
+ * @param a Expression to square root
+ * @param eps Optional positive epsilon to avoid domain errors for small values in @p a
+ * @ingroup graph_ops_mathematical
+ */
+Expr sqrt(Expr a, float eps = 0.f);
-Expr operator/(Expr a, Expr b);
-Expr operator/(float a, Expr b);
-Expr operator/(Expr a, float b);
+/**
+ * Computes the square of an expression.
+ * Evaluates @f$a^2 @f$ element-wise on the expression
+ * @param a Expression to square
+ * @ingroup graph_ops_mathematical
+ */
+Expr square(Expr a);
+/**
+ * Calculate the element-wise abolute value of an expression.
+ * Returns the value of @f$ |a| @f$ element-wise for the expression @p a.
+ * @see AbsNodeOp.
+ * @ingroup graph_ops_mathematical
+ */
Expr abs(Expr a);
// Expr pow(Expr a, Expr b);
// Expr pow(float a, Expr b);
// Expr pow(Expr a, float b);
+/**
+ * Computes @f$\log(e^a + e^b)@f$.
+ */
Expr logaddexp(Expr a, Expr b);
-// Note: Following numpy, minimum() is element-wise, while min() is along an axis in both Numpy and PyTorch.
+
+///@addtogroup graph_ops_mathematical
+///@{
+/**
+ * @name Element-wise min/max
+ * Performs an element-wise min max comparison between expressions.
+ * @see min, max for axis level operations
+ * @see MinimumNodeOp, MaximumNodeOp
+ * @todo implement version without ExpressionGraph::constant.
+ */
+///@{
+
+/**
+ * Computes the element-wise maximum of its inputs.
+ */
Expr maximum(Expr a, Expr b);
+
+/**
+ * @copybrief maximum
+ * Promotes float input to a @ref ExpressionGraph::constant.
+ */
Expr maximum(float a, Expr b);
+
+/**
+ * @copybrief maximum
+ * Promotes float input to a @ref ExpressionGraph::constant.
+ */
Expr maximum(Expr a, float b);
+/**
+ * Computes the element-wise minimum its inputs.
+ */
Expr minimum(Expr a, Expr b);
+
+/**
+ * @copybrief minimum
+ * Promotes float input to a @ref ExpressionGraph::constant.
+ */
Expr minimum(float a, Expr b);
-Expr minimum(Expr a, float b);
-// Pair of expressions, currently used for topk nodes only
+/**
+ * @copybrief minimum
+ * Promotes float input to a @ref ExpressionGraph::constant.
+ */
+Expr minimum(Expr a, float b);
+///@}
+///@}
+
+/**
+ * Pair of expressions.
+ * Currently only used for topk-like nodes
+ * @see topk(), argmin(), argmax()
+ */
typedef std::tuple<Expr, Expr> Expr2;
-// Marian pseudo-operator to access elements of a tuple, just the same as std::get<N>(tuple)
+/**
+ * Pseudo-operator to access elements of a tuple.
+ * Provides the same utility as @c std::get<I>(tuple)
+ * @see Expr2
+ */
template <int I>
Expr get(Expr2 tuple) { return std::get<I>(tuple); }
-// PyTorch-like topk operator, returns a 2-tuple of nodes, first node is top-k values
-// second node is indices of these values according to given axis. Order is descending
-// by default, outputs are ordered.
+/**
+ * Returns top k elements of an expression along an axis.
+ * Return a 2-tuple (values, indices) of the @p k largest, or smallest, elements of an expression
+ * along a specified @p axis.
+ * The output is ordered according to the value of @p descending.
+ * @param a Expression to search
+ * @param k Number of elements to return
+ * @param axis Axis to along which to operate
+ * @param descending If true, consider the largest elements. Otherwise, consider the smallest elements.
+ * Default is true.
+ * @returns An ordered 2-tuple of Expressions
+ */
Expr2 topk(Expr a, int k, int axis, bool descending = true);
-// Convenience operator that maps to topk(a, k=1, axis, descending=true)
+/**
+ * Returns largest elements of an expression along an axis.
+ * Return a 2-tuple (values, indices) of largest elements of an expression
+ * along a specified @p axis.
+ * @see topk(a, k=1, axis, descending=true)
+ */
Expr2 argmax(Expr a, int axis);
-// Convenience operator that maps to topk(a, k=1, axis, descending=false)
+/**
+ * Returns smallest elements of an expression along an axis.
+ * Return a 2-tuple (values, indices) of smallest elements of an expression
+ * along a specified @p axis.
+ * @see topk(a, k=1, axis, descending=false)
+ */
Expr2 argmin(Expr a, int axis);
-// Note: We cannot overload the relational operators, as they also mean something for Expr itself.
-// Note: These names follow PyTorch convention.
-Expr lt(Expr a, Expr b);
-Expr eq(Expr a, Expr b);
-Expr gt(Expr a, Expr b);
-Expr ge(Expr a, Expr b);
-Expr ne(Expr a, Expr b);
-Expr le(Expr a, Expr b);
-
-Expr lt(float a, Expr b);
-Expr eq(float a, Expr b);
-Expr gt(float a, Expr b);
-Expr ge(float a, Expr b);
-Expr ne(float a, Expr b);
-Expr le(float a, Expr b);
-
-Expr lt(Expr a, float b);
-Expr eq(Expr a, float b);
-Expr gt(Expr a, float b);
-Expr ge(Expr a, float b);
-Expr ne(Expr a, float b);
-Expr le(Expr a, float b);
+/**
+ * @addtogroup graph_ops_cmp Comparison
+ * Performs comparision operations in the expression graph.
+ * @ingroup graph_ops
+ * Uses CmpNodeOp to perform comparison of graph expression e.g. @f$ a < b @f$.
+ * @note
+ * We cannot overload the relational operators, as they also mean something for Expr itself.
+ * @par
+ * @note
+ * These names follow <a href="https://pytorch.org/docs">PyTorch</a> convention.
+ * @{
+ */
+
+/**
+ * @name Expr-Expr comparisons
+ */
+///@{
+Expr lt(Expr a, Expr b); ///< @f$ a < b @f$
+Expr eq(Expr a, Expr b); ///< @f$ a \equiv b @f$
+Expr gt(Expr a, Expr b); ///< @f$ a > b @f$
+Expr ge(Expr a, Expr b); ///< @f$ a \geq b @f$
+Expr ne(Expr a, Expr b); ///< @f$ a \neq b @f$
+Expr le(Expr a, Expr b); ///< @f$ a \leq b @f$
+///@}
+
+/**
+ * @name Float-Expr comparisons
+ * Floats are promoted to a @ref ExpressionGraph::constant and use the Expr-Expr methods
+ */
+///@{
+Expr lt(float a, Expr b); ///< @f$ a < b @f$
+Expr eq(float a, Expr b); ///< @f$ a \equiv b @f$
+Expr gt(float a, Expr b); ///< @f$ a > b @f$
+Expr ge(float a, Expr b); ///< @f$ a \geq b @f$
+Expr ne(float a, Expr b); ///< @f$ a \neq b @f$
+Expr le(float a, Expr b); ///< @f$ a \leq b @f$
+
+Expr lt(Expr a, float b); ///< @f$ a < b @f$
+Expr eq(Expr a, float b); ///< @f$ a \equiv b @f$
+Expr gt(Expr a, float b); ///< @f$ a > b @f$
+Expr ge(Expr a, float b); ///< @f$ a \geq b @f$
+Expr ne(Expr a, float b); ///< @f$ a \neq b @f$
+Expr le(Expr a, float b); ///< @f$ a \leq b @f$
+///@}
+
+///@}
+
+/**
+ * Computes the dot product of @p a and @p b.
+ * Computes @f$ C = \alpha \operatorname{op}(A) \cdot \operatorname{op}(B) @f$,
+ * where @f$ \operatorname{op}(A) = A @f$ if @p transA is @c false, and
+ * @f$ \operatorname{op}(A) = A^\top @f$ if @c true. The @f$\alpha@f$ parameter
+ * is set by @p scalar.
+ */
Expr dot(Expr a,
Expr b,
bool transA = false,
bool transB = false,
float scalar = 1.f);
+/**
+ * Computes the batch dot product of @p a and @p b.
+ * @copydetails dot
+ */
Expr bdot(Expr a,
Expr b,
bool transA = false,
bool transB = false,
float scalar = 1.f);
+/**
+ * Performs an affine transformation.
+ * Computes
+ * @f$ C \leftarrow \alpha \operatorname{op}(A) \cdot \operatorname{op}(B) + C@f$,
+ * where @f$ \operatorname{op}(A) = A @f$ if @p transA is @c false, and
+ * @f$ \operatorname{op}(A) = A^\top @f$ if @c true. The @f$\alpha@f$ parameter
+ * is set by @p scalar.
+ */
Expr affine(Expr a,
Expr b,
Expr c,
@@ -149,47 +493,195 @@ Expr affine(Expr a,
bool transB = false,
float scalar = 1.f);
+/**
+ * Computes the dot product of CSR-tensor @p A with @p B.
+ */
Expr csr_dot(const Shape& A_shape, Expr Avalues, Expr Aindices, Expr Aoffsets, Expr B, bool transA = false);
+
+/**
+ * Computes the dot product of @p A with CSR-tensor @p B.
+ */
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB = false);
+/**
+ * @addtogroup graph_ops_manipulation Manipulation Operations
+ * Operators that manipulate expressions.
+ * @ingroup graph_ops
+ * @{
+ */
+
+/**
+ * Returns the transpose of an expression.
+ * Swaps the last two axes of an expression.
+ * @see TransposeNodeOp
+ */
Expr transpose(Expr a);
+
+/**
+ * Returns the transpose of an expression.
+ * Permutes the axes of an expression to resemble @p axes. Axis @c i of the returned
+ * expression corresponds to @c axes[i] of the input @p a.
+ * @param a Expression to manipulate
+ * @param axes Desired permutation of axes
+ * @see TransposeNodeOp
+ */
Expr transpose(Expr a, const std::vector<int>& axes);
+/**
+ * Swap two axes of an expression.
+ * Swaps two axes of an expression via reshaping, if possible, or transpose.
+ * @param x Expression to manipulate
+ * @param axis1 Axis to be swapped
+ * @param axis2 Axis to swap with
+ * @returns Expression with the axes @p axis1 and @p axis2 interchanged
+ * @see reshape() and transpose()
+ */
Expr swapAxes(Expr x, int axis1, int axis2);
+/**
+ * Cast an expression to a specified type.
+ * @param a Expression to cast
+ * @param type Desired type
+ * @returns Expression with data cast to @p type
+ */
Expr cast(Expr a, Type type = Type::float32);
+/**
+ * Join a list of expressions along an axis.
+ * Concatenates the elements of the expressions in @p concats along the axis @p ax.
+ * By default, @p ax operates on the first axis.
+ */
Expr concatenate(const std::vector<Expr>& concats, int ax = 0);
+
+/**
+ * Repeat elements of an expression.
+ * Repeats the elements of @p a along the @p ax axis @p repeats times.
+ * By default, @p ax operates on the first axis.
+ * @see concatenate()
+ */
Expr repeat(Expr a, size_t repeats, int ax = 0);
+/**
+ * Reshape expression to a given shape.
+ * @param a The expression to be reshaped
+ * @param shape The new shape
+ * @returns An expression with shape @p shape.
+ */
Expr reshape(Expr a, Shape shape);
-Expr clipGradient(Expr a, float clipValue);
+/**
+ * Clip the values in an expression.
+ * Clips the values of the Expr @p a to be within the interval @f$ [-c, c] @f$.
+ * @param a Expr to clip
+ * @param c Threshold to clip at
+ * @see ClipNodeOp
+ */
+Expr clip(Expr a, float c);
+/**
+ * Clip the gradient in an expression.
+ * Clips the gradient of the Expr @p a to be within the interval @f$ [-c, c] @f$
+ * @see clip for the equivalent function which clips values
+ * @see ClipGradientNodeOp
+ */
+Expr clipGradient(Expr a, float c);
+
+/**
+ * Converts input to an expression with a least one dimension.
+ * @see atleast_nd()
+ */
Expr atleast_1d(Expr a);
+
+/**
+ * Converts input to an expression with a least two dimensions.
+ * @see atleast_nd()
+ */
Expr atleast_2d(Expr a);
+
+/**
+ * Converts input to an expression with a least three dimensions.
+ * @see atleast_nd()
+ */
Expr atleast_3d(Expr a);
+
+/**
+ * Converts input to an expression with a least four dimensions.
+ * @see atleast_nd()
+ */
Expr atleast_4d(Expr a);
-Expr atleast_nd(Expr a, size_t dims);
-// create a constant of shape a->shape() and initialize with init
-// @TODO: add a && version, to avoid a ref count. NodeInitializers are typically temps.
-// @TODO: and/or make this a template on init
+/**
+ * Converts input to an expression with a least n-dimension dimensions.
+ * @param a Expression
+ * @param dims Required number of dimensions
+ * @returns An expression with at least n-dimensions
+ */
+Expr atleast_nd(Expr a, size_t dims);
+///@}
+
+/**
+ * @addtogroup graph_ops_creation Creation Operations
+ * Operators that create expressions.
+ * @ingroup graph_ops
+ * @{
+ */
+
+/**
+ * Create a constant of with the shape of @p a and initialize with @p init.
+ * @todo add a && version, to avoid a ref count. NodeInitializers are typically temps.
+ * and/or make this a template on init
+ */
static inline Expr constant_like(Expr a, const Ptr<inits::NodeInitializer>& init) {
return a->graph()->constant(a->shape(), init, a->value_type());
}
-// short-cut to init from std::vector, since we do this so often
+/**
+ * Convenience function to initialize from a vector.
+ */
template<typename ElementType>
Expr constant_like(Expr a, const std::vector<ElementType>& v) { return constant_like(a, inits::fromVector(std::move(v))); }
+
+/**
+ * Convenience function to initialize from a vector.
+ */
template<typename ElementType>
Expr constant_like(Expr a, std::vector<ElementType>&& v) { return constant_like(a, inits::fromVector(v)); }
+///@}
+
+/**
+ * @addtogroup graph_ops_manipulation
+ * @{
+ */
+
+/**
+ * Flattens an expression to one dimension.
+ * @see ReshapeNodeOp
+ */
Expr flatten(Expr a);
+
+/**
+ * Flattens an expression to two-dimensions preserving the last dimension.
+ * @see ReshapeNodeOp
+ */
Expr flatten_2d(Expr a);
+///@}
+
+/**
+ * Wraps an expression as a non-trainable expression.
+ */
Expr stopGradient(Expr a);
+/**
+ * Gathers elements along an axis.
+ * @param a The input expression
+ * @param axis The axis along which to index
+ * @param indices The indices to be gathered
+ * @returns Gathered expression with the same shape as @p indices
+ * @note @p a and @p indices must have the same rank
+ * @note The non-target axes of @p a and @p indicies must have the same size, or be broadcastable.
+ */
Expr gather(Expr a, int axis, Expr indices);
#if 0
@@ -216,70 +708,221 @@ Expr scatter(Expr a, int axis, Expr indices, Expr b);
#endif
-// Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr...
+/**
+ * Returns a new expression containing the @p indicies of expression @p a
+ * along the specified @p axis.
+ * @warning Do not pass a scalar literal 0 as @p indices;
+ * it will compile but pass a nullptr.
+ */
Expr index_select(Expr a, int axis, Expr indices);
-// convenience wrappers for index_select()
+/**
+ * @copybrief index_select
+ * Convenience wrapper that promotes a vector of @ref IndexType to an Expr
+ */
Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);
+
+/**
+ * Performs an @ref index_select() along the first axis.
+ * @see index_select()
+ */
static inline Expr rows(Expr a, Expr indices) {
return index_select(a, 0, indices);
}
+
+/**
+ * @copybrief rows
+ * Convenience wrapper that promotes a vector of @ref IndexType to an Expr
+ */
static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) {
return index_select(a, 0, indexVector);
}
+
+/**
+ * Performs an @ref index_select() along the last axis.
+ * @see index_select()
+ */
static inline Expr cols(Expr a, Expr indices) {
return index_select(a, -1, indices);
}
+
+/**
+ * @copybrief cols
+ * Convenience wrapper that promotes a vector of @ref IndexType to an Expr
+ */
static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) {
return index_select(a, -1, indexVector);
}
+/**
+ * Returns the @p slice of the expression @p a along @p axis.
+ * @see Slice
+ */
Expr slice(Expr a, int axis, Slice slice);
-// convenience wrappers for slice()
-static inline Expr slice(Expr a, int axis, int index) { // single index @NOTE: This was formerlly called step()
+/**
+ * @copybrief slice
+ * Convenience wrapper for slice() that returns the slice along @p axis
+ * from @p index to @p index+1
+ */
+static inline Expr slice(Expr a, int axis, int index) {
return slice(a, axis, Slice(index));
}
-static inline Expr narrow(Expr a, int axis, size_t start, size_t length) { // PyTorch name
+/**
+ * @copybrief slice
+ * Convenience wrapper for slice() that returns the slice along @p axis
+ * from @p index to @p index + @p length
+ * @note this is named after an equivalent function in PyTorch
+ */
+static inline Expr narrow(Expr a, int axis, size_t start, size_t length) {
return slice(a, axis, Slice((int)start, (int)(start + length)));
}
/*********************************************************/
+///@addtogroup graph_ops_mathematical
+///@{
+///@name Aggregations
+///@{
+
+/**
+ * Compute the sum along the specified axis.
+ * @param ax Axis along which to compute the sum. Default is @c 0.
+ * @see ReduceNodeOp
+ */
Expr sum(Expr a, int ax = 0);
+
+/**
+ * Compute the arithmetic mean along the specified axis.
+ * @param ax Axis along which to compute the mean. Default is @c 0.
+ * @see ReduceNodeOp
+ */
Expr mean(Expr a, int ax = 0);
+
+/**
+ * Compute the standard deviation along the specified axis.
+ * @param ax Axis along which to compute the standard deviation
+ * @see ReduceNodeOp
+ */
Expr std(Expr a, int ax);
+
+/**
+ * Compute the variance along the specified axis.
+ * @param ax Axis along which to compute the variance
+ * @see ReduceNodeOp
+ */
Expr var(Expr a, int ax);
+
+/**
+ * Compute the maximum along the specified axis.
+ * @param ax Axis along which to find the maximum
+ * @see ReduceNodeOp
+ */
Expr max(Expr a, int ax);
+
+/**
+ * Compute the minimum along the specified axis.
+ * @param ax Axis along which to find the minimum
+ * @see ReduceNodeOp
+ */
Expr min(Expr a, int ax);
+
+/**
+ * Compute the product along the specified axis.
+ * @param ax Axis along which to compute the product
+ * @see ReduceNodeOp
+ */
Expr prod(Expr a, int ax);
+
+///@}
+///@}
+
+/**
+ * Compute the log of the sum of exponentials along the specified axis.
+ * @param ax Axis along which to perform the operation
+ * @see ReduceNodeOp
+ */
Expr logsumexp(Expr a, int ax);
+/**
+ * Computes the softmax fuction along the given axis.
+ * Applies the softmax function
+ * @f[
+ \operatorname{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+ * @f]
+ * @see SoftmaxNodeOp
+ */
Expr softmax(Expr x, int axis = -1);
-// @TODO: maybe get rid of this entirely to not obfuscate, what's going on inside.
-// @TODO: switch to log-masking everywhere?
+/**
+ * @copybrief softmax
+ * Applies the softmax function over the unmasked values.
+ * @see SoftmaxNodeOp
+ */
Expr softmax(Expr a, Expr zeroOneMask, int axis = -1);
+/**
+ * Computes the log of the softmax function along the last axis.
+ * Applies @f$ \log(\operatorname{softmax}(x)) @f$.
+ * @see LogSoftmaxNodeOp
+ */
Expr logsoftmax(Expr a);
+/**
+ * Computes the cross-entropy loss.
+ * @param labelSmoothingAlpha The amount of label smoothing @f$\alpha \in [0,1]@f$.
+ * Default is no smoothing, @f$\alpha = 0 @f$.
+ * @see CrossEntropyNodeOp
+ */
Expr cross_entropy(Expr a, Expr b, float labelSmoothingAlpha = 0.f, Type outputType = Type::float32);
+/**
+ * Computes the unlikelihood loss.
+ * Computes the <a href="https://arxiv.org/abs/1908.04319">unlikelihood</a> loss
+ * @f$ -\log \sum (1 - \operatorname{softmax}(x)) @f$
+ */
Expr unlikelihood(Expr a, Expr b);
+/**
+ * Computes the scalar product along the specified axis.
+ * @see ScalarProductNodeOp
+ */
Expr scalar_product(Expr a, Expr b, int ax = 0);
+/**
+ * Compute the weighted arithmetic mean along the specified axis.
+ */
Expr weighted_average(Expr in, Expr weights, int ax = 0);
-Expr sqrt(Expr a, float eps = 0.f);
-Expr square(Expr a);
+/**
+ * Applies layer normalization over the last dimension.
+ * @f[
+ \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \mathrm{eps}}} \times \gamma + \beta
+ * @f]
+ * @see LayerNormalizationOp
+ */
Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
+/**
+ * Highway transformation.
+ * Computes the highway tranform on @p y and @p x as gated by @p t:
+ * @f$ \operatorname{sigmoid}(t) y + (1-\operatorname{sigmoid}(t)) x @f$
+ * @see HighwayNodeOp
+ */
Expr highway(Expr y, Expr x, Expr t);
+
+/** @copybrief highway
+ * Generates a highway network for @p x with a @ref relu activated layer and
+ * @ref sigmoid activated layer for gating.
+ * @see mlp::dense()
+ */
Expr highway(const std::string prefix, Expr x);
+/**
+ * Performs dropout using a given mask.
+ */
static inline Expr dropout(Expr x, Expr mask) {
if (mask)
return x * mask;
@@ -287,6 +930,9 @@ static inline Expr dropout(Expr x, Expr mask) {
return x;
}
+/**
+ * Performs dropout with a given probably and explicit shape.
+ */
static inline Expr dropout(Expr x, float dropProb, Shape shape) {
if(dropProb == 0)
return x;
@@ -295,18 +941,35 @@ static inline Expr dropout(Expr x, float dropProb, Shape shape) {
return dropout(x, mask);
}
+/**
+ * Performs dropout with a given probably.
+ */
static inline Expr dropout(Expr x, float dropProb) {
if(dropProb == 0)
return x;
return dropout(x, dropProb, x->shape());
}
-Expr shift(Expr, Shape, float padValue = 0);
+/**
+ * Shifts the elements of an expression by a per-axis offset @p shift
+ * padded with @p padValue.
+ */
+Expr shift(Expr x, Shape shift, float padValue = 0);
+/**
+ * Reindexes an expression from internal to cuDNN format.
+ */
Expr convert2cudnnFormat(Expr x);
+/**
+ * Reindexes an expression from cuDNN to internal format.
+ */
Expr convertFromcudnnFormat(Expr x);
+/**
+ * Performs average pooling.
+ * @see PoolingOp
+ */
Expr avg_pooling(Expr x,
int height,
int width,
@@ -315,6 +978,10 @@ Expr avg_pooling(Expr x,
int strideHeight = 1,
int strideWidth = 1);
+/**
+ * Performs max pooling.
+ * @see PoolingOp
+ */
Expr max_pooling(Expr x,
int height,
int width,
@@ -323,5 +990,11 @@ Expr max_pooling(Expr x,
int strideHeight = 1,
int strideWidth = 1);
+/**
+ * Pooling operation with masking.
+ * @warning not implemented
+ */
Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven = false);
+
+///@}
} // namespace marian
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index c565e035..5d3e2cc9 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -646,7 +646,7 @@ struct CosNodeOp : public UnaryNodeOp {
return {NodeOp(Add(_1 * -sin(_2), child(0)->grad(), adj_, child(0)->val()))};
}
- const std::string type() override { return "sin"; }
+ const std::string type() override { return "cos"; }
};
struct TanNodeOp : public UnaryNodeOp {
@@ -662,7 +662,7 @@ struct TanNodeOp : public UnaryNodeOp {
return {NodeOp(Add(_1 / sqr(cos(_2)), child(0)->grad(), adj_, child(0)->val()))};
}
- const std::string type() override { return "sin"; }
+ const std::string type() override { return "tan"; }
};
struct SqrtNodeOp : public UnaryNodeOp {
diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc
index 903ee3ba..6d4c4a95 100755
--- a/src/tensors/gpu/add.inc
+++ b/src/tensors/gpu/add.inc
@@ -37,3 +37,5 @@ template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functio
template void marian::gpu::Add<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Aggregate<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>);
+template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::Tensor, marian::Tensor, marian::Tensor);
diff --git a/src/tensors/gpu/add_all.inc b/src/tensors/gpu/add_all.inc
index 29a3a5d6..a3f5c27d 100644
--- a/src/tensors/gpu/add_all.inc
+++ b/src/tensors/gpu/add_all.inc
@@ -37,6 +37,9 @@ template void marian::AggregateAll<float, float, marian::functional::BinaryFunct
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
+
#if COMPILE_FP16
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -75,4 +78,6 @@ template void marian::AggregateAll<__half, float, marian::functional::BinaryFunc
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
+template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
#endif
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index 0eb75625..ade8b489 100755
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -68,6 +68,8 @@ template void marian::gpu::Element<marian::functional::Assign<marian::functional
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > >, marian::Tensor, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Tan, marian::functional::Assignee<2> > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Tan, marian::functional::Assignee<2> > >, marian::Tensor, marian::Tensor);
// How to add new specializations:
// When you use a new specialization, it will cause a link error of this form (example):
// .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'