diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-05-23 22:42:29 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-05-25 23:49:32 +0300 |
commit | c075de1c09b59ad5667d4f1ec289a9a6da3534cf (patch) | |
tree | 68128616b38d3b91e57c2120cfefbdee0cc8691b | |
parent | 92e9c08ca008bd5c2246ac028d6174f42429874d (diff) |
Add scatterAdd
-rw-r--r-- | lib/THC/THCTensorScatterGather.cu | 28 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorScatterGather.cu | 87 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorScatterGather.h | 1 |
3 files changed, 116 insertions, 0 deletions
Patch summary: adds a scatterAdd operation to THC. It introduces the device
kernel THCudaTensor_scatterAddKernel, which atomically adds each src element
into tensor at the position selected by index along dim; the host entry point
THCTensor_(scatterAdd), which validates its arguments and launches the kernel
with 32-bit or 64-bit index math; and the matching declaration in the generic
header.

Recovery note: this patch was recovered from a page on which every line break
and all leading indentation had been collapsed. Record boundaries were rebuilt
from the hunk headers (1 + 27 + 87 + 1 added lines, matching the 116 insertions
in the diffstat). Indentation is restored to a conventional 2-space style and
may not match the original bytes, so apply with whitespace tolerance.

Review notes on the added code (the code itself is unchanged):
 * scatterAdd frees indexDims only after the isSize THArgCheck; if that check
   raises, the storage is presumably never released. Confirm how THArgCheck
   unwinds.
 * In the kernel, blockIdx.x * blockDim.x and gridDim.x * blockDim.x are
   evaluated in 32-bit unsigned arithmetic before conversion to IndexType,
   including in the unsigned long instantiation.
 * An empty index tensor gives totalElements == 0; whether getApplyGrid then
   yields a launchable grid is not visible in this patch. Verify.

diff --git a/lib/THC/THCTensorScatterGather.cu b/lib/THC/THCTensorScatterGather.cu
index 18c9dee..b3e262d 100644
--- a/lib/THC/THCTensorScatterGather.cu
+++ b/lib/THC/THCTensorScatterGather.cu
@@ -1,5 +1,6 @@
 #include "THCTensorMath.h"
 #include "THCGeneral.h"
+#include "THCAtomics.cuh"
 #include "THCApply.cuh"
 
 // Compute the offsets into the given tensors for a linear index. For the 't2'
@@ -128,6 +129,33 @@ __global__ void THCudaTensor_scatterKernel(
 }
 
 template <typename IndexType, typename Real, int Dims>
+__global__ void THCudaTensor_scatterAddKernel(
+    TensorInfo<Real, IndexType> tensor,
+    TensorInfo<Real, IndexType> src,
+    TensorInfo<long, IndexType> index,
+    const int dim,
+    const IndexType totalElements) {
+  for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
+       linearId < totalElements;
+       linearId += gridDim.x * blockDim.x) {
+    IndexType tensorOffset = 0;
+    IndexType srcOffset = 0;
+    IndexType indexOffset = 0;
+
+    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(linearId, dim,
+                                                          index, &indexOffset,
+                                                          src, &srcOffset,
+                                                          tensor, &tensorOffset);
+
+    long indexValue = index.data[indexOffset] - TH_INDEX_BASE;
+    assert(indexValue >= 0 && indexValue < tensor.sizes[dim]);
+    tensorOffset += indexValue * tensor.strides[dim];
+
+    atomicAdd(&tensor.data[tensorOffset], src.data[srcOffset]);
+  }
+}
+
+template <typename IndexType, typename Real, int Dims>
 __global__ void THCudaTensor_scatterFillKernel(
     TensorInfo<Real, IndexType> tensor,
     TensorInfo<long, IndexType> index,
diff --git a/lib/THC/generic/THCTensorScatterGather.cu b/lib/THC/generic/THCTensorScatterGather.cu
index c3afbbf..36e8602 100644
--- a/lib/THC/generic/THCTensorScatterGather.cu
+++ b/lib/THC/generic/THCTensorScatterGather.cu
@@ -184,6 +184,93 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
 #undef RUN
 
 #define RUN(TYPE, DIMS, REAL) \
+  THCudaTensor_scatterAddKernel<TYPE, REAL, DIMS> \
+      <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+          tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements);
+
+void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src) {
+  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
+  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
+
+  THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
+             "Index dimension is out of bounds");
+  THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
+             "Index tensor must have same dimensions as input tensor");
+  THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
+             "Input tensor must have same dimensions as output tensor");
+  THLongStorage *indexDims = THCudaLongTensor_newSizeOf(state, index);
+  THArgCheck(THCTensor_(isSize)(state, src, indexDims), 3,
+             "Index tensor must have the same size as input tensor.");
+  THLongStorage_free(indexDims);
+
+  for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
+    if (d != dim) {
+      THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 4,
+                 "Input tensor must have same size as output tensor apart from the specified dimension");
+    }
+  }
+
+  THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
+             1, CUTORCH_DIM_WARNING);
+
+  const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
+  const dim3 block = getApplyBlock();
+  dim3 grid;
+  THArgCheck(getApplyGrid(state, totalElements, grid), 1, CUTORCH_DIM_WARNING);
+
+  THCTensor* oldTensor = NULL;
+  if (TensorUtils<THCTensor>::overlappingIndices(state, tensor)) {
+    oldTensor = tensor;
+    tensor = THCTensor_(newContiguous)(state, tensor);
+  }
+
+  if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, tensor) &&
+      TensorUtils<THCTensor>::canUse32BitIndexMath(state, src) &&
+      TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, index)) {
+    TensorInfo<real, unsigned int> tensorInfo =
+      getTensorInfo<THCTensor, unsigned int>(state, tensor);
+    TensorInfo<real, unsigned int> srcInfo =
+      getTensorInfo<THCTensor, unsigned int>(state, src);
+    TensorInfo<long, unsigned int> indexInfo =
+      getTensorInfo<THCudaLongTensor, unsigned int>(state, index);
+
+    // Specialize for a small number of dimensions.
+    switch (indexInfo.dims) {
+      case 1:
+        RUN(unsigned int, 1, real);
+        break;
+      case 2:
+        RUN(unsigned int, 2, real);
+        break;
+      case 3:
+        RUN(unsigned int, 3, real);
+        break;
+      default:
+        RUN(unsigned int, -1, real);
+        break;
+    }
+  } else {
+    TensorInfo<real, unsigned long> tensorInfo =
+      getTensorInfo<THCTensor, unsigned long>(state, tensor);
+    TensorInfo<real, unsigned long> srcInfo =
+      getTensorInfo<THCTensor, unsigned long>(state, src);
+    TensorInfo<long, unsigned long> indexInfo =
+      getTensorInfo<THCudaLongTensor, unsigned long>(state, index);
+
+    RUN(unsigned long, -1, real)
+  }
+
+  if (oldTensor) {
+    TensorUtils<THCTensor>::copyIgnoringOverlaps(state, oldTensor, tensor);
+    THCTensor_(free)(state, tensor);
+    tensor = oldTensor;
+  }
+  THCudaCheck(cudaGetLastError());
+}
+
+#undef RUN
+
+#define RUN(TYPE, DIMS, REAL) \
   THCudaTensor_scatterFillKernel<TYPE, REAL, DIMS> \
       <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
           tensorInfo, indexInfo, value, dim, (TYPE)totalElements);
diff --git a/lib/THC/generic/THCTensorScatterGather.h b/lib/THC/generic/THCTensorScatterGather.h
index 2071014..e7e83b2 100644
--- a/lib/THC/generic/THCTensorScatterGather.h
+++ b/lib/THC/generic/THCTensorScatterGather.h
@@ -4,6 +4,7 @@
 
 THC_API void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index);
 THC_API void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src);
+THC_API void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src);
 THC_API void THCTensor_(scatterFill)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, real value);
 
 #endif