github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-10-27 01:07:20 +0300
committer  Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-10-27 01:07:20 +0300
commit     fdd9516e982dad030b2fac36acab4184e27e532b (patch)
tree       2a65e091e5cc48ec42767118ce1f5d0dd6898035 /src
parent     c84244f76c66336d98f5ce5130f198e49b7f6b79 (diff)

    fixed bug in transpose for ndarray
Diffstat (limited to 'src')
 src/common/shape.h                |  5 +++++
 src/graph/expression_operators.cu | 10 +++++++++-
 src/graph/node_operators_binary.h | 12 ++++++------
 src/graph/node_operators_unary.h  |  7 +++++--
 src/kernels/tensor_operators.cu   | 14 ++++++++++++--
 src/tests/operator_tests.cpp      | 10 +++++-----
 6 files changed, 42 insertions(+), 16 deletions(-)
diff --git a/src/common/shape.h b/src/common/shape.h
index 0fa77b29..b0b88120 100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -29,6 +29,11 @@ struct Shape {
     updateStrides();
   }

+  void resize(size_t n) {
+    shape_.resize(n, 1);
+    updateStrides();
+  }
+
   void updateStrides() {
     stride_.resize(shape_.size());
     bstride_.resize(shape_.size());
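Note: the new Shape::resize grows or shrinks the dimension vector and keeps the stride caches consistent. Because std::vector::resize fills new slots with the supplied value, axes added at the end get extent 1, so the element count is unchanged. A minimal standalone sketch of that behavior (MiniShape is a hypothetical stand-in, not marian code):

    #include <cstddef>
    #include <vector>

    struct MiniShape {
      std::vector<int> dims;
      void resize(size_t n) {
        dims.resize(n, 1);  // new trailing axes get extent 1
        // the real Shape::resize additionally calls updateStrides() here
      }
    };

    // dims == {2, 3}, resize(4) -> dims == {2, 3, 1, 1}

Transpose4D further down uses this to pad a low-rank permutation out to the kernel's fixed rank before launching.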
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 2d796bee..c2096afc 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -165,7 +165,15 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scalar) {
 }

 Expr transpose(Expr a) {
-  return Expression<TransposeNodeOp>(a, Shape({1, 0, 2, 3}));
+  Shape s = a->shape();
+  for(int i = 0; i < s.size(); ++i) {
+    s.set(i, i);
+  }
+  if(s.size() > 1) {
+    s.set(s.size() - 1, s.size() - 2);
+    s.set(s.size() - 2, s.size() - 1);
+  }
+  return Expression<TransposeNodeOp>(a, s);
 }

 Expr transpose(Expr a, Shape permute) {
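With the hard-coded Shape({1, 0, 2, 3}) gone, the one-argument transpose now builds an identity permutation over however many axes the input actually has and swaps the last two. A sketch of the permutation it constructs, with std::vector standing in for Shape (the helper name is ours, not marian's):

    #include <numeric>
    #include <vector>

    std::vector<int> lastTwoSwapped(size_t ndim) {
      std::vector<int> p(ndim);
      std::iota(p.begin(), p.end(), 0);       // identity: 0, 1, ..., ndim-1
      if(ndim > 1)
        std::swap(p[ndim - 2], p[ndim - 1]);  // transpose the innermost matrix
      return p;
    }

    // lastTwoSwapped(2) -> {1, 0}; lastTwoSwapped(4) -> {0, 1, 3, 2}

So a 2-D expression still gets an ordinary matrix transpose, while leading (batch) axes of higher-rank inputs are left untouched.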
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 315743e3..c07620ab 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -52,19 +52,19 @@ public:
   Shape newShape(Expr a, Expr b, bool transA, bool transB) {
     auto shapeA = a->shape();
     if(transA) {
-      shapeA.set(0, a->shape()[1]);
-      shapeA.set(1, a->shape()[0]);
+      shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
+      shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
     }

     auto shapeB = b->shape();
     if(transB) {
-      shapeB.set(0, b->shape()[1]);
-      shapeB.set(1, b->shape()[0]);
+      shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
+      shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
     }

     Shape outShape = shapeA;
-    outShape.set(1, shapeB[1]);
-    UTIL_THROW_IF2(shapeA[1] != shapeB[0],
+    outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
+    UTIL_THROW_IF2(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
                    "matrix product requires dimensions to match");
     return outShape;
   }
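The dot-product shape inference switches from the fixed axes 0 and 1 to the trailing two axes, so leading batch dimensions pass through unchanged. A hypothetical standalone rendering of the same rule (std::vector instead of Shape, assert instead of UTIL_THROW_IF2), assuming transA/transB have already been folded in as the real newShape does above:

    #include <cassert>
    #include <vector>

    std::vector<int> dotShape(std::vector<int> a, const std::vector<int>& b) {
      // inner dimensions must agree: last axis of a vs. second-to-last of b
      assert(a.back() == b[b.size() - 2]);
      a.back() = b.back();  // output takes b's last axis
      return a;             // leading (batch) axes are inherited from a
    }

    // dotShape({2, 2, 3}, {3, 2}) -> {2, 2, 2}: the 8 entries of vC in the
    // updated operator test below, where A becomes {2, 2, 3} and B is {3, 2}.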
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 852d6d02..9881357c 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -702,9 +702,12 @@ struct TransposeNodeOp : public UnaryNodeOp {
   template <class... Args>
   Shape newShape(Expr a, Shape permute) {
-    Shape shape;
+    Shape shape = a->shape();
+
+    UTIL_THROW_IF2(shape.size() != permute.size(),
+                   "Shape and transpose axis have different number of dimensions");

-    for(int i = 0; i < 4; ++i)
+    for(int i = 0; i < shape.size(); ++i)
       shape.set(i, a->shape()[permute[i]]);

     return shape;
   }
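TransposeNodeOp::newShape now follows the input's rank instead of assuming exactly four axes, and rejects permutations of the wrong length. The same gather rule, out[i] = in[permute[i]], in a self-contained sketch (our names, not the repo's):

    #include <cassert>
    #include <vector>

    std::vector<int> permutedShape(const std::vector<int>& in,
                                   const std::vector<int>& permute) {
      assert(in.size() == permute.size());  // the new rank check
      std::vector<int> out(in.size());
      for(size_t i = 0; i < in.size(); ++i)
        out[i] = in[permute[i]];            // output axis i is input axis permute[i]
      return out;
    }

    // permutedShape({2, 1, 2, 2}, {2, 0, 1, 3}) -> {2, 2, 1, 2},
    // the t5 shape asserted in operator_tests.cpp below.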
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 645e6af8..1c0b7c35 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -190,7 +190,7 @@ __global__ void gTranspose4D(float* out,
       outShape.dims(index, dims1);
       for(int i = 0; i < num; ++i)
-        dims2[i] = dims1[permute[i]];
+        dims2[permute[i]] = dims1[i];

       int inIndex = inShape.index(dims2);
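This one-line change in the kernel is the actual bug fix: outShape.dims() recovers output coordinates, and output axis i corresponds to input axis permute[i], so the coordinate must be scattered back through the permutation rather than gathered. The old gather form effectively applied the permutation where its inverse was required, which agrees only for self-inverse permutations such as the old default {1, 0, 2, 3}; that is why plain matrix transposes kept working while general ndarray permutations silently produced wrong layouts. A CPU reference of the fixed indexing, as a sketch under the same conventions (row-major strides, outShape[i] == inShape[permute[i]]):

    #include <vector>

    void transpose4DRef(float* out, const float* in,
                        const std::vector<int>& inShape,    // 4 extents
                        const std::vector<int>& permute) {  // 4 axis indices
      std::vector<int> outShape(4), sIn(4), sOut(4);
      for(int i = 0; i < 4; ++i)
        outShape[i] = inShape[permute[i]];
      sIn[3] = sOut[3] = 1;
      for(int i = 2; i >= 0; --i) {           // row-major strides
        sIn[i] = sIn[i + 1] * inShape[i + 1];
        sOut[i] = sOut[i + 1] * outShape[i + 1];
      }
      std::vector<int> dims1(4), dims2(4);
      int length = sOut[0] * outShape[0];
      for(int index = 0; index < length; ++index) {
        int rest = index;
        for(int i = 0; i < 4; ++i) {          // flat output index -> coordinates
          dims1[i] = rest / sOut[i];
          rest %= sOut[i];
        }
        for(int i = 0; i < 4; ++i)
          dims2[permute[i]] = dims1[i];       // the fix: scatter, not gather
        int inIndex = 0;
        for(int i = 0; i < 4; ++i)
          inIndex += dims2[i] * sIn[i];
        out[index] = in[inIndex];
      }
    }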
@@ -207,8 +207,18 @@ void Transpose4D(Tensor out, Tensor in, Shape permute) {
   int threads = std::min(MAX_THREADS, length);
   int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));

+  Shape permuteGPU;
+  permuteGPU.resize(ShapeGPU::size());
+
+  int diff = ShapeGPU::size() - permute.size();
+  for(int i = 0; i < permuteGPU.size(); ++i)
+    if(i < diff)
+      permuteGPU.set(i, i);
+    else
+      permuteGPU.set(i, permute[i - diff] + diff);
+
   gTranspose4D<<<blocks, threads>>>(
-      out->data(), out->shape(), in->data(), in->shape(), permute);
+      out->data(), out->shape(), in->data(), in->shape(), permuteGPU);
 }

 __global__ void gSoftmax(float* out,
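Transpose4D now pads the user-supplied permutation to the kernel's fixed rank (ShapeGPU::size(), which the 4-D kernel suggests is 4): padded leading axes map to themselves and the user's entries are shifted by the rank difference. A sketch of the padding rule with std::vector standing in for Shape (function name ours):

    #include <vector>

    std::vector<int> padPermute(const std::vector<int>& permute, int gpuDims = 4) {
      std::vector<int> p(gpuDims);
      int diff = gpuDims - (int)permute.size();
      for(int i = 0; i < gpuDims; ++i)
        p[i] = (i < diff) ? i                         // padded leading axes stay put
                          : permute[i - diff] + diff; // user permutation, shifted
      return p;
    }

    // padPermute({1, 0}) -> {0, 1, 3, 2}: a 2-D transpose expressed over 4 axes.

This padding is also why Shape::resize was added in src/common/shape.h above.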
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 67cc0265..1feddb18 100644
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -21,7 +21,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     values.clear();
     std::vector<float> vC({22, 28, 49, 64, 76, 100, 103, 136});

-    auto A = graph->param("A", {2, 3, 2}, keywords::init = inits::from_vector(vA));
+    auto A = graph->param("A", {2, 2, 3}, keywords::init = inits::from_vector(vA));
     auto B = graph->param("B", {3, 2}, keywords::init = inits::from_vector(vB));
     auto C = dot(A, B);
     graph->forward();
@@ -127,15 +127,15 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     std::vector<float> vT1({1, 5, 2, 6, 3, 7, 4, 8});
     std::vector<float> vT3({1, 2, 5, 6, 3, 4, 7, 8});
     std::vector<float> vT4({1, 5, 3, 7, 2, 6, 4, 8});
-    std::vector<float> vT5({1, 3, 2, 4, 5, 7, 6, 8});
+    std::vector<float> vT5({1, 2, 5, 6, 3, 4, 7, 8});

     auto a = graph->constant({2, 4}, keywords::init = inits::from_vector(vA));

     auto t1 = transpose(a);
     auto t2 = transpose(t1);
     auto t3 = transpose(reshape(t1, {2, 2, 2}));
-    auto t4 = transpose(reshape(a, {2, 2, 1, 2}), {2, 3, 0, 1});
-    auto t5 = transpose(reshape(a, {2, 2, 1, 2}), {1, 2, 3, 0});
+    auto t4 = transpose(reshape(a, {2, 1, 2, 2}), {1, 3, 2, 0});
+    auto t5 = transpose(reshape(a, {2, 1, 2, 2}), {2, 0, 1, 3});

     graph->forward();
@@ -143,7 +143,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     CHECK(t2->shape() == Shape({2, 4}));
     CHECK(t3->shape() == Shape({2, 2, 2}));
     CHECK(t4->shape() == Shape({1, 2, 2, 2}));
-    CHECK(t5->shape() == Shape({2, 1, 2, 2}));
+    CHECK(t5->shape() == Shape({2, 2, 1, 2}));

     t1->val()->get(values);
     CHECK( values == vT1 );
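As a hand cross-check of the updated t5 expectation (our trace, not repository output): reshape(a, {2, 1, 2, 2}) lays the values 1..8 out as in(0,0,:,:) = [[1, 2], [3, 4]] and in(1,0,:,:) = [[5, 6], [7, 8]]. Applying permute = {2, 0, 1, 3} under the fixed scatter rule dims2[permute[i]] = dims1[i] yields output shape {2, 2, 1, 2} and rows out(0,0,0,:) = (1, 2), out(0,1,0,:) = (5, 6), out(1,0,0,:) = (3, 4), out(1,1,0,:) = (7, 8), i.e. exactly the new vT5 = {1, 2, 5, 6, 3, 4, 7, 8}.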