github.com/torch/cutorch.git
author    Natalia Gimelshein <ngimelshein@nvidia.com>  2017-08-14 23:38:05 +0300
committer Soumith Chintala <soumith@gmail.com>         2017-08-15 10:00:44 +0300
commit    cec53c34e71dbbbb53574f6371c6d7e4a9f6757b (patch)
tree      5dfc135c81a861ce0f4163bf092df10c9e8e7c6e
parent    0252bcd1b43cc70986a359c5250c80edb6eb29c2 (diff)
accumulate in accType for reductions over dimensions
-rw-r--r--  lib/THC/THCReduce.cuh                  | 103
-rw-r--r--  lib/THC/generic/THCTensorMathReduce.cu |  26
2 files changed, 70 insertions(+), 59 deletions(-)
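This patch makes the per-dimension reductions (sum, prod, norm) accumulate their partial results in the accumulator type (accreal, e.g. a wider floating-point type for half tensors) instead of the tensor's element type, converting back to the element type only when the result is written out. To do this, the kernels gain a second reduce operator for combining two accumulator values, and the init value is now expressed in the accumulator type. A minimal sketch of the idea, outside the THC machinery (the function name and types below are illustrative, not part of the patch):

    // Accumulate in a wider type AccT and convert back to T only at the end.
    // T = element type (e.g. half), AccT = accumulator type (e.g. float).
    template <typename T, typename AccT>
    __device__ void acc_reduce(T* out, const T* in, int n) {
      AccT acc = AccT(0);                      // init is expressed in AccT
      for (int i = 0; i < n; ++i) {
        acc = acc + static_cast<AccT>(in[i]);  // reduce in AccT, not T
      }
      *out = static_cast<T>(acc);              // single narrowing conversion on write-out
    }

Keeping the running value in AccT avoids the rounding error (and, for narrow integer types, overflow) that builds up when every partial result is squeezed back into T.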
diff --git a/lib/THC/THCReduce.cuh b/lib/THC/THCReduce.cuh
index b7df49b..cae6cf1 100644
--- a/lib/THC/THCReduce.cuh
+++ b/lib/THC/THCReduce.cuh
@@ -10,6 +10,7 @@
#include "THCTensorTypeUtils.cuh"
#include "THCReduceApplyUtils.cuh"
+#include "THCNumerics.cuh"
// Threads per thread block
#define THC_NONCONTIG_REDUCE_BLOCK_SIZE 32 * 16
@@ -23,7 +24,9 @@ __device__ __forceinline__ IndexType getReduceNoncontigDimSliceIndex() {
// Kernel that handles an entire reduction of a slice of a tensor per each thread
template <typename ModifyOp,
typename ReduceOp,
+ typename ReduceAccOp,
typename T,
+ typename AccT,
typename IndexType,
int ADims, int BDims>
#if __CUDA_ARCH__ >= 350
@@ -35,17 +38,18 @@ kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
IndexType reductionStride,
IndexType reductionSize,
IndexType totalSlices,
- T init,
+ AccT init,
ModifyOp modifyOp,
- ReduceOp reduceOp) {
+ ReduceOp reduceOp,
+ ReduceAccOp reduceAccOp) {
IndexType sliceIndex = blockIdx.x * blockDim.x + threadIdx.x;
IndexType sliceStride = gridDim.x * blockDim.x;
- __shared__ T local_reduce[THC_NONCONTIG_REDUCE_BLOCK_SIZE];
- T* shmem = &local_reduce[threadIdx.x + threadIdx.y * blockDim.x];
+ __shared__ AccT local_reduce[THC_NONCONTIG_REDUCE_BLOCK_SIZE];
+ AccT* shmem = &local_reduce[threadIdx.x + threadIdx.y * blockDim.x];
T load_reg[4];
- T local_reg;
+ AccT local_reg;
for(;sliceIndex<totalSlices; sliceIndex+=sliceStride){
local_reg = init;
@@ -65,31 +69,22 @@ kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
load_reg[3] = modifyOp(in.data[inOffset + (i + blockDim.y * 3) * reductionStride]);
-
- local_reg = reduceOp(local_reg,
- reduceOp(
- reduceOp(load_reg[0], load_reg[1]),
- reduceOp(load_reg[2], load_reg[3])
- )
- );
-
+ local_reg = reduceOp(local_reg, load_reg[0]);
+ local_reg = reduceOp(local_reg, load_reg[1]);
+ local_reg = reduceOp(local_reg, load_reg[2]);
+ local_reg = reduceOp(local_reg, load_reg[3]);
}else if(i + blockDim.y * 2 < reductionSize){
load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
-
- local_reg = reduceOp(
- reduceOp(load_reg[0], load_reg[1]),
- reduceOp(load_reg[2], local_reg)
- );
-
- }else if( (i + blockDim.y) < reductionSize){
+ local_reg = reduceOp(local_reg, load_reg[0]);
+ local_reg = reduceOp(local_reg, load_reg[1]);
+ local_reg = reduceOp(local_reg, load_reg[2]);
+ }else if( (i + blockDim.y) < reductionSize){
load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
- local_reg = reduceOp(
- local_reg, reduceOp(load_reg[0], load_reg[1])
- );
-
+ local_reg = reduceOp(local_reg, load_reg[0]);
+ local_reg = reduceOp(local_reg, load_reg[1]);
}else if(i + blockDim.y * 0 < reductionSize){
local_reg = reduceOp(local_reg, modifyOp(in.data[inOffset + i * reductionStride]));
}
@@ -100,15 +95,15 @@ kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
while(dimy > 1){
__syncthreads();
if( threadIdx.y == 0 && (dimy%2 != 0) ){
- *shmem = reduceOp(*shmem, *(shmem + (dimy-1) * blockDim.x) );
+ *shmem = reduceAccOp(*shmem, *(shmem + (dimy-1) * blockDim.x) );
}
if(threadIdx.y < dimy/2){
- *shmem = reduceOp(*shmem, *(shmem + (dimy/2)*blockDim.x) );
+ *shmem = reduceAccOp(*shmem, *(shmem + (dimy/2)*blockDim.x) );
}
dimy /= 2;
}
if(threadIdx.y == 0)
- out.data[outOffset] = *shmem;
+ out.data[outOffset] = ScalarConvert<AccT, T>::to(*shmem);
}
}
@@ -116,7 +111,9 @@ kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
// Kernel that handles an entire reduction of a slice of a tensor per each thread
template <typename ModifyOp,
typename ReduceOp,
+ typename ReduceAccOp,
typename T,
+ typename AccT,
typename IndexType,
int ADims, int BDims>
#if __CUDA_ARCH__ >= 350
@@ -128,9 +125,10 @@ kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
IndexType reductionStride,
IndexType reductionSize,
IndexType totalSlices,
- T init,
+ AccT init,
ModifyOp modifyOp,
- ReduceOp reduceOp) {
+ ReduceOp reduceOp,
+ ReduceAccOp reduceAccOp) {
const IndexType sliceIndex = getReduceNoncontigDimSliceIndex<IndexType>();
if (sliceIndex >= totalSlices) {
@@ -146,7 +144,7 @@ kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
// For each point in reductionSize, reduce into `r`
IndexType inOffset = inBaseOffset;
- T r = init;
+ AccT r = init;
for (IndexType i = 0; i < reductionSize; ++i) {
r = reduceOp(r, modifyOp(in.data[inOffset]));
@@ -154,7 +152,7 @@ kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
}
// Write out reduced value
- out.data[outOffset] = r;
+ out.data[outOffset] = ScalarConvert<AccT, T>::to(r);
}
template <typename IndexType>
@@ -167,7 +165,9 @@ __device__ __forceinline__ IndexType getReduceContigDimSliceIndex() {
// each block
template <typename ModifyOp,
typename ReduceOp,
+ typename ReduceAccOp,
typename T,
+ typename AccT,
typename IndexType,
int ADims, int BDims>
__global__ void
@@ -175,9 +175,10 @@ kernelReduceContigDim(TensorInfo<T, IndexType> out,
TensorInfo<T, IndexType> in,
IndexType reductionSize,
IndexType totalSlices,
- T init,
+ AccT init,
ModifyOp modifyOp,
- ReduceOp reduceOp) {
+ ReduceOp reduceOp,
+ ReduceAccOp reduceAccOp) {
const IndexType sliceIndex = getReduceContigDimSliceIndex<IndexType>();
if (sliceIndex >= totalSlices) {
@@ -195,7 +196,7 @@ kernelReduceContigDim(TensorInfo<T, IndexType> out,
// Each thread in the block will reduce some subset of elements in
// the slice. The elements are guaranteed contiguous starting at
// `inBaseOffset`.
- T r = init;
+ AccT r = init;
for (IndexType i = threadIdx.x; i < reductionSize; i += blockDim.x) {
r = reduceOp(r, modifyOp(in.data[inBaseOffset + i]));
}
@@ -203,12 +204,12 @@ kernelReduceContigDim(TensorInfo<T, IndexType> out,
// Reduce within the block
// FIXME: extern name
extern __shared__ char smemChar[];
- T* smem = (T*) smemChar;
- r = reduceBlock<T, ReduceOp>(smem, blockDim.x, r, reduceOp, init);
+ AccT* smem = (AccT*) smemChar;
+ r = reduceBlock<AccT, ReduceAccOp>(smem, blockDim.x, r, reduceAccOp, init);
if (threadIdx.x == 0) {
// Write out reduced value
- out.data[outOffset] = r;
+ out.data[outOffset] = ScalarConvert<AccT, T>::to(r);
}
}
@@ -254,13 +255,18 @@ inline bool getContigReduceGrid(ptrdiff_t elements, dim3& grid) {
// Performs a reduction out[..., 0, ...] = reduce_i(modify(in[..., i, ...])) for
// all in where i and the out's 0 are indexed at dimension `dim`
-template <typename TensorType, typename ModifyOp, typename ReduceOp>
+template <typename TensorType,
+typename ModifyOp,
+typename ReduceOp,
+typename ReduceAccOp,
+typename AccT>
bool THC_reduceDim(THCState* state,
TensorType* out,
TensorType* in,
const ModifyOp& modifyOp,
const ReduceOp& reduceOp,
- typename TensorUtils<TensorType>::DataType init,
+ const ReduceAccOp& reduceAccOp,
+ AccT init,
int dim,
int keepdim) {
ptrdiff_t inElements = TensorUtils<TensorType>::getNumElements(state, in);
@@ -292,7 +298,7 @@ bool THC_reduceDim(THCState* state,
}
block = getContigReduceBlock(outElements, reductionSize);
- smemSize = sizeof(typename TensorUtils<TensorType>::DataType) * block.x;
+ smemSize = sizeof(AccT) * block.x;
} else {
if (!getNoncontigReduceGrid(outElements, grid)) {
return false;
@@ -335,27 +341,31 @@ bool THC_reduceDim(THCState* state,
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, OUT, IN) \
if (contigReduction) { \
- kernelReduceContigDim<ModifyOp, ReduceOp, \
+ kernelReduceContigDim<ModifyOp, ReduceOp, ReduceAccOp, \
typename TensorUtils<TensorType>::DataType, \
+ AccT, \
TYPE, OUT, IN> \
<<<grid, block, smemSize, THCState_getCurrentStream(state)>>>( \
outInfo, inInfo, reductionSize, \
- (TYPE) outElements, init, modifyOp, reduceOp); \
+ (TYPE) outElements, init, modifyOp, reduceOp, reduceAccOp); \
} else { \
if(block.y == 1){ \
- kernelReduceNoncontigDim<ModifyOp, ReduceOp, \
+ kernelReduceNoncontigDim<ModifyOp, ReduceOp, ReduceAccOp, \
typename TensorUtils<TensorType>::DataType, \
+ AccT, \
TYPE, OUT, IN> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
outInfo, inInfo, reductionStride, reductionSize, \
- (TYPE) outElements, init, modifyOp, reduceOp); \
+ (TYPE) outElements, init, modifyOp, reduceOp, reduceAccOp); \
}else{ \
- kernelReduceNoncontigDim_shared<ModifyOp, ReduceOp, \
- kernelReduceNoncontigDim_shared<ModifyOp, ReduceOp, ReduceAccOp, \
typename TensorUtils<TensorType>::DataType, \
+ AccT, \
TYPE, OUT, IN> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
outInfo, inInfo, reductionStride, reductionSize, \
- (TYPE) outElements, init, modifyOp, reduceOp); \
+ (TYPE) outElements, init, modifyOp, reduceOp, \
+ reduceAccOp); \
} \
} \
@@ -409,7 +419,6 @@ bool THC_reduceDim(THCState* state,
getTensorInfo<TensorType, unsigned int>(state, in);
inInfo.reduceDim(dim);
inInfo.collapseDims();
-
HANDLE_OUT_CASE(unsigned int, outInfo.dims, inInfo.dims);
} else {
TensorInfo<typename TensorUtils<TensorType>::DataType,
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 846e7fd..a72562e 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -7,8 +7,9 @@ THCTensor_(sum)(THCState* state, THCTensor *self, THCTensor *src, long dimension
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
if (!THC_reduceDim(state, self, src,
thrust::identity<real>(),
- ReduceAdd<real, real>(),
- ScalarConvert<int, real>::to(0),
+ ReduceAdd<real, accreal>(),
+ ReduceAdd<accreal, accreal>(),
+ ScalarConvert<int, accreal>::to(0),
dimension,
keepdim)) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
@@ -22,8 +23,9 @@ THCTensor_(prod)(THCState* state, THCTensor *self, THCTensor *src, long dimension
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
if (!THC_reduceDim(state, self, src,
thrust::identity<real>(),
- ReduceMultiply<real, real>(),
- ScalarConvert<int, real>::to(1),
+ ReduceMultiply<real, accreal>(),
+ ReduceMultiply<accreal, accreal>(),
+ ScalarConvert<int, accreal>::to(1),
dimension,
keepdim)) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
@@ -161,23 +163,23 @@ THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, l
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, src));
if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(0.0))) {
THC_reduceDim(state, self, src,
- TensorNonZeroOp<real>(), ReduceAdd<real, real>(),
- ScalarConvert<float, real>::to(0.0), dimension, keepdim);
+ TensorNonZeroOp<real>(), ReduceAdd<real, accreal>(), ReduceAdd<accreal, accreal>(),
+ ScalarConvert<float, accreal>::to(0.0), dimension, keepdim);
} else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(1.0))) {
THC_reduceDim(state, self, src,
- TensorNormOp<real, 1>(value), ReduceAdd<real, real>(),
- ScalarConvert<float, real>::to(0.0), dimension, keepdim);
+ TensorNormOp<real, 1>(value), ReduceAdd<real, accreal>(), ReduceAdd<accreal, accreal>(),
+ ScalarConvert<float, accreal>::to(0.0), dimension, keepdim);
} else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(2.0))) {
THC_reduceDim(state, self, src,
- TensorNormOp<real, 2>(value), ReduceAdd<real, real>(),
- ScalarConvert<float, real>::to(0.0), dimension, keepdim);
+ TensorNormOp<real, 2>(value), ReduceAdd<real, accreal>(), ReduceAdd<accreal, accreal>(),
+ ScalarConvert<float, accreal>::to(0.0), dimension, keepdim);
THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5));
} else {
THC_reduceDim(state, self, src,
- TensorNormOp<real, -1>(value), ReduceAdd<real, real>(),
- ScalarConvert<float, real>::to(0.0), dimension, keepdim);
+ TensorNormOp<real, -1>(value), ReduceAdd<real, accreal>(), ReduceAdd<accreal, accreal>(),
+ ScalarConvert<float, accreal>::to(0.0), dimension, keepdim);
THCTensor_(pow)(state, self, self, THCNumerics<real>::cinv(value));
}
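After the change, each caller passes two reduce operators and an init value in accreal: one operator folds an element (real) into the accumulator (accreal), the other merges two partial accumulators, as used by reduceBlock and the shared-memory tree reduction in THCReduce.cuh above. An annotated sketch of the new THC_reduceDim call for sum, mirroring the hunk above (the concrete types are illustrative; for the half tensor type, accreal is a wider floating-point type):

    THC_reduceDim(state, self, src,
                  thrust::identity<real>(),            // modifyOp: applied to each loaded element
                  ReduceAdd<real, accreal>(),          // reduceOp: folds a real into an accreal partial result
                  ReduceAdd<accreal, accreal>(),       // reduceAccOp: merges two accreal partial results
                  ScalarConvert<int, accreal>::to(0),  // init, expressed in the accumulator type
                  dimension, keepdim);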