diff options
Diffstat (limited to 'lib/THC/THCTensorSort.cu')
-rw-r--r-- | lib/THC/THCTensorSort.cu | 39 |
1 files changed, 23 insertions, 16 deletions
diff --git a/lib/THC/THCTensorSort.cu b/lib/THC/THCTensorSort.cu index c35a36a..3b9562d 100644 --- a/lib/THC/THCTensorSort.cu +++ b/lib/THC/THCTensorSort.cu @@ -1,6 +1,7 @@ #include "THCReduceApplyUtils.cuh" #include "THCSortUtils.cuh" #include "THCTensorCopy.h" +#include "THCTensorTypeUtils.cuh" #include <thrust/device_ptr.h> #include <thrust/sort.h> @@ -28,7 +29,7 @@ unsigned long nextHighestPowerOf2(unsigned long n) { // `sliceSize - 1`. template <typename IndexType, int Dim> __global__ void -fillSliceWithIndex(TensorInfo<IndexType> out, +fillSliceWithIndex(TensorInfo<float, IndexType> out, IndexType totalSlices, IndexType sliceSize, IndexType sliceStride) { @@ -39,7 +40,7 @@ fillSliceWithIndex(TensorInfo<IndexType> out, } const unsigned long offset = - IndexToOffset<IndexType, Dim>::get(slice, out); + IndexToOffset<float, IndexType, Dim>::get(slice, out); float* base = &out.data[offset]; for (long i = threadIdx.x; i < sliceSize; i += blockDim.x) { @@ -76,9 +77,10 @@ void THCudaTensor_fillSliceWithIndex(THCState* state, <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \ info, numSlices, sliceSize, info.strides[collapseDim]) - if (THC_canUse32BitIndexMath(state, t)) { - TensorInfo<unsigned int> info(state, t, dim); - info.sizes[dim] = 1; + if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, t)) { + TensorInfo<float, unsigned int> info = + getTensorInfo<THCudaTensor, unsigned int>(state, t); + info.reduceDim(dim); int collapseDim = info.collapseDims(dim); if (info.isContiguous()) { @@ -93,8 +95,9 @@ void THCudaTensor_fillSliceWithIndex(THCState* state, } } } else { - TensorInfo<unsigned long> info(state, t, dim); - info.sizes[dim] = 1; + TensorInfo<float, unsigned long> info = + getTensorInfo<THCudaTensor, unsigned long>(state, t); + info.reduceDim(dim); int collapseDim = info.collapseDims(dim); // catch-all implementation @@ -221,13 +224,15 @@ THC_API void THCudaTensor_sortKeyValueInplace(THCState* state, // The constructed key/value tensor info is used to select the slice // we are sorting on a per-block basis - if (THC_canUse32BitIndexMath(state, key)) { - TensorInfo<unsigned int> keyInfo(state, key); - keyInfo.sizes[dim] = 1; + if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, key)) { + TensorInfo<float, unsigned int> keyInfo = + getTensorInfo<THCudaTensor, unsigned int>(state, key); + keyInfo.reduceDim(dim); int collapseKeyDim = keyInfo.collapseDims(dim); - TensorInfo<unsigned int> valueInfo(state, value); - valueInfo.sizes[dim] = 1; + TensorInfo<float, unsigned int> valueInfo = + getTensorInfo<THCudaTensor, unsigned int>(state, value); + valueInfo.reduceDim(dim); int collapseValueDim = valueInfo.collapseDims(dim); if (keyInfo.isContiguous()) { @@ -246,12 +251,14 @@ THC_API void THCudaTensor_sortKeyValueInplace(THCState* state, } } } else { - TensorInfo<unsigned long> keyInfo(state, key); - keyInfo.sizes[dim] = 1; + TensorInfo<float, unsigned long> keyInfo = + getTensorInfo<THCudaTensor, unsigned long>(state, key); + keyInfo.reduceDim(dim); int collapseKeyDim = keyInfo.collapseDims(dim); - TensorInfo<unsigned long> valueInfo(state, value); - valueInfo.sizes[dim] = 1; + TensorInfo<float, unsigned long> valueInfo = + getTensorInfo<THCudaTensor, unsigned long>(state, value); + valueInfo.reduceDim(dim); int collapseValueDim = valueInfo.collapseDims(dim); // long case is rare, just instantiate these versions |