Diffstat (limited to 'src/graph/node_operators_unary.h')
-rwxr-xr-x | src/graph/node_operators_unary.h | 130
1 file changed, 73 insertions, 57 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 7dbaec46..6dd90faf 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -412,20 +412,75 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
   const std::string type() override { return "logsoftmax"; }
 };
 
-struct SumNodeOp : public UnaryNodeOp {
+enum class ReduceNodeOpCode {
+  sum, mean, rms, meanSqr, min, max, prod, logSumExp
+};
+
+struct ReduceNodeOp : public UnaryNodeOp {
   int axis_;
+  ReduceNodeOpCode opCode_;
+  int reducedDim_; // dimension of axis being reduced, e.g. used in mean()
 
-  SumNodeOp(Expr a, int axis) : UnaryNodeOp(a, newShape(a, axis)) {}
+  ReduceNodeOp(Expr a, int axis, ReduceNodeOpCode opCode)
+      : UnaryNodeOp(a, newShape(a, axis)), opCode_(opCode)
+  {
+    reducedDim_ = a->shape()[axis]; // e.g. used in mean()
+    ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(), "bug in determining reducedDim");
+  }
 
   NodeOps forwardOps() override {
     using namespace functional;
-    return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+    switch (opCode_) {
+    case ReduceNodeOpCode::sum:
+      return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+    case ReduceNodeOpCode::mean:
+      return {NodeOp(Reduce(_1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
+    case ReduceNodeOpCode::rms:
+      return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val());
+                     Element(_1 = sqrt(_1), val_))};
+    case ReduceNodeOpCode::meanSqr:
+      return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
+    case ReduceNodeOpCode::min:
+      return {NodeOp(Reduce(_1, min(_1,_2), std::numeric_limits<float>::max(), val_, child(0)->val()))};
+    case ReduceNodeOpCode::max:
+      return {NodeOp(Reduce(_1, max(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
+    case ReduceNodeOpCode::prod:
+      return {NodeOp(Reduce(_1, _1 * _2, 1.0f, val_, child(0)->val()))};
+    case ReduceNodeOpCode::logSumExp:
+      return {NodeOp(Reduce(_1, logaddexp(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
+    default:
+      ABORT("Unexpected reduction op-code {}", (int)opCode_);
+    }
   }
 
   NodeOps backwardOps() override {
     using namespace functional;
-    return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+    switch (opCode_) {
+    case ReduceNodeOpCode::sum:
+      return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+    case ReduceNodeOpCode::mean:
+      return {NodeOp(Add(_1, 1.0f / (float)reducedDim_, child(0)->grad(), adj_))};
+    case ReduceNodeOpCode::rms: // WARNING: UNTESTED!!
+      // y = (sum_j x_j^2)^0.5
+      // dJ/dx_i = dJ/dy * 0.5 (sum_j x_j^2)^-0.5 * 2 x_i = dJ/dy * x_i / y --@REVIEW: is this correct?
+      // @TODO: do we need protection against div by 0? L'Hospital's rule?
+      return {NodeOp(Add(_1 * _2 / _3, child(0)->grad(), adj_, child(0)->val(), val_))};
+    case ReduceNodeOpCode::meanSqr: // WARNING: UNTESTED!!
+      // y = sum_j x_j^2
+      // dJ/dx_i = dJ/dy * sum_j dx_j^2/dx_i = dJ/dy * 2 x_i --@REVIEW: is this correct?
+      return {NodeOp(Add(_1 * 2.0f * _2, child(0)->grad(), adj_, child(0)->val()))};
+    case ReduceNodeOpCode::min: // WARNING: UNTESTED!!
+    case ReduceNodeOpCode::max: // WARNING: UNTESTED!!
+      // adj_ gets routed into the min/max value --@REVIEW: is this correct?
+      return {NodeOp(Add((_1 == _2) * _3, child(0)->grad(), child(0)->val(), val_, adj_))};
+    case ReduceNodeOpCode::logSumExp:
+      // y = log(sum_j exp(x_j))
+      // dJ/dx_i = dJ/dy * 1/(sum_j exp(x_j)) * exp(x_i) = dJ/dy * exp(x_i - y) --@REVIEW: is this correct?
+      return {NodeOp(Add(_1 * exp(_2 - _3), child(0)->grad(), adj_, child(0)->val(), val_))};
+    default:
+      ABORT("Unexpected reduction op-code {}", (int)opCode_);
+    }
   }
 
   Shape newShape(Expr a, int axis) {
@@ -436,66 +491,27 @@ struct SumNodeOp : public UnaryNodeOp {
     return shape;
   }
 
-  const std::string type() override { return "sum"; }
-
-  const std::string color() override { return "orange"; }
-
-  virtual size_t hash() override {
-    if(!hash_) {
-      hash_ = NaryNodeOp::hash();
-      util::hash_combine(hash_, axis_);
+  const std::string type() override {
+    switch (opCode_) {
+    case ReduceNodeOpCode::sum:       return "sum";
+    case ReduceNodeOpCode::mean:      return "mean";
+    case ReduceNodeOpCode::rms:       return "rms";
+    case ReduceNodeOpCode::meanSqr:   return "meanSqr";
+    case ReduceNodeOpCode::min:       return "min";
+    case ReduceNodeOpCode::max:       return "max";
+    case ReduceNodeOpCode::prod:      return "prod";
+    case ReduceNodeOpCode::logSumExp: return "logSumExp";
+    default: ABORT("Unexpected reduction op-code {}", (int)opCode_);
     }
-    return hash_;
   }
 
-  virtual bool equal(Expr node) override {
-    if(!NaryNodeOp::equal(node))
-      return false;
-    Ptr<SumNodeOp> cnode = std::dynamic_pointer_cast<SumNodeOp>(node);
-    if(!cnode)
-      return false;
-    if(axis_ != cnode->axis_)
-      return false;
-    return true;
-  }
-};
-
-struct MeanNodeOp : public UnaryNodeOp {
-  int axis_;
-
-  MeanNodeOp(Expr a, int axis) : UnaryNodeOp(a, newShape(a, axis)) {}
-
-  NodeOps forwardOps() override {
-    using namespace functional;
-    int left = child(0)->shape().elements() / val_->shape().elements();
-    float scale = 1.f / left;
-
-    return {NodeOp(Reduce(_1, scale, val_, child(0)->val()))};
-  }
-
-  NodeOps backwardOps() override {
-    using namespace functional;
-    int left = child(0)->shape().elements() / val_->shape().elements();
-    float scale = 1.f / left;
-
-    return {NodeOp(Add(_1, scale, child(0)->grad(), adj_))};
-  }
-
-  Shape newShape(Expr a, int axis) {
-    Shape shape = a->shape();
-    axis_ = shape.axis(axis);
-    shape.set(axis_, 1);
-    return shape;
-  }
-
-  const std::string type() override { return "mean"; }
-
   const std::string color() override { return "orange"; }
 
   virtual size_t hash() override {
     if(!hash_) {
       hash_ = NaryNodeOp::hash();
       util::hash_combine(hash_, axis_);
+      util::hash_combine(hash_, (int)opCode_);
     }
     return hash_;
   }
@@ -503,10 +519,10 @@ struct MeanNodeOp : public UnaryNodeOp {
   virtual bool equal(Expr node) override {
     if(!NaryNodeOp::equal(node))
       return false;
-    Ptr<MeanNodeOp> cnode = std::dynamic_pointer_cast<MeanNodeOp>(node);
+    Ptr<ReduceNodeOp> cnode = std::dynamic_pointer_cast<ReduceNodeOp>(node);
     if(!cnode)
       return false;
-    if(axis_ != cnode->axis_)
+    if(axis_ != cnode->axis_ || opCode_ != cnode->opCode_)
       return false;
     return true;
   }
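All eight reductions now funnel through one node type selected by ReduceNodeOpCode, so hash() and equal() are written once, with the op-code mixed into both. As a sketch of how the unified node might be exposed as expression operators, assuming Marian's usual Expression<T>(...) factory (the wrapper names here are illustrative, not necessarily the ones this commit adds elsewhere):

Expr sum(Expr a, int axis)  { return Expression<ReduceNodeOp>(a, axis, ReduceNodeOpCode::sum); }
Expr mean(Expr a, int axis) { return Expression<ReduceNodeOp>(a, axis, ReduceNodeOpCode::mean); }
Expr max(Expr a, int axis)  { return Expression<ReduceNodeOp>(a, axis, ReduceNodeOpCode::max); }
Expr logsumexp(Expr a, int axis) { return Expression<ReduceNodeOp>(a, axis, ReduceNodeOpCode::logSumExp); }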
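The logSumExp backward answers its own @REVIEW question; writing the derivative out in LaTeX:

y = \log \sum_j e^{x_j},
\qquad
\frac{\partial J}{\partial x_i}
  = \frac{\partial J}{\partial y}\,\frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{\partial J}{\partial y}\, e^{x_i - y},

which is exactly Add(_1 * exp(_2 - _3), ...) with _1 = adj_, _2 = child(0)->val(), _3 = val_.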
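For the paths marked WARNING: UNTESTED!!, a central-difference check outside the graph is the quickest way to settle the remaining @REVIEW questions. The standalone sketch below (plain C++, no Marian types; the test vector and probed index are arbitrary) probes the rms case:

#include <cmath>
#include <cstdio>
#include <vector>

// Forward pass re-implemented to match the diff: mean of squares, then sqrt.
static float rmsForward(const std::vector<float>& x) {
  float s = 0.f;
  for (float v : x) s += v * v;
  return std::sqrt(s / (float)x.size());
}

int main() {
  std::vector<float> x = {0.5f, -1.25f, 2.0f, 0.75f};
  const int i = 2;          // component whose gradient we probe
  const float eps = 1e-3f;

  // Numeric gradient: central difference (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps).
  std::vector<float> xp = x, xm = x;
  xp[i] += eps;
  xm[i] -= eps;
  float numeric = (rmsForward(xp) - rmsForward(xm)) / (2.f * eps);

  // Analytic gradient as coded in backwardOps() above: x_i / y (taking dJ/dy = 1).
  float y = rmsForward(x);
  float analytic = x[i] / y;

  std::printf("numeric %.6f  analytic %.6f\n", numeric, analytic);
  return 0;
}

If the analytic x_i / y expression is missing the 1/N implied by the forward's mean-of-squares, the two printed values will differ by exactly a factor of N; the same harness adapts to meanSqr, min, and max in a few lines.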