From c6350c666f293c64f3745d20d9cb9796eed849c8 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Wed, 27 Jun 2018 13:04:44 -0700 Subject: fix transpose operator --- src/graph/node_operators_unary.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index bb0b66f4..0fc17d28 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -806,11 +806,11 @@ struct TransposeNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {} NodeOps forwardOps() { - return {NodeOp(TransposeND(val_, child(0)->val(), axes_))}; + return {NodeOp(TransposeND(val_, child(0)->val(), axes_, 0.f))}; } NodeOps backwardOps() { - return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_))}; + return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_, 1.f))}; } template -- cgit v1.2.3 From 0edf3b3913c04d7a90d2af6797b9b817ac94dca9 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Wed, 27 Jun 2018 14:06:38 -0700 Subject: add proper gradient summation to shift operator --- src/graph/node_operators_unary.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 0fc17d28..2cc4fa37 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -997,11 +997,13 @@ struct ShiftNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, a->shape()), shift_(shift) {} NodeOps forwardOps() { - return {NodeOp(Shift(val_, child(0)->val(), shift_, false))}; + // last parameter beta=0 says to use = (out = in + beta * out) + return {NodeOp(Shift(val_, child(0)->val(), shift_, false, 0.f))}; } NodeOps backwardOps() { - return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))}; + // last parameter beta=1 says to use += (out = in + beta * out) + return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true, 1.0f))}; } const std::string type() { return "shift"; } -- cgit v1.2.3 From 0cdde80c9b4a6e5f4e2ffe030c6d83931cfbc150 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Thu, 28 Jun 2018 22:12:07 -0700 Subject: another attempt at fixing gradients for transpose etc --- src/graph/node_operators_unary.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 2cc4fa37..1e0c71e6 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -806,11 +806,11 @@ struct TransposeNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {} NodeOps forwardOps() { - return {NodeOp(TransposeND(val_, child(0)->val(), axes_, 0.f))}; + return {NodeOp(TransposeND(val_, child(0)->val(), axes_))}; } NodeOps backwardOps() { - return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_, 1.f))}; + return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axes_))}; } template @@ -998,12 +998,12 @@ struct ShiftNodeOp : public UnaryNodeOp { NodeOps forwardOps() { // last parameter beta=0 says to use = (out = in + beta * out) - return {NodeOp(Shift(val_, child(0)->val(), shift_, false, 0.f))}; + return {NodeOp(Shift(val_, child(0)->val(), shift_, false))}; } NodeOps backwardOps() { // last parameter beta=1 says to use += (out = in + beta * out) - return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true, 1.0f))}; + return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))}; } const std::string type() { return "shift"; } -- cgit v1.2.3 From 9bcb4ee5a30de27f9ac8976f06f1c5ef952f7930 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Mon, 2 Jul 2018 17:17:14 +0100 Subject: Fix #262 PoolingWrapper dependency --- src/graph/node_operators_unary.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'src/graph/node_operators_unary.h') diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 1e0c71e6..acc84c9c 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -7,7 +7,9 @@ #include "graph/node.h" #include "tensors/tensor_operators.h" -//#include "tensors/gpu/cudnn_wrappers.h" +#ifdef CUDNN +#include "tensors/gpu/cudnn_wrappers.h" +#endif namespace marian { @@ -1068,6 +1070,7 @@ struct ShiftNodeOp : public UnaryNodeOp { // Ptr lf_; //}; +#ifdef CUDNN class PoolingOp : public UnaryNodeOp { public: PoolingOp(Expr x, @@ -1101,6 +1104,7 @@ public: protected: PoolingWrapper pooling_; }; +#endif class PoolingWithMaskingOp : public UnaryNodeOp { public: -- cgit v1.2.3