From 951070d355b2c1d3f285973dc27d8c0ec53c167a Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Mon, 8 May 2017 09:02:15 -0700 Subject: Make torch.cat not synchronize the host and device --- lib/THC/THCGeneral.c | 22 ++++++++++++++++- lib/THC/THCGeneral.h.in | 4 ++++ lib/THC/generic/THCTensorMath.cu | 51 ++++++++++++++++++---------------------- 3 files changed, 48 insertions(+), 29 deletions(-) diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 26ba750..e99487e 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -848,6 +848,27 @@ cudaError_t THCudaFree(THCState *state, void *ptr) return allocator->free(allocator->state, ptr); } +void* THCudaHostAlloc(THCState *state, size_t size) +{ + THCudaCheck(cudaGetLastError()); + THAllocator* allocator = state->cudaHostAllocator; + return allocator->malloc(NULL, size); +} + +void THCudaHostFree(THCState *state, void *ptr) +{ + THAllocator* allocator = state->cudaHostAllocator; + return allocator->free(NULL, ptr); +} + +void THCudaHostRecord(THCState *state, void *ptr) +{ + if (state->cudaHostAllocator == &THCCachingHostAllocator) { + THCStream* stream = THCState_getStream(state); + THCCachingHostAllocator_recordEvent(ptr, stream); + } +} + cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes) { size_t cachedBytes = 0; @@ -932,4 +953,3 @@ float THC_half2float(half h) TH_halfbits2float(&h.x, &f); return f; } - diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index d718f7e..f33446d 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -204,6 +204,10 @@ THC_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size); THC_API cudaError_t THCudaFree(THCState *state, void *ptr); +THC_API void* THCudaHostAlloc(THCState *state, size_t size); +THC_API void THCudaHostFree(THCState *state, void *ptr); +THC_API void THCudaHostRecord(THCState *state, void *ptr); + THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes); THC_API void THCSetGCHandler(THCState *state, void (*torchGCHandlerFunction)(void *data), diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 4c609ba..0eed5a9 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -175,23 +175,9 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, real *data = THCTensor_(data)(state, result); // Kernel Parameter - CatArrInputTensor stackInputs[CAT_ARRAY_BATCH_SIZE]; - CatArrInputTensor *d_inputs; - - // Attempt to re-use stream's scratch space for the input metadata - bool usedScratch = false; size_t tensorMetadataSize = sizeof(CatArrInputTensor) * CAT_ARRAY_BATCH_SIZE; - if (THCState_getCurrentDeviceScratchSpaceSize(state) > tensorMetadataSize) { - void* space = THCState_getCurrentDeviceScratchSpace(state); - if (space) { - d_inputs = (CatArrInputTensor *) space; - usedScratch = true; - } - } - if (!usedScratch) { - // Fallback to allocating GPU memory - THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize)); - } + CatArrInputTensor *d_inputs; + THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize)); OutputTensorSizeStride param; @@ -201,13 +187,17 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, param.outputStride[i] = THCTensor_(stride)(state, result, i); } + THCStream* stream = THCState_getStream(state); + // Template Declarations for dim = 1, 2, 3, 4 #define HANDLE_CASE(DIMS) \ - CatArrayBatchedCopy<<>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]); + CatArrayBatchedCopy<<stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]); // Now we loop offset = 0; for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) { + // Re-allocate stackInputs every iteration to avoid read-after-write hazard + CatArrInputTensor* stackInputs = (CatArrInputTensor*) THCudaHostAlloc(state, tensorMetadataSize); cohortMax = 0; for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) { long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[i+j]) @@ -223,7 +213,14 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, // update offset offset += dimSize; } - THCudaCheck(cudaMemcpy(d_inputs, stackInputs, j * sizeof(CatArrInputTensor), cudaMemcpyHostToDevice)); + THCudaCheck(cudaMemcpyAsync( + d_inputs, + stackInputs, + j * sizeof(CatArrInputTensor), + cudaMemcpyHostToDevice, + stream->stream)); + THCudaHostRecord(state, stackInputs); + THCudaHostFree(state, stackInputs); // Next, let's consider how we set our kernel launch parameters. // We borrow from THCApply, which the kernel's internal indexing @@ -256,9 +253,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, } THCudaCheck(cudaGetLastError()); } - if (!usedScratch) { - THCudaCheck(THCudaFree(state, (void *)d_inputs)); - } + THCudaCheck(THCudaFree(state, d_inputs)); #undef HANDLE_CASE } else { offset = 0; @@ -399,10 +394,10 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); if (n == 1) THCTensor_(fill)(state, r_, a); else { - THCTensor *r = THCTensor_(isContiguous)(state, r_) + THCTensor *r = THCTensor_(isContiguous)(state, r_) ? r_ // if r_ is contiguous we can direct work on it : THCTensor_(newContiguous)(state, r_); - real step = THCNumerics::div(THCNumerics::sub(b, a), + real step = THCNumerics::div(THCNumerics::sub(b, a), ScalarConvert::to(n - 1)); LinspaceOp linspace_method(a, step); thrust::device_ptr data_(THCTensor_(data)(state, r)); @@ -420,10 +415,10 @@ void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n); if (n == 1) THCTensor_(fill)(state, r_, THCNumerics::exp10(a)); else { - THCTensor *r = THCTensor_(isContiguous)(state, r_) - ? r_ + THCTensor *r = THCTensor_(isContiguous)(state, r_) + ? r_ : THCTensor_(newContiguous)(state, r_); - real step = THCNumerics::div(THCNumerics::sub(b, a), + real step = THCNumerics::div(THCNumerics::sub(b, a), ScalarConvert::to(n - 1)); LogspaceOp logspace_method(a, step); thrust::device_ptr data_(THCTensor_(data)(state, r)); @@ -444,8 +439,8 @@ void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xma , 2, "upper bound and larger bound incoherent with step sign"); ptrdiff_t size = (ptrdiff_t) (((xmax - xmin) / step) + 1); if (THCTensor_(nElement)(state, r_) != size) THCTensor_(resize1d)(state, r_, size); - THCTensor *r = THCTensor_(isContiguous)(state, r_) - ? r_ + THCTensor *r = THCTensor_(isContiguous)(state, r_) + ? r_ : THCTensor_(newContiguous)(state, r_); LinspaceOp linspace_method(xmin, step); thrust::device_ptr data_(THCTensor_(data)(state, r)); -- cgit v1.2.3