diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-27 01:07:20 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-27 01:07:20 +0300 |
commit | fdd9516e982dad030b2fac36acab4184e27e532b (patch) | |
tree | 2a65e091e5cc48ec42767118ce1f5d0dd6898035 /src/graph | |
parent | c84244f76c66336d98f5ce5130f198e49b7f6b79 (diff) |
fixed bug in transpose for ndarray
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/expression_operators.cu | 10 | ||||
-rw-r--r-- | src/graph/node_operators_binary.h | 12 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 7 |
3 files changed, 20 insertions, 9 deletions
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu index 2d796bee..c2096afc 100644 --- a/src/graph/expression_operators.cu +++ b/src/graph/expression_operators.cu @@ -165,7 +165,15 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scalar) { } Expr transpose(Expr a) { - return Expression<TransposeNodeOp>(a, Shape({1, 0, 2, 3})); + Shape s = a->shape(); + for(int i = 0; i < s.size(); ++i) { + s.set(i, i); + } + if(s.size() > 1) { + s.set(s.size() - 1, s.size() - 2); + s.set(s.size() - 2, s.size() - 1); + } + return Expression<TransposeNodeOp>(a, s); } Expr transpose(Expr a, Shape permute) { diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 315743e3..c07620ab 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -52,19 +52,19 @@ public: Shape newShape(Expr a, Expr b, bool transA, bool transB) { auto shapeA = a->shape(); if(transA) { - shapeA.set(0, a->shape()[1]); - shapeA.set(1, a->shape()[0]); + shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]); + shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]); } auto shapeB = b->shape(); if(transB) { - shapeB.set(0, b->shape()[1]); - shapeB.set(1, b->shape()[0]); + shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]); + shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]); } Shape outShape = shapeA; - outShape.set(1, shapeB[1]); - UTIL_THROW_IF2(shapeA[1] != shapeB[0], + outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]); + UTIL_THROW_IF2(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2], "matrix product requires dimensions to match"); return outShape; } diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 852d6d02..9881357c 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -702,9 +702,12 @@ struct TransposeNodeOp : public UnaryNodeOp { template <class... Args> Shape newShape(Expr a, Shape permute) { - Shape shape; + Shape shape = a->shape(); + + UTIL_THROW_IF2(shape.size() != permute.size(), + "Shape and transpose axis have different number of dimensions"); - for(int i = 0; i < 4; ++i) + for(int i = 0; i < shape.size(); ++i) shape.set(i, a->shape()[permute[i]]); return shape; |