diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-03-15 21:37:21 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-15 21:37:21 +0300 |
commit | bf5ad8546b51e5f54816836fbfb42ab307d8040e (patch) | |
tree | 2ad48fbf7589525914ec866d0c5d33a371ae15fb | |
parent | 4cbe9333a23cf33df4ef97b318c26151005960d9 (diff) | |
parent | 3f2daf730b2188294df77548ac7c64d9f73f1b1a (diff) |
Merge pull request #723 from killeent/scan-primitive
add implementation of inclusive scan via upsweep-downsweep
-rw-r--r-- | lib/THC/THCScanUtils.cuh | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/lib/THC/THCScanUtils.cuh b/lib/THC/THCScanUtils.cuh index da53f1f..ee20808 100644 --- a/lib/THC/THCScanUtils.cuh +++ b/lib/THC/THCScanUtils.cuh @@ -5,6 +5,50 @@ // Collection of in-kernel scan / prefix sum utilities +// Inclusive Scan via an upsweep/downsweep mechanism. Assumes: +// +// 1. Power2ScanSize is a power of 2. This code still works for collections that +// do not exactly contain a power of 2 number of elements, simply round up to the +// nearest power of 2 and then call. +// +// 2. That there are two-elements per thread, i.e. the size of the smem storage +// is 2 * blockDim.x * sizeof(T). +// +// Consider a (+)-Scan on the following elements: +// +// Upsweep: +// +// 0 1 2 3 4 5 6 7 +// 1 5 9 13 +// 6 22 +// 28 +// +// Downsweep: +// 15 +// 3 10 21 +template <typename T, class BinaryOp, int Power2ScanSize> +__device__ void inclusivePrefixScan(T *smem, BinaryOp binop) { + // Reduce step ("upsweep") +#pragma unroll + for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if (index < Power2ScanSize) { + smem[index] = binop(smem[index], smem[index - stride]); + } + __syncthreads(); + } + + // Post-reduce step ("downsweep") +#pragma unroll + for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if ((index + stride) < Power2ScanSize) { + smem[index + stride] = binop(smem[index + stride], smem[index]); + } + __syncthreads(); + } +} + // Inclusive prefix sum using shared memory template <typename T, bool KillWARDependency, class BinaryFunction> __device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) { |