Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-03-01 02:52:40 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-03-01 02:52:40 +0300
commitfd208d841b0ec4fdbfdc81df601e86986528a051 (patch)
treeefe85c4e634fa1885583e2d4192a8d7252a0e945 /src/graph/node_operators_unary.h
parent71f911940c91d2fd5c337ddf4d8e88108d5ed822 (diff)
simplified interface, get slowly rid of keywords
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h96
1 files changed, 41 insertions, 55 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 0a76471b..e857e790 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -14,9 +14,11 @@
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- template <typename... Args>
- UnaryNodeOp(Expr a, Args... args)
- : NaryNodeOp({a}, keywords::shape = a->shape(), args...) {}
+ UnaryNodeOp(Expr a, Shape shape)
+ : NaryNodeOp({a}, shape) {}
+
+ UnaryNodeOp(Expr a)
+ : NaryNodeOp({a}, a->shape()) {}
const std::string color() { return "yellow"; }
};
@@ -26,9 +28,9 @@ private:
float scalar_{0};
public:
- template <typename... Args>
- ScalarAddNodeOp(Expr a, float scalar, Args... args)
- : UnaryNodeOp(a, args...), scalar_{scalar} {}
+ ScalarAddNodeOp(Expr a, float scalar)
+ : UnaryNodeOp(a),
+ scalar_{scalar} {}
NodeOps forwardOps() {
using namespace functional;
@@ -67,9 +69,8 @@ private:
float scalar_{0};
public:
- template <typename... Args>
- ScalarMultNodeOp(Expr a, float scalar, Args... args)
- : UnaryNodeOp(a, args...), scalar_{scalar} {}
+ ScalarMultNodeOp(Expr a, float scalar)
+ : UnaryNodeOp(a), scalar_{scalar} {}
NodeOps forwardOps() {
using namespace functional;
@@ -104,8 +105,7 @@ public:
};
struct LogitNodeOp : public UnaryNodeOp {
- template <typename... Args>
- LogitNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ LogitNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
@@ -164,7 +164,7 @@ struct LogitNodeOp : public UnaryNodeOp {
struct TanhNodeOp : public NaryNodeOp {
TanhNodeOp(const std::vector<Expr>& nodes)
- : NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {}
+ : NaryNodeOp(nodes, newShape(nodes)) {}
Shape newShape(const std::vector<Expr>& nodes) {
return Shape::broadcast(nodes);
@@ -214,8 +214,7 @@ struct TanhNodeOp : public NaryNodeOp {
struct ReLUNodeOp : public UnaryNodeOp {
- template <typename... Args>
- ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ ReLUNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
// f(x) = max(0, x)
@@ -265,9 +264,8 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f]
*/
struct PReLUNodeOp : public UnaryNodeOp {
- template <typename... Args>
- PReLUNodeOp(float alpha, Args... args)
- : UnaryNodeOp(args...), alpha_(alpha) {}
+ PReLUNodeOp(float alpha, Expr a)
+ : UnaryNodeOp(a), alpha_(alpha) {}
NodeOps forwardOps() {
using namespace functional;
@@ -316,8 +314,7 @@ private:
*
*/
struct SwishNodeOp : public UnaryNodeOp {
- template <typename... Args>
- SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ SwishNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
@@ -338,14 +335,12 @@ struct SwishNodeOp : public UnaryNodeOp {
const std::string type() { return "swish"; }
};
-struct SoftmaxNodeOp : public NaryNodeOp {
- template <typename... Args>
- SoftmaxNodeOp(Expr a, Args... args)
- : NaryNodeOp(a, args...), mask_(nullptr) {}
+struct SoftmaxNodeOp : public UnaryNodeOp {
+ SoftmaxNodeOp(Expr a)
+ : UnaryNodeOp(a), mask_(nullptr) {}
- template <typename... Args>
- SoftmaxNodeOp(Expr a, Expr mask, Args... args)
- : NaryNodeOp({a}, args...), mask_(mask) {}
+ SoftmaxNodeOp(Expr a, Expr mask)
+ : UnaryNodeOp(a), mask_(mask) {}
Expr mask_;
@@ -396,8 +391,7 @@ struct SoftmaxNodeOp : public NaryNodeOp {
};
struct LogSoftmaxNodeOp : public UnaryNodeOp {
- template <typename... Args>
- LogSoftmaxNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ LogSoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() { return {NodeOp(LogSoftmax(val_, child(0)->val()))}; }
@@ -416,7 +410,7 @@ struct SumNodeOp : public UnaryNodeOp {
template <typename... Args>
SumNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
+ : UnaryNodeOp(a, newShape(a, args...)) {}
NodeOps forwardOps() {
using namespace functional;
@@ -465,7 +459,7 @@ struct MeanNodeOp : public UnaryNodeOp {
template <typename... Args>
MeanNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
+ : UnaryNodeOp(a, newShape(a, args...)) {}
NodeOps forwardOps() {
using namespace functional;
@@ -516,8 +510,7 @@ struct MeanNodeOp : public UnaryNodeOp {
};
struct LogNodeOp : public UnaryNodeOp {
- template <typename... Args>
- LogNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ LogNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
@@ -534,8 +527,7 @@ struct LogNodeOp : public UnaryNodeOp {
};
struct ExpNodeOp : public UnaryNodeOp {
- template <typename... Args>
- ExpNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ ExpNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
@@ -553,9 +545,8 @@ struct ExpNodeOp : public UnaryNodeOp {
struct SqrtNodeOp : public UnaryNodeOp {
float epsilon_;
- template <typename... Args>
- SqrtNodeOp(Expr a, float epsilon, Args... args)
- : UnaryNodeOp(a, args...), epsilon_(epsilon) {}
+ SqrtNodeOp(Expr a, float epsilon)
+ : UnaryNodeOp(a), epsilon_(epsilon) {}
NodeOps forwardOps() {
using namespace functional;
@@ -591,8 +582,7 @@ struct SqrtNodeOp : public UnaryNodeOp {
};
struct SquareNodeOp : public UnaryNodeOp {
- template <typename... Args>
- SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ SquareNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
@@ -609,8 +599,7 @@ struct SquareNodeOp : public UnaryNodeOp {
};
struct NegNodeOp : public UnaryNodeOp {
- template <typename... Args>
- NegNodeOp(Args... args) : UnaryNodeOp(args...) {}
+ NegNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
@@ -626,9 +615,8 @@ struct NegNodeOp : public UnaryNodeOp {
};
struct RowsNodeOp : public UnaryNodeOp {
- template <typename... Args>
- RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
+ RowsNodeOp(Expr a, const std::vector<size_t>& indeces)
+ : UnaryNodeOp(a, newShape(a, indeces)),
indices_(indeces) {}
NodeOps forwardOps() {
@@ -679,9 +667,8 @@ struct RowsNodeOp : public UnaryNodeOp {
};
struct ColsNodeOp : public UnaryNodeOp {
- template <typename... Args>
- ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...),
+ ColsNodeOp(Expr a, const std::vector<size_t>& indeces)
+ : UnaryNodeOp(a, newShape(a, indeces)),
indices_(indeces) {}
NodeOps forwardOps() {
@@ -731,7 +718,7 @@ struct ColsNodeOp : public UnaryNodeOp {
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
+ : UnaryNodeOp(a, newShape(a, axis, indeces)),
indices_(indeces) {}
NodeOps forwardOps() {
@@ -787,7 +774,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
std::vector<int> axes_;
TransposeNodeOp(Expr a, const std::vector<int>& axes)
- : UnaryNodeOp(a, keywords::shape = newShape(a, axes)),
+ : UnaryNodeOp(a, newShape(a, axes)),
axes_{axes} {}
NodeOps forwardOps() {
@@ -844,8 +831,8 @@ private:
public:
template <typename... Args>
- ReshapeNodeOp(Expr a, Shape shape, Args... args)
- : UnaryNodeOp(a, keywords::shape = shape, args...), reshapee_(a) {
+ ReshapeNodeOp(Expr a, Shape shape)
+ : UnaryNodeOp(a, shape), reshapee_(a) {
Node::destroy_ = false;
}
@@ -909,7 +896,7 @@ private:
public:
StepNodeOp(Expr a, int step, int axis)
- : UnaryNodeOp(a, keywords::shape = newShape(a, axis)),
+ : UnaryNodeOp(a, newShape(a, axis)),
stepNode_(a),
step_(step) {
Node::destroy_ = false;
@@ -981,9 +968,8 @@ public:
};
struct ShiftNodeOp : public UnaryNodeOp {
- template <typename... Args>
- ShiftNodeOp(Expr a, Shape shift, Args... args)
- : UnaryNodeOp(a, keywords::shape = a->shape(), args...), shift_(shift) {}
+ ShiftNodeOp(Expr a, Shape shift)
+ : UnaryNodeOp(a, a->shape()), shift_(shift) {}
NodeOps forwardOps() {
return {NodeOp(Shift(val_, child(0)->val(), shift_, false))};