Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-03-08 18:34:14 +0300
committerTrevor Killeen <killeentm@gmail.com>2017-03-08 18:34:14 +0300
commit3f2daf730b2188294df77548ac7c64d9f73f1b1a (patch)
tree4d0b29bb60e6efa62a81f00f1fc05ee50f2197d7
parentd4c2b1d34631b78717289e0d219751f448d32b9c (diff)
add implementation of inclusive scan via upsweep-downsweep
-rw-r--r--lib/THC/THCScanUtils.cuh44
1 files changed, 44 insertions, 0 deletions
diff --git a/lib/THC/THCScanUtils.cuh b/lib/THC/THCScanUtils.cuh
index da53f1f..ee20808 100644
--- a/lib/THC/THCScanUtils.cuh
+++ b/lib/THC/THCScanUtils.cuh
@@ -5,6 +5,50 @@
// Collection of in-kernel scan / prefix sum utilities
+// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
+//
+// 1. Power2ScanSize is a power of 2. This code still works for collections that
+// do not exactly contain a power of 2 number of elements, simply round up to the
+// nearest power of 2 and then call.
+//
+// 2. That there are two-elements per thread, i.e. the size of the smem storage
+// is 2 * blockDim.x * sizeof(T).
+//
+// Consider a (+)-Scan on the following elements:
+//
+// Upsweep:
+//
+// 0 1 2 3 4 5 6 7
+// 1 5 9 13
+// 6 22
+// 28
+//
+// Downsweep:
+// 15
+// 3 10 21
+template <typename T, class BinaryOp, int Power2ScanSize>
+__device__ void inclusivePrefixScan(T *smem, BinaryOp binop) {
+ // Reduce step ("upsweep")
+#pragma unroll
+ for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
+ int index = (threadIdx.x + 1) * stride * 2 - 1;
+ if (index < Power2ScanSize) {
+ smem[index] = binop(smem[index], smem[index - stride]);
+ }
+ __syncthreads();
+ }
+
+ // Post-reduce step ("downsweep")
+#pragma unroll
+ for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
+ int index = (threadIdx.x + 1) * stride * 2 - 1;
+ if ((index + stride) < Power2ScanSize) {
+ smem[index + stride] = binop(smem[index + stride], smem[index]);
+ }
+ __syncthreads();
+ }
+}
+
// Inclusive prefix sum using shared memory
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusivePrefixScan(T* smem, T in, T* out, BinaryFunction binop) {