#ifndef THC_SCAN_UTILS_INC
#define THC_SCAN_UTILS_INC

#include "THCAsmUtils.cuh"

// Collection of in-kernel scan / prefix sum utilities.
//
// Every function in this file contains __syncthreads() barriers, so it must
// be called by all threads of the block from uniform control flow. Threads
// are identified by threadIdx.x only.
// NOTE(review): this presumably assumes a 1-D thread block -- confirm against
// callers.

// Inclusive prefix sum using shared memory.
//
// smem : scratch buffer indexed by threadIdx.x; must hold at least
//        blockDim.x elements of T
// in   : this thread's input value
// out  : receives the sum of `in` over threads 0..threadIdx.x (inclusive)
//
// If KillWARDependency is true, a trailing barrier is issued so that the
// caller may overwrite smem immediately after the call returns.
template <typename T, bool KillWARDependency>
__device__ void inclusivePrefixSum(T* smem, T in, T* out) {
  // FIXME: this is a slow, simple implementation; need up/down sweep,
  // prevent smem conflicts
  smem[threadIdx.x] = in;

  __syncthreads();

  for (int offset = 1; offset < blockDim.x; offset *= 2) {
    T val = 0;

    // Read both operands before anyone overwrites them...
    if (threadIdx.x >= offset) {
      val = smem[threadIdx.x - offset] + smem[threadIdx.x];
    }

    __syncthreads();

    // ...and only then publish the partial sum for this step.
    if (threadIdx.x >= offset) {
      smem[threadIdx.x] = val;
    }

    __syncthreads();
  }

  *out = smem[threadIdx.x];

  // Prevent write-after-read dependencies on smem usage above if necessary
  if (KillWARDependency) {
    __syncthreads();
  }
}

// Exclusive prefix sum using shared memory.
//
// smem  : scratch buffer; must hold at least blockDim.x elements of T
// in    : this thread's input value
// out   : receives the sum of `in` over threads 0..threadIdx.x-1
// carry : receives the sum of `in` over the entire block (same value in
//         every thread)
//
// If KillWARDependency is true, a trailing barrier is issued so that the
// caller may overwrite smem immediately after the call returns.
template <typename T, bool KillWARDependency>
__device__ void exclusivePrefixSum(T* smem, T in, T* out, T* carry) {
  // FIXME: simple implementation layered on top of the inclusive scan.
  // We kill write-after-read dependencies separately below, hence the `false`
  inclusivePrefixSum<T, false>(smem, in, out);

  // Inclusive to exclusive
  *out -= in;

  // The last thread's inclusive sum is the block-wide total.
  *carry = smem[blockDim.x - 1];

  // Prevent write-after-read dependencies on smem usage above if necessary
  if (KillWARDependency) {
    __syncthreads();
  }
}

// Inclusive prefix sum for binary vars using intra-warp voting +
// shared memory.
//
// smem : scratch buffer; must hold at least ceil(blockDim.x / 32) elements
//        of T (one slot per warp, including a trailing partial warp)
// in   : this thread's predicate
// out  : receives the number of threads 0..threadIdx.x (inclusive) whose
//        predicate is true
//
// After the second barrier, smem[w] holds the inclusive count of true
// predicates over warps 0..w; exclusiveBinaryPrefixSum relies on this.
//
// If KillWARDependency is true, a trailing barrier is issued so that the
// caller may overwrite smem immediately after the call returns.
template <typename T, bool KillWARDependency>
__device__ void inclusiveBinaryPrefixSum(T* smem, bool in, T* out) {
  // Within-warp, we use warp voting. The ballot is kept in a 32-bit unsigned
  // value so that a narrow or signed T cannot truncate or sign-extend it.
#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000
  // The mask-less vote intrinsics are not available for Volta and newer
  // targets; use the explicitly synchronizing form with a full-warp mask.
  unsigned int vote = __ballot_sync(0xffffffffu, in);
#else
  unsigned int vote = __ballot(in);
#endif
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);

  int warp = threadIdx.x / 32;

  // Round up so that a trailing partial warp (blockDim.x not a multiple of
  // 32) is included in the cross-warp scan.
  int numWarps = (blockDim.x + 31) / 32;

  // Per each warp, write out a value
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  __syncthreads();

  // Sum across warps in one thread. This appears to be faster than a
  // warp shuffle scan for CC 3.0+
  if (threadIdx.x == 0) {
    int current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      smem[i] += current;
      current += v;
    }
  }

  __syncthreads();

  // load the carry from the preceding warp
  if (warp >= 1) {
    index += smem[warp - 1];
  }

  *out = index;

  // Prevent write-after-read dependencies on smem usage above if necessary
  if (KillWARDependency) {
    __syncthreads();
  }
}

// Exclusive prefix sum for binary vars using intra-warp voting +
// shared memory.
//
// smem  : scratch buffer; must hold at least ceil(blockDim.x / 32) elements
//         of T
// in    : this thread's predicate
// out   : receives the number of threads 0..threadIdx.x-1 whose predicate
//         is true
// carry : receives the number of true predicates in the entire block (same
//         value in every thread)
//
// If KillWARDependency is true, a trailing barrier is issued so that the
// caller may overwrite smem immediately after the call returns.
template <typename T, bool KillWARDependency>
__device__ void exclusiveBinaryPrefixSum(T* smem, bool in, T* out, T* carry) {
  // We kill write-after-read dependencies separately below, hence the `false`
  inclusiveBinaryPrefixSum<T, false>(smem, in, out);

  // Inclusive to exclusive
  *out -= (T) in;

  // The outgoing carry for all threads is the last warp's sum. The warp
  // count is rounded up so that blocks smaller than one warp do not index
  // smem[-1] and a trailing partial warp is not dropped from the total.
  *carry = smem[((blockDim.x + 31) / 32) - 1];

  // Prevent write-after-read dependencies on smem usage above if necessary
  if (KillWARDependency) {
    __syncthreads();
  }
}

#endif // THC_SCAN_UTILS_INC