diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2017-12-06 14:06:16 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2017-12-06 14:06:16 +0300 |
commit | 13f3ca094b6ed3a8a19c7f605f78d6679573c22f (patch) | |
tree | 2f062ad7d010a660ba0cfbcefb0d321375a79172 | |
parent | 7bbe96dd0b3961e413a1eac6862bd9e7d81c1034 (diff) |
bugs in gBroadcast
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.h | 35 |
1 files changed, 17 insertions, 18 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h index 75b274b6..eda683fc 100644 --- a/src/amun/gpu/mblas/matrix_functions.h +++ b/src/amun/gpu/mblas/matrix_functions.h @@ -151,36 +151,35 @@ __global__ void gBroadcast(Functor functor, const MatrixWrapper<float> in2Wrap, const VectorWrapper<uint> batchMappingWrap) { - size_t srcSize = outWrap.dim(2); - size_t inRows = in2Wrap.dim(0); - size_t cols = in1Wrap.dim(1); - int id = threadIdx.x + blockIdx.x * blockDim.x; if (id < outWrap.size()) { /* - size_t indices[SHAPE_SIZE]; + uint indices[SHAPE_SIZE]; outWrap.id2Indices(id, indices); - int row = id / cols; // len * batch for in1 - int srcId = row % srcSize; // source pos for in1 - - int batchMappingIdx = row / srcSize; // batch for in1 - int batchIdx = batchMappingWrap[batchMappingIdx]; // batch id for in1 - - outWrap[id] = functor(in1Wrap(srcId, indices[1], 0, batchIdx), - in2Wrap(batchMappingIdx, indices[1], 0, 0) ); + uint srcId = indices[0]; + uint stateIdx = indices[1]; + uint beamIdx = indices[2]; + //assert(0 == indices[3]); */ - int row = id / cols; - int stateIdx = id % cols; + uint cols = in1Wrap.dim(1); + uint srcSize = outWrap.dim(0); + + uint row = id / cols; + uint stateIdx = id % cols; + uint beamIdx = row / srcSize; + uint srcId = row % srcSize; - int beamIdx = row / srcSize; - int srcId = row % srcSize; + uint batchIdx = batchMappingWrap[ beamIdx ]; - int batchIdx = batchMappingWrap[beamIdx]; outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx], in2Wrap[beamIdx * cols + stateIdx]); + //outWrap[id] = functor(in1Wrap(indices[0], indices[1], 0, batchIdx), + // in2Wrap(indices[2], indices[1], 0, 0)); + //outWrap(srcId, stateIdx, beamIdx, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx), + // in2Wrap(beamIdx, stateIdx, 0, 0)); } } |