Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorFrank Seide <fseide@microsoft.com>2019-02-13 01:09:04 +0300
committerFrank Seide <fseide@microsoft.com>2019-02-13 01:09:04 +0300
commit94791594eb647f3e72fbe81d1e15020da20dd11f (patch)
treebe2c580d5207837c0b4f4f791747c1de5d41eefe /src/graph
parent43e389d7c703b7c393ff12827121f869b1a1e472 (diff)
swapAxes() now optimizes for case where it can reshape() instead
Diffstat (limited to 'src/graph')
-rwxr-xr-xsrc/graph/expression_operators.cpp25
1 files changed, 21 insertions, 4 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 5c37beef..43e6b1a2 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -213,6 +213,8 @@ Expr repeat(Expr a, size_t repeats, int ax) {
}
Expr reshape(Expr a, Shape shape) {
+ if (a->shape() == shape)
+ return a;
return Expression<ReshapeNodeOp>(a, shape);
}
@@ -256,7 +258,7 @@ Expr flatten_2d(Expr a) {
Expr stopGradient(Expr a) {
// implemented as a dummy reshape that is not trainable
- auto res = reshape(a, a->shape());
+ auto res = Expression<ReshapeNodeOp>(a, a->shape());
res->setTrainable(false);
return res;
}
@@ -530,12 +532,27 @@ Expr transpose(Expr a, const std::vector<int>& axes) {
Expr swapAxes(Expr x, int axis1, int axis2)
{
- axis1 = x->shape().axis(axis1);
- axis2 = x->shape().axis(axis2);
+ const auto& shape = x->shape();
+ axis1 = shape.axis(axis1);
+ axis2 = shape.axis(axis2);
if (axis1 == axis2)
return x;
+ if (shape[axis1] == 1 || shape[axis2] == 1) { // can we use a reshape instead?
+ if (axis1 > axis2)
+ std::swap(axis1, axis2);
+ bool canReshape = true;
+ for (int ax = axis1 + 1; ax < axis2 && canReshape; ax++)
+ canReshape &= (shape[ax] == 1);
+ if (canReshape) {
+ auto newShape = shape;
+ newShape.set(axis1, shape[axis2]);
+ newShape.set(axis2, shape[axis1]);
+ //LOG(info, "SwapAxes() did a reshape from {} to {}", shape.toString(), newShape.toString());
+ return reshape(x, newShape);
+ }
+ }
// TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, 0, 1)
- std::vector<int> axes(x->shape().size());
+ std::vector<int> axes(shape.size());
for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota()
axes[i] = i;
std::swap(axes[axis1], axes[axis2]);