diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-03-15 21:36:31 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-15 21:36:31 +0300 |
commit | 4cbe9333a23cf33df4ef97b318c26151005960d9 (patch) | |
tree | 7a8443408256c2954128f6a1bdd2d90ceb336622 | |
parent | c31cc583a33f5dbf052257c72b51d3341a089bb8 (diff) | |
parent | dae84256b778e20057946f4e889e3e12a5e28f32 (diff) |
Merge pull request #727 from killeent/key-only-sort
key only block-wide bitonic sort
-rw-r--r-- | lib/THC/THCSortUtils.cuh | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/lib/THC/THCSortUtils.cuh b/lib/THC/THCSortUtils.cuh
index ec676c0..bf9ead0 100644
--- a/lib/THC/THCSortUtils.cuh
+++ b/lib/THC/THCSortUtils.cuh
@@ -41,6 +41,18 @@ __device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
   }
 };
 
+template <typename Comparator, typename K>
+__device__ inline void bitonicSwapKeys(K& kA, bool& validA,
+                                       K& kB, bool& validB,
+                                       bool dir,
+                                       const Comparator& comp) {
+  bool swap = (comp(kA, kB) && validA) || !validB;
+  if (swap == dir) {
+    swapVars(kA, kB);
+    swapVars(validA, validB);
+  }
+}
+
 template <typename Comparator, typename K, typename V,
           typename IndexType, int Power2SortSize>
 __device__ inline void bitonicSort(K keys[Power2SortSize],
@@ -87,6 +99,51 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
   }
 }
 
+template <typename Comparator, typename K,
+          typename IndexType, int Power2SortSize>
+__device__ inline void bitonicSortKeys(K keys[Power2SortSize],
+                                       bool valid[Power2SortSize],
+                                       const Comparator& comp) {
+#pragma unroll
+  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
+    bool flag = ((threadIdx.x & (size / 2)) != 0);
+
+#pragma unroll
+    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
+
+      // Single warp per slice is completely synchronous
+      if (Power2SortSize > 64) {
+        __syncthreads();
+      }
+
+      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      bitonicSwapKeys<Comparator, K>(
+        keys[pos], valid[pos],
+        keys[pos + stride], valid[pos + stride],
+        flag, comp);
+    }
+  }
+
+#pragma unroll
+  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
+    // Single warp per slice is completely synchronous
+    if (Power2SortSize > 64) {
+      __syncthreads();
+    }
+
+    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    bitonicSwapKeys<Comparator, K>(
+      keys[pos], valid[pos],
+      keys[pos + stride], valid[pos + stride],
+      false, comp);
+  }
+
+  // Single warp per slice is completely synchronous
+  if (Power2SortSize > 64) {
+    __syncthreads();
+  }
+}
+
 // Sorts (key, value) pairs (in different tensors) in-place; i.e.,
 // modifies the input `keys` and `values`
 template <typename K, typename V,