From 997b8b1bb8f28f0eab1483d76f1d4e9a000b5a70 Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Mon, 28 Aug 2017 18:26:35 +0000
Subject: Fix grid size for batch cat tensor now that getApplyGrid has been
 changed.

---
 lib/THC/THCTensorMath.cuh        | 52 ++++++++++++++++++++++++++++++----------
 lib/THC/generic/THCTensorMath.cu | 15 +++++-------
 2 files changed, 46 insertions(+), 21 deletions(-)

diff --git a/lib/THC/THCTensorMath.cuh b/lib/THC/THCTensorMath.cuh
index ae8f5db..202090e 100644
--- a/lib/THC/THCTensorMath.cuh
+++ b/lib/THC/THCTensorMath.cuh
@@ -26,6 +26,24 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
 #define CAT_ARRAY_BATCH_SIZE 1024
 #define CAT_ARRAY_MAX_INPUT_DIMS 4
 
+inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
+  int curDevice = -1;
+  cudaGetDevice(&curDevice);
+
+  if (curDevice == -1) {
+    return false;
+  }
+
+  // Assume a reasonable number of SMs if no state is available
+  int numSM =
+      state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;
+  // The x dim of the grid for the cat array cooperates on a single tensor in the cat.
+  // Given half of the GPU, full utilization will always occur.
+  grid = dim3(2LL * numSM, (long long) nTensors);
+
+  return true;
+}
+
 // Similar to any other IndexToOffset calculation for copying along a given dimension.
 template <typename IndexType, int Dims>
 struct CatArrIndexToOffset {
@@ -77,6 +95,9 @@ struct OutputTensorSizeStride {
  *
  * The most important assumption made is that the input tensors are contiguous.
  */
+
+
+
 template <typename T, typename IndexType, int Dims>
 __global__ void CatArrayBatchedCopy(
     T* output,
@@ -84,19 +105,26 @@ __global__ void CatArrayBatchedCopy(
     OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
     const int concatDim,
     IndexType dimStride) {
-  T* data = inputs[blockIdx.y].input;
-  IndexType offset = inputs[blockIdx.y].offset;
-  IndexType dimSize = inputs[blockIdx.y].dimSize;
-  IndexType nElements = inputs[blockIdx.y].nElements;
-  IndexType dataOffset = offset * dimStride;
-
-  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
-       linearIndex < nElements;
-       linearIndex += gridDim.x * blockDim.x) {
+
+  IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+  IndexType nElements = inputs[blockIdx.y].nElements;
+
+  if (tid >= nElements) return;
+
+  T* data = inputs[blockIdx.y].input;
+  IndexType offset = inputs[blockIdx.y].offset;
+  IndexType dimSize = inputs[blockIdx.y].dimSize;
+  IndexType dataOffset = offset * dimStride;
+
+  IndexType stride = gridDim.x * blockDim.x;
+
+  while (tid < nElements) {
     IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
-        os.outputSize, os.outputStride, dimSize, concatDim, linearIndex);
-    output[dataOffset + elementOffset] = data[linearIndex];
-  }
+        os.outputSize, os.outputStride, dimSize, concatDim, tid);
+    output[dataOffset + elementOffset] = data[tid];
+
+    tid += stride;
+  }
 }
 
 #endif
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
index 628240a..ceb6f2d 100644
--- a/lib/THC/generic/THCTensorMath.cu
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -207,7 +207,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
 
     // Template Declarations for dim = 1, 2, 3, 4
 #define HANDLE_CASE(DIMS) \
-  CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
+  CatArrayBatchedCopy<real, unsigned int, DIMS><<<catGrid, applyBlock, 0, stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
 
     // Now we loop
     offset = 0;
@@ -243,15 +243,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
       // is based on.
       dim3 applyBlock = getApplyBlock();
 
-      // We also re-use the applyGrid - but note that we use the maximum number of
-      // elements for a given tensor in this grouping to determine the count
-      dim3 applyGrid;
-      getApplyGrid(state, cohortMax, applyGrid);
+      // Get a grid whose x dim fills half the GPU and whose y dim is the number of tensors.
+      // This lets cat'ing two tensors fill the entire grid, but prevents many
+      // threads from needlessly loading metadata when their tensors are small.
+      dim3 catGrid;
+      getCatGrid(state, j, catGrid);
 
-      // Next, we set our grid's y component to be the number of tensors in
-      // the batch. This will allow the kernel to determine which input
-      // tensor it is responsible for copying
-      applyGrid.y = j;
 
       switch (maxDim) {
         case 1:
-- 
cgit v1.2.3
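
For illustration (placed after the signature delimiter, so it is not part of the applied patch): below is a minimal, self-contained CUDA sketch of the launch pattern this patch adopts. blockIdx.y selects a tensor in the batch, the x dimension grid-strides over that tensor's elements, and the early tid >= nElements exit lets blocks assigned to small tensors return before touching the rest of the metadata. The Input struct, catCopy kernel, and outOffsets handling are hypothetical stand-ins for the THC types, not the library's API; sizing grid.x to 2 * multiProcessorCount mirrors getCatGrid's "half the GPU" choice.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for CatArrInputTensor: per-tensor metadata the
// kernel reads through blockIdx.y.
struct Input {
  const float* data;
  int nElements;
};

// 2D launch: blockIdx.y picks the tensor, the x dimension grid-strides
// over that tensor's elements (flat copy; no concat-dim indexing here).
__global__ void catCopy(float* out, const Input* inputs, const int* outOffsets) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int n = inputs[blockIdx.y].nElements;
  if (tid >= n) return;  // small tensor: exit before loading more metadata
  const float* src = inputs[blockIdx.y].data;
  float* dst = out + outOffsets[blockIdx.y];
  int stride = gridDim.x * blockDim.x;
  while (tid < n) {      // grid-stride loop, as in CatArrayBatchedCopy
    dst[tid] = src[tid];
    tid += stride;
  }
}

int main() {
  const int nTensors = 2, n0 = 1000, n1 = 500;
  float *d0, *d1, *out;
  cudaMalloc(&d0, n0 * sizeof(float));
  cudaMalloc(&d1, n1 * sizeof(float));
  cudaMalloc(&out, (n0 + n1) * sizeof(float));
  cudaMemset(d0, 0, n0 * sizeof(float));
  cudaMemset(d1, 0, n1 * sizeof(float));

  Input hIn[] = { {d0, n0}, {d1, n1} };
  int hOff[] = { 0, n0 };  // 1-D cat: second tensor lands right after the first
  Input* dIn;  int* dOff;
  cudaMalloc(&dIn, sizeof(hIn));
  cudaMalloc(&dOff, sizeof(hOff));
  cudaMemcpy(dIn, hIn, sizeof(hIn), cudaMemcpyHostToDevice);
  cudaMemcpy(dOff, hOff, sizeof(hOff), cudaMemcpyHostToDevice);

  // Mirror getCatGrid: x = 2 * numSM so cat'ing two tensors fills the grid,
  // y = number of tensors in this batch.
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  dim3 block(512);
  dim3 grid(2 * prop.multiProcessorCount, nTensors);
  catCopy<<<grid, block>>>(out, dIn, dOff);
  cudaDeviceSynchronize();
  printf("cat copy: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}

The early return is the point of the restructuring: with grid.x sized for the whole GPU rather than per tensor, most blocks in a small tensor's row would otherwise still load data, offset, and dimSize before discovering they have no elements to copy.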