diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-07-27 20:14:21 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-07-27 20:14:21 +0300 |
commit | dceb7185d86ed8fd1994e86dc3e3c0e03740ec4a (patch) | |
tree | 3514f87aa2da28313043959ebd0381b3ba7de233 /src/graph/node_operators_unary.h | |
parent | 5cc8674d974bb5cae7bc8f25a51472166164a579 (diff) | |
parent | 8b0e2f951b5ce09a622fa7239b2e1e5bd8344fe4 (diff) |
fix merge
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index fa6d25c7..d7ef751d 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -7,7 +7,9 @@ #include "graph/node.h" #include "tensors/tensor_operators.h" -//#include "tensors/gpu/cudnn_wrappers.h" +#ifdef CUDNN +#include "tensors/gpu/cudnn_wrappers.h" +#endif namespace marian { @@ -815,7 +817,7 @@ struct TransposeNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { - return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_))}; + return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axes_))}; } template <class... Args> @@ -1009,7 +1011,9 @@ struct ShiftNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { - return {NodeOp(Shift(child(0)->grad(), adj_, shift_, /*padValue=*/0.f, /*invert=*/true))}; + // last parameter beta=1 says to use += (out = in + beta * out) + // @TODO: check need for padValue_ + return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))}; } const std::string type() { return "shift"; } @@ -1076,6 +1080,7 @@ struct ShiftNodeOp : public UnaryNodeOp { // Ptr<sparse::CSR> lf_; //}; +#ifdef CUDNN class PoolingOp : public UnaryNodeOp { public: PoolingOp(Expr x, @@ -1109,6 +1114,7 @@ public: protected: PoolingWrapper pooling_; }; +#endif class PoolingWithMaskingOp : public UnaryNodeOp { public: |