diff options
-rw-r--r-- | lib/THC/THCSortUtils.cuh | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/lib/THC/THCSortUtils.cuh b/lib/THC/THCSortUtils.cuh
index ec676c0..bf9ead0 100644
--- a/lib/THC/THCSortUtils.cuh
+++ b/lib/THC/THCSortUtils.cuh
@@ -41,6 +41,21 @@ __device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
   }
 };
 
+// Variant of bitonicSwap without values: conditionally exchanges the two
+// (key, valid) pairs. `swap` is true when comp(kA, kB) holds for a valid A,
+// or when B is invalid; the exchange happens only if `swap` equals `dir`.
+template <typename Comparator, typename K>
+__device__ inline void bitonicSwapKeys(K& kA, bool& validA,
+                                       K& kB, bool& validB,
+                                       bool dir,
+                                       const Comparator& comp) {
+  bool swap = (comp(kA, kB) && validA) || !validB;
+  if (swap == dir) {
+    swapVars(kA, kB);
+    swapVars(validA, validB);
+  }
+}
+
 template <typename Comparator, typename K, typename V,
           typename IndexType, int Power2SortSize>
 __device__ inline void bitonicSort(K keys[Power2SortSize],
@@ -87,6 +102,48 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
   }
 }
 
+// Key-only bitonic sort of `keys` and their `valid` flags. The indexing
+// assumes threadIdx.x < Power2SortSize / 2 (TODO confirm against callers).
+// Every thread of the block must call this: barriers are unconditional, as
+// warp lanes need not run in lockstep (independent thread scheduling).
+template <typename Comparator, typename K,
+          typename IndexType, int Power2SortSize>
+__device__ inline void bitonicSortKeys(K keys[Power2SortSize],
+                                       bool valid[Power2SortSize],
+                                       const Comparator& comp) {
+#pragma unroll
+  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
+    bool flag = ((threadIdx.x & (size / 2)) != 0);
+
+#pragma unroll
+    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
+      // Wait until the previous step's swaps are visible to all threads
+      __syncthreads();
+
+      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      bitonicSwapKeys<Comparator, K>(
+        keys[pos], valid[pos],
+        keys[pos + stride], valid[pos + stride],
+        flag, comp);
+    }
+  }
+
+#pragma unroll
+  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
+    // Wait until the previous step's swaps are visible to all threads
+    __syncthreads();
+
+    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    bitonicSwapKeys<Comparator, K>(
+      keys[pos], valid[pos],
+      keys[pos + stride], valid[pos + stride],
+      false, comp);
+  }
+
+  // Make the final swaps visible before the caller reads the result
+  __syncthreads();
+}
+
 // Sorts (key, value) pairs (in different tensors) in-place; i.e.,
 // modifies the input `keys` and `values`
 template <typename K, typename V, |