Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-12-06 14:06:16 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-12-06 14:06:16 +0300
commit13f3ca094b6ed3a8a19c7f605f78d6679573c22f (patch)
tree2f062ad7d010a660ba0cfbcefb0d321375a79172
parent7bbe96dd0b3961e413a1eac6862bd9e7d81c1034 (diff)
bugs in gBroadcast
-rw-r--r--src/amun/gpu/mblas/matrix_functions.h35
1 files changed, 17 insertions, 18 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h
index 75b274b6..eda683fc 100644
--- a/src/amun/gpu/mblas/matrix_functions.h
+++ b/src/amun/gpu/mblas/matrix_functions.h
@@ -151,36 +151,35 @@ __global__ void gBroadcast(Functor functor,
const MatrixWrapper<float> in2Wrap,
const VectorWrapper<uint> batchMappingWrap)
{
- size_t srcSize = outWrap.dim(2);
- size_t inRows = in2Wrap.dim(0);
- size_t cols = in1Wrap.dim(1);
-
int id = threadIdx.x + blockIdx.x * blockDim.x;
if (id < outWrap.size()) {
/*
- size_t indices[SHAPE_SIZE];
+ uint indices[SHAPE_SIZE];
outWrap.id2Indices(id, indices);
- int row = id / cols; // len * batch for in1
- int srcId = row % srcSize; // source pos for in1
-
- int batchMappingIdx = row / srcSize; // batch for in1
- int batchIdx = batchMappingWrap[batchMappingIdx]; // batch id for in1
-
- outWrap[id] = functor(in1Wrap(srcId, indices[1], 0, batchIdx),
- in2Wrap(batchMappingIdx, indices[1], 0, 0) );
+ uint srcId = indices[0];
+ uint stateIdx = indices[1];
+ uint beamIdx = indices[2];
+ //assert(0 == indices[3]);
*/
- int row = id / cols;
- int stateIdx = id % cols;
+ uint cols = in1Wrap.dim(1);
+ uint srcSize = outWrap.dim(0);
+
+ uint row = id / cols;
+ uint stateIdx = id % cols;
+ uint beamIdx = row / srcSize;
+ uint srcId = row % srcSize;
- int beamIdx = row / srcSize;
- int srcId = row % srcSize;
+ uint batchIdx = batchMappingWrap[ beamIdx ];
- int batchIdx = batchMappingWrap[beamIdx];
outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx],
in2Wrap[beamIdx * cols + stateIdx]);
+ //outWrap[id] = functor(in1Wrap(indices[0], indices[1], 0, batchIdx),
+ // in2Wrap(indices[2], indices[1], 0, 0));
+ //outWrap(srcId, stateIdx, beamIdx, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx),
+ // in2Wrap(beamIdx, stateIdx, 0, 0));
}
}