Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Soumith Chintala <soumith@gmail.com>, 2017-03-15 21:36:31 +0300
committer: GitHub <noreply@github.com>, 2017-03-15 21:36:31 +0300
commit: 4cbe9333a23cf33df4ef97b318c26151005960d9 (patch)
tree: 7a8443408256c2954128f6a1bdd2d90ceb336622
parent: c31cc583a33f5dbf052257c72b51d3341a089bb8 (diff)
parent: dae84256b778e20057946f4e889e3e12a5e28f32 (diff)
Merge pull request #727 from killeent/key-only-sort
key only block-wide bitonic sort
-rw-r--r-- lib/THC/THCSortUtils.cuh 57
1 files changed, 57 insertions, 0 deletions
diff --git a/lib/THC/THCSortUtils.cuh b/lib/THC/THCSortUtils.cuh
index ec676c0..bf9ead0 100644
--- a/lib/THC/THCSortUtils.cuh
+++ b/lib/THC/THCSortUtils.cuh
@@ -41,6 +41,18 @@ __device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
}
};
+// Key-only compare-and-exchange step: the same decision rule as a
+// key/value swap, but with no value payload to move.
+//
+// Every key carries a validity flag. `comp(kA, kB)` is always evaluated;
+// its result only counts when `validA` is set, and a cleared `validB`
+// makes the predicate true regardless of the comparison. The pair (key
+// together with its flag) is exchanged exactly when that predicate equals
+// `dir`, so the caller selects the ordering direction through `dir`.
+template <typename Comparator, typename K>
+__device__ inline void bitonicSwapKeys(K& kA, bool& validA,
+                                       K& kB, bool& validB,
+                                       bool dir,
+                                       const Comparator& comp) {
+  const bool compResult = comp(kA, kB);
+  const bool predicate = !validB || (validA && compResult);
+
+  // Leave both slots untouched when the predicate disagrees with `dir`.
+  if (predicate != dir) {
+    return;
+  }
+
+  // The validity flag always travels with its key.
+  swapVars(validA, validB);
+  swapVars(kA, kB);
+}
+
template <typename Comparator, typename K, typename V,
typename IndexType, int Power2SortSize>
__device__ inline void bitonicSort(K keys[Power2SortSize],
@@ -87,6 +99,51 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
}
}
+// Key-only bitonic sort of `keys[0 .. Power2SortSize)`, carried out
+// cooperatively by the threads that call it.
+//
+// Template parameters:
+//   Comparator      callable as `comp(a, b)` on two keys, returning bool.
+//   K               key type.
+//   IndexType       not referenced by this function; it appears in the
+//                   template list alongside the other parameters.
+//   Power2SortSize  number of slots in `keys` / `valid`; assumed to be a
+//                   power of two (as the name says) -- confirm at call sites.
+//
+// Parameters:
+//   keys   keys to sort in place.
+//   valid  per-slot validity flags, permuted together with the keys. In
+//          the final merge (direction `false`) a valid entry is moved in
+//          front of an invalid one, so invalid slots end up at the higher
+//          indices.
+//   comp   ordering predicate forwarded to bitonicSwapKeys.
+//
+// Threading contract:
+//   * Each thread handles one compare-and-exchange pair per pass, chosen
+//     from threadIdx.x. Presumably the block runs Power2SortSize / 2
+//     threads so that every pair is covered and no index is out of range
+//     -- verify against the launch configuration.
+//   * `keys` and `valid` are read and written by different threads on
+//     successive passes, so they must live in memory visible to the whole
+//     block (presumably __shared__ -- confirm at the call site).
+//   * Every thread of the block must call this function: it executes
+//     __syncthreads() unconditionally. Earlier revisions skipped the
+//     barrier for Power2SortSize <= 64 on the assumption that a single
+//     warp runs in lock-step. That is not guaranteed (independent thread
+//     scheduling on Volta and newer), and without a barrier the compiler
+//     may also reorder or cache the array accesses between passes.
+template <typename Comparator, typename K,
+          typename IndexType, int Power2SortSize>
+__device__ inline void bitonicSortKeys(K keys[Power2SortSize],
+                                       bool valid[Power2SortSize],
+                                       const Comparator& comp) {
+  // Phase 1: build bitonic runs of length 2, 4, ..., Power2SortSize / 2.
+#pragma unroll
+  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
+    // Direction alternates between neighbouring groups of `size / 2`
+    // threads, so adjacent runs come out in opposite order.
+    bool flag = ((threadIdx.x & (size / 2)) != 0);
+
+#pragma unroll
+    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
+      // Make the previous pass's writes visible before any thread reads
+      // the slots it is about to compare.
+      __syncthreads();
+
+      // Lower index of this thread's pair; its partner is `stride` above.
+      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      bitonicSwapKeys<Comparator, K>(
+          keys[pos], valid[pos],
+          keys[pos + stride], valid[pos + stride],
+          flag, comp);
+    }
+  }
+
+  // Phase 2: final merge across the whole array, one direction for all
+  // threads.
+#pragma unroll
+  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
+    // Same barrier as above: order this pass after the previous one.
+    __syncthreads();
+
+    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    bitonicSwapKeys<Comparator, K>(
+        keys[pos], valid[pos],
+        keys[pos + stride], valid[pos + stride],
+        false, comp);
+  }
+
+  // Ensure the fully sorted contents are visible to every thread before
+  // the caller reads them.
+  __syncthreads();
+}
+
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
// modifies the input `keys` and `values`
template <typename K, typename V,