Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristian Sarofeen <csarofeen@nvidia.com>2017-08-28 21:26:35 +0300
committerSoumith Chintala <soumith@gmail.com>2017-08-29 04:41:46 +0300
commit997b8b1bb8f28f0eab1483d76f1d4e9a000b5a70 (patch)
tree249a5a4356a8c93b7643a957f296a5b8d59dd247
parentd891ff361afe33e118bc9798539453f8eb0000db (diff)
Fix grid size for batch cat tensor now that getApplyGrid has been changed.
-rw-r--r--lib/THC/THCTensorMath.cuh52
-rw-r--r--lib/THC/generic/THCTensorMath.cu15
2 files changed, 46 insertions, 21 deletions
diff --git a/lib/THC/THCTensorMath.cuh b/lib/THC/THCTensorMath.cuh
index ae8f5db..202090e 100644
--- a/lib/THC/THCTensorMath.cuh
+++ b/lib/THC/THCTensorMath.cuh
@@ -26,6 +26,24 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
#define CAT_ARRAY_BATCH_SIZE 1024
#define CAT_ARRAY_MAX_INPUT_DIMS 4
+inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+
+ if (curDevice == -1) {
+ return false;
+ }
+
+ // Assume a reasonable number of SMs if no state is available
+ int numSM =
+ state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;
+ // The x dim of the grid cooperates on a single tensor in the cat.
+ // Given half of the GPU, full utilization will always occur.
+ grid = dim3( 2LL * numSM, (long long) nTensors );
+
+ return true;
+}
+
// Similar to any other IndexToOffset calculation for copying along a given dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
@@ -77,6 +95,9 @@ struct OutputTensorSizeStride {
*
* The most important assumption made is that the input tensors are contiguous.
*/
+
+
+
template <typename T, typename IndexType, int Dims>
__global__ void CatArrayBatchedCopy(
T* output,
@@ -84,19 +105,26 @@ __global__ void CatArrayBatchedCopy(
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {
- T* data = inputs[blockIdx.y].input;
- IndexType offset = inputs[blockIdx.y].offset;
- IndexType dimSize = inputs[blockIdx.y].dimSize;
- IndexType nElements = inputs[blockIdx.y].nElements;
- IndexType dataOffset = offset * dimStride;
-
- for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
- linearIndex < nElements;
- linearIndex += gridDim.x * blockDim.x) {
+
+ IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+ IndexType nElements = inputs[blockIdx.y].nElements;
+
+ if(tid >= nElements) return;
+
+ T* data = inputs[blockIdx.y].input;
+ IndexType offset = inputs[blockIdx.y].offset;
+ IndexType dimSize = inputs[blockIdx.y].dimSize;
+ IndexType dataOffset = offset * dimStride;
+
+ IndexType stride = gridDim.x * blockDim.x;
+
+ while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
- os.outputSize, os.outputStride, dimSize, concatDim, linearIndex);
- output[dataOffset + elementOffset] = data[linearIndex];
- }
+ os.outputSize, os.outputStride, dimSize, concatDim, tid);
+ output[dataOffset + elementOffset] = data[tid];
+
+ tid += stride;
+ }
}
#endif
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
index 628240a..ceb6f2d 100644
--- a/lib/THC/generic/THCTensorMath.cu
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -207,7 +207,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
- CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
+ CatArrayBatchedCopy<real, unsigned int, DIMS><<<catGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
// Now we loop
offset = 0;
@@ -243,15 +243,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
// is based on.
dim3 applyBlock = getApplyBlock();
- // We also re-use the applyGrid - but note that we use the maximum number of
- // elements for a given tensor in this grouping to determine the count
- dim3 applyGrid;
- getApplyGrid(state, cohortMax, applyGrid);
+ // Get a grid whose x dim fills half the GPU and whose y dim is the number of tensors.
+ // This lets concatenating two tensors fill the entire grid, while preventing
+ // many threads from needlessly loading metadata when their sizes are small.
+ dim3 catGrid;
+ getCatGrid(state, j, catGrid);
- // Next, we set our grid's y component to be the number of tensors in
- // the batch. This will allow the kernel to determine which input
- // tensor it is responsible for copying
- applyGrid.y = j;
switch (maxDim) {
case 1: