diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-27 01:07:20 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-10-27 01:07:20 +0300 |
commit | fdd9516e982dad030b2fac36acab4184e27e532b (patch) | |
tree | 2a65e091e5cc48ec42767118ce1f5d0dd6898035 /src | |
parent | c84244f76c66336d98f5ce5130f198e49b7f6b79 (diff) |
fixed bug in transpose for ndarray
Diffstat (limited to 'src')
-rw-r--r-- | src/common/shape.h | 5 | ||||
-rw-r--r-- | src/graph/expression_operators.cu | 10 | ||||
-rw-r--r-- | src/graph/node_operators_binary.h | 12 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 7 | ||||
-rw-r--r-- | src/kernels/tensor_operators.cu | 14 | ||||
-rw-r--r-- | src/tests/operator_tests.cpp | 10 |
6 files changed, 42 insertions, 16 deletions
diff --git a/src/common/shape.h b/src/common/shape.h
index 0fa77b29..b0b88120 100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -29,6 +29,11 @@ struct Shape {
     updateStrides();
   }
 
+  void resize(size_t n) {
+    shape_.resize(n, 1);
+    updateStrides();
+  }
+
   void updateStrides() {
     stride_.resize(shape_.size());
     bstride_.resize(shape_.size());
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 2d796bee..c2096afc 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -165,7 +165,15 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scalar) {
 }
 
 Expr transpose(Expr a) {
-  return Expression<TransposeNodeOp>(a, Shape({1, 0, 2, 3}));
+  Shape s = a->shape();
+  for(int i = 0; i < s.size(); ++i) {
+    s.set(i, i);
+  }
+  if(s.size() > 1) {
+    s.set(s.size() - 1, s.size() - 2);
+    s.set(s.size() - 2, s.size() - 1);
+  }
+  return Expression<TransposeNodeOp>(a, s);
 }
 
 Expr transpose(Expr a, Shape permute) {
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 315743e3..c07620ab 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -52,19 +52,19 @@ public:
   Shape newShape(Expr a, Expr b, bool transA, bool transB) {
     auto shapeA = a->shape();
     if(transA) {
-      shapeA.set(0, a->shape()[1]);
-      shapeA.set(1, a->shape()[0]);
+      shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
+      shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
     }
 
     auto shapeB = b->shape();
     if(transB) {
-      shapeB.set(0, b->shape()[1]);
-      shapeB.set(1, b->shape()[0]);
+      shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
+      shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
     }
 
     Shape outShape = shapeA;
-    outShape.set(1, shapeB[1]);
-    UTIL_THROW_IF2(shapeA[1] != shapeB[0],
+    outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
+    UTIL_THROW_IF2(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
                    "matrix product requires dimensions to match");
     return outShape;
   }
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 852d6d02..9881357c 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -702,9 +702,12 @@ struct TransposeNodeOp : public UnaryNodeOp {
 
   template <class... Args>
   Shape newShape(Expr a, Shape permute) {
-    Shape shape;
+    Shape shape = a->shape();
+
+    UTIL_THROW_IF2(shape.size() != permute.size(),
+                   "Shape and transpose axis have different number of dimensions");
 
-    for(int i = 0; i < 4; ++i)
+    for(int i = 0; i < shape.size(); ++i)
       shape.set(i, a->shape()[permute[i]]);
 
     return shape;
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 645e6af8..1c0b7c35 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -190,7 +190,7 @@ __global__ void gTranspose4D(float* out,
       outShape.dims(index, dims1);
 
       for(int i = 0; i < num; ++i)
-        dims2[i] = dims1[permute[i]];
+        dims2[permute[i]] = dims1[i];
 
       int inIndex = inShape.index(dims2);
 
@@ -207,8 +207,18 @@ void Transpose4D(Tensor out, Tensor in, Shape permute) {
   int threads = std::min(MAX_THREADS, length);
   int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
 
+  Shape permuteGPU;
+  permuteGPU.resize(ShapeGPU::size());
+
+  int diff = ShapeGPU::size() - permute.size();
+  for(int i = 0; i < permuteGPU.size(); ++i)
+    if(i < diff)
+      permuteGPU.set(i, i);
+    else
+      permuteGPU.set(i, permute[i - diff] + diff);
+
   gTranspose4D<<<blocks, threads>>>(
-      out->data(), out->shape(), in->data(), in->shape(), permute);
+      out->data(), out->shape(), in->data(), in->shape(), permuteGPU);
 }
 
 __global__ void gSoftmax(float* out,
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 67cc0265..1feddb18 100644
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -21,7 +21,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     values.clear();
     std::vector<float> vC({22, 28, 49, 64, 76, 100, 103, 136});
 
-    auto A = graph->param("A", {2, 3, 2}, keywords::init = inits::from_vector(vA));
+    auto A = graph->param("A", {2, 2, 3}, keywords::init = inits::from_vector(vA));
     auto B = graph->param("B", {3, 2}, keywords::init = inits::from_vector(vB));
     auto C = dot(A, B);
     graph->forward();
@@ -127,15 +127,15 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     std::vector<float> vT1({1, 5, 2, 6, 3, 7, 4, 8});
     std::vector<float> vT3({1, 2, 5, 6, 3, 4, 7, 8});
     std::vector<float> vT4({1, 5, 3, 7, 2, 6, 4, 8});
-    std::vector<float> vT5({1, 3, 2, 4, 5, 7, 6, 8});
+    std::vector<float> vT5({1, 2, 5, 6, 3, 4, 7, 8});
 
     auto a = graph->constant({2, 4}, keywords::init = inits::from_vector(vA));
     auto t1 = transpose(a);
     auto t2 = transpose(t1);
     auto t3 = transpose(reshape(t1, {2, 2, 2}));
 
-    auto t4 = transpose(reshape(a, {2, 2, 1, 2}), {2, 3, 0, 1});
-    auto t5 = transpose(reshape(a, {2, 2, 1, 2}), {1, 2, 3, 0});
+    auto t4 = transpose(reshape(a, {2, 1, 2, 2}), {1, 3, 2, 0});
+    auto t5 = transpose(reshape(a, {2, 1, 2, 2}), {2, 0, 1, 3});
 
     graph->forward();
 
@@ -143,7 +143,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     CHECK(t2->shape() == Shape({2, 4}));
     CHECK(t3->shape() == Shape({2, 2, 2}));
     CHECK(t4->shape() == Shape({1, 2, 2, 2}));
-    CHECK(t5->shape() == Shape({2, 1, 2, 2}));
+    CHECK(t5->shape() == Shape({2, 2, 1, 2}));
 
     t1->val()->get(values);
     CHECK( values == vT1 );