diff options
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 9881357c..a3f60366 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -146,8 +146,8 @@ struct TanhNodeOp : public NaryNodeOp { for(int n = 1; n < nodes.size(); ++n) { Shape shapen = nodes[n]->shape(); for(int i = 0; i < shapen.size(); ++i) { - UTIL_THROW_IF2(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, - "Shapes cannot be broadcasted"); + ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1, + "Shapes cannot be broadcasted"); shape.set(i, std::max(shape[i], shapen[i])); } } @@ -237,8 +237,11 @@ struct SwishNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { - return {NodeOp( - Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), child(0)->grad(), adj_, child(0)->val(), val_))}; + return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), + child(0)->grad(), + adj_, + child(0)->val(), + val_))}; } const std::string type() { return "swish"; } |