Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-07-27 20:14:21 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-07-27 20:14:21 +0300
commitdceb7185d86ed8fd1994e86dc3e3c0e03740ec4a (patch)
tree3514f87aa2da28313043959ebd0381b3ba7de233 /src/graph/node_operators_unary.h
parent5cc8674d974bb5cae7bc8f25a51472166164a579 (diff)
parent8b0e2f951b5ce09a622fa7239b2e1e5bd8344fe4 (diff)
fix merge
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index fa6d25c7..d7ef751d 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -7,7 +7,9 @@
#include "graph/node.h"
#include "tensors/tensor_operators.h"
-//#include "tensors/gpu/cudnn_wrappers.h"
+#ifdef CUDNN
+#include "tensors/gpu/cudnn_wrappers.h"
+#endif
namespace marian {
@@ -815,7 +817,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
- return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_))};
+ return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axes_))};
}
template <class... Args>
@@ -1009,7 +1011,9 @@ struct ShiftNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
- return {NodeOp(Shift(child(0)->grad(), adj_, shift_, /*padValue=*/0.f, /*invert=*/true))};
+ // last parameter beta=1 says to use += (out = in + beta * out)
+ // @TODO: check need for padValue_
+ return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))};
}
const std::string type() { return "shift"; }
@@ -1076,6 +1080,7 @@ struct ShiftNodeOp : public UnaryNodeOp {
// Ptr<sparse::CSR> lf_;
//};
+#ifdef CUDNN
class PoolingOp : public UnaryNodeOp {
public:
PoolingOp(Expr x,
@@ -1109,6 +1114,7 @@ public:
protected:
PoolingWrapper pooling_;
};
+#endif
class PoolingWithMaskingOp : public UnaryNodeOp {
public: