Diffstat (limited to 'lib/THC/THCApply.cuh')
-rw-r--r-- | lib/THC/THCApply.cuh | 566 |
1 file changed, 301 insertions, 265 deletions
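The patch below generalizes this header from the fixed THCudaTensor type to arbitrary tensor types: the device kernels are renamed kernelPointwiseApply{1,2,3} and gain per-tensor data-type template parameters, and the host entry points become THC_pointwiseApply{1,2,3} templates that go through a TensorUtils<TensorType> traits class instead of calling THCudaTensor_* functions directly. As a rough sketch of how the new entry points are meant to be called (the functor AddScalarOp and the wrapper addScalarInPlace are illustrative assumptions, not part of this commit; only THC_pointwiseApply1's signature comes from the diff):

// Hypothetical caller sketch. AddScalarOp and addScalarInPlace are made up
// for illustration; THC_pointwiseApply1 is the templated entry point added
// by this change.
struct AddScalarOp {
  explicit AddScalarOp(float v) : val(v) {}

  // Called once per element with a pointer into the tensor's storage.
  __device__ __forceinline__ void operator()(float* x) {
    *x += val;
  }

  const float val;
};

void addScalarInPlace(THCState* state, THCudaTensor* self, float value) {
  // The tensor type is given explicitly; the functor type is deduced.
  if (!THC_pointwiseApply1<THCudaTensor>(state, self, AddScalarOp(value))) {
    THError("tensor has too many dimensions or too many elements");
  }
}

The kernels launched underneath walk the tensor with a grid-stride loop and hand the functor a raw pointer per element, so the same functor works for any dimensionality or layout.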
diff --git a/lib/THC/THCApply.cuh b/lib/THC/THCApply.cuh
index 707e22f..dd6d32a 100644
--- a/lib/THC/THCApply.cuh
+++ b/lib/THC/THCApply.cuh
@@ -3,6 +3,7 @@
#include "THCTensorCopy.h"
#include "THCReduceApplyUtils.cuh"
+#include "THCTensorTypeUtils.cuh"

//
// This file contains pointwise operation functions and kernels that
@@ -12,81 +13,85 @@
//

// Threads per block for our apply kernel
+// FIXME: use occupancy calculator instead
#define THC_APPLY_THREADS_PER_BLOCK 32 * 16

-// Called when we are copying into an overlapping index `dst`, but
-// we don't care which writer wins. Hacky but it works.
-THC_API void THCudaTensor_copyIgnoringOverlaps(THCState* state,
-                                               THCudaTensor* dst,
-                                               THCudaTensor* src);
-
-template <typename Op, typename IndexType, int ADims>
+template <typename Op,
+          typename Ta,
+          typename IndexType,
+          int ADims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THCudaTensor_pointwiseApply1(TensorInfo<IndexType> a,
-                             IndexType totalElements,
-                             Op op) {
+kernelPointwiseApply1(TensorInfo<Ta, IndexType> a,
+                      IndexType totalElements,
+                      Op op) {
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    // Convert `linearIndex` into an offset of `a`
    const IndexType aOffset =
-      IndexToOffset<IndexType, ADims>::get(linearIndex, a);
+      IndexToOffset<Ta, IndexType, ADims>::get(linearIndex, a);

    op(&a.data[aOffset]);
  }
}

-template <typename Op, typename IndexType, int ADims, int BDims>
+template <typename Op,
+          typename Ta, typename Tb,
+          typename IndexType,
+          int ADims, int BDims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THCudaTensor_pointwiseApply2(TensorInfo<IndexType> a,
-                             TensorInfo<IndexType> b,
-                             IndexType totalElements,
-                             Op op) {
+kernelPointwiseApply2(TensorInfo<Ta, IndexType> a,
+                      TensorInfo<Tb, IndexType> b,
+                      IndexType totalElements,
+                      Op op) {
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    // Convert `linearIndex` into an offset of `a`
    const IndexType aOffset =
-      IndexToOffset<IndexType, ADims>::get(linearIndex, a);
+      IndexToOffset<Ta, IndexType, ADims>::get(linearIndex, a);

    // Convert `linearIndex` into an offset of `b`
    const IndexType bOffset =
-      IndexToOffset<IndexType, BDims>::get(linearIndex, b);
+      IndexToOffset<Tb, IndexType, BDims>::get(linearIndex, b);

    op(&a.data[aOffset], &b.data[bOffset]);
  }
}

-template <typename Op, typename IndexType, int ADims, int BDims, int CDims>
+template <typename Op,
+          typename Ta, typename Tb, typename Tc,
+          typename IndexType,
+          int ADims, int BDims, int CDims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
-THCudaTensor_pointwiseApply3(TensorInfo<IndexType> a,
-                             TensorInfo<IndexType> b,
-                             TensorInfo<IndexType> c,
-                             IndexType totalElements,
-                             Op op) {
+kernelPointwiseApply3(TensorInfo<Ta, IndexType> a,
+                      TensorInfo<Tb, IndexType> b,
+                      TensorInfo<Tc, IndexType> c,
+                      IndexType totalElements,
+                      Op op) {
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    // Convert `linearIndex` into an offset of `a`
    const IndexType aOffset =
-      IndexToOffset<IndexType, ADims>::get(linearIndex, a);
+      IndexToOffset<Ta, IndexType, ADims>::get(linearIndex, a);

    // Convert `linearIndex` into an offset of `b`
    const IndexType bOffset =
-      IndexToOffset<IndexType, BDims>::get(linearIndex, b);
+      IndexToOffset<Tb, IndexType, BDims>::get(linearIndex, b);

    // Convert `linearIndex` into an offset of `c`
    const IndexType cOffset =
-      IndexToOffset<IndexType, CDims>::get(linearIndex, c);
+      IndexToOffset<Tc, IndexType, CDims>::get(linearIndex, c);

    op(&a.data[aOffset], &b.data[bOffset], &c.data[cOffset]);
  }
}
@@ -116,18 +121,17 @@ inline bool getApplyGrid(THCState* state, long totalElements, dim3& grid) {
  return true;
}

-template <typename Op>
-bool THCudaTensor_pointwiseApply1(THCState* state,
-                                  THCudaTensor* a,
-                                  const Op& op,
-                                  TensorArgType aType = ReadWrite) {
-  long totalElements = THCudaTensor_nElement(state, a);
-
-  if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS) {
+template <typename TensorTypeA,
+          typename Op>
+bool THC_pointwiseApply1(THCState* state,
+                         TensorTypeA* a,
+                         const Op& op,
+                         TensorArgType aType = ReadWrite) {
+  if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS) {
    return false;
  }

-  if (THCudaTensor_nDimension(state, a) == 0) {
+  if (TensorUtils<TensorTypeA>::getDims(state, a) == 0) {
    // Zero-dim tensor; do nothing
    return true;
  }
@@ -135,6 +139,8 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
  const dim3 block = getApplyBlock();
  dim3 grid;
+  long totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
+
  if (!getApplyGrid(state, totalElements, grid)) {
    return false;
  }
@@ -148,12 +154,13 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
  // indices of a tensor with overlapping indices should probably be
  // an error, since it is unclear which one should win), but we will
  // preserve this last-writer-wins (in arbitrary copy order) behavior.
-  THCudaTensor* oldA = NULL;
+  TensorTypeA* oldA = NULL;

-  if (aType == ReadWrite && THC_overlappingIndices(state, a)) {
+  if (aType == ReadWrite &&
+      TensorUtils<TensorTypeA>::overlappingIndices(state, a)) {
    // Must perform in contiguous space
    oldA = a;
-    a = THCudaTensor_newContiguous(state, a);
+    a = TensorUtils<TensorTypeA>::newContiguous(state, a);
  }

  // It is possible that the tensor dimensions are able to be collapsed,
@@ -164,55 +171,60 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
-#define HANDLE_CASE(TYPE, A) \
-  THCudaTensor_pointwiseApply1<Op, TYPE, A> \
-    <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+#define HANDLE_CASE(TYPE, A) \
+  kernelPointwiseApply1<Op, \
+                        typename TensorUtils<TensorTypeA>::DataType, \
+                        TYPE, A> \
+    <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
      aInfo, (TYPE) totalElements, op);

-#define HANDLE_A_CASE(TYPE, A) \
-  { \
-    if (aInfo.isContiguous()) { \
-      HANDLE_CASE(TYPE, -2); \
-    } else { \
-      switch (A) { \
-        case 1: \
-          HANDLE_CASE(TYPE, 1); \
-          break; \
-        case 2: \
-          HANDLE_CASE(TYPE, 2); \
-          break; \
-        case 3: \
-          HANDLE_CASE(TYPE, 3); \
-          break; \
-        default: \
-          HANDLE_CASE(TYPE, -1); \
-          break; \
-      } \
-    } \
+#define HANDLE_A_CASE(TYPE, A) \
+  { \
+    if (aInfo.isContiguous()) { \
+      HANDLE_CASE(TYPE, -2); \
+    } else { \
+      switch (A) { \
+        case 1: \
+          HANDLE_CASE(TYPE, 1); \
+          break; \
+        case 2: \
+          HANDLE_CASE(TYPE, 2); \
+          break; \
+        default: \
+          HANDLE_CASE(TYPE, -1); \
+          break; \
+      } \
+    } \
  }

  // Can we use 32-bit integer math in the kernel (the linear ID for the copy
  // and the resulting non-linear offset is all computable using 32-bit math?)
  // We also use unsigned index math in the kernel, as signed div/mod has
  // additional overhead.
-  if (THC_canUse32BitIndexMath(state, a)) {
-    TensorInfo<unsigned int> aInfo(state, a);
+  if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a)) {
+    TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
+      getTensorInfo<TensorTypeA, unsigned int>(state, a);
    aInfo.collapseDims();

    HANDLE_A_CASE(unsigned int, aInfo.dims);
  } else {
-    TensorInfo<unsigned long> aInfo(state, a);
+    TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned long> aInfo =
+      getTensorInfo<TensorTypeA, unsigned long>(state, a);
    aInfo.collapseDims();

    // For large tensors, we only compile the completely contiguous
    // version and the completely generic version, to reduce
    // compilation time.
    if (aInfo.isContiguous()) {
-      THCudaTensor_pointwiseApply1<Op, unsigned long, -2>
+      kernelPointwiseApply1<Op,
+                            typename TensorUtils<TensorTypeA>::DataType,
+                            unsigned long, -2>
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
          aInfo, (unsigned long) totalElements, op);
    } else {
-      THCudaTensor_pointwiseApply1<Op, unsigned long, -1>
+      kernelPointwiseApply1<Op,
+                            typename TensorUtils<TensorTypeA>::DataType,
+                            unsigned long, -1>
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
          aInfo, (unsigned long) totalElements, op);
    }
@@ -221,36 +233,38 @@ bool THCudaTensor_pointwiseApply1(THCState* state,
#undef HANDLE_A_CASE

  if (oldA) {
-    // Ignore overlaps when copying back; if we use THCudaTensor_copy
+    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
-    THCudaTensor_copyIgnoringOverlaps(state, oldA, a);
-    THCudaTensor_free(state, a);
+    TensorUtils<TensorTypeA>::copyIgnoringOverlaps(state, oldA, a);
+    TensorUtils<TensorTypeA>::free(state, a);
    a = oldA;
  }

  return true;
}

-template <typename Op>
-bool THCudaTensor_pointwiseApply2(THCState* state,
-                                  THCudaTensor* a,
-                                  THCudaTensor* b,
-                                  const Op& op,
-                                  TensorArgType aType = ReadWrite,
-                                  TensorArgType bType = ReadOnly) {
-  long totalElements = THCudaTensor_nElement(state, a);
-
-  if (totalElements != THCudaTensor_nElement(state, b)) {
+template <typename TensorTypeA,
+          typename TensorTypeB,
+          typename Op>
+bool THC_pointwiseApply2(THCState* state,
+                         TensorTypeA* a,
+                         TensorTypeB* b,
+                         const Op& op,
+                         TensorArgType aType = ReadWrite,
+                         TensorArgType bType = ReadOnly) {
+  long totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
+
+  if (totalElements != TensorUtils<TensorTypeB>::getNumElements(state, b)) {
    return false;
  }

-  if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS ||
-      THCudaTensor_nDimension(state, b) > MAX_CUTORCH_DIMS) {
+  if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS ||
+      TensorUtils<TensorTypeB>::getDims(state, b) > MAX_CUTORCH_DIMS) {
    return false;
  }

-  if (THCudaTensor_nDimension(state, a) == 0) {
+  if (TensorUtils<TensorTypeA>::getDims(state, a) == 0) {
    // Zero-dim tensor; do nothing
    return true;
  }
@@ -271,18 +285,20 @@ bool THCudaTensor_pointwiseApply2(THCState* state,
  // indices of a tensor with overlapping indices should probably be
  // an error, since it is unclear which one should win), but we will
  // preserve this last-writer-wins (in arbitrary copy order) behavior.
-  THCudaTensor* oldA = NULL;
-  THCudaTensor* oldB = NULL;
+  TensorTypeA* oldA = NULL;
+  TensorTypeB* oldB = NULL;

-  if (aType == ReadWrite && THC_overlappingIndices(state, a)) {
+  if (aType == ReadWrite &&
+      TensorUtils<TensorTypeA>::overlappingIndices(state, a)) {
    // Must perform in contiguous space
    oldA = a;
-    a = THCudaTensor_newContiguous(state, a);
+    a = TensorUtils<TensorTypeA>::newContiguous(state, a);
  }

-  if (bType == ReadWrite && THC_overlappingIndices(state, b)) {
+  if (bType == ReadWrite &&
+      TensorUtils<TensorTypeB>::overlappingIndices(state, b)) {
    // Must perform in contiguous space
    oldB = b;
-    b = THCudaTensor_newContiguous(state, b);
+    b = TensorUtils<TensorTypeB>::newContiguous(state, b);
  }

  // It is possible that the tensor dimensions are able to be collapsed,
@@ -293,80 +309,87 @@ bool THCudaTensor_pointwiseApply2(THCState* state,
  // (or vice versa), the contiguous tensor can be collapsed to one
  // dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
-#define HANDLE_CASE(TYPE, A, B) \
-  THCudaTensor_pointwiseApply2<Op, TYPE, A, B> \
-    <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+#define HANDLE_CASE(TYPE, A, B) \
+  kernelPointwiseApply2<Op, \
+                        typename TensorUtils<TensorTypeA>::DataType, \
+                        typename TensorUtils<TensorTypeB>::DataType, \
+                        TYPE, A, B> \
+    <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
      aInfo, bInfo, (TYPE) totalElements, op);

-#define HANDLE_B_CASE(TYPE, A, B) \
-  { \
-    if (bInfo.isContiguous()) { \
-      HANDLE_CASE(TYPE, A, -2); \
-    } else { \
-      switch (B) { \
-        case 1: \
-          HANDLE_CASE(TYPE, A, 1); \
-          break; \
-        case 2: \
-          HANDLE_CASE(TYPE, A, 2); \
-          break; \
-        case 3: \
-          HANDLE_CASE(TYPE, A, 3); \
-          break; \
-        default: \
-          HANDLE_CASE(TYPE, A, -1); \
-          break; \
-      } \
-    } \
-  }
-
-#define HANDLE_A_CASE(TYPE, A, B) \
-  { \
-    if (aInfo.isContiguous()) { \
-      HANDLE_B_CASE(TYPE, -2, B); \
-    } else { \
-      switch (A) { \
-        case 1: \
-          HANDLE_B_CASE(TYPE, 1, B); \
-          break; \
-        case 2: \
-          HANDLE_B_CASE(TYPE, 2, B); \
-          break; \
-        case 3: \
-          HANDLE_B_CASE(TYPE, 3, B); \
-          break; \
-        default: \
-          HANDLE_B_CASE(TYPE, -1, B); \
-          break; \
-      } \
-    } \
-  }
-
-  if (THC_canUse32BitIndexMath(state, a) &&
-      THC_canUse32BitIndexMath(state, b)) {
-    TensorInfo<unsigned int> aInfo(state, a);
+#define HANDLE_B_CASE(TYPE, A, B) \
+  { \
+    if (bInfo.isContiguous()) { \
+      HANDLE_CASE(TYPE, A, -2); \
+    } else { \
+      switch (B) { \
+        case 1: \
+          HANDLE_CASE(TYPE, A, 1); \
+          break; \
+        case 2: \
+          HANDLE_CASE(TYPE, A, 2); \
+          break; \
+        default: \
+          HANDLE_CASE(TYPE, A, -1); \
+          break; \
+      } \
+    } \
+  }
+
+#define HANDLE_A_CASE(TYPE, A, B) \
+  { \
+    if (aInfo.isContiguous()) { \
+      HANDLE_B_CASE(TYPE, -2, B); \
+    } else { \
+      switch (A) { \
+        case 1: \
+          HANDLE_B_CASE(TYPE, 1, B); \
+          break; \
+        case 2: \
+          HANDLE_B_CASE(TYPE, 2, B); \
+          break; \
+        default: \
+          HANDLE_B_CASE(TYPE, -1, B); \
+          break; \
+      } \
+    } \
+  }
+
+  if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a) &&
+      TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b)) {
+    TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
+      getTensorInfo<TensorTypeA, unsigned int>(state, a);
    aInfo.collapseDims();
-    TensorInfo<unsigned int> bInfo(state, b);
+    TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
+      getTensorInfo<TensorTypeB, unsigned int>(state, b);
    bInfo.collapseDims();

    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
  } else {
-    TensorInfo<unsigned long> aInfo(state, a);
+    TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned long> aInfo =
+      getTensorInfo<TensorTypeA, unsigned long>(state, a);
    aInfo.collapseDims();
-    TensorInfo<unsigned long> bInfo(state, b);
+    TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned long> bInfo =
+      getTensorInfo<TensorTypeB, unsigned long>(state, b);
    bInfo.collapseDims();

    // For large tensors, we only compile the completely contiguous
    // version and the completely generic version, to reduce
    // compilation time.
    if (aInfo.isContiguous() && bInfo.isContiguous()) {
-      THCudaTensor_pointwiseApply2<Op, unsigned long, -2, -2>
+      kernelPointwiseApply2<Op,
+                            typename TensorUtils<TensorTypeA>::DataType,
+                            typename TensorUtils<TensorTypeB>::DataType,
+                            unsigned long, -2, -2>
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
          aInfo, bInfo, (unsigned long) totalElements, op);
    } else {
-      THCudaTensor_pointwiseApply2<Op, unsigned long, -1, -1>
+      kernelPointwiseApply2<Op,
+                            typename TensorUtils<TensorTypeA>::DataType,
+                            typename TensorUtils<TensorTypeB>::DataType,
+                            unsigned long, -1, -1>
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
          aInfo, bInfo, (unsigned long) totalElements, op);
    }
@@ -376,49 +399,52 @@ bool THCudaTensor_pointwiseApply2(THCState* state,
#undef HANDLE_A_CASE

  if (oldA) {
-    // Ignore overlaps when copying back; if we use THCudaTensor_copy
+    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
-    THCudaTensor_copyIgnoringOverlaps(state, oldA, a);
-    THCudaTensor_free(state, a);
+    TensorUtils<TensorTypeA>::copyIgnoringOverlaps(state, oldA, a);
+    TensorUtils<TensorTypeA>::free(state, a);
    a = oldA;
  }

  if (oldB) {
-    // Ignore overlaps when copying back; if we use THCudaTensor_copy
+    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
-    THCudaTensor_copyIgnoringOverlaps(state, oldB, b);
-    THCudaTensor_free(state, b);
+    TensorUtils<TensorTypeB>::copyIgnoringOverlaps(state, oldB, b);
+    TensorUtils<TensorTypeB>::free(state, b);
    b = oldB;
  }

  return true;
}

-template <typename Op>
-bool THCudaTensor_pointwiseApply3(THCState* state,
-                                  THCudaTensor* a,
-                                  THCudaTensor* b,
-                                  THCudaTensor* c,
-                                  const Op& op,
-                                  TensorArgType aType = ReadWrite,
-                                  TensorArgType bType = ReadOnly,
-                                  TensorArgType cType = ReadOnly) {
-  long totalElements = THCudaTensor_nElement(state, a);
-
-  if (totalElements != THCudaTensor_nElement(state, b) ||
-      totalElements != THCudaTensor_nElement(state, c)) {
+template <typename TensorTypeA,
+          typename TensorTypeB,
+          typename TensorTypeC,
+          typename Op>
+bool THC_pointwiseApply3(THCState* state,
+                         TensorTypeA* a,
+                         TensorTypeB* b,
+                         TensorTypeC* c,
+                         const Op& op,
+                         TensorArgType aType = ReadWrite,
+                         TensorArgType bType = ReadOnly,
+                         TensorArgType cType = ReadOnly) {
+  long totalElements = TensorUtils<TensorTypeA>::getNumElements(state, a);
+
+  if (totalElements != TensorUtils<TensorTypeB>::getNumElements(state, b) ||
+      totalElements != TensorUtils<TensorTypeC>::getNumElements(state, c)) {
    return false;
  }

-  if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS ||
-      THCudaTensor_nDimension(state, b) > MAX_CUTORCH_DIMS ||
-      THCudaTensor_nDimension(state, c) > MAX_CUTORCH_DIMS) {
+  if (TensorUtils<TensorTypeA>::getDims(state, a) > MAX_CUTORCH_DIMS ||
+      TensorUtils<TensorTypeB>::getDims(state, b) > MAX_CUTORCH_DIMS ||
+      TensorUtils<TensorTypeC>::getDims(state, c) > MAX_CUTORCH_DIMS) {
    return false;
  }

-  if (THCudaTensor_nDimension(state, a) == 0) {
+  if (TensorUtils<TensorTypeA>::getDims(state, a) == 0) {
    // Zero-dim tensor; do nothing
    return true;
  }
@@ -439,131 +465,141 @@ bool THCudaTensor_pointwiseApply3(THCState* state,
  // indices of a tensor with overlapping indices should probably be
  // an error, since it is unclear which one should win), but we will
  // preserve this last-writer-wins (in arbitrary copy order) behavior.
-  THCudaTensor* oldA = NULL;
-  THCudaTensor* oldB = NULL;
-  THCudaTensor* oldC = NULL;
+  TensorTypeA* oldA = NULL;
+  TensorTypeB* oldB = NULL;
+  TensorTypeC* oldC = NULL;

-  if (aType == ReadWrite && THC_overlappingIndices(state, a)) {
+  if (aType == ReadWrite &&
+      TensorUtils<TensorTypeA>::overlappingIndices(state, a)) {
    // Must perform in contiguous space
    oldA = a;
-    a = THCudaTensor_newContiguous(state, a);
+    a = TensorUtils<TensorTypeA>::newContiguous(state, a);
  }
-
-  if (bType == ReadWrite && THC_overlappingIndices(state, b)) {
+  if (bType == ReadWrite &&
+      TensorUtils<TensorTypeB>::overlappingIndices(state, b)) {
    // Must perform in contiguous space
    oldB = b;
-    b = THCudaTensor_newContiguous(state, b);
+    b = TensorUtils<TensorTypeB>::newContiguous(state, b);
  }
-
-  if (cType == ReadWrite && THC_overlappingIndices(state, c)) {
+  if (cType == ReadWrite &&
+      TensorUtils<TensorTypeC>::overlappingIndices(state, c)) {
    // Must perform in contiguous space
    oldC = c;
-    c = THCudaTensor_newContiguous(state, c);
+    c = TensorUtils<TensorTypeC>::newContiguous(state, c);
  }

#define HANDLE_CASE(TYPE, A, B, C) \
-  THCudaTensor_pointwiseApply3<Op, TYPE, A, B, C> \
+  kernelPointwiseApply3<Op, \
+                        typename TensorUtils<TensorTypeA>::DataType, \
+                        typename TensorUtils<TensorTypeB>::DataType, \
+                        typename TensorUtils<TensorTypeC>::DataType, \
+                        TYPE, A, B, C> \
    <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
      aInfo, bInfo, cInfo, (TYPE) totalElements, op);

-#define HANDLE_C_CASE(TYPE, A, B, C) \
-  { \
-    if (cInfo.isContiguous()) { \
-      HANDLE_CASE(TYPE, A, B, -2); \
-    } else { \
-      switch (C) { \
-        case 1: \
-          HANDLE_CASE(TYPE, A, B, 1); \
-          break; \
-        case 2: \
-          HANDLE_CASE(TYPE, A, B, 2); \
-          break; \
-        case 3: \
-          HANDLE_CASE(TYPE, A, B, 3); \
-          break; \
-        default: \
-          HANDLE_CASE(TYPE, A, B, -1); \
-          break; \
-      } \
-    } \
-  }
-
-#define HANDLE_B_CASE(TYPE, A, B, C) \
-  { \
-    if (bInfo.isContiguous()) { \
-      HANDLE_C_CASE(TYPE, A, -2, C); \
-    } else { \
-      switch (B) { \
-        case 1: \
-          HANDLE_C_CASE(TYPE, A, 1, C); \
-          break; \
-        case 2: \
-          HANDLE_C_CASE(TYPE, A, 2, C); \
-          break; \
-        case 3: \
-          HANDLE_C_CASE(TYPE, A, 3, C); \
-          break; \
-        default: \
-          HANDLE_C_CASE(TYPE, A, -1, C); \
-          break; \
-      } \
-    } \
-  }
-
-#define HANDLE_A_CASE(TYPE, A, B, C) \
-  { \
-    if (aInfo.isContiguous()) { \
-      HANDLE_B_CASE(TYPE, -2, B, C); \
-    } else { \
-      switch (A) { \
-        case 1: \
-          HANDLE_B_CASE(TYPE, 1, B, C); \
-          break; \
-        case 2: \
-          HANDLE_B_CASE(TYPE, 2, B, C); \
-          break; \
-        case 3: \
-          HANDLE_B_CASE(TYPE, 3, B, C); \
-          break; \
-        default: \
-          HANDLE_B_CASE(TYPE, -1, B, C); \
-          break; \
-      } \
-    } \
-  }
-
-  if (THC_canUse32BitIndexMath(state, a) &&
-      THC_canUse32BitIndexMath(state, b) &&
-      THC_canUse32BitIndexMath(state, c)) {
-    TensorInfo<unsigned int> aInfo(state, a);
+#define HANDLE_C_CASE(TYPE, A, B, C) \
+  { \
+    if (cInfo.isContiguous()) { \
+      HANDLE_CASE(TYPE, A, B, -2); \
+    } else { \
+      switch (C) { \
+        case 1: \
+          HANDLE_CASE(TYPE, A, B, 1); \
+          break; \
+        case 2: \
+          HANDLE_CASE(TYPE, A, B, 2); \
+          break; \
+        default: \
+          HANDLE_CASE(TYPE, A, B, -1); \
+          break; \
+      } \
+    } \
+  }
+
+#define HANDLE_B_CASE(TYPE, A, B, C) \
+  { \
+    if (bInfo.isContiguous()) { \
+      HANDLE_C_CASE(TYPE, A, -2, C); \
+    } else { \
+      switch (B) { \
+        case 1: \
+          HANDLE_C_CASE(TYPE, A, 1, C); \
+          break; \
+        case 2: \
+          HANDLE_C_CASE(TYPE, A, 2, C); \
+          break; \
+        default: \
+          HANDLE_C_CASE(TYPE, A, -1, C); \
+          break; \
+      } \
+    } \
+  }
+
+#define HANDLE_A_CASE(TYPE, A, B, C) \
+  { \
+    if (aInfo.isContiguous()) { \
+      HANDLE_B_CASE(TYPE, -2, B, C); \
+    } else { \
+      switch (A) { \
+        case 1: \
+          HANDLE_B_CASE(TYPE, 1, B, C); \
+          break; \
+        case 2: \
+          HANDLE_B_CASE(TYPE, 2, B, C); \
+          break; \
+        default: \
+          HANDLE_B_CASE(TYPE, -1, B, C); \
+          break; \
+      } \
+    } \
+  }
+
+  if (TensorUtils<TensorTypeA>::canUse32BitIndexMath(state, a) &&
+      TensorUtils<TensorTypeB>::canUse32BitIndexMath(state, b) &&
+      TensorUtils<TensorTypeC>::canUse32BitIndexMath(state, c)) {
+    TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned int> aInfo =
+      getTensorInfo<TensorTypeA, unsigned int>(state, a);
    aInfo.collapseDims();
-    TensorInfo<unsigned int> bInfo(state, b);
+    TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned int> bInfo =
+      getTensorInfo<TensorTypeB, unsigned int>(state, b);
    bInfo.collapseDims();
-    TensorInfo<unsigned int> cInfo(state, c);
+    TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned int> cInfo =
+      getTensorInfo<TensorTypeC, unsigned int>(state, c);
    cInfo.collapseDims();

    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims);
  } else {
-    TensorInfo<unsigned long> aInfo(state, a);
+    TensorInfo<typename TensorUtils<TensorTypeA>::DataType, unsigned long> aInfo =
+      getTensorInfo<TensorTypeA, unsigned long>(state, a);
    aInfo.collapseDims();
-    TensorInfo<unsigned long> bInfo(state, b);
+    TensorInfo<typename TensorUtils<TensorTypeB>::DataType, unsigned long> bInfo =
+      getTensorInfo<TensorTypeB, unsigned long>(state, b);
    bInfo.collapseDims();
-    TensorInfo<unsigned long> cInfo(state, c);
+    TensorInfo<typename TensorUtils<TensorTypeC>::DataType, unsigned long> cInfo =
+      getTensorInfo<TensorTypeC, unsigned long>(state, c);
    cInfo.collapseDims();

    // For large tensors, we only compile the completely contiguous
    // version and the completely generic version, to reduce
    // compilation time.
    if (aInfo.isContiguous() && bInfo.isContiguous() && cInfo.isContiguous()) {
-      THCudaTensor_pointwiseApply3<Op, unsigned long, -2, -2, -2>
+      kernelPointwiseApply3<Op,
+                            typename TensorUtils<TensorTypeA>::DataType,
+                            typename TensorUtils<TensorTypeB>::DataType,
+                            typename TensorUtils<TensorTypeC>::DataType,
+                            unsigned long, -2, -2, -2>
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
          aInfo, bInfo, cInfo, (unsigned long) totalElements, op);
    } else {
-      THCudaTensor_pointwiseApply3<Op, unsigned long, -1, -1, -1>
+      kernelPointwiseApply3<Op,
+                            typename TensorUtils<TensorTypeA>::DataType,
+                            typename TensorUtils<TensorTypeB>::DataType,
+                            typename TensorUtils<TensorTypeC>::DataType,
+                            unsigned long, -1, -1, -1>
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
          aInfo, bInfo, cInfo, (unsigned long) totalElements, op);
    }
@@ -574,29 +610,29 @@ bool THCudaTensor_pointwiseApply3(THCState* state,
#undef HANDLE_A_CASE

  if (oldA) {
-    // Ignore overlaps when copying back; if we use THCudaTensor_copy
+    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldA contiguous.
-    THCudaTensor_copyIgnoringOverlaps(state, oldA, a);
-    THCudaTensor_free(state, a);
+    TensorUtils<TensorTypeA>::copyIgnoringOverlaps(state, oldA, a);
+    TensorUtils<TensorTypeA>::free(state, a);
    a = oldA;
  }

  if (oldB) {
-    // Ignore overlaps when copying back; if we use THCudaTensor_copy
+    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldB contiguous.
-    THCudaTensor_copyIgnoringOverlaps(state, oldB, b);
-    THCudaTensor_free(state, b);
+    TensorUtils<TensorTypeB>::copyIgnoringOverlaps(state, oldB, b);
+    TensorUtils<TensorTypeB>::free(state, b);
    b = oldB;
  }

  if (oldC) {
-    // Ignore overlaps when copying back; if we use THCudaTensor_copy
+    // Ignore overlaps when copying back; if we use THCTensor_copy
    // instead, it will recursively try and invoke ourselves to make
    // oldC contiguous.
-    THCudaTensor_copyIgnoringOverlaps(state, oldC, c);
-    THCudaTensor_free(state, c);
+    TensorUtils<TensorTypeC>::copyIgnoringOverlaps(state, oldC, c);
+    TensorUtils<TensorTypeC>::free(state, c);
    c = oldC;
  }
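The TensorUtils<TensorType> traits class and the getTensorInfo<TensorType, IndexType> helper used throughout the rewritten entry points are defined in the newly included THCTensorTypeUtils.cuh, not in this file. A minimal sketch of the shape such a specialization takes for THCudaTensor, assuming it simply forwards to the pre-existing C functions that the removed code called directly (the real header may differ in names and details):

// Sketch under assumptions: roughly what a TensorUtils specialization for
// THCudaTensor would look like if it only forwarded to the existing C API
// that the old THCApply.cuh code called. Not copied from the real header.
template <>
struct TensorUtils<THCudaTensor> {
  typedef float DataType;

  static int getDims(THCState* state, THCudaTensor* t) {
    return THCudaTensor_nDimension(state, t);
  }
  static long getNumElements(THCState* state, THCudaTensor* t) {
    return THCudaTensor_nElement(state, t);
  }
  static bool overlappingIndices(THCState* state, THCudaTensor* t) {
    return THC_overlappingIndices(state, t);
  }
  static bool canUse32BitIndexMath(THCState* state, THCudaTensor* t) {
    return THC_canUse32BitIndexMath(state, t);
  }
  static THCudaTensor* newContiguous(THCState* state, THCudaTensor* t) {
    return THCudaTensor_newContiguous(state, t);
  }
  static void copyIgnoringOverlaps(THCState* state,
                                   THCudaTensor* dst, THCudaTensor* src) {
    THCudaTensor_copyIgnoringOverlaps(state, dst, src);
  }
  static void free(THCState* state, THCudaTensor* t) {
    THCudaTensor_free(state, t);
  }
};

Splitting these per-tensor-type operations into a traits class is what lets THC_pointwiseApply{1,2,3} remain a single implementation while cutorch adds tensors with other element types.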