diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index dda4dd03..fa6d25c7 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -138,12 +138,12 @@ public: } }; -struct LogitNodeOp : public UnaryNodeOp { - LogitNodeOp(Expr a) : UnaryNodeOp(a) {} +struct SigmoidNodeOp : public UnaryNodeOp { + SigmoidNodeOp(Expr a) : UnaryNodeOp(a) {} NodeOps forwardOps() { using namespace functional; - return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = sigmoid(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { @@ -151,7 +151,7 @@ struct LogitNodeOp : public UnaryNodeOp { return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))}; } - const std::string type() { return "logit"; } + const std::string type() { return "sigmoid"; } }; // struct Scalar2PowNodeOp : public UnaryNodeOp { @@ -350,13 +350,13 @@ struct SwishNodeOp : public UnaryNodeOp { NodeOps forwardOps() { using namespace functional; - return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { using namespace functional; // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)), + return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x @@ -936,8 +936,10 @@ public: Shape outShape = a->shape(); axis_ = outShape.axis(axis); +#if 0 // this check currently fails in translation; I think should not fail for step==0 for(int i = 0; i < axis_; ++i) ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()"); +#endif outShape.set(axis_, 1); return outShape; |