diff options
author | Trevor Killeen <killeentm@gmail.com> | 2017-03-08 18:34:14 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2017-03-08 18:34:14 +0300 |
commit | 3f2daf730b2188294df77548ac7c64d9f73f1b1a (patch) | |
tree | 4d0b29bb60e6efa62a81f00f1fc05ee50f2197d7 | |
parent | d4c2b1d34631b78717289e0d219751f448d32b9c (diff) |
add implementation of inclusive scan via upsweep-downsweep
-rw-r--r-- | lib/THC/THCScanUtils.cuh | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/lib/THC/THCScanUtils.cuh b/lib/THC/THCScanUtils.cuh index da53f1f..ee20808 100644 --- a/lib/THC/THCScanUtils.cuh +++ b/lib/THC/THCScanUtils.cuh @@ -5,6 +5,50 @@ // Collection of in-kernel scan / prefix sum utilities +// Inclusive Scan via an upsweep/downsweep mechanism. Assumes: +// +// 1. Power2ScanSize is a power of 2. This code still works for collections that +// do not exactly contain a power of 2 number of elements, simply round up to the +// nearest power of 2 and then call. +// +// 2. That there are two-elements per thread, i.e. the size of the smem storage +// is 2 * blockDim.x * sizeof(T). +// +// Consider a (+)-Scan on the following elements: +// +// Upsweep: +// +// 0 1 2 3 4 5 6 7 +// 1 5 9 13 +// 6 22 +// 28 +// +// Downsweep: +// 15 +// 3 10 21 +template <typename T, class BinaryOp, int Power2ScanSize> +__device__ void inclusivePrefixScan(T *smem, BinaryOp binop) { + // Reduce step ("upsweep") +#pragma unroll + for (int stride = 1; stride < Power2ScanSize; stride <<= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if (index < Power2ScanSize) { + smem[index] = binop(smem[index], smem[index - stride]); + } + __syncthreads(); + } + + // Post-reduce step ("downsweep") +#pragma unroll + for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) { + int index = (threadIdx.x + 1) * stride * 2 - 1; + if ((index + stride) < Power2ScanSize) { + smem[index + stride] = binop(smem[index + stride], smem[index]); + } + __syncthreads(); + } +} + // Inclusive prefix sum using shared memory template <typename T, bool KillWARDependency, class BinaryFunction> __device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) { |