Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index fa6d25c7..d7ef751d 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -7,7 +7,9 @@
#include "graph/node.h"
#include "tensors/tensor_operators.h"
-//#include "tensors/gpu/cudnn_wrappers.h"
+#ifdef CUDNN
+#include "tensors/gpu/cudnn_wrappers.h"
+#endif
namespace marian {
@@ -815,7 +817,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
- return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_))};
+ return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axes_))};
}
template <class... Args>
@@ -1009,7 +1011,9 @@ struct ShiftNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
- return {NodeOp(Shift(child(0)->grad(), adj_, shift_, /*padValue=*/0.f, /*invert=*/true))};
+ // last parameter beta=1 says to use += (out = in + beta * out)
+ // @TODO: check need for padValue_
+ return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))};
}
const std::string type() { return "shift"; }
@@ -1076,6 +1080,7 @@ struct ShiftNodeOp : public UnaryNodeOp {
// Ptr<sparse::CSR> lf_;
//};
+#ifdef CUDNN
class PoolingOp : public UnaryNodeOp {
public:
PoolingOp(Expr x,
@@ -1109,6 +1114,7 @@ public:
protected:
PoolingWrapper pooling_;
};
+#endif
class PoolingWithMaskingOp : public UnaryNodeOp {
public: