diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-29 00:32:12 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-06-29 00:32:12 +0300 |
commit | 34b16ff585055bb18ccd85db196d65ff236a1c7e (patch) | |
tree | b09f646b92616d8663b74cae6bf4a8cca61403d1 /src/graph/node_operators_unary.h | |
parent | 352a437ab49ec00be944e11ed4bba0d52ac49931 (diff) | |
parent | 54dac41e9dc044105f9e5cc4df8345c876a22c09 (diff) |
resolve merge conflict
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index dda4dd03..fa6d25c7 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -138,12 +138,12 @@ public: } }; -struct LogitNodeOp : public UnaryNodeOp { - LogitNodeOp(Expr a) : UnaryNodeOp(a) {} +struct SigmoidNodeOp : public UnaryNodeOp { + SigmoidNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; - return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = sigmoid(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { @@ -151,7 +151,7 @@ struct LogitNodeOp : public UnaryNodeOp { return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))}; } - const std::string type() { return "logit"; } + const std::string type() { return "sigmoid"; } }; // struct Scalar2PowNodeOp : public UnaryNodeOp { @@ -350,13 +350,13 @@ struct SwishNodeOp : public UnaryNodeOp { NodeOps forwardOps() { using namespace functional; - return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { using namespace functional; // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)), + return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x @@ -936,8 +936,10 @@ public: Shape outShape = a->shape(); axis_ = outShape.axis(axis); +#if 0 // this check currently fails in translation; I think should not fail for step==0 for(int i = 0; i < axis_; ++i) ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()"); +#endif outShape.set(axis_, 1); return outShape; |