diff options
author | Hieu Hoang <hieuhoang@gmail.com> | 2017-12-04 19:33:09 +0300 |
---|---|---|
committer | Hieu Hoang <hieuhoang@gmail.com> | 2017-12-04 19:33:09 +0300 |
commit | f2581bf7f968d7dce02f52f89a19b3afad3b5824 (patch) | |
tree | 45ad0da63cb2e3b710023a7a8ceb310e15f58ab3 | |
parent | 2b4a5550bc53e612fefc9ca5ba122eca1c276ad0 (diff) |
fastest
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.h | 28 |
1 file changed, 14 insertions, 14 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h index c162910f..bd791797 100644 --- a/src/amun/gpu/mblas/matrix_functions.h +++ b/src/amun/gpu/mblas/matrix_functions.h @@ -151,16 +151,16 @@ __global__ void gBroadcast(Functor functor, const MatrixWrapper<float> in2Wrap, const VectorWrapper<uint> batchMappingWrap) { - //size_t srcSize = outWrap.dim(0); - //size_t inRows = in2Wrap.dim(0); - //size_t cols = in1Wrap.dim(1); + uint srcSize = outWrap.dim(0); + uint inRows = in2Wrap.dim(0); + uint cols = in1Wrap.dim(1); int id = threadIdx.x + blockIdx.x * blockDim.x; if (id < outWrap.size()) { + /* uint indices[SHAPE_SIZE]; outWrap.id2Indices(id, indices); - /* int row = id / cols; // len * batch for in1 int srcId = row % srcSize; // source pos for in1 @@ -171,25 +171,25 @@ __global__ void gBroadcast(Functor functor, in2Wrap(batchMappingIdx, indices[1], 0, 0) ); */ - //int row = id / cols; - //int stateIdx = id % cols; - - //int beamIdx = row / srcSize; - //int srcId = row % srcSize; + uint row = id / cols; + uint stateIdx = id % cols; + uint beamIdx = row / srcSize; + uint srcId = row % srcSize; + uint batchIdx = batchMappingWrap[ beamIdx ]; - uint batchIdx = batchMappingWrap[ indices[2] ]; + //uint batchIdx = batchMappingWrap[ indices[2] ]; //assert(srcId == indices[0]); //assert(stateIdx == indices[1]); //assert(beamIdx == indices[2]); //assert(0 == indices[3]); - //outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx], - // in2Wrap[beamIdx * cols + stateIdx]); + outWrap[id] = functor(in1Wrap[(batchIdx * srcSize + srcId) * cols + stateIdx], + in2Wrap[beamIdx * cols + stateIdx]); //outWrap[id] = functor(in1Wrap(indices[0], indices[1], 0, batchIdx), // in2Wrap(indices[2], indices[1], 0, 0)); - outWrap(indices[0], indices[1], indices[2], 0) = functor(in1Wrap(indices[0], indices[1], 0, batchIdx), - in2Wrap(indices[2], indices[1], 0, 0)); + //outWrap(srcId, stateIdx, beamIdx, 0) = functor(in1Wrap(srcId, 
stateIdx, 0, batchIdx), + // in2Wrap(beamIdx, stateIdx, 0, 0)); } } |