Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rwxr-xr-x src/graph/node_operators_unary.h 130
1 files changed, 73 insertions, 57 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 7dbaec46..6dd90faf 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -412,20 +412,75 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
const std::string type() override { return "logsoftmax"; }
};
-struct SumNodeOp : public UnaryNodeOp {
+enum class ReduceNodeOpCode {
+ sum, mean, rms, meanSqr, min, max, prod, logSumExp
+};
+
+struct ReduceNodeOp : public UnaryNodeOp {
int axis_;
+ ReduceNodeOpCode opCode_;
+ int reducedDim_; // dimension of axis being reduced, e.g. used in mean()
- SumNodeOp(Expr a, int axis) : UnaryNodeOp(a, newShape(a, axis)) {}
+ ReduceNodeOp(Expr a, int axis, ReduceNodeOpCode opCode)
+ : UnaryNodeOp(a, newShape(a, axis)), opCode_(opCode)
+ {
+ reducedDim_ = a->shape()[axis]; // e.g. used in mean()
+ ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(), "bug in determining reducedDim");
+ }
NodeOps forwardOps() override {
using namespace functional;
- return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+ switch (opCode_) {
+ case ReduceNodeOpCode::sum:
+ return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+ case ReduceNodeOpCode::mean:
+ return {NodeOp(Reduce(_1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
+ case ReduceNodeOpCode::rms:
+ return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val());
+ Element(_1 = sqrt(_1), val_))};
+ case ReduceNodeOpCode::meanSqr:
+ return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
+ case ReduceNodeOpCode::min:
+ return {NodeOp(Reduce(_1, min(_1,_2), std::numeric_limits<float>::max(), val_, child(0)->val()))};
+ case ReduceNodeOpCode::max:
+ return {NodeOp(Reduce(_1, max(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
+ case ReduceNodeOpCode::prod:
+ return {NodeOp(Reduce(_1, _1 * _2, 1.0f, val_, child(0)->val()))};
+ case ReduceNodeOpCode::logSumExp:
+ return {NodeOp(Reduce(_1, logaddexp(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
+ default:
+ ABORT("Unexpected reduction op-code {}", (int)opCode_);
+ }
}
NodeOps backwardOps() override {
using namespace functional;
- return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+ switch (opCode_) {
+ case ReduceNodeOpCode::sum:
+ return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+ case ReduceNodeOpCode::mean:
+ return {NodeOp(Add(_1, 1.0f / (float)reducedDim_, child(0)->grad(), adj_))};
+ case ReduceNodeOpCode::rms: // WARNING: UNTESTED!!
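+ // y = (sum_j x_j^2)^0.5 --NOTE(review): forwardOps() appears to compute sqrt((1/N) sum_j x_j^2) with N = reducedDim_; the 1/N is missing here
+ // dJ/dx_i = dJ/dy * 0.5 (sum_j x_j^2)^-0.5 * 2 x_i = dJ/dy * x_i / y --@REVIEW: is this correct? (with the 1/N from forwardOps() this would be dJ/dy * x_i / (N y) -- verify)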
+ // @TODO: do we need protection against div by 0? L'hospital rule?
+ return {NodeOp(Add(_1 * _2 / _3, child(0)->grad(), adj_, child(0)->val(), val_))};
+ case ReduceNodeOpCode::meanSqr: // WARNING: UNTESTED!!
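+ // y = sum_j x_j^2 --NOTE(review): forwardOps() appears to also scale by 1/reducedDim_; that factor is missing here and in the expression below -- verify
+ // dJ/dx_i = dJ/dy * sum_j dx_j^2/dx_i = dJ/dy * 2 x_i --@REVIEW: is this correct?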
+ return {NodeOp(Add(_1 * 2.0f * _2, child(0)->grad(), adj_, child(0)->val()))};
+ case ReduceNodeOpCode::min: // WARNING: UNTESTED!!
+ case ReduceNodeOpCode::max: // WARNING: UNTESTED!!
+ // adj_ gets routed to every input element that equals the min/max value (ties all receive it) --@REVIEW: is this correct?
+ return {NodeOp(Add((_1 == _2) * _3, child(0)->grad(), child(0)->val(), val_, adj_))};
+ case ReduceNodeOpCode::logSumExp:
+ // y = log(sum_j exp(x_j))
+ // dJ/dx_i = dJ/dy * 1/(sum_j exp(x_j)) exp(x_i) = dJ/dy * exp(x_i - y) --@REVIEW: is this correct?
+ return {NodeOp(Add(_1 * exp(_2 - _3), child(0)->grad(), adj_, child(0)->val(), val_))};
+ default:
+ ABORT("Unexpected reduction op-code {}", (int)opCode_);
+ }
}
Shape newShape(Expr a, int axis) {
@@ -436,66 +491,27 @@ struct SumNodeOp : public UnaryNodeOp {
return shape;
}
- const std::string type() override { return "sum"; }
-
- const std::string color() override { return "orange"; }
-
- virtual size_t hash() override {
- if(!hash_) {
- hash_ = NaryNodeOp::hash();
- util::hash_combine(hash_, axis_);
+ const std::string type() override {
+ switch (opCode_) {
+ case ReduceNodeOpCode::sum: return "sum";
+ case ReduceNodeOpCode::mean: return "mean";
+ case ReduceNodeOpCode::rms: return "rms";
+ case ReduceNodeOpCode::meanSqr: return "meanSqr";
+ case ReduceNodeOpCode::min: return "min";
+ case ReduceNodeOpCode::max: return "max";
+ case ReduceNodeOpCode::prod: return "prod";
+ case ReduceNodeOpCode::logSumExp: return "logSumExp";
+ default: ABORT("Unexpected reduction op-code {}", (int)opCode_);
}
- return hash_;
}
- virtual bool equal(Expr node) override {
- if(!NaryNodeOp::equal(node))
- return false;
- Ptr<SumNodeOp> cnode = std::dynamic_pointer_cast<SumNodeOp>(node);
- if(!cnode)
- return false;
- if(axis_ != cnode->axis_)
- return false;
- return true;
- }
-};
-
-struct MeanNodeOp : public UnaryNodeOp {
- int axis_;
-
- MeanNodeOp(Expr a, int axis) : UnaryNodeOp(a, newShape(a, axis)) {}
-
- NodeOps forwardOps() override {
- using namespace functional;
- int left = child(0)->shape().elements() / val_->shape().elements();
- float scale = 1.f / left;
-
- return {NodeOp(Reduce(_1, scale, val_, child(0)->val()))};
- }
-
- NodeOps backwardOps() override {
- using namespace functional;
- int left = child(0)->shape().elements() / val_->shape().elements();
- float scale = 1.f / left;
-
- return {NodeOp(Add(_1, scale, child(0)->grad(), adj_))};
- }
-
- Shape newShape(Expr a, int axis) {
- Shape shape = a->shape();
- axis_ = shape.axis(axis);
- shape.set(axis_, 1);
- return shape;
- }
-
- const std::string type() override { return "mean"; }
-
const std::string color() override { return "orange"; }
virtual size_t hash() override {
if(!hash_) {
hash_ = NaryNodeOp::hash();
util::hash_combine(hash_, axis_);
+ util::hash_combine(hash_, (int)opCode_);
}
return hash_;
}
@@ -503,10 +519,10 @@ struct MeanNodeOp : public UnaryNodeOp {
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<MeanNodeOp> cnode = std::dynamic_pointer_cast<MeanNodeOp>(node);
+ Ptr<ReduceNodeOp> cnode = std::dynamic_pointer_cast<ReduceNodeOp>(node);
if(!cnode)
return false;
- if(axis_ != cnode->axis_)
+ if(axis_ != cnode->axis_ || opCode_ != cnode->opCode_)
return false;
return true;
}