github.com/torch/cutorch.git
Diffstat (limited to 'lib/THC/THCApply.cuh'):
 lib/THC/THCApply.cuh | 566
 1 file changed, 301 insertions(+), 265 deletions(-)
diff --git a/lib/THC/THCApply.cuh b/lib/THC/THCApply.cuh
index 707e22f..dd6d32a 100644
--- a/lib/THC/THCApply.cuh
+++ b/lib/THC/THCApply.cuh
@@ -3,6 +3,7 @@
#include "THCTensorCopy.h"
#include "THCReduceApplyUtils.cuh"
+#include "THCTensorTypeUtils.cuh"
//
// This file contains pointwise operation functions and kernels that
@@ -12,81 +13,85 @@
//
// Threads per block for our apply kernel
+// FIXME: use occupancy calculator instead
#define THC_APPLY_THREADS_PER_BLOCK 32 * 16
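
The new FIXME points at CUDA's occupancy calculator. A minimal sketch of what such a replacement might look like (hypothetical helper, not part of this diff; cudaOccupancyMaxPotentialBlockSize requires CUDA 6.5 or newer):

// Hypothetical sketch only: let the occupancy calculator suggest a block size
// for a concrete kernel instantiation instead of the hard-coded 32 * 16.
template <typename KernelT>
inline int suggestApplyBlockSize(KernelT kernel) {
  int minGridSize = 0;
  int blockSize = THC_APPLY_THREADS_PER_BLOCK;  // fall back to the fixed value
  if (cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize,
                                         kernel) != cudaSuccess) {
    blockSize = THC_APPLY_THREADS_PER_BLOCK;
  }
  return blockSize;
}
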
-// Called when we are copying into an overlapping index `dst`, but
-// we don't care which writer wins. Hacky but it works.
-THC_API void THCudaTensor_copyIgnoringOverlaps(THCState* state,
- THCudaTensor* dst,
- THCudaTensor* src);
-
-template <typename Op, typename IndexType, int ADims>
+template <typename Op,
+ typename Ta,
+ typename IndexType,
+ int ADims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THCudaTensor_pointwiseApply1(TensorInfo<IndexType> a,
- IndexType totalElements,
- Op op) {
+kernelPointwiseApply1(TensorInfo<Ta, IndexType> a,
+ IndexType totalElements,
+ Op op) {
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x) {
// Convert `linearIndex` into an offset of `a`
const IndexType aOffset =
- IndexToOffset<IndexType, ADims>::get(linearIndex, a);
+ IndexToOffset<Ta, IndexType, ADims>::get(linearIndex, a);
op(&a.data[aOffset]);
}
}
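
For context, the grid-stride loop above gives each thread a linear element index, and IndexToOffset maps that to a strided memory offset. An illustrative sketch of what the fully generic (Dims == -1) case amounts to, assuming TensorInfo exposes dims, sizes and strides as in THCReduceApplyUtils.cuh (the contiguous -2 case just uses the linear index directly):

// Illustrative sketch only, not the cutorch implementation: recover the
// strided offset from a linear index, walking dimensions innermost-first.
template <typename T, typename IndexType>
__device__ __forceinline__ IndexType
linearIndexToOffset(IndexType linearIndex, const TensorInfo<T, IndexType>& info) {
  IndexType offset = 0;
  for (int i = info.dims - 1; i >= 0; --i) {
    IndexType curIndex = linearIndex % info.sizes[i];  // coordinate in dim i
    linearIndex /= info.sizes[i];
    offset += curIndex * info.strides[i];
  }
  return offset;
}
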
-template <typename Op, typename IndexType, int ADims, int BDims>
+template <typename Op,
+ typename Ta, typename Tb,
+ typename IndexType,
+ int ADims, int BDims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THCudaTensor_pointwiseApply2(TensorInfo<IndexType> a,
- TensorInfo<IndexType> b,
- IndexType totalElements,
- Op op) {
+kernelPointwiseApply2(TensorInfo<Ta, IndexType> a,
+ TensorInfo<Tb, IndexType> b,
+ IndexType totalElements,
+ Op op) {
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x) {
// Convert `linearIndex` into an offset of `a`
const IndexType aOffset =
- IndexToOffset<IndexType, ADims>::get(linearIndex, a);
+ IndexToOffset<Ta, IndexType, ADims>::get(linearIndex, a);
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset =
- IndexToOffset<IndexType, BDims>::get(linearIndex, b);
+ IndexToOffset<Tb, IndexType, BDims>::get(linearIndex, b);
op(&a.data[aOffset], &b.data[bOffset]);
}
}
-template <typename Op, typename IndexType, int ADims, int BDims, int CDims>
+template <typename Op,
+ typename Ta, typename Tb, typename Tc,
+ typename IndexType,
+ int ADims, int BDims, int CDims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THCudaTensor_pointwiseApply3(TensorInfo<IndexType> a,
- TensorInfo<IndexType> b,
- TensorInfo<IndexType> c,
- IndexType totalElements,
- Op op) {
+kernelPointwiseApply3(TensorInfo<Ta, IndexType> a,
+ TensorInfo<Tb, IndexType> b,
+ TensorInfo<Tc, IndexType> c,
+ IndexType totalElements,
+ Op op) {
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < totalElements;
linearIndex += gridDim.x * blockDim.x) {
// Convert `linearIndex` into an offset of `a`
const IndexType aOffset =
- IndexToOffset<IndexType, ADims>::get(linearIndex, a);
+ IndexToOffset<Ta, IndexType, ADims>::get(linearIndex, a);
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset =
- IndexToOffset<IndexType, BDims>::get(linearIndex, b);
+ IndexToOffset<Tb, IndexType, BDims>::get(linearIndex, b);
// Convert `linearIndex` into an offset of `c`
const IndexType cOffset =
- IndexToOffset<IndexType, CDims>::get(linearIndex, c);
+ IndexToOffset<Tc, IndexType, CDims>::get(linearIndex, c);
op(&a.data[aOffset], &b.data[bOffset], &c.data[cOffset]);
}
@@ -116,18 +121,17 @@ inline bool getApplyGrid(THCState* state, long totalElements, dim3& grid) {
return true;
}
-template <typename Op>
-bool THCudaTensor_pointwiseApply1(THCState* state,
- THCudaTensor* a,
- const Op& op,
- TensorArgType aType = ReadWrite) {
- long totalElements = THCudaTensor_nElement(state, a);
-
- if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS) {
+template <typename TensorTypeA,
+ typename Op>
+bool THC_pointwiseApply1(THCState* state,
+ TensorTypeA* a,
+ const Op& op,
+ TensorArgType aType = ReadWrite) {
+ if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS) {
return false;
}
- if (THCudaTensor_nDimension(state, a) == 0) {
+ if (TensorUtils<TensorTypeA>::getDims(state, a) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@@ -135,6 +139,8 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
const dim3 block = getApplyBlock();
dim3 grid;
+ long totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
+
if (!getApplyGrid(state, totalElements, grid)) {
return false;
}
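
getApplyGrid is only partially visible in the earlier hunk; roughly, it sizes a one-dimensional grid from the element count and the fixed block size, with some cap. A sketch of that shape (assumed, not the exact upstream body):

// Assumed sketch: ceil-divide elements by the block size and cap the grid.
// The kernels stride by gridDim.x * blockDim.x, so a capped grid still
// covers every element.
inline bool getApplyGridSketch(THCState* state, long totalElements, dim3& grid) {
  (void) state;  // the real version can consult device properties here
  const long block = THC_APPLY_THREADS_PER_BLOCK;
  long blocks = (totalElements + block - 1) / block;
  if (blocks == 0) {
    blocks = 1;
  }
  const long maxBlocks = 64 * 1024 - 1;  // illustrative cap only
  grid = dim3((unsigned int) (blocks < maxBlocks ? blocks : maxBlocks));
  return true;
}
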
@@ -148,12 +154,13 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
// indices of a tensor with overlapping indices should probably be
// an error, since it is unclear which one should win), but we will
// preserve this last-writer-wins (in arbitrary copy order) behavior.
- THCudaTensor* oldA = NULL;
+ TensorTypeA* oldA = NULL;
- if (aType == ReadWrite && THC_overlappingIndices(state, a)) {
+ if (aType == ReadWrite &&
+ TensorUtils<TensorTypeA>::overlappingIndices(state, a)) {
// Must perform in contiguous space
oldA = a;
- a = THCudaTensor_newContiguous(state, a);
+ a = TensorUtils<TensorTypeA>::newContiguous(state, a);
}
// It is possible that the tensor dimensions are able to be collapsed,
@@ -164,55 +171,60 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
// (or vice versa), the contiguous tensor can be collapsed to one
// dimension, and the loop to translate the linear index to the array
// index can be similarly collapsed. That is what this unrolling is for.
-#define HANDLE_CASE(TYPE, A) \
- THCudaTensor_pointwiseApply1<Op, TYPE, A> \
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+#define HANDLE_CASE(TYPE, A) \
+ kernelPointwiseApply1<Op, \
+ typename TensorUtils<TensorTypeA>::DataType, \
+ TYPE, A> \
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
aInfo, (TYPE) totalElements, op);
-#define HANDLE_A_CASE(TYPE, A) \
- { \
- if (aInfo.isContiguous()) { \
- HANDLE_CASE(TYPE, -2); \
- } else { \
- switch (A) { \
- case 1: \
- HANDLE_CASE(TYPE, 1); \
- break; \
- case 2: \
- HANDLE_CASE(TYPE, 2); \
- break; \
- case 3: \
- HANDLE_CASE(TYPE, 3); \
- break; \
- default: \
- HANDLE_CASE(TYPE, -1); \
- break; \
- } \
- } \
+#define HANDLE_A_CASE(TYPE, A) \
+ { \
+ if (aInfo.isContiguous()) { \
+ HANDLE_CASE(TYPE, -2); \
+ } else { \
+ switch (A) { \
+ case 1: \
+ HANDLE_CASE(TYPE, 1); \
+ break; \
+ case 2: \
+ HANDLE_CASE(TYPE, 2); \
+ break; \
+ default: \
+ HANDLE_CASE(TYPE, -1); \
+ break; \
+ } \
+ } \
}
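
The collapsing remark above, and the aInfo.collapseDims() calls below, boil down to fusing adjacent dimensions whose layout is already dense with respect to each other; a conceptual stand-in (not the cutorch implementation) looks like this:

// Conceptual sketch only: two adjacent dimensions can be fused when the outer
// stride equals inner_size * inner_stride, i.e. the pair is laid out as one
// contiguous run. Collapsing this way reduces the per-element index
// arithmetic in the kernels above.
template <typename IndexType>
int collapseAdjacentDims(IndexType sizes[], IndexType strides[], int dims) {
  int out = 0;
  for (int i = 1; i < dims; ++i) {
    if (strides[out] == sizes[i] * strides[i]) {
      // Fuse dimension i into dimension `out`.
      sizes[out] *= sizes[i];
      strides[out] = strides[i];
    } else {
      ++out;
      sizes[out] = sizes[i];
      strides[out] = strides[i];
    }
  }
  return out + 1;  // new number of dimensions
}
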
// Can we use 32-bit integer math in the kernel (i.e., are the linear ID for the
// copy and the resulting non-linear offset computable using 32-bit math)?
// We also use unsigned index math in the kernel, as signed div/mod has
// additional overhead.
- if (THC_canUse32BitIndexMath(state, a)) {
- TensorInfo<unsigned int> aInfo(state, a);
+ if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a)) {
+ TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
+ getTensorInfo<TensorTypeA, unsigned int>(state, a);
aInfo.collapseDims();
HANDLE_A_CASE(unsigned int, aInfo.dims);
} else {
- TensorInfo<unsigned long> aInfo(state, a);
+ TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned long> aInfo =
+ getTensorInfo<TensorTypeA, unsigned long>(state, a);
aInfo.collapseDims();
// For large tensors, we only compile the completely contiguous
// version and the completely generic version, to reduce
// compilation time.
if (aInfo.isContiguous()) {
- THCudaTensor_pointwiseApply1<Op, unsigned long, -2>
+ kernelPointwiseApply1<Op,
+ typename TensorUtils<TensorTypeA>::DataType,
+ unsigned long, -2>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aInfo, (unsigned long) totalElements, op);
} else {
- THCudaTensor_pointwiseApply1<Op, unsigned long, -1>
+ kernelPointwiseApply1<Op,
+ typename TensorUtils<TensorTypeA>::DataType,
+ unsigned long, -1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aInfo, (unsigned long) totalElements, op);
}
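
The 32-bit question above is an overflow question, not an element-type one: indexing can stay in 32 bits when both the element count and the farthest reachable strided offset fit in a signed 32-bit integer. A rough sketch of such a test (assumed logic, not lifted from TensorUtils):

// Assumed sketch: 32-bit indexing is safe when the element count and the
// largest reachable strided offset both fit in a signed 32-bit integer.
// Unsigned index math is then used in the kernels because signed div/mod
// carries extra overhead.
inline bool fitsIn32BitIndexMath(const long sizes[], const long strides[],
                                 int dims, long nElements) {
  if (nElements >= (1L << 31)) {
    return false;
  }
  long maxOffset = 0;
  for (int i = 0; i < dims; ++i) {
    maxOffset += (sizes[i] - 1) * strides[i];  // farthest element reached
  }
  return maxOffset < (1L << 31);
}
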
@@ -221,36 +233,38 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
#undef HANDLE_A_CASE
if (oldA) {
- // Ignore overlaps when copying back; if we use THCudaTensor_copy
+ // Ignore overlaps when copying back; if we use THCTensor_copy
// instead, it will recursively try and invoke ourselves to make
// oldA contiguous.
- THCudaTensor_copyIgnoringOverlaps(state, oldA, a);
- THCudaTensor_free(state, a);
+ TensorUtils<TensorTypeA>::copyIgnoringOverlaps(state, oldA, a);
+ TensorUtils<TensorTypeA>::free(state, a);
a = oldA;
}
return true;
}
-template <typename Op>
-bool THCudaTensor_pointwiseApply2(THCState* state,
- THCudaTensor* a,
- THCudaTensor* b,
- const Op& op,
- TensorArgType aType = ReadWrite,
- TensorArgType bType = ReadOnly) {
- long totalElements = THCudaTensor_nElement(state, a);
-
- if (totalElements != THCudaTensor_nElement(state, b)) {
+template <typename TensorTypeA,
+ typename TensorTypeB,
+ typename Op>
+bool THC_pointwiseApply2(THCState* state,
+ TensorTypeA* a,
+ TensorTypeB* b,
+ const Op& op,
+ TensorArgType aType = ReadWrite,
+ TensorArgType bType = ReadOnly) {
+ long totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
+
+ if (totalElements != TensorUtils<TensorTypeB>::getNumElements(state, b)) {
return false;
}
- if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS ||
- THCudaTensor_nDimension(state, b) > MAX_CUTORCH_DIMS) {
+ if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS ||
+ TensorUtils<TensorTypeB>::getDims(state, b) > MAX_CUTORCH_DIMS) {
return false;
}
- if (THCudaTensor_nDimension(state, a) == 0) {
+ if (TensorUtils<TensorTypeA>::getDims(state, a) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@@ -271,18 +285,20 @@ bool THCudaTensor_pointwiseApply2(THCState* state,
// indices of a tensor with overlapping indices should probably be
// an error, since it is unclear which one should win), but we will
// preserve this last-writer-wins (in arbitrary copy order) behavior.
- THCudaTensor* oldA = NULL;
- THCudaTensor* oldB = NULL;
+ TensorTypeA* oldA = NULL;
+ TensorTypeB* oldB = NULL;
- if (aType == ReadWrite && THC_overlappingIndices(state, a)) {
+ if (aType == ReadWrite &&
+ TensorUtils<TensorTypeA>::overlappingIndices(state, a)) {
// Must perform in contiguous space
oldA = a;
- a = THCudaTensor_newContiguous(state, a);
+ a = TensorUtils<TensorTypeA>::newContiguous(state, a);
}
- if (bType == ReadWrite && THC_overlappingIndices(state, b)) {
+ if (bType == ReadWrite &&
+ TensorUtils<TensorTypeB>::overlappingIndices(state, b)) {
// Must perform in contiguous space
oldB = b;
- b = THCudaTensor_newContiguous(state, b);
+ b = TensorUtils<TensorTypeB>::newContiguous(state, b);
}
// It is possible that the tensor dimensions are able to be collapsed,
@@ -293,80 +309,87 @@ bool THCudaTensor_pointwiseApply2(THCState* state,
// (or vice versa), the contiguous tensor can be collapsed to one
// dimension, and the loop to translate the linear index to the array
// index can be similarly collapsed. That is what this unrolling is for.
-#define HANDLE_CASE(TYPE, A, B) \
- THCudaTensor_pointwiseApply2<Op, TYPE, A, B> \
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+#define HANDLE_CASE(TYPE, A, B) \
+ kernelPointwiseApply2<Op, \
+ typename TensorUtils<TensorTypeA>::DataType, \
+ typename TensorUtils<TensorTypeB>::DataType, \
+ TYPE, A, B> \
+ <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
aInfo, bInfo, (TYPE) totalElements, op);
-#define HANDLE_B_CASE(TYPE, A, B) \
- { \
- if (bInfo.isContiguous()) { \
- HANDLE_CASE(TYPE, A, -2); \
- } else { \
- switch (B) { \
- case 1: \
- HANDLE_CASE(TYPE, A, 1); \
- break; \
- case 2: \
- HANDLE_CASE(TYPE, A, 2); \
- break; \
- case 3: \
- HANDLE_CASE(TYPE, A, 3); \
- break; \
- default: \
- HANDLE_CASE(TYPE, A, -1); \
- break; \
- } \
- } \
- }
-
-#define HANDLE_A_CASE(TYPE, A, B) \
- { \
- if (aInfo.isContiguous()) { \
- HANDLE_B_CASE(TYPE, -2, B); \
- } else { \
- switch (A) { \
- case 1: \
- HANDLE_B_CASE(TYPE, 1, B); \
- break; \
- case 2: \
- HANDLE_B_CASE(TYPE, 2, B); \
- break; \
- case 3: \
- HANDLE_B_CASE(TYPE, 3, B); \
- break; \
- default: \
- HANDLE_B_CASE(TYPE, -1, B); \
- break; \
- } \
- } \
- }
-
- if (THC_canUse32BitIndexMath(state, a) &&
- THC_canUse32BitIndexMath(state, b)) {
- TensorInfo<unsigned int> aInfo(state, a);
+#define HANDLE_B_CASE(TYPE, A, B) \
+ { \
+ if (bInfo.isContiguous()) { \
+ HANDLE_CASE(TYPE, A, -2); \
+ } else { \
+ switch (B) { \
+ case 1: \
+ HANDLE_CASE(TYPE, A, 1); \
+ break; \
+ case 2: \
+ HANDLE_CASE(TYPE, A, 2); \
+ break; \
+ default: \
+ HANDLE_CASE(TYPE, A, -1); \
+ break; \
+ } \
+ } \
+ }
+
+#define HANDLE_A_CASE(TYPE, A, B) \
+ { \
+ if (aInfo.isContiguous()) { \
+ HANDLE_B_CASE(TYPE, -2, B); \
+ } else { \
+ switch (A) { \
+ case 1: \
+ HANDLE_B_CASE(TYPE, 1, B); \
+ break; \
+ case 2: \
+ HANDLE_B_CASE(TYPE, 2, B); \
+ break; \
+ default: \
+ HANDLE_B_CASE(TYPE, -1, B); \
+ break; \
+ } \
+ } \
+ }
+
+ if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a) &&
+ TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b)) {
+ TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
+ getTensorInfo<TensorTypeA, unsigned int>(state, a);
aInfo.collapseDims();
- TensorInfo<unsigned int> bInfo(state, b);
+ TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
+ getTensorInfo<TensorTypeB, unsigned int>(state, b);
bInfo.collapseDims();
HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
} else {
- TensorInfo<unsigned long> aInfo(state, a);
+ TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned long> aInfo =
+ getTensorInfo<TensorTypeA, unsigned long>(state, a);
aInfo.collapseDims();
- TensorInfo<unsigned long> bInfo(state, b);
+ TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned long> bInfo =
+ getTensorInfo<TensorTypeB, unsigned long>(state, b);
bInfo.collapseDims();
// For large tensors, we only compile the completely contiguous
// version and the completely generic version, to reduce
// compilation time.
if (aInfo.isContiguous() && bInfo.isContiguous()) {
- THCudaTensor_pointwiseApply2<Op, unsigned long, -2, -2>
+ kernelPointwiseApply2<Op,
+ typename TensorUtils<TensorTypeA>::DataType,
+ typename TensorUtils<TensorTypeB>::DataType,
+ unsigned long, -2, -2>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aInfo, bInfo, (unsigned long) totalElements, op);
} else {
- THCudaTensor_pointwiseApply2<Op, unsigned long, -1, -1>
+ kernelPointwiseApply2<Op,
+ typename TensorUtils<TensorTypeA>::DataType,
+ typename TensorUtils<TensorTypeB>::DataType,
+ unsigned long, -1, -1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aInfo, bInfo, (unsigned long) totalElements, op);
}
@@ -376,49 +399,52 @@ bool THCudaTensor_pointwiseApply2(THCState* state,
#undef HANDLE_A_CASE
if (oldA) {
- // Ignore overlaps when copying back; if we use THCudaTensor_copy
+ // Ignore overlaps when copying back; if we use THCTensor_copy
// instead, it will recursively try and invoke ourselves to make
// oldA contiguous.
- THCudaTensor_copyIgnoringOverlaps(state, oldA, a);
- THCudaTensor_free(state, a);
+ TensorUtils<TensorTypeA>::copyIgnoringOverlaps(state, oldA, a);
+ TensorUtils<TensorTypeA>::free(state, a);
a = oldA;
}
if (oldB) {
- // Ignore overlaps when copying back; if we use THCudaTensor_copy
+ // Ignore overlaps when copying back; if we use THCTensor_copy
// instead, it will recursively try and invoke ourselves to make
// oldB contiguous.
- THCudaTensor_copyIgnoringOverlaps(state, oldB, b);
- THCudaTensor_free(state, b);
+ TensorUtils<TensorTypeB>::copyIgnoringOverlaps(state, oldB, b);
+ TensorUtils<TensorTypeB>::free(state, b);
b = oldB;
}
return true;
}
-template <typename Op>
-bool THCudaTensor_pointwiseApply3(THCState* state,
- THCudaTensor* a,
- THCudaTensor* b,
- THCudaTensor* c,
- const Op& op,
- TensorArgType aType = ReadWrite,
- TensorArgType bType = ReadOnly,
- TensorArgType cType = ReadOnly) {
- long totalElements = THCudaTensor_nElement(state, a);
-
- if (totalElements != THCudaTensor_nElement(state, b) ||
- totalElements != THCudaTensor_nElement(state, c)) {
+template <typename TensorTypeA,
+ typename TensorTypeB,
+ typename TensorTypeC,
+ typename Op>
+bool THC_pointwiseApply3(THCState* state,
+ TensorTypeA* a,
+ TensorTypeB* b,
+ TensorTypeC* c,
+ const Op& op,
+ TensorArgType aType = ReadWrite,
+ TensorArgType bType = ReadOnly,
+ TensorArgType cType = ReadOnly) {
+ long totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
+
+ if (totalElements != TensorUtils<TensorTypeB>::getNumElements(state, b) ||
+ totalElements != TensorUtils<TensorTypeC>::getNumElements(state, c)) {
return false;
}
- if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS ||
- THCudaTensor_nDimension(state, b) > MAX_CUTORCH_DIMS ||
- THCudaTensor_nDimension(state, c) > MAX_CUTORCH_DIMS) {
+ if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS ||
+ TensorUtils<TensorTypeB>::getDims(state, b) > MAX_CUTORCH_DIMS ||
+ TensorUtils<TensorTypeC>::getDims(state, c) > MAX_CUTORCH_DIMS) {
return false;
}
- if (THCudaTensor_nDimension(state, a) == 0) {
+ if (TensorUtils<TensorTypeA>::getDims(state, a) == 0) {
// Zero-dim tensor; do nothing
return true;
}
@@ -439,131 +465,141 @@ bool THCudaTensor_pointwiseApply3(THCState* state,
// indices of a tensor with overlapping indices should probably be
// an error, since it is unclear which one should win), but we will
// preserve this last-writer-wins (in arbitrary copy order) behavior.
- THCudaTensor* oldA = NULL;
- THCudaTensor* oldB = NULL;
- THCudaTensor* oldC = NULL;
+ TensorTypeA* oldA = NULL;
+ TensorTypeB* oldB = NULL;
+ TensorTypeC* oldC = NULL;
- if (aType == ReadWrite && THC_overlappingIndices(state, a)) {
+ if (aType == ReadWrite &&
+ TensorUtils<TensorTypeA>::overlappingIndices(state, a)) {
// Must perform in contiguous space
oldA = a;
- a = THCudaTensor_newContiguous(state, a);
+ a = TensorUtils<TensorTypeA>::newContiguous(state, a);
}
-
- if (bType == ReadWrite && THC_overlappingIndices(state, b)) {
+ if (bType == ReadWrite &&
+ TensorUtils<TensorTypeB>::overlappingIndices(state, b)) {
// Must perform in contiguous space
oldB = b;
- b = THCudaTensor_newContiguous(state, b);
+ b = TensorUtils<TensorTypeB>::newContiguous(state, b);
}
-
- if (cType == ReadWrite && THC_overlappingIndices(state, c)) {
+ if (cType == ReadWrite &&
+ TensorUtils<TensorTypeC>::overlappingIndices(state, c)) {
// Must perform in contiguous space
oldC = c;
- c = THCudaTensor_newContiguous(state, c);
+ c = TensorUtils<TensorTypeC>::newContiguous(state, c);
}
#define HANDLE_CASE(TYPE, A, B, C) \
- THCudaTensor_pointwiseApply3<Op, TYPE, A, B, C> \
+ kernelPointwiseApply3<Op, \
+ typename TensorUtils<TensorTypeA>::DataType, \
+ typename TensorUtils<TensorTypeB>::DataType, \
+ typename TensorUtils<TensorTypeC>::DataType, \
+ TYPE, A, B, C> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
aInfo, bInfo, cInfo, (TYPE) totalElements, op);
-#define HANDLE_C_CASE(TYPE, A, B, C) \
- { \
- if (cInfo.isContiguous()) { \
- HANDLE_CASE(TYPE, A, B, -2); \
- } else { \
- switch (C) { \
- case 1: \
- HANDLE_CASE(TYPE, A, B, 1); \
- break; \
- case 2: \
- HANDLE_CASE(TYPE, A, B, 2); \
- break; \
- case 3: \
- HANDLE_CASE(TYPE, A, B, 3); \
- break; \
- default: \
- HANDLE_CASE(TYPE, A, B, -1); \
- break; \
- } \
- } \
- }
-
-#define HANDLE_B_CASE(TYPE, A, B, C) \
- { \
- if (bInfo.isContiguous()) { \
- HANDLE_C_CASE(TYPE, A, -2, C); \
- } else { \
- switch (B) { \
- case 1: \
- HANDLE_C_CASE(TYPE, A, 1, C); \
- break; \
- case 2: \
- HANDLE_C_CASE(TYPE, A, 2, C); \
- break; \
- case 3: \
- HANDLE_C_CASE(TYPE, A, 3, C); \
- break; \
- default: \
- HANDLE_C_CASE(TYPE, A, -1, C); \
- break; \
- } \
- } \
- }
-
-#define HANDLE_A_CASE(TYPE, A, B, C) \
- { \
- if (aInfo.isContiguous()) { \
- HANDLE_B_CASE(TYPE, -2, B, C); \
- } else { \
- switch (A) { \
- case 1: \
- HANDLE_B_CASE(TYPE, 1, B, C); \
- break; \
- case 2: \
- HANDLE_B_CASE(TYPE, 2, B, C); \
- break; \
- case 3: \
- HANDLE_B_CASE(TYPE, 3, B, C); \
- break; \
- default: \
- HANDLE_B_CASE(TYPE, -1, B, C); \
- break; \
- } \
- } \
- }
-
- if (THC_canUse32BitIndexMath(state, a) &&
- THC_canUse32BitIndexMath(state, b) &&
- THC_canUse32BitIndexMath(state, c)) {
- TensorInfo<unsigned int> aInfo(state, a);
+#define HANDLE_C_CASE(TYPE, A, B, C) \
+ { \
+ if (cInfo.isContiguous()) { \
+ HANDLE_CASE(TYPE, A, B, -2); \
+ } else { \
+ switch (C) { \
+ case 1: \
+ HANDLE_CASE(TYPE, A, B, 1); \
+ break; \
+ case 2: \
+ HANDLE_CASE(TYPE, A, B, 2); \
+ break; \
+ default: \
+ HANDLE_CASE(TYPE, A, B, -1); \
+ break; \
+ } \
+ } \
+ }
+
+#define HANDLE_B_CASE(TYPE, A, B, C) \
+ { \
+ if (bInfo.isContiguous()) { \
+ HANDLE_C_CASE(TYPE, A, -2, C); \
+ } else { \
+ switch (B) { \
+ case 1: \
+ HANDLE_C_CASE(TYPE, A, 1, C); \
+ break; \
+ case 2: \
+ HANDLE_C_CASE(TYPE, A, 2, C); \
+ break; \
+ default: \
+ HANDLE_C_CASE(TYPE, A, -1, C); \
+ break; \
+ } \
+ } \
+ }
+
+#define HANDLE_A_CASE(TYPE, A, B, C) \
+ { \
+ if (aInfo.isContiguous()) { \
+ HANDLE_B_CASE(TYPE, -2, B, C); \
+ } else { \
+ switch (A) { \
+ case 1: \
+ HANDLE_B_CASE(TYPE, 1, B, C); \
+ break; \
+ case 2: \
+ HANDLE_B_CASE(TYPE, 2, B, C); \
+ break; \
+ default: \
+ HANDLE_B_CASE(TYPE, -1, B, C); \
+ break; \
+ } \
+ } \
+ }
+
+ if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a) &&
+ TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b) &&
+ TensorUtils<TensorTypeC>::canUse32BitIndexMath(state, c)) {
+ TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
+ getTensorInfo<TensorTypeA, unsigned int>(state, a);
aInfo.collapseDims();
- TensorInfo<unsigned int> bInfo(state, b);
+ TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
+ getTensorInfo<TensorTypeB, unsigned int>(state, b);
bInfo.collapseDims();
- TensorInfo<unsigned int> cInfo(state, c);
+ TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned int> cInfo =
+ getTensorInfo<TensorTypeC, unsigned int>(state, c);
cInfo.collapseDims();
HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims);
} else {
- TensorInfo<unsigned long> aInfo(state, a);
+ TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned long> aInfo =
+ getTensorInfo<TensorTypeA, unsigned long>(state, a);
aInfo.collapseDims();
- TensorInfo<unsigned long> bInfo(state, b);
+ TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned long> bInfo =
+ getTensorInfo<TensorTypeB, unsigned long>(state, b);
bInfo.collapseDims();
- TensorInfo<unsigned long> cInfo(state, c);
+ TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned long> cInfo =
+ getTensorInfo<TensorTypeC, unsigned long>(state, c);
cInfo.collapseDims();
// For large tensors, we only compile the completely contiguous
// version and the completely generic version, to reduce
// compilation time.
if (aInfo.isContiguous() && bInfo.isContiguous() && cInfo.isContiguous()) {
- THCudaTensor_pointwiseApply3<Op, unsigned long, -2, -2, -2>
+ kernelPointwiseApply3<Op,
+ typename TensorUtils<TensorTypeA>::DataType,
+ typename TensorUtils<TensorTypeB>::DataType,
+ typename TensorUtils<TensorTypeC>::DataType,
+ unsigned long, -2, -2, -2>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aInfo, bInfo, cInfo, (unsigned long) totalElements, op);
} else {
- THCudaTensor_pointwiseApply3<Op, unsigned long, -1, -1, -1>
+ kernelPointwiseApply3<Op,
+ typename TensorUtils<TensorTypeA>::DataType,
+ typename TensorUtils<TensorTypeB>::DataType,
+ typename TensorUtils<TensorTypeC>::DataType,
+ unsigned long, -1, -1, -1>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
aInfo, bInfo, cInfo, (unsigned long) totalElements, op);
}
@@ -574,29 +610,29 @@ bool THCudaTensor_pointwiseApply3(THCState* state,
#undef HANDLE_A_CASE
if (oldA) {
- // Ignore overlaps when copying back; if we use THCudaTensor_copy
+ // Ignore overlaps when copying back; if we use THCTensor_copy
// instead, it will recursively try and invoke ourselves to make
// oldA contiguous.
- THCudaTensor_copyIgnoringOverlaps(state, oldA, a);
- THCudaTensor_free(state, a);
+ TensorUtils<TensorTypeA>::copyIgnoringOverlaps(state, oldA, a);
+ TensorUtils<TensorTypeA>::free(state, a);
a = oldA;
}
if (oldB) {
- // Ignore overlaps when copying back; if we use THCudaTensor_copy
+ // Ignore overlaps when copying back; if we use THCTensor_copy
// instead, it will recursively try and invoke ourselves to make
// oldB contiguous.
- THCudaTensor_copyIgnoringOverlaps(state, oldB, b);
- THCudaTensor_free(state, b);
+ TensorUtils<TensorTypeB>::copyIgnoringOverlaps(state, oldB, b);
+ TensorUtils<TensorTypeB>::free(state, b);
b = oldB;
}
if (oldC) {
- // Ignore overlaps when copying back; if we use THCudaTensor_copy
+ // Ignore overlaps when copying back; if we use THCTensor_copy
// instead, it will recursively try and invoke ourselves to make
// oldC contiguous.
- THCudaTensor_copyIgnoringOverlaps(state, oldC, c);
- THCudaTensor_free(state, c);
+ TensorUtils<TensorTypeC>::copyIgnoringOverlaps(state, oldC, c);
+ TensorUtils<TensorTypeC>::free(state, c);
c = oldC;
}
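
For orientation, this is the shape of a typical caller of the reworked entry points; the functor and wrapper below are illustrative only, not part of this change. The pointer types in the functor's operator() must match TensorUtils<TensorType>::DataType for the tensors passed in (float for THCudaTensor).

// Illustrative usage sketch (not part of this diff): a pointwise
// "multiply by a constant" over float tensors.
struct TensorMulConstantOp {
  explicit TensorMulConstantOp(float v) : val(v) {}

  // Two-tensor form: out[i] = in[i] * val
  __device__ __forceinline__ void operator()(float* out, float* in) {
    *out = *in * val;
  }

  // In-place form: v[i] *= val
  __device__ __forceinline__ void operator()(float* v) {
    *v *= val;
  }

  const float val;
};

// Host side (hypothetical caller): dimension checks, dimension collapsing,
// and the kernel launch are all handled inside THC_pointwiseApply2; it
// returns false if the tensors exceed MAX_CUTORCH_DIMS or their element
// counts differ.
void mulConstant(THCState* state, THCudaTensor* dst, THCudaTensor* src, float value) {
  if (!THC_pointwiseApply2(state, dst, src, TensorMulConstantOp(value))) {
    THError("tensor arguments not supported by THC_pointwiseApply2");
  }
}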