Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-27 01:07:20 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2017-10-27 01:07:20 +0300
commitfdd9516e982dad030b2fac36acab4184e27e532b (patch)
tree2a65e091e5cc48ec42767118ce1f5d0dd6898035 /src/graph
parentc84244f76c66336d98f5ce5130f198e49b7f6b79 (diff)
fixed bug in transpose for ndarray
Diffstat (limited to 'src/graph')
-rw-r--r--src/graph/expression_operators.cu10
-rw-r--r--src/graph/node_operators_binary.h12
-rw-r--r--src/graph/node_operators_unary.h7
3 files changed, 20 insertions, 9 deletions
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 2d796bee..c2096afc 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -165,7 +165,15 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scalar) {
}
Expr transpose(Expr a) {
- return Expression<TransposeNodeOp>(a, Shape({1, 0, 2, 3}));
+ Shape s = a->shape();
+ for(int i = 0; i < s.size(); ++i) {
+ s.set(i, i);
+ }
+ if(s.size() > 1) {
+ s.set(s.size() - 1, s.size() - 2);
+ s.set(s.size() - 2, s.size() - 1);
+ }
+ return Expression<TransposeNodeOp>(a, s);
}
Expr transpose(Expr a, Shape permute) {
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 315743e3..c07620ab 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -52,19 +52,19 @@ public:
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
- shapeA.set(0, a->shape()[1]);
- shapeA.set(1, a->shape()[0]);
+ shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
+ shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
}
auto shapeB = b->shape();
if(transB) {
- shapeB.set(0, b->shape()[1]);
- shapeB.set(1, b->shape()[0]);
+ shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
+ shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
}
Shape outShape = shapeA;
- outShape.set(1, shapeB[1]);
- UTIL_THROW_IF2(shapeA[1] != shapeB[0],
+ outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
+ UTIL_THROW_IF2(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"matrix product requires dimensions to match");
return outShape;
}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 852d6d02..9881357c 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -702,9 +702,12 @@ struct TransposeNodeOp : public UnaryNodeOp {
template <class... Args>
Shape newShape(Expr a, Shape permute) {
- Shape shape;
+ Shape shape = a->shape();
+
+ UTIL_THROW_IF2(shape.size() != permute.size(),
+ "Shape and transpose axis have different number of dimensions");
- for(int i = 0; i < 4; ++i)
+ for(int i = 0; i < shape.size(); ++i)
shape.set(i, a->shape()[permute[i]]);
return shape;