| author | Christian Sarofeen <csarofeen@nvidia.com> | 2017-07-03 07:39:40 +0300 |
|---|---|---|
| committer | Soumith Chintala <soumith@gmail.com> | 2017-07-03 07:39:40 +0300 |
| commit | 707f94c48c6491abee4802abcff9b992188a0ec0 (patch) | |
| tree | bc53e064ff45699451a3d488bd46b931675c1d65 | |
| parent | 0e5ccc371397156a88201b59bf079894f49853d6 (diff) | |
Add a nonContigDim reduction kernel to improve latency for small tensors. (#768)
-rw-r--r-- | lib/THC/THCReduce.cuh | 132 |
1 file changed, 126 insertions, 6 deletions
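The heart of the change is the launch-shape heuristic added in the second hunk below: for small outputs (`outElements <= 4096`), the block keeps its total thread count but trades x-threads (one output slice each) for y-threads (cooperating on a single slice), with y capped at 16 so x never drops below 32. Here is a minimal host-side sketch of just that computation; `pickBlock`, `Shape`, and the assumption that `THC_NONCONTIG_REDUCE_BLOCK_SIZE` is 512 (suggested by the `__launch_bounds__(32 * 16, 4)` in this file) are illustrative, not taken from the patch.

```cpp
// Hypothetical standalone restatement of the patch's block-shape heuristic.
#include <algorithm>
#include <cstdio>

constexpr unsigned BLOCK_SIZE = 512;  // assumed THC_NONCONTIG_REDUCE_BLOCK_SIZE

struct Shape { unsigned x, y; };

Shape pickBlock(long reductionSize) {
  // If each thread would loop at most ~8 times anyway, don't bother
  // splitting the reduction across the y dimension.
  unsigned long ydim = (reductionSize + 7) / 8;  // THCCeilDiv(reductionSize, 8L)
  // Cap y at 16 so x (one thread per output slice) never drops below 32.
  ydim = std::min(16UL, ydim);

  Shape block{BLOCK_SIZE, 1};
  while (ydim > 1) {   // trade slice-parallel threads for reduction-parallel
    block.x /= 2;      // threads, keeping the total thread count constant
    block.y *= 2;
    ydim /= 2;
  }
  return block;
}

int main() {
  for (long rs : {4L, 32L, 200L, 4096L}) {
    Shape b = pickBlock(rs);
    std::printf("reductionSize=%4ld -> block=(%u, %u)\n", rs, b.x, b.y);
  }
  // Prints (512, 1), (128, 4), (32, 16), (32, 16): tiny reductions keep all
  // threads on separate slices; longer ones split each slice 16 ways.
}
```

Keeping `block.x * block.y` constant preserves occupancy while letting short, wide workloads expose parallelism inside each reduction instead of only across reductions.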
diff --git a/lib/THC/THCReduce.cuh b/lib/THC/THCReduce.cuh
index 067d796..b7df49b 100644
--- a/lib/THC/THCReduce.cuh
+++ b/lib/THC/THCReduce.cuh
@@ -30,6 +30,99 @@ template <typename ModifyOp,
 __launch_bounds__(32 * 16, 4)
 #endif
 __global__ void
+kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
+                                TensorInfo<T, IndexType> in,
+                                IndexType reductionStride,
+                                IndexType reductionSize,
+                                IndexType totalSlices,
+                                T init,
+                                ModifyOp modifyOp,
+                                ReduceOp reduceOp) {
+
+  IndexType sliceIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  IndexType sliceStride = gridDim.x * blockDim.x;
+
+  __shared__ T local_reduce[THC_NONCONTIG_REDUCE_BLOCK_SIZE];
+  T* shmem = &local_reduce[threadIdx.x + threadIdx.y * blockDim.x];
+  T load_reg[4];
+  T local_reg;
+
+  for(;sliceIndex<totalSlices; sliceIndex+=sliceStride){
+    local_reg = init;
+
+    const IndexType outOffset =
+      IndexToOffset<T, IndexType, ADims>::get(sliceIndex, out);
+    const IndexType inOffset =
+      IndexToOffset<T, IndexType, BDims>::get(sliceIndex, in);
+
+    //Unroll this loop
+    //for(IndexType i=threadIdx.y; i<reductionSize; i+=blockDim.y){
+    //  local_reg += in[inOffset + i * reductionStride];
+    //}
+    for(IndexType i=threadIdx.y; i<reductionSize; i+=blockDim.y*4){
+      if(i + blockDim.y * 3 < reductionSize){
+        load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
+        load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
+        load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
+        load_reg[3] = modifyOp(in.data[inOffset + (i + blockDim.y * 3) * reductionStride]);
+
+        local_reg = reduceOp(local_reg,
+                             reduceOp(
+                                      reduceOp(load_reg[0], load_reg[1]),
+                                      reduceOp(load_reg[2], load_reg[3])
+                                      )
+                             );
+
+      }else if(i + blockDim.y * 2 < reductionSize){
+        load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
+        load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
+        load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
+
+        local_reg = reduceOp(
+                             reduceOp(load_reg[0], load_reg[1]),
+                             reduceOp(load_reg[2], local_reg)
+                             );
+
+      }else if( (i + blockDim.y) < reductionSize){
+        load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
+        load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
+        local_reg = reduceOp(
+                             local_reg, reduceOp(load_reg[0], load_reg[1])
+                             );
+
+      }else if(i + blockDim.y * 0 < reductionSize){
+        local_reg = reduceOp(local_reg, modifyOp(in.data[inOffset + i * reductionStride]));
+      }
+    }
+
+    *shmem = local_reg;
+    int dimy = blockDim.y;
+    while(dimy > 1){
+      __syncthreads();
+      if( threadIdx.y == 0 && (dimy%2 != 0) ){
+        *shmem = reduceOp(*shmem, *(shmem + (dimy-1) * blockDim.x) );
+      }
+      if(threadIdx.y < dimy/2){
+        *shmem = reduceOp(*shmem, *(shmem + (dimy/2)*blockDim.x) );
+      }
+      dimy /= 2;
+    }
+    if(threadIdx.y == 0)
+      out.data[outOffset] = *shmem;
+  }
+}
+
+
+// Kernel that handles an entire reduction of a slice of a tensor per each thread
+template <typename ModifyOp,
+          typename ReduceOp,
+          typename T,
+          typename IndexType,
+          int ADims, int BDims>
+#if __CUDA_ARCH__ >= 350
+__launch_bounds__(32 * 16, 4)
+#endif
+__global__ void
 kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
                          TensorInfo<T, IndexType> in,
                          IndexType reductionStride,
@@ -206,8 +299,26 @@ bool THC_reduceDim(THCState* state,
     }
 
     block = getNoncontigReduceBlock();
-  }
+    if(outElements <= 4096){
+      //x dim does different columns
+      //y dim helps with the same reduction
+      //If we only have 8 loops, don't bother sharing work across ydim
+      unsigned long ydim = THCCeilDiv(reductionSize, 8L);
+
+      //don't want y dim any bigger than 16, leaving min x dim to 32
+      ydim = min((unsigned long) 16, ydim);
+
+      block = dim3(THC_NONCONTIG_REDUCE_BLOCK_SIZE, 1, 1);
+      while(ydim > 1){
+        block.x /= 2;
+        block.y *= 2;
+        ydim /= 2;
+      }
+      THC_getGridFromTiles(THCCeilDiv(outElements, (long)block.x), grid);
+
+    }
+  }
 
   // Resize out to correspond to the reduced size
   THLongStorage* sizes = TensorUtils<TensorType>::newSizeOf(state, in);
   THLongStorage_set(sizes, dim, 1);
@@ -231,12 +342,21 @@ bool THC_reduceDim(THCState* state,
         outInfo, inInfo, reductionSize,                                 \
         (TYPE) outElements, init, modifyOp, reduceOp);                  \
   } else {                                                              \
-    kernelReduceNoncontigDim<ModifyOp, ReduceOp,                        \
-      typename TensorUtils<TensorType>::DataType,                       \
-      TYPE, OUT, IN>                                                    \
-      <<<grid, block, 0, THCState_getCurrentStream(state)>>>(           \
-        outInfo, inInfo, reductionStride, reductionSize,                \
+    if(block.y == 1){                                                   \
+      kernelReduceNoncontigDim<ModifyOp, ReduceOp,                      \
+        typename TensorUtils<TensorType>::DataType,                     \
+        TYPE, OUT, IN>                                                  \
+        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
+          outInfo, inInfo, reductionStride, reductionSize,              \
        (TYPE) outElements, init, modifyOp, reduceOp);                   \
+    }else{                                                              \
+      kernelReduceNoncontigDim_shared<ModifyOp, ReduceOp,               \
+        typename TensorUtils<TensorType>::DataType,                     \
+        TYPE, OUT, IN>                                                  \
+        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
+          outInfo, inInfo, reductionStride, reductionSize,              \
+          (TYPE) outElements, init, modifyOp, reduceOp);                \
+    }                                                                   \
   }                                                                     \
 
 #define HANDLE_IN_CASE(TYPE, OUT, IN)                                   \
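To see the new kernel's shape in isolation: each (x, y) thread first accumulates a `blockDim.y`-strided portion of its slice into a register (the patch unrolls this loop by 4), and the `blockDim.y` partial results are then folded in shared memory, halving the number of active rows per step, with row 0 absorbing the leftover row whenever the count is odd. The following is a minimal, self-contained CUDA sketch of that scheme, specialized to a float sum over the columns of a row-major matrix (the noncontiguous case: consecutive reduction elements are a full row apart); the kernel name, `BLOCK_SIZE`, and the `main` driver are illustrative, not part of the patch.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512  // assumed THC_NONCONTIG_REDUCE_BLOCK_SIZE (32 * 16)

// Sum each column of an n-by-m row-major matrix: threadIdx.x picks the
// column, threadIdx.y splits the column's n elements across blockDim.y
// cooperating threads. Assumes m == gridDim.x * blockDim.x.
__global__ void reduceColumns(const float* in, float* out, int n, int m) {
  __shared__ float buf[BLOCK_SIZE];
  float* slot = &buf[threadIdx.x + threadIdx.y * blockDim.x];
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  // Phase 1: private, register-resident accumulation with stride blockDim.y.
  float acc = 0.f;
  for (int i = threadIdx.y; i < n; i += blockDim.y)
    acc += in[i * m + col];
  *slot = acc;

  // Phase 2: fold blockDim.y partials in shared memory, halving the active
  // rows each step; row 0 absorbs the odd straggler before each halving.
  int dimy = blockDim.y;
  while (dimy > 1) {
    __syncthreads();
    if (threadIdx.y == 0 && (dimy % 2 != 0))
      *slot += *(slot + (dimy - 1) * blockDim.x);
    if (threadIdx.y < dimy / 2)
      *slot += *(slot + (dimy / 2) * blockDim.x);
    dimy /= 2;
  }
  if (threadIdx.y == 0)
    out[col] = *slot;
}

int main() {
  const int n = 1000, m = 32;  // 32 columns -> one block of shape (32, 16)
  float *in, *out;
  cudaMallocManaged(&in, n * m * sizeof(float));
  cudaMallocManaged(&out, m * sizeof(float));
  for (int i = 0; i < n * m; ++i) in[i] = 1.f;

  reduceColumns<<<1, dim3(32, 16)>>>(in, out, n, m);
  cudaDeviceSynchronize();
  std::printf("out[0] = %g (expected %d)\n", out[0], n);  // prints 1000
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

Accumulating in registers first means the `__syncthreads()`-laden folding loop runs only log2(`blockDim.y`) times regardless of `reductionSize`, which is what keeps latency low for the small tensors named in the commit message.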