diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-03-01 02:52:40 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-03-01 02:52:40 +0300 |
commit | fd208d841b0ec4fdbfdc81df601e86986528a051 (patch) | |
tree | efe85c4e634fa1885583e2d4192a8d7252a0e945 /src/graph/node_operators_unary.h | |
parent | 71f911940c91d2fd5c337ddf4d8e88108d5ed822 (diff) |
simplified interface, get slowly rid of keywords
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 96 |
1 files changed, 41 insertions, 55 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 0a76471b..e857e790 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -14,9 +14,11 @@ namespace marian { struct UnaryNodeOp : public NaryNodeOp { - template <typename... Args> - UnaryNodeOp(Expr a, Args... args) - : NaryNodeOp({a}, keywords::shape = a->shape(), args...) {} + UnaryNodeOp(Expr a, Shape shape) + : NaryNodeOp({a}, shape) {} + + UnaryNodeOp(Expr a) + : NaryNodeOp({a}, a->shape()) {} const std::string color() { return "yellow"; } }; @@ -26,9 +28,9 @@ private: float scalar_{0}; public: - template <typename... Args> - ScalarAddNodeOp(Expr a, float scalar, Args... args) - : UnaryNodeOp(a, args...), scalar_{scalar} {} + ScalarAddNodeOp(Expr a, float scalar) + : UnaryNodeOp(a), + scalar_{scalar} {} NodeOps forwardOps() { using namespace functional; @@ -67,9 +69,8 @@ private: float scalar_{0}; public: - template <typename... Args> - ScalarMultNodeOp(Expr a, float scalar, Args... args) - : UnaryNodeOp(a, args...), scalar_{scalar} {} + ScalarMultNodeOp(Expr a, float scalar) + : UnaryNodeOp(a), scalar_{scalar} {} NodeOps forwardOps() { using namespace functional; @@ -104,8 +105,7 @@ public: }; struct LogitNodeOp : public UnaryNodeOp { - template <typename... Args> - LogitNodeOp(Args... args) : UnaryNodeOp(args...) {} + LogitNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -164,7 +164,7 @@ struct LogitNodeOp : public UnaryNodeOp { struct TanhNodeOp : public NaryNodeOp { TanhNodeOp(const std::vector<Expr>& nodes) - : NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {} + : NaryNodeOp(nodes, newShape(nodes)) {} Shape newShape(const std::vector<Expr>& nodes) { return Shape::broadcast(nodes); @@ -214,8 +214,7 @@ struct TanhNodeOp : public NaryNodeOp { struct ReLUNodeOp : public UnaryNodeOp { - template <typename... Args> - ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {} + ReLUNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { // f(x) = max(0, x) @@ -265,9 +264,8 @@ struct ReLUNodeOp : public UnaryNodeOp { * \f] */ struct PReLUNodeOp : public UnaryNodeOp { - template <typename... Args> - PReLUNodeOp(float alpha, Args... args) - : UnaryNodeOp(args...), alpha_(alpha) {} + PReLUNodeOp(float alpha, Expr a) + : UnaryNodeOp(a), alpha_(alpha) {} NodeOps forwardOps() { using namespace functional; @@ -316,8 +314,7 @@ private: * */ struct SwishNodeOp : public UnaryNodeOp { - template <typename... Args> - SwishNodeOp(Args... args) : UnaryNodeOp(args...) {} + SwishNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -338,14 +335,12 @@ struct SwishNodeOp : public UnaryNodeOp { const std::string type() { return "swish"; } }; -struct SoftmaxNodeOp : public NaryNodeOp { - template <typename... Args> - SoftmaxNodeOp(Expr a, Args... args) - : NaryNodeOp(a, args...), mask_(nullptr) {} +struct SoftmaxNodeOp : public UnaryNodeOp { + SoftmaxNodeOp(Expr a) + : UnaryNodeOp(a), mask_(nullptr) {} - template <typename... Args> - SoftmaxNodeOp(Expr a, Expr mask, Args... args) - : NaryNodeOp({a}, args...), mask_(mask) {} + SoftmaxNodeOp(Expr a, Expr mask) + : UnaryNodeOp(a), mask_(mask) {} Expr mask_; @@ -396,8 +391,7 @@ struct SoftmaxNodeOp : public NaryNodeOp { }; struct LogSoftmaxNodeOp : public UnaryNodeOp { - template <typename... Args> - LogSoftmaxNodeOp(Args... args) : UnaryNodeOp(args...) {} + LogSoftmaxNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { return {NodeOp(LogSoftmax(val_, child(0)->val()))}; } @@ -416,7 +410,7 @@ struct SumNodeOp : public UnaryNodeOp { template <typename... Args> SumNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} + : UnaryNodeOp(a, newShape(a, args...)) {} NodeOps forwardOps() { using namespace functional; @@ -465,7 +459,7 @@ struct MeanNodeOp : public UnaryNodeOp { template <typename... Args> MeanNodeOp(Expr a, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} + : UnaryNodeOp(a, newShape(a, args...)) {} NodeOps forwardOps() { using namespace functional; @@ -516,8 +510,7 @@ struct MeanNodeOp : public UnaryNodeOp { }; struct LogNodeOp : public UnaryNodeOp { - template <typename... Args> - LogNodeOp(Args... args) : UnaryNodeOp(args...) {} + LogNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -534,8 +527,7 @@ struct LogNodeOp : public UnaryNodeOp { }; struct ExpNodeOp : public UnaryNodeOp { - template <typename... Args> - ExpNodeOp(Args... args) : UnaryNodeOp(args...) {} + ExpNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -553,9 +545,8 @@ struct ExpNodeOp : public UnaryNodeOp { struct SqrtNodeOp : public UnaryNodeOp { float epsilon_; - template <typename... Args> - SqrtNodeOp(Expr a, float epsilon, Args... args) - : UnaryNodeOp(a, args...), epsilon_(epsilon) {} + SqrtNodeOp(Expr a, float epsilon) + : UnaryNodeOp(a), epsilon_(epsilon) {} NodeOps forwardOps() { using namespace functional; @@ -591,8 +582,7 @@ struct SqrtNodeOp : public UnaryNodeOp { }; struct SquareNodeOp : public UnaryNodeOp { - template <typename... Args> - SquareNodeOp(Args... args) : UnaryNodeOp(args...) {} + SquareNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -609,8 +599,7 @@ struct SquareNodeOp : public UnaryNodeOp { }; struct NegNodeOp : public UnaryNodeOp { - template <typename... Args> - NegNodeOp(Args... args) : UnaryNodeOp(args...) {} + NegNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; @@ -626,9 +615,8 @@ struct NegNodeOp : public UnaryNodeOp { }; struct RowsNodeOp : public UnaryNodeOp { - template <typename... Args> - RowsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), + RowsNodeOp(Expr a, const std::vector<size_t>& indeces) + : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} NodeOps forwardOps() { @@ -679,9 +667,8 @@ struct RowsNodeOp : public UnaryNodeOp { }; struct ColsNodeOp : public UnaryNodeOp { - template <typename... Args> - ColsNodeOp(Expr a, const std::vector<size_t>& indeces, Args... args) - : UnaryNodeOp(a, keywords::shape = newShape(a, indeces), args...), + ColsNodeOp(Expr a, const std::vector<size_t>& indeces) + : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {} NodeOps forwardOps() { @@ -731,7 +718,7 @@ struct ColsNodeOp : public UnaryNodeOp { struct SelectNodeOp : public UnaryNodeOp { SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces) - : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)), + : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {} NodeOps forwardOps() { @@ -787,7 +774,7 @@ struct TransposeNodeOp : public UnaryNodeOp { std::vector<int> axes_; TransposeNodeOp(Expr a, const std::vector<int>& axes) - : UnaryNodeOp(a, keywords::shape = newShape(a, axes)), + : UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {} NodeOps forwardOps() { @@ -844,8 +831,8 @@ private: public: template <typename... Args> - ReshapeNodeOp(Expr a, Shape shape, Args... args) - : UnaryNodeOp(a, keywords::shape = shape, args...), reshapee_(a) { + ReshapeNodeOp(Expr a, Shape shape) + : UnaryNodeOp(a, shape), reshapee_(a) { Node::destroy_ = false; } @@ -909,7 +896,7 @@ private: public: StepNodeOp(Expr a, int step, int axis) - : UnaryNodeOp(a, keywords::shape = newShape(a, axis)), + : UnaryNodeOp(a, newShape(a, axis)), stepNode_(a), step_(step) { Node::destroy_ = false; @@ -981,9 +968,8 @@ public: }; struct ShiftNodeOp : public UnaryNodeOp { - template <typename... Args> - ShiftNodeOp(Expr a, Shape shift, Args... args) - : UnaryNodeOp(a, keywords::shape = a->shape(), args...), shift_(shift) {} + ShiftNodeOp(Expr a, Shape shift) + : UnaryNodeOp(a, a->shape()), shift_(shift) {} NodeOps forwardOps() { return {NodeOp(Shift(val_, child(0)->val(), shift_, false))}; |