Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/THCTensorSort.cu')
-rw-r--r--lib/THC/THCTensorSort.cu39
1 files changed, 23 insertions, 16 deletions
diff --git a/lib/THC/THCTensorSort.cu b/lib/THC/THCTensorSort.cu
index c35a36a..3b9562d 100644
--- a/lib/THC/THCTensorSort.cu
+++ b/lib/THC/THCTensorSort.cu
@@ -1,6 +1,7 @@
#include "THCReduceApplyUtils.cuh"
#include "THCSortUtils.cuh"
#include "THCTensorCopy.h"
+#include "THCTensorTypeUtils.cuh"
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
@@ -28,7 +29,7 @@ unsigned long nextHighestPowerOf2(unsigned long n) {
// `sliceSize - 1`.
template <typename IndexType, int Dim>
__global__ void
-fillSliceWithIndex(TensorInfo<IndexType> out,
+fillSliceWithIndex(TensorInfo<float, IndexType> out,
IndexType totalSlices,
IndexType sliceSize,
IndexType sliceStride) {
@@ -39,7 +40,7 @@ fillSliceWithIndex(TensorInfo<IndexType> out,
}
const unsigned long offset =
- IndexToOffset<IndexType, Dim>::get(slice, out);
+ IndexToOffset<float, IndexType, Dim>::get(slice, out);
float* base = &out.data[offset];
for (long i = threadIdx.x; i < sliceSize; i += blockDim.x) {
@@ -76,9 +77,10 @@ void THCudaTensor_fillSliceWithIndex(THCState* state,
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
info, numSlices, sliceSize, info.strides[collapseDim])
- if (THC_canUse32BitIndexMath(state, t)) {
- TensorInfo<unsigned int> info(state, t, dim);
- info.sizes[dim] = 1;
+ if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, t)) {
+ TensorInfo<float, unsigned int> info =
+ getTensorInfo<THCudaTensor, unsigned int>(state, t);
+ info.reduceDim(dim);
int collapseDim = info.collapseDims(dim);
if (info.isContiguous()) {
@@ -93,8 +95,9 @@ void THCudaTensor_fillSliceWithIndex(THCState* state,
}
}
} else {
- TensorInfo<unsigned long> info(state, t, dim);
- info.sizes[dim] = 1;
+ TensorInfo<float, unsigned long> info =
+ getTensorInfo<THCudaTensor, unsigned long>(state, t);
+ info.reduceDim(dim);
int collapseDim = info.collapseDims(dim);
// catch-all implementation
@@ -221,13 +224,15 @@ THC_API void THCudaTensor_sortKeyValueInplace(THCState* state,
// The constructed key/value tensor info is used to select the slice
// we are sorting on a per-block basis
- if (THC_canUse32BitIndexMath(state, key)) {
- TensorInfo<unsigned int> keyInfo(state, key);
- keyInfo.sizes[dim] = 1;
+ if (TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, key)) {
+ TensorInfo<float, unsigned int> keyInfo =
+ getTensorInfo<THCudaTensor, unsigned int>(state, key);
+ keyInfo.reduceDim(dim);
int collapseKeyDim = keyInfo.collapseDims(dim);
- TensorInfo<unsigned int> valueInfo(state, value);
- valueInfo.sizes[dim] = 1;
+ TensorInfo<float, unsigned int> valueInfo =
+ getTensorInfo<THCudaTensor, unsigned int>(state, value);
+ valueInfo.reduceDim(dim);
int collapseValueDim = valueInfo.collapseDims(dim);
if (keyInfo.isContiguous()) {
@@ -246,12 +251,14 @@ THC_API void THCudaTensor_sortKeyValueInplace(THCState* state,
}
}
} else {
- TensorInfo<unsigned long> keyInfo(state, key);
- keyInfo.sizes[dim] = 1;
+ TensorInfo<float, unsigned long> keyInfo =
+ getTensorInfo<THCudaTensor, unsigned long>(state, key);
+ keyInfo.reduceDim(dim);
int collapseKeyDim = keyInfo.collapseDims(dim);
- TensorInfo<unsigned long> valueInfo(state, value);
- valueInfo.sizes[dim] = 1;
+ TensorInfo<float, unsigned long> valueInfo =
+ getTensorInfo<THCudaTensor, unsigned long>(state, value);
+ valueInfo.reduceDim(dim);
int collapseValueDim = valueInfo.collapseDims(dim);
// long case is rare, just instantiate these versions