From fd208d841b0ec4fdbfdc81df601e86986528a051 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Wed, 28 Feb 2018 15:52:40 -0800 Subject: simplified interface, get slowly rid of keywords --- src/graph/node_operators_unary.h | 96 +++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 55 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 0a76471b..e857e790 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -14,9 +14,11 @@ namespace marian { struct UnaryNodeOp : public NaryNodeOp { - template - UnaryNodeOp(Expr a, Args... args) - : NaryNodeOp({a}, keywords::shape = a->shape(), args...) {} + UnaryNodeOp(Expr a, Shape shape) + : NaryNodeOp({a}, shape) {} + + UnaryNodeOp(Expr a) + : NaryNodeOp({a}, a->shape()) {} const std::string color() { return "yellow"; } }; @@ -26,9 +28,9 @@ private: float scalar_{0}; public: - template - ScalarAddNodeOp(Expr a, float scalar, Args... args) - : UnaryNodeOp(a, args...), scalar_{scalar} {} + ScalarAddNodeOp(Expr a, float scalar) + : UnaryNodeOp(a), + scalar_{scalar} {} NodeOps forwardOps() { using namespace functional; @@ -67,9 +69,8 @@ private: float scalar_{0}; public: - template - ScalarMultNodeOp(Expr a, float scalar, Args... args) - : UnaryNodeOp(a, args...), scalar_{scalar} {} + ScalarMultNodeOp(Expr a, float scalar) + : UnaryNodeOp(a), scalar_{scalar} {} NodeOps forwardOps() { using namespace functional; @@ -104,8 +105,7 @@ public: }; struct LogitNodeOp : public UnaryNodeOp { - template - LogitNodeOp(Args... args) : UnaryNodeOp(args...) {} + LogitNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -164,7 +164,7 @@ struct LogitNodeOp : public UnaryNodeOp { struct TanhNodeOp : public NaryNodeOp { TanhNodeOp(const std::vector& nodes) - : NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {} + : NaryNodeOp(nodes, newShape(nodes)) {} Shape newShape(const std::vector& nodes) { return Shape::broadcast(nodes); @@ -214,8 +214,7 @@ struct TanhNodeOp : public NaryNodeOp { struct ReLUNodeOp : public UnaryNodeOp { - template - ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {} + ReLUNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { // f(x) = max(0, x) @@ -265,9 +264,8 @@ struct ReLUNodeOp : public UnaryNodeOp { * \f] */ struct PReLUNodeOp : public UnaryNodeOp { - template - PReLUNodeOp(float alpha, Args... args) - : UnaryNodeOp(args...), alpha_(alpha) {} + PReLUNodeOp(float alpha, Expr a) + : UnaryNodeOp(a), alpha_(alpha) {} NodeOps forwardOps() { using namespace functional; @@ -316,8 +314,7 @@ private: * */ struct SwishNodeOp : public UnaryNodeOp { - template - SwishNodeOp(Args... args) : UnaryNodeOp(args...) {} + SwishNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -338,14 +335,12 @@ struct SwishNodeOp : public UnaryNodeOp { const std::string type() { return "swish"; } }; -struct SoftmaxNodeOp : public NaryNodeOp { - template - SoftmaxNodeOp(Expr a, Args... args) - : NaryNodeOp(a, args...), mask_(nullptr) {} +struct SoftmaxNodeOp : public UnaryNodeOp { + SoftmaxNodeOp(Expr a) + : UnaryNodeOp(a), mask_(nullptr) {} - template - SoftmaxNodeOp(Expr a, Expr mask, Args... args) - : NaryNodeOp({a}, args...), mask_(mask) {} + SoftmaxNodeOp(Expr a, Expr mask) + : UnaryNodeOp(a), mask_(mask) {} Expr mask_; @@ -396,8 +391,7 @@ struct SoftmaxNodeOp : public NaryNodeOp { }; struct LogSoftmaxNodeOp : public UnaryNodeOp { - template - LogSoftmaxNodeOp(Args... args) : UnaryNodeOp(args...) {} + LogSoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { return {NodeOp(LogSoftmax(val_, child(0)->val()))}; } @@ -416,7 +410,7 @@ struct SumNodeOp : public UnaryNodeOp { template SumNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} + : UnaryNodeOp(a, newShape(a, args...)) {} NodeOps forwardOps() { using namespace functional; @@ -465,7 +459,7 @@ struct MeanNodeOp : public UnaryNodeOp { template MeanNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} + : UnaryNodeOp(a, newShape(a, args...)) {} NodeOps forwardOps() { using namespace functional; @@ -516,8 +510,7 @@ struct MeanNodeOp : public UnaryNodeOp { }; struct LogNodeOp : public UnaryNodeOp { - template - LogNodeOp(Args... args) : UnaryNodeOp(args...) {} + LogNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -534,8 +527,7 @@ struct LogNodeOp : public UnaryNodeOp { }; struct ExpNodeOp : public UnaryNodeOp { - template - ExpNodeOp(Args... args) : UnaryNodeOp(args...) {} + ExpNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -553,9 +545,8 @@ struct ExpNodeOp : public UnaryNodeOp { struct SqrtNodeOp : public UnaryNodeOp { float epsilon_; - template - SqrtNodeOp(Expr a, float epsilon, Args... args) - : UnaryNodeOp(a, args...), epsilon_(epsilon) {} + SqrtNodeOp(Expr a, float epsilon) + : UnaryNodeOp(a), epsilon_(epsilon) {} NodeOps forwardOps() { using namespace functional; @@ -591,8 +582,7 @@ struct SqrtNodeOp : public UnaryNodeOp { }; struct SquareNodeOp : public UnaryNodeOp { - template - SquareNodeOp(Args... args) : UnaryNodeOp(args...) {} + SquareNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -609,8 +599,7 @@ struct SquareNodeOp : public UnaryNodeOp { }; struct NegNodeOp : public UnaryNodeOp { - template - NegNodeOp(Args... args) : UnaryNodeOp(args...) {} + NegNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -626,9 +615,8 @@ struct NegNodeOp : public UnaryNodeOp { }; struct RowsNodeOp : public UnaryNodeOp { - template - RowsNodeOp(Expr a, const std::vector& indeces, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), + RowsNodeOp(Expr a, const std::vector& indeces) + : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} NodeOps forwardOps() { @@ -679,9 +667,8 @@ struct RowsNodeOp : public UnaryNodeOp { }; struct ColsNodeOp : public UnaryNodeOp { - template - ColsNodeOp(Expr a, const std::vector& indeces, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), + ColsNodeOp(Expr a, const std::vector& indeces) + : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} NodeOps forwardOps() { @@ -731,7 +718,7 @@ struct ColsNodeOp : public UnaryNodeOp { struct SelectNodeOp : public UnaryNodeOp { SelectNodeOp(Expr a, int axis, const std::vector& indeces) - : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)), + : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {} NodeOps forwardOps() { @@ -787,7 +774,7 @@ struct TransposeNodeOp : public UnaryNodeOp { std::vector axes_; TransposeNodeOp(Expr a, const std::vector& axes) - : UnaryNodeOp(a, keywords::shape = newShape(a, axes)), + : UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {} NodeOps forwardOps() { @@ -844,8 +831,8 @@ private: public: template - ReshapeNodeOp(Expr a, Shape shape, Args... args) - : UnaryNodeOp(a, keywords::shape = shape, args...), reshapee_(a) { + ReshapeNodeOp(Expr a, Shape shape) + : UnaryNodeOp(a, shape), reshapee_(a) { Node::destroy_ = false; } @@ -909,7 +896,7 @@ private: public: StepNodeOp(Expr a, int step, int axis) - : UnaryNodeOp(a, keywords::shape = newShape(a, axis)), + : UnaryNodeOp(a, newShape(a, axis)), stepNode_(a), step_(step) { Node::destroy_ = false; @@ -981,9 +968,8 @@ public: }; struct ShiftNodeOp : public UnaryNodeOp { - template - ShiftNodeOp(Expr a, Shape shift, Args... args) - : UnaryNodeOp(a, keywords::shape = a->shape(), args...), shift_(shift) {} + ShiftNodeOp(Expr a, Shape shift) + : UnaryNodeOp(a, a->shape()), shift_(shift) {} NodeOps forwardOps() { return {NodeOp(Shift(val_, child(0)->val(), shift_, false))}; -- cgit v1.2.3