Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-03-12 23:34:10 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-03-12 23:34:10 +0300
commit6d0c75cf48bab913e2c9c52f1c4c6cd0d656005d (patch)
tree717342edade369af33a771f00a7dd05354ea8afb /src/graph/node_operators_unary.h
parent5f2eedc6e505eecf5bdef474be3e4f7066702fa7 (diff)
Autoformat files
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h126
1 file changed, 48 insertions, 78 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 0ca2c2a2..8d81a63a 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -12,11 +12,9 @@
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- UnaryNodeOp(Expr a, Shape shape)
- : NaryNodeOp({a}, shape) {}
+ UnaryNodeOp(Expr a, Shape shape) : NaryNodeOp({a}, shape) {}
- UnaryNodeOp(Expr a)
- : NaryNodeOp({a}, a->shape()) {}
+ UnaryNodeOp(Expr a) : NaryNodeOp({a}, a->shape()) {}
const std::string color() { return "yellow"; }
};
@@ -26,9 +24,7 @@ private:
float scalar_{0};
public:
- ScalarAddNodeOp(Expr a, float scalar)
- : UnaryNodeOp(a),
- scalar_{scalar} {}
+ ScalarAddNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}
NodeOps forwardOps() {
using namespace functional;
@@ -67,8 +63,7 @@ private:
float scalar_{0};
public:
- ScalarMultNodeOp(Expr a, float scalar)
- : UnaryNodeOp(a), scalar_{scalar} {}
+ ScalarMultNodeOp(Expr a, float scalar) : UnaryNodeOp(a), scalar_{scalar} {}
NodeOps forwardOps() {
using namespace functional;
@@ -210,7 +205,6 @@ struct TanhNodeOp : public NaryNodeOp {
const std::string type() { return "tanh"; }
};
-
struct ReLUNodeOp : public UnaryNodeOp {
ReLUNodeOp(Expr a) : UnaryNodeOp(a) {}
@@ -262,8 +256,7 @@ struct ReLUNodeOp : public UnaryNodeOp {
* \f]
*/
struct PReLUNodeOp : public UnaryNodeOp {
- PReLUNodeOp(float alpha, Expr a)
- : UnaryNodeOp(a), alpha_(alpha) {}
+ PReLUNodeOp(float alpha, Expr a) : UnaryNodeOp(a), alpha_(alpha) {}
NodeOps forwardOps() {
using namespace functional;
@@ -334,11 +327,9 @@ struct SwishNodeOp : public UnaryNodeOp {
};
struct SoftmaxNodeOp : public UnaryNodeOp {
- SoftmaxNodeOp(Expr a)
- : UnaryNodeOp(a), mask_(nullptr) {}
+ SoftmaxNodeOp(Expr a) : UnaryNodeOp(a), mask_(nullptr) {}
- SoftmaxNodeOp(Expr a, Expr mask)
- : UnaryNodeOp(a), mask_(mask) {}
+ SoftmaxNodeOp(Expr a, Expr mask) : UnaryNodeOp(a), mask_(mask) {}
Expr mask_;
@@ -407,17 +398,18 @@ struct SumNodeOp : public UnaryNodeOp {
int ax_;
template <typename... Args>
- SumNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, newShape(a, args...)) {}
+ SumNodeOp(Expr a, Args... args) : UnaryNodeOp(a, newShape(a, args...)) {}
NodeOps forwardOps() {
using namespace functional;
- return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
+ return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+ }
NodeOps backwardOps() {
using namespace functional;
- return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+ return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+ }
template <class... Args>
Shape newShape(Expr a, Args... args) {
@@ -456,8 +448,7 @@ struct MeanNodeOp : public UnaryNodeOp {
int ax_;
template <typename... Args>
- MeanNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, newShape(a, args...)) {}
+ MeanNodeOp(Expr a, Args... args) : UnaryNodeOp(a, newShape(a, args...)) {}
NodeOps forwardOps() {
using namespace functional;
@@ -543,8 +534,7 @@ struct ExpNodeOp : public UnaryNodeOp {
struct SqrtNodeOp : public UnaryNodeOp {
float epsilon_;
- SqrtNodeOp(Expr a, float epsilon)
- : UnaryNodeOp(a), epsilon_(epsilon) {}
+ SqrtNodeOp(Expr a, float epsilon) : UnaryNodeOp(a), epsilon_(epsilon) {}
NodeOps forwardOps() {
using namespace functional;
@@ -614,8 +604,7 @@ struct NegNodeOp : public UnaryNodeOp {
struct RowsNodeOp : public UnaryNodeOp {
RowsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)),
- indices_(indeces) {}
+ : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
@@ -666,8 +655,7 @@ struct RowsNodeOp : public UnaryNodeOp {
struct ColsNodeOp : public UnaryNodeOp {
ColsNodeOp(Expr a, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, indeces)),
- indices_(indeces) {}
+ : UnaryNodeOp(a, newShape(a, indeces)), indices_(indeces) {}
NodeOps forwardOps() {
// @TODO: solve this with a tensor!
@@ -716,8 +704,7 @@ struct ColsNodeOp : public UnaryNodeOp {
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
- : UnaryNodeOp(a, newShape(a, axis, indeces)),
- indices_(indeces) {}
+ : UnaryNodeOp(a, newShape(a, axis, indeces)), indices_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
@@ -772,8 +759,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
std::vector<int> axes_;
TransposeNodeOp(Expr a, const std::vector<int>& axes)
- : UnaryNodeOp(a, newShape(a, axes)),
- axes_{axes} {}
+ : UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {}
NodeOps forwardOps() {
return {NodeOp(TransposeND(val_, child(0)->val(), axes_))};
@@ -788,7 +774,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
Shape shape = a->shape();
ABORT_IF(shape.size() != axes.size(),
- "Shape and transpose axes have different number of dimensions");
+ "Shape and transpose axes have different number of dimensions");
for(int i = 0; i < shape.size(); ++i)
shape.set(i, a->shape()[axes[i]]);
@@ -829,8 +815,7 @@ private:
public:
template <typename... Args>
- ReshapeNodeOp(Expr a, Shape shape)
- : UnaryNodeOp(a, shape), reshapee_(a) {
+ ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape), reshapee_(a) {
Node::destroy_ = false;
}
@@ -894,9 +879,7 @@ private:
public:
StepNodeOp(Expr a, int step, int axis)
- : UnaryNodeOp(a, newShape(a, axis)),
- stepNode_(a),
- step_(step) {
+ : UnaryNodeOp(a, newShape(a, axis)), stepNode_(a), step_(step) {
Node::destroy_ = false;
}
@@ -1056,67 +1039,54 @@ public:
padWidth,
strideHeight,
strideWidth,
- mode) {
- }
+ mode) {}
NodeOps forwardOps() {
return {NodeOp(pooling_.forward(child(0)->val(), val_))};
}
NodeOps backwardOps() {
- return {NodeOp(pooling_.backward(
- child(0)->val(),
- child(0)->grad(),
- val_,
- adj_))};
+ return {NodeOp(
+ pooling_.backward(child(0)->val(), child(0)->grad(), val_, adj_))};
}
const std::string type() { return "layer_pooling"; }
-
protected:
PoolingWrapper pooling_;
};
class PoolingWithMaskingOp : public UnaryNodeOp {
- public:
- PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
- : UnaryNodeOp(x),
- mask_(mask),
- width_(width),
- isEven_(isEven)
- {
- auto xShape = x->shape();
- int dimBatch = xShape[0];
- int dimWord = xShape[1];
- int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
- int dimSentence = (cols / width_) + (cols % width_ != 0);
- shape_ = {dimBatch, dimWord, dimSentence};
- }
+public:
+ PoolingWithMaskingOp(Expr x, Expr mask, int width, bool isEven = false)
+ : UnaryNodeOp(x), mask_(mask), width_(width), isEven_(isEven) {
+ auto xShape = x->shape();
+ int dimBatch = xShape[0];
+ int dimWord = xShape[1];
+ int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+ int dimSentence = (cols / width_) + (cols % width_ != 0);
+ shape_ = {dimBatch, dimWord, dimSentence};
+ }
- NodeOps forwardOps() {
- return {NodeOp(PoolingWithMaskingForward(val_,
+ NodeOps forwardOps() {
+ return {NodeOp(PoolingWithMaskingForward(
+ val_, child(0)->val(), mask_->val(), width_, isEven_))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(PoolingWithMaskingBackward(adj_,
+ child(0)->grad(),
child(0)->val(),
mask_->val(),
width_,
isEven_))};
- }
-
- NodeOps backwardOps() {
- return {NodeOp(PoolingWithMaskingBackward(adj_,
- child(0)->grad(),
- child(0)->val(),
- mask_->val(),
- width_,
- isEven_))};
- }
+ }
- const std::string type() {return "layer_pooling";}
+ const std::string type() { return "layer_pooling"; }
- protected:
- Expr mask_;
- int width_;
- bool isEven_;
+protected:
+ Expr mask_;
+ int width_;
+ bool isEven_;
};
-
}