Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THC/THCTensorMath.cuh')
-rw-r--r--lib/THC/THCTensorMath.cuh52
1 files changed, 40 insertions, 12 deletions
diff --git a/lib/THC/THCTensorMath.cuh b/lib/THC/THCTensorMath.cuh
index ae8f5db..202090e 100644
--- a/lib/THC/THCTensorMath.cuh
+++ b/lib/THC/THCTensorMath.cuh
@@ -26,6 +26,24 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
#define CAT_ARRAY_BATCH_SIZE 1024
#define CAT_ARRAY_MAX_INPUT_DIMS 4
+inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+
+ if (curDevice == -1) {
+ return false;
+ }
+
+ // Assume a reasonable number of SMs if no state is available
+ int numSM =
+ state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;
+ //X dim of grid for cat array cooperates on a single tensor in the cat.
+ //Given half of the GPU, full utilization will always occur.
+ grid = dim3( 2LL * numSM, (long long) nTensors );
+
+ return true;
+}
+
// Similar to any other IndexToOffset calculation for copying along a given dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
@@ -77,6 +95,9 @@ struct OutputTensorSizeStride {
*
* The most important assumption made is that the input tensors are contiguous.
*/
+
+
+
template <typename T, typename IndexType, int Dims>
__global__ void CatArrayBatchedCopy(
T* output,
@@ -84,19 +105,26 @@ __global__ void CatArrayBatchedCopy(
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {
- T* data = inputs[blockIdx.y].input;
- IndexType offset = inputs[blockIdx.y].offset;
- IndexType dimSize = inputs[blockIdx.y].dimSize;
- IndexType nElements = inputs[blockIdx.y].nElements;
- IndexType dataOffset = offset * dimStride;
-
- for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
- linearIndex < nElements;
- linearIndex += gridDim.x * blockDim.x) {
+
+ IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+ IndexType nElements = inputs[blockIdx.y].nElements;
+
+ if(tid >= nElements) return;
+
+ T* data = inputs[blockIdx.y].input;
+ IndexType offset = inputs[blockIdx.y].offset;
+ IndexType dimSize = inputs[blockIdx.y].dimSize;
+ IndexType dataOffset = offset * dimStride;
+
+ IndexType stride = gridDim.x * blockDim.x;
+
+ while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
- os.outputSize, os.outputStride, dimSize, concatDim, linearIndex);
- output[dataOffset + elementOffset] = data[linearIndex];
- }
+ os.outputSize, os.outputStride, dimSize, concatDim, tid);
+ output[dataOffset + elementOffset] = data[tid];
+
+ tid += stride;
+ }
}
#endif