Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-28 00:06:38 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-28 00:06:38 +0300
commit0edf3b3913c04d7a90d2af6797b9b817ac94dca9 (patch)
tree28e89a823314e8cc57470602460c1edee4657198 /src/graph/node_operators_unary.h
parent94645a31fc93f0a93499027cdad16b7ac33ca42f (diff)
add proper gradient summation to shift operator
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 0fc17d28..2cc4fa37 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -997,11 +997,13 @@ struct ShiftNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, a->shape()), shift_(shift) {}
NodeOps forwardOps() {
- return {NodeOp(Shift(val_, child(0)->val(), shift_, false))};
+ // last parameter beta=0 says to use = (out = in + beta * out)
+ return {NodeOp(Shift(val_, child(0)->val(), shift_, false, 0.f))};
}
NodeOps backwardOps() {
- return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))};
+ // last parameter beta=1 says to use += (out = in + beta * out)
+ return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true, 1.0f))};
}
const std::string type() { return "shift"; }