diff options
Diffstat (limited to 'lib/THC/generic/THCTensorMath.cu')
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 51 |
1 files changed, 23 insertions, 28 deletions
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 4c609ba..0eed5a9 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -175,23 +175,9 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, real *data = THCTensor_(data)(state, result); // Kernel Parameter - CatArrInputTensor<real, unsigned int> stackInputs[CAT_ARRAY_BATCH_SIZE]; - CatArrInputTensor<real, unsigned int> *d_inputs; - - // Attempt to re-use stream's scratch space for the input metadata - bool usedScratch = false; size_t tensorMetadataSize = sizeof(CatArrInputTensor<real, unsigned int>) * CAT_ARRAY_BATCH_SIZE; - if (THCState_getCurrentDeviceScratchSpaceSize(state) > tensorMetadataSize) { - void* space = THCState_getCurrentDeviceScratchSpace(state); - if (space) { - d_inputs = (CatArrInputTensor<real, unsigned int> *) space; - usedScratch = true; - } - } - if (!usedScratch) { - // Fallback to allocating GPU memory - THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize)); - } + CatArrInputTensor<real, unsigned int> *d_inputs; + THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize)); OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param; @@ -201,13 +187,17 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, param.outputStride[i] = THCTensor_(stride)(state, result, i); } + THCStream* stream = THCState_getStream(state); + // Template Declarations for dim = 1, 2, 3, 4 #define HANDLE_CASE(DIMS) \ - CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, THCState_getCurrentStream(state)>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]); + CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]); // Now we loop offset = 0; for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) { + // Re-allocate stackInputs every iteration to avoid read-after-write hazard + CatArrInputTensor<real, unsigned int>* stackInputs = (CatArrInputTensor<real, unsigned int>*) THCudaHostAlloc(state, tensorMetadataSize); cohortMax = 0; for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) { long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[i+j]) @@ -223,7 +213,14 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, // update offset offset += dimSize; } - THCudaCheck(cudaMemcpy(d_inputs, stackInputs, j * sizeof(CatArrInputTensor<real, unsigned int>), cudaMemcpyHostToDevice)); + THCudaCheck(cudaMemcpyAsync( + d_inputs, + stackInputs, + j * sizeof(CatArrInputTensor<real, unsigned int>), + cudaMemcpyHostToDevice, + stream->stream)); + THCudaHostRecord(state, stackInputs); + THCudaHostFree(state, stackInputs); // Next, let's consider how we set our kernel launch parameters. // We borrow from THCApply, which the kernel's internal indexing @@ -256,9 +253,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, } THCudaCheck(cudaGetLastError()); } - if (!usedScratch) { - THCudaCheck(THCudaFree(state, (void *)d_inputs)); - } + THCudaCheck(THCudaFree(state, d_inputs)); #undef HANDLE_CASE } else { offset = 0; @@ -399,10 +394,10 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); if (n == 1) THCTensor_(fill)(state, r_, a); else { - THCTensor *r = THCTensor_(isContiguous)(state, r_) + THCTensor *r = THCTensor_(isContiguous)(state, r_) ? r_ // if r_ is contiguous we can direct work on it : THCTensor_(newContiguous)(state, r_); - real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a), + real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a), ScalarConvert<long,real>::to(n - 1)); LinspaceOp<real> linspace_method(a, step); thrust::device_ptr<real> data_(THCTensor_(data)(state, r)); @@ -420,10 +415,10 @@ void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); if (n == 1) THCTensor_(fill)(state, r_, THCNumerics<real>::exp10(a)); else { - THCTensor *r = THCTensor_(isContiguous)(state, r_) - ? r_ + THCTensor *r = THCTensor_(isContiguous)(state, r_) + ? r_ : THCTensor_(newContiguous)(state, r_); - real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a), + real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a), ScalarConvert<long,real>::to(n - 1)); LogspaceOp<real> logspace_method(a, step); thrust::device_ptr<real> data_(THCTensor_(data)(state, r)); @@ -444,8 +439,8 @@ void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xma , 2, "upper bound and larger bound incoherent with step sign"); ptrdiff_t size = (ptrdiff_t) (((xmax - xmin) / step) + 1); if (THCTensor_(nElement)(state, r_) != size) THCTensor_(resize1d)(state, r_, size); - THCTensor *r = THCTensor_(isContiguous)(state, r_) - ? r_ + THCTensor *r = THCTensor_(isContiguous)(state, r_) + ? r_ : THCTensor_(newContiguous)(state, r_); LinspaceOp<real,accreal> linspace_method(xmin, step); thrust::device_ptr<real> data_(THCTensor_(data)(state, r)); |