Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/generic/THCTensorScatterGather.cu')
-rw-r--r--lib/THC/generic/THCTensorScatterGather.cu87
1 files changed, 87 insertions, 0 deletions
diff --git a/lib/THC/generic/THCTensorScatterGather.cu b/lib/THC/generic/THCTensorScatterGather.cu
index c3afbbf..36e8602 100644
--- a/lib/THC/generic/THCTensorScatterGather.cu
+++ b/lib/THC/generic/THCTensorScatterGather.cu
@@ -184,6 +184,93 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong
#undef RUN
#define RUN(TYPE, DIMS, REAL) \
+ THCudaTensor_scatterAddKernel<TYPE, REAL, DIMS> \
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+ tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements);
+
+void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaLongTensor *index, THCTensor *src) {
+ THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
+ THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
+
+ THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2,
+ "Index dimension is out of bounds");
+ THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3,
+ "Index tensor must have same dimensions as input tensor");
+ THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4,
+ "Input tensor must have same dimensions as output tensor");
+ THLongStorage *indexDims = THCudaLongTensor_newSizeOf(state, index);
+ THArgCheck(THCTensor_(isSize)(state, src, indexDims), 3,
+ "Index tensor must have the same size as input tensor.");
+ THLongStorage_free(indexDims);
+
+ for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) {
+ if (d != dim) {
+ THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 4,
+ "Input tensor must have same size as output tensor apart from the specified dimension");
+ }
+ }
+
+ THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS,
+ 1, CUTORCH_DIM_WARNING);
+
+ const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index);
+ const dim3 block = getApplyBlock();
+ dim3 grid;
+ THArgCheck(getApplyGrid(state, totalElements, grid), 1, CUTORCH_DIM_WARNING);
+
+ THCTensor* oldTensor = NULL;
+ if (TensorUtils<THCTensor>::overlappingIndices(state, tensor)) {
+ oldTensor = tensor;
+ tensor = THCTensor_(newContiguous)(state, tensor);
+ }
+
+ if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, tensor) &&
+ TensorUtils<THCTensor>::canUse32BitIndexMath(state, src) &&
+ TensorUtils<THCudaLongTensor>::canUse32BitIndexMath(state, index)) {
+ TensorInfo<real, unsigned int> tensorInfo =
+ getTensorInfo<THCTensor, unsigned int>(state, tensor);
+ TensorInfo<real, unsigned int> srcInfo =
+ getTensorInfo<THCTensor, unsigned int>(state, src);
+ TensorInfo<long, unsigned int> indexInfo =
+ getTensorInfo<THCudaLongTensor, unsigned int>(state, index);
+
+ // Specialize for a small number of dimensions.
+ switch (indexInfo.dims) {
+ case 1:
+ RUN(unsigned int, 1, real);
+ break;
+ case 2:
+ RUN(unsigned int, 2, real);
+ break;
+ case 3:
+ RUN(unsigned int, 3, real);
+ break;
+ default:
+ RUN(unsigned int, -1, real);
+ break;
+ }
+ } else {
+ TensorInfo<real, unsigned long> tensorInfo =
+ getTensorInfo<THCTensor, unsigned long>(state, tensor);
+ TensorInfo<real, unsigned long> srcInfo =
+ getTensorInfo<THCTensor, unsigned long>(state, src);
+ TensorInfo<long, unsigned long> indexInfo =
+ getTensorInfo<THCudaLongTensor, unsigned long>(state, index);
+
+ RUN(unsigned long, -1, real)
+ }
+
+ if (oldTensor) {
+ TensorUtils<THCTensor>::copyIgnoringOverlaps(state, oldTensor, tensor);
+ THCTensor_(free)(state, tensor);
+ tensor = oldTensor;
+ }
+ THCudaCheck(cudaGetLastError());
+}
+
+#undef RUN
+
+#define RUN(TYPE, DIMS, REAL) \
THCudaTensor_scatterFillKernel<TYPE, REAL, DIMS> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
tensorInfo, indexInfo, value, dim, (TYPE)totalElements);