diff options
author | Frank Seide <fseide@microsoft.com> | 2019-02-13 01:09:04 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2019-02-13 01:09:04 +0300 |
commit | 94791594eb647f3e72fbe81d1e15020da20dd11f (patch) | |
tree | be2c580d5207837c0b4f4f791747c1de5d41eefe /src/graph | |
parent | 43e389d7c703b7c393ff12827121f869b1a1e472 (diff) |
swapAxes() now optimizes for the case where it can use reshape() instead
Diffstat (limited to 'src/graph')
-rwxr-xr-x | src/graph/expression_operators.cpp | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 5c37beef..43e6b1a2 100755 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -213,6 +213,8 @@ Expr repeat(Expr a, size_t repeats, int ax) { } Expr reshape(Expr a, Shape shape) { + if (a->shape() == shape) + return a; return Expression<ReshapeNodeOp>(a, shape); } @@ -256,7 +258,7 @@ Expr flatten_2d(Expr a) { Expr stopGradient(Expr a) { // implemented as a dummy reshape that is not trainable - auto res = reshape(a, a->shape()); + auto res = Expression<ReshapeNodeOp>(a, a->shape()); res->setTrainable(false); return res; } @@ -530,12 +532,27 @@ Expr transpose(Expr a, const std::vector<int>& axes) { Expr swapAxes(Expr x, int axis1, int axis2) { - axis1 = x->shape().axis(axis1); - axis2 = x->shape().axis(axis2); + const auto& shape = x->shape(); + axis1 = shape.axis(axis1); + axis2 = shape.axis(axis2); if (axis1 == axis2) return x; + if (shape[axis1] == 1 || shape[axis2] == 1) { // can we use a reshape instead? + if (axis1 > axis2) + std::swap(axis1, axis2); + bool canReshape = true; + for (int ax = axis1 + 1; ax < axis2 && canReshape; ax++) + canReshape &= (shape[ax] == 1); + if (canReshape) { + auto newShape = shape; + newShape.set(axis1, shape[axis2]); + newShape.set(axis2, shape[axis1]); + //LOG(info, "SwapAxes() did a reshape from {} to {}", shape.toString(), newShape.toString()); + return reshape(x, newShape); + } + } // TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, 0, 1) - std::vector<int> axes(x->shape().size()); + std::vector<int> axes(shape.size()); for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota() axes[i] = i; std::swap(axes[axis1], axes[axis2]); |