Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Adam Paszke <adam.paszke@gmail.com> 2017-05-23 22:42:29 +0300
committer: Soumith Chintala <soumith@gmail.com> 2017-05-25 23:49:32 +0300
commit: c075de1c09b59ad5667d4f1ec289a9a6da3534cf (patch)
tree: 68128616b38d3b91e57c2120cfefbdee0cc8691b
parent: 92e9c08ca008bd5c2246ac028d6174f42429874d (diff)
Add scatterAdd
-rw-r--r-- lib/THC/THCTensorScatterGather.cu | 28
-rw-r--r-- lib/THC/generic/THCTensorScatterGather.cu | 87
-rw-r--r-- lib/THC/generic/THCTensorScatterGather.h | 1
3 files changed, 116 insertions, 0 deletions
diff --git a/lib/THC/THCTensorScatterGather.cu b/lib/THC/THCTensorScatterGather.cu
index 18c9dee..b3e262d 100644
--- a/lib/THC/THCTensorScatterGather.cu
+++ b/lib/THC/THCTensorScatterGather.cu
@@ -1,5 +1,6 @@
#include "THCTensorMath.h"
#include "THCGeneral.h"
+#include "THCAtomics.cuh"
#include "THCApply.cuh"
// Compute the offsets into the given tensors for a linear index. For the 't2'
@@ -128,6 +129,33 @@ __global__ void THCudaTensor_scatterKernel(
}
+// Scatter-add kernel. For every linear id in [0, totalElements) it reads one
+// index value and one `src` value, then accumulates the `src` value into
+// `tensor` at the slot that the index value selects along `dim`.
+//
+//   tensor        - destination tensor; updated with atomicAdd
+//   src           - values to accumulate
+//   index         - index values, stored offset by TH_INDEX_BASE
+//   dim           - dimension of `tensor` addressed by the index values
+//   totalElements - number of linear ids to process
+//
+// Each thread starts at its global thread id and advances by the total number
+// of launched threads, so any 1-D launch configuration covers all ids.
template <typename IndexType, typename Real, int Dims>
+__global__ void THCudaTensor_scatterAddKernel(
+    TensorInfo<Real, IndexType> tensor,
+    TensorInfo<Real, IndexType> src,
+    TensorInfo<long, IndexType> index,
+    const int dim,
+    const IndexType totalElements) {
+  // Distance between two consecutive ids handled by the same thread.
+  const IndexType step = gridDim.x * blockDim.x;
+  IndexType id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (id < totalElements) {
+    IndexType dstOff = 0;
+    IndexType srcOff = 0;
+    IndexType idxOff = 0;
+
+    // Fills in the three offsets for this linear id; the destination offset
+    // is advanced along `dim` below, using the index value.
+    IndexToScatterGatherOffsets<IndexType, Real, Dims>::compute(
+        id, dim,
+        index, &idxOff,
+        src, &srcOff,
+        tensor, &dstOff);
+
+    // Remove the index base and make sure the slot lies inside `tensor`.
+    const long slot = index.data[idxOff] - TH_INDEX_BASE;
+    assert(slot >= 0 && slot < tensor.sizes[dim]);
+    dstOff += slot * tensor.strides[dim];
+
+    // Index values are data-dependent, so two threads may target the same
+    // destination element; the accumulation therefore has to be atomic.
+    atomicAdd(&tensor.data[dstOff], src.data[srcOff]);
+
+    id += step;
+  }
+}
+
+template <typename IndexType, typename Real, int Dims>
__global__ void THCudaTensor_scatterFillKernel(
TensorInfo<Real, IndexType> tensor,
TensorInfo<long, IndexType> index,
diff --git a/lib/THC/generic/THCTensorScatterGather.cu b/lib/THC/generic/THCTensorScatterGather.cu
index c3afbbf..36e8602 100644
--- a/lib/THC/generic/THCTensorScatterGather.cu
+++ b/lib/THC/generic/THCTensorScatterGather.cu
@@ -184,6 +184,93 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
#undef RUN
// Launch helper for scatterAdd: expands to a launch of
// THCudaTensor_scatterAddKernel<TYPE, REAL, DIMS> on the stream returned by
// THCState_getCurrentStream(state). The names grid, block, state, tensorInfo,
// srcInfo, indexInfo, dim and totalElements must be in scope where the macro
// is expanded. The expansion already ends with a semicolon.
#define RUN(TYPE, DIMS, REAL) \
+ THCudaTensor_scatterAddKernel<TYPE, REAL, DIMS> \
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+ tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements);
+
+// Adds every element of `src` into `tensor` at the position selected by the
+// matching element of `index` along dimension `dim`, by launching the
+// scatterAdd kernel on the current stream.
+//
+// Argument checks (a failed check is reported through THArgCheck):
+//   - `dim` is a valid dimension of `tensor`;
+//   - `index`, `src` and `tensor` have the same number of dimensions;
+//   - `index` has exactly the same size as `src`;
+//   - `src` and `tensor` have the same size in every dimension except `dim`;
+//   - `tensor` has at most MAX_CUTORCH_DIMS dimensions.
+//
+//   state  - THC state; supplies the current stream
+//   tensor - destination tensor, updated in place
+//   dim    - dimension of `tensor` addressed by the values in `index`
+//   index  - index tensor (values offset by TH_INDEX_BASE)
+//   src    - values to accumulate
+void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src) {
+  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
+  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
+
+  THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
+             "Index dimension is out of bounds");
+  THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
+             "Index tensor must have same dimensions as input tensor");
+  THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
+             "Input tensor must have same dimensions as output tensor");
+
+  // Evaluate the size comparison and release the temporary size storage
+  // BEFORE calling THArgCheck: a failing THArgCheck may not return to this
+  // function, which would otherwise leak `indexDims`.
+  THLongStorage *indexDims = THCudaLongTensor_newSizeOf(state, index);
+  const int srcMatchesIndex = THCTensor_(isSize)(state, src, indexDims);
+  THLongStorage_free(indexDims);
+  THArgCheck(srcMatchesIndex, 3,
+             "Index tensor must have the same size as input tensor.");
+
+  for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
+    if (d != dim) {
+      THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 4,
+                 "Input tensor must have same size as output tensor apart from the specified dimension");
+    }
+  }
+
+  THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
+             1, CUTORCH_DIM_WARNING);
+
+  // One unit of kernel work per element of `index`.
+  const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
+  const dim3 block = getApplyBlock();
+  dim3 grid;
+  THArgCheck(getApplyGrid(state, totalElements, grid), 1, CUTORCH_DIM_WARNING);
+
+  // If the destination reports overlapping indices, run the kernel on a
+  // contiguous copy and copy the result back afterwards.
+  THCTensor* oldTensor = NULL;
+  if (TensorUtils<THCTensor>::overlappingIndices(state, tensor)) {
+    oldTensor = tensor;
+    tensor = THCTensor_(newContiguous)(state, tensor);
+  }
+
+  if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, tensor) &&
+      TensorUtils<THCTensor>::canUse32BitIndexMath(state, src) &&
+      TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, index)) {
+    // All three tensors can be addressed with 32-bit offsets.
+    TensorInfo<real, unsigned int> tensorInfo =
+      getTensorInfo<THCTensor, unsigned int>(state, tensor);
+    TensorInfo<real, unsigned int> srcInfo =
+      getTensorInfo<THCTensor, unsigned int>(state, src);
+    TensorInfo<long, unsigned int> indexInfo =
+      getTensorInfo<THCudaLongTensor, unsigned int>(state, index);
+
+    // Specialize for a small number of dimensions.
+    switch (indexInfo.dims) {
+      case 1:
+        RUN(unsigned int, 1, real);
+        break;
+      case 2:
+        RUN(unsigned int, 2, real);
+        break;
+      case 3:
+        RUN(unsigned int, 3, real);
+        break;
+      default:
+        RUN(unsigned int, -1, real);
+        break;
+    }
+  } else {
+    // Fall back to 64-bit offsets with the generic (-1) dimension count.
+    TensorInfo<real, unsigned long> tensorInfo =
+      getTensorInfo<THCTensor, unsigned long>(state, tensor);
+    TensorInfo<real, unsigned long> srcInfo =
+      getTensorInfo<THCTensor, unsigned long>(state, src);
+    TensorInfo<long, unsigned long> indexInfo =
+      getTensorInfo<THCudaLongTensor, unsigned long>(state, index);
+
+    RUN(unsigned long, -1, real);
+  }
+
+  if (oldTensor) {
+    // Write the result back into the caller's tensor and drop the copy.
+    TensorUtils<THCTensor>::copyIgnoringOverlaps(state, oldTensor, tensor);
+    THCTensor_(free)(state, tensor);
+    tensor = oldTensor;
+  }
+  THCudaCheck(cudaGetLastError());
+}
+
+#undef RUN
+
// Launch helper for scatterFill: expands to a launch of
// THCudaTensor_scatterFillKernel<TYPE, REAL, DIMS> on the stream returned by
// THCState_getCurrentStream(state). The names grid, block, state, tensorInfo,
// indexInfo, value, dim and totalElements must be in scope where the macro is
// expanded. The expansion already ends with a semicolon.
+#define RUN(TYPE, DIMS, REAL) \
THCudaTensor_scatterFillKernel<TYPE, REAL, DIMS> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
tensorInfo, indexInfo, value, dim, (TYPE)totalElements);
diff --git a/lib/THC/generic/THCTensorScatterGather.h b/lib/THC/generic/THCTensorScatterGather.h
index 2071014..e7e83b2 100644
--- a/lib/THC/generic/THCTensorScatterGather.h
+++ b/lib/THC/generic/THCTensorScatterGather.h
@@ -4,6 +4,7 @@
THC_API void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCTensor *src, int dim, THCudaLongTensor *index);
THC_API void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src);
/* scatterAdd: adds each element of src into tensor at the position selected
 * by the matching element of index along dimension dim. */
+THC_API void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src);
THC_API void THCTensor_(scatterFill)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, real value);
#endif