Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-29 08:12:07 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-29 08:12:07 +0300
commit0cdde80c9b4a6e5f4e2ffe030c6d83931cfbc150 (patch)
tree3e9841b0ec372da1a54f361ec69dc484d5705f71 /src/graph/node_operators_unary.h
parente7344470d18c182aa8cc19471e755769a0e6e47f (diff)
another attempt at fixing gradients for transpose etc
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 2cc4fa37..1e0c71e6 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -806,11 +806,11 @@ struct TransposeNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, newShape(a, axes)), axes_{axes} {}
NodeOps forwardOps() {
- return {NodeOp(TransposeND(val_, child(0)->val(), axes_, 0.f))};
+ return {NodeOp(TransposeND(val_, child(0)->val(), axes_))};
}
NodeOps backwardOps() {
- return {NodeOp(TransposeND(child(0)->grad(), adj_, axes_, 1.f))};
+ return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axes_))};
}
template <class... Args>
@@ -998,12 +998,12 @@ struct ShiftNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
// last parameter beta=0 says to use = (out = in + beta * out)
- return {NodeOp(Shift(val_, child(0)->val(), shift_, false, 0.f))};
+ return {NodeOp(Shift(val_, child(0)->val(), shift_, false))};
}
NodeOps backwardOps() {
// last parameter beta=1 says to use += (out = in + beta * out)
- return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true, 1.0f))};
+ return {NodeOp(ShiftGrad(child(0)->grad(), adj_, shift_, true))};
}
const std::string type() { return "shift"; }