diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-28 00:06:38 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-28 00:06:38 +0300 |
commit | 0edf3b3913c04d7a90d2af6797b9b817ac94dca9 (patch) | |
tree | 28e89a823314e8cc57470602460c1edee4657198 | |
parent | 94645a31fc93f0a93499027cdad16b7ac33ca42f (diff) |
add proper gradient summation to shift operator
-rw-r--r-- | src/graph/expression_operators.h | 2 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 6 | ||||
-rw-r--r-- | src/tensors/cpu/tensor_operators.cpp | 6 | ||||
-rw-r--r-- | src/tensors/gpu/tensor_operators.cu | 8 | ||||
-rw-r--r-- | src/tensors/tensor_operators.h | 2 |
5 files changed, 12 insertions, 12 deletions
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 57d8dad6..6e994f75 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -82,9 +82,7 @@ Expr affine(Expr a,
 Expr transpose(Expr a);
 Expr transpose(Expr a, const std::vector<int>& axes);
 
-// check
 Expr concatenate(const std::vector<Expr>& concats, keywords::axis_k ax = 0);
-// check
 Expr repeat(Expr a, size_t repeats, keywords::axis_k ax = 0);
 
 Expr reshape(Expr a, Shape shape);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 0fc17d28..2cc4fa37 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -997,11 +997,13 @@ struct ShiftNodeOp : public UnaryNodeOp {
       : UnaryNodeOp(a, a->shape()), shift_(shift) {}
 
   NodeOps forwardOps() {
-    return {NodeOp(Shift(val_, child(0)->val(), shift_, false))};
+    // last parameter beta=0 says to use = (out = in + beta * out)
+    return {NodeOp(Shift(val_, child(0)->val(), shift_, false, 0.f))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))};
+    // last parameter beta=1 says to use += (out = in + beta * out)
+    return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true, 1.0f))};
   }
 
   const std::string type() { return "shift"; }
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 9197f85f..7d8134ea 100644
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -896,7 +896,7 @@ void LayerNormalizationGrad(Tensor gradX_,
   }
 }
 
-void Shift(Tensor out_, Tensor in_, marian::Shape shift, bool invert) {
+void Shift(Tensor out_, Tensor in_, marian::Shape shift, bool invert, float beta) {
   int offset = 0;
   for(int i = 0; i < shift.size(); ++i)
     offset += in_->shape().stride(i) * shift[i];
@@ -911,9 +911,9 @@ void Shift(Tensor out_, Tensor in_, marian::Shape shift, bool invert) {
 #pragma omp parallel for
   for(int i = 0; i < length; ++i) {
     if(i - offset < 0 || i - offset >= length) {
-      out[i] = 0.f;
+      out[i] = 0.f + beta * out[i];
     } else {
-      out[i] = in[i - offset];
+      out[i] = in[i - offset] + beta * out[i];
     }
   }
 }
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index ccc99640..0e1b99c3 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -1693,19 +1693,19 @@ void LayerNormalizationGrad(Tensor gradX,
                              eps);
 }
 
-__global__ void gShift(float* out, const float* in, int length, int offset) {
+__global__ void gShift(float* out, const float* in, int length, int offset, float beta) {
   for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
     int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
     if(index < length) {
       if(index - offset < 0 || index - offset >= length)
         out[index] = 0;
       else
-        out[index] = in[index - offset];
+        out[index] = in[index - offset] + beta * out[index];
     }
   }
 }
 
-void Shift(Tensor out, Tensor in, marian::Shape shift, bool invert) {
+void Shift(Tensor out, Tensor in, marian::Shape shift, bool invert, float beta) {
   ABORT_IF(in->shape().size() != shift.size(), "bad dimensions");
 
   // BUGBUG: This can only shift along the first axis. Shifting, e.g., along the last axis cannot be implemented this way.
@@ -1723,7 +1723,7 @@ void Shift(Tensor out, Tensor in, marian::Shape shift, bool invert) {
   int threads = std::min(MAX_THREADS, length);
   int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
 
-  gShift<<<blocks, threads>>>(out->data(), in->data(), length, offset);
+  gShift<<<blocks, threads>>>(out->data(), in->data(), length, offset, beta);
 }
 
 __global__ void gSetSparse(float* out,
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index b7422e0c..cbaff174 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -79,7 +79,7 @@ void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
 DISPATCH4(CrossEntropyPickBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor)
 
 DISPATCH4(TransposeND, marian::Tensor, marian::Tensor, const std::vector<int>&, float)
-DISPATCH4(Shift, marian::Tensor, marian::Tensor, marian::Shape, bool)
+DISPATCH5(Shift, marian::Tensor, marian::Tensor, marian::Shape, bool, float)
 
 DISPATCH3(Concatenate, marian::Tensor, const std::vector<marian::Tensor>&, int)