Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-29 00:32:12 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-29 00:32:12 +0300
commit34b16ff585055bb18ccd85db196d65ff236a1c7e (patch)
treeb09f646b92616d8663b74cae6bf4a8cca61403d1 /src/graph/node_operators_unary.h
parent352a437ab49ec00be944e11ed4bba0d52ac49931 (diff)
parent54dac41e9dc044105f9e5cc4df8345c876a22c09 (diff)
resolve merge conflict
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r--src/graph/node_operators_unary.h14
1 files changed, 8 insertions, 6 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index dda4dd03..fa6d25c7 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -138,12 +138,12 @@ public:
}
};
-struct LogitNodeOp : public UnaryNodeOp {
- LogitNodeOp(Expr a) : UnaryNodeOp(a) {}
+struct SigmoidNodeOp : public UnaryNodeOp {
+ SigmoidNodeOp(Expr a) : UnaryNodeOp(a) {}
NodeOps forwardOps() {
using namespace functional;
- return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))};
+ return {NodeOp(Element(_1 = sigmoid(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
@@ -151,7 +151,7 @@ struct LogitNodeOp : public UnaryNodeOp {
return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))};
}
- const std::string type() { return "logit"; }
+ const std::string type() { return "sigmoid"; }
};
// struct Scalar2PowNodeOp : public UnaryNodeOp {
@@ -350,13 +350,13 @@ struct SwishNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
using namespace functional;
- return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))};
+ return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
using namespace functional;
// dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)),
+ return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
@@ -936,8 +936,10 @@ public:
Shape outShape = a->shape();
axis_ = outShape.axis(axis);
+#if 0 // this check currently fails in translation; I think should not fail for step==0
for(int i = 0; i < axis_; ++i)
ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()");
+#endif
outShape.set(axis_, 1);
return outShape;