diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-17 02:08:03 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-17 02:08:03 +0300 |
commit | bde012eb73552a60aa3b567d0d562793b234dc97 (patch) | |
tree | 049c683cee6c4f8788621c6c9b29ce2b7f6688c9 | |
parent | f479a676bba3979d7606cf517bb3d855ad38859f (diff) | |
parent | 157cb3f91df758cc147c7760b25c56f44834c776 (diff) |
Merge pull request #550 from colesbury/streams
Add stream API that is not based on indices
-rw-r--r-- | FFI.lua | 9 | ||||
-rw-r--r-- | init.c | 3 | ||||
-rw-r--r-- | lib/THC/CMakeLists.txt | 2 | ||||
-rw-r--r-- | lib/THC/THC.h | 1 | ||||
-rw-r--r-- | lib/THC/THCGeneral.c | 159 | ||||
-rw-r--r-- | lib/THC/THCGeneral.h.in | 8 | ||||
-rw-r--r-- | lib/THC/THCReduceAll.cuh | 26 | ||||
-rw-r--r-- | lib/THC/THCStream.c | 30 | ||||
-rw-r--r-- | lib/THC/THCStream.h | 19 | ||||
-rw-r--r-- | lib/THC/THCTensorCopy.cu | 10 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorCopy.c | 6 |
11 files changed, 207 insertions, 66 deletions
@@ -8,8 +8,15 @@ struct cublasContext; typedef struct cublasContext *cublasHandle_t; typedef struct CUhandle_st *cublasHandle_t; +typedef struct _THCStream { + cudaStream_t stream; + int device; + int refcount; +} THCStream; + + typedef struct _THCCudaResourcesPerDevice { - cudaStream_t* streams; + THCStream** streams; cublasHandle_t* blasHandles; size_t scratchSpacePerStream; void** devScratchSpacePerStream; @@ -408,8 +408,7 @@ static int cutorch_getBlasHandle(lua_State *L) static int cutorch_setDefaultStream(lua_State *L) { THCState *state = cutorch_getstate(L); - THCState_setCurrentStreamIndex(state, 0); - + THCState_setStream(state, NULL); return 0; } diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index b9ddfbe..cbd9cba 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -128,6 +128,7 @@ SET(src THCGeneral.c THCStorage.c THCStorageCopy.c + THCStream.c THCTensor.c THCTensorCopy.c THCThreadLocal.c @@ -200,6 +201,7 @@ INSTALL(FILES THCBlas.h THCStorage.h THCStorageCopy.h + THCStream.h THCTensor.h THCTensorCopy.h THCTensorRandom.h diff --git a/lib/THC/THC.h b/lib/THC/THC.h index 14cc342..b13eb39 100644 --- a/lib/THC/THC.h +++ b/lib/THC/THC.h @@ -6,6 +6,7 @@ #include "THCBlas.h" #include "THCStorage.h" #include "THCStorageCopy.h" +#include "THCStream.h" #include "THCTensor.h" #include "THCTensorCopy.h" #include "THCTensorRandom.h" diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 7cc7818..1d4622b 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -4,6 +4,7 @@ #include "THCBlas.h" #include "THCAllocator.h" #include "THCThreadLocal.h" +#include "THCStream.h" #include <stdlib.h> #include <stdint.h> @@ -12,7 +13,7 @@ typedef struct _THCCudaResourcesPerDevice { - cudaStream_t* streams; + THCStream** streams; cublasHandle_t* blasHandles; /* Size of scratch space per each stream on this device available */ size_t scratchSpacePerStream; @@ -40,10 +41,11 @@ struct THCState { THAllocator* cudaHostAllocator; THCDeviceAllocator* cudaDeviceAllocator; - /* Index of the current selected per-device resource. Actual CUDA resource - changes based on the current device, since resources are per-device */ - THCThreadLocal/*<int>*/ currentPerDeviceStream; + /* Index of the current selected BLAS handle. The actual BLAS handle used + depends on the current device. */ THCThreadLocal/*<int>*/ currentPerDeviceBlasHandle; + /* Array of thread locals containing the current stream for each device */ + THCThreadLocal* currentStreams; /* Table of enabled peer-to-peer access between directed pairs of GPUs. If i accessing allocs on j is enabled, p2pAccess[i][j] is 1; 0 otherwise. */ @@ -110,7 +112,10 @@ void THCudaInit(THCState* state) THCudaCheck(cudaGetDevice(&device)); /* Start in the default stream on the current device */ - state->currentPerDeviceStream = THCThreadLocal_alloc(); + state->currentStreams = (THCThreadLocal*) malloc(numDevices * sizeof(THCThreadLocal)); + for (int i = 0; i < numDevices; ++i) { + state->currentStreams[i] = THCThreadLocal_alloc(); + } state->currentPerDeviceBlasHandle = THCThreadLocal_alloc(); state->resourcesPerDevice = (THCCudaResourcesPerDevice*) @@ -134,6 +139,10 @@ void THCudaInit(THCState* state) THCudaCheck(cudaSetDevice(i)); THCudaCheck(cudaGetDeviceProperties(&state->deviceProperties[i], i)); + // Allocate space for the NULL stream + res->streams = (THCStream**) malloc(sizeof(THCStream*)); + res->streams[0] = NULL; + /* The scratch space that we want to have available per each device is based on the number of SMs available per device */ int numSM = state->deviceProperties[i].multiProcessorCount; @@ -182,10 +191,10 @@ void THCudaShutdown(THCState* state) /* cleanup per-device state */ for (int dev = 0; dev < deviceCount; ++dev) { THCudaCheck(cudaSetDevice(dev)); - /* Free Torch-defined streams (0 is the default stream) */ - for (int stream = 1; stream <= state->numUserStreams; ++stream) { - THCudaCheck(cudaStreamDestroy( - THCState_getDeviceStream(state, dev, stream))); + THCCudaResourcesPerDevice* res = &(state->resourcesPerDevice[dev]); + /* Free user reserved streams (0 is the default stream) */ + for (int i = 1; i <= state->numUserStreams; ++i) { + THCStream_free(res->streams[i]); } /* Free Torch-defined handles (0 is NULL for consistency with streams API) */ for (int handle = 1; handle <= state->numUserBlasHandles; ++handle) { @@ -198,15 +207,17 @@ void THCudaShutdown(THCState* state) THCudaCheck(THCudaFree(state, THCState_getDeviceScratchSpace(state, dev, stream))); } - free(state->resourcesPerDevice[dev].streams); - free(state->resourcesPerDevice[dev].blasHandles); - free(state->resourcesPerDevice[dev].devScratchSpacePerStream); + free(res->streams); + free(res->blasHandles); + free(res->devScratchSpacePerStream); + THCStream_free((THCStream*)THCThreadLocal_get(state->currentStreams[dev])); + THCThreadLocal_free(state->currentStreams[dev]); } free(state->resourcesPerDevice); if (state->cudaDeviceAllocator->emptyCache) { state->cudaDeviceAllocator->emptyCache(state->cudaDeviceAllocator->state); } - THCThreadLocal_free(state->currentPerDeviceStream); + free(state->currentStreams); THCThreadLocal_free(state->currentPerDeviceBlasHandle); THCudaCheck(cudaSetDevice(prevDev)); @@ -369,22 +380,14 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) /* Otherwise, we have to allocate a new set of streams and stream data */ for (int dev = 0; dev < state->numDevices; ++dev) { THCudaCheck(cudaSetDevice(dev)); + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); /* +1 for the default stream as well */ - cudaStream_t* newStreams = - (cudaStream_t*) malloc((numStreams + 1) * sizeof(cudaStream_t)); + THCStream** newStreams = realloc(res->streams, (numStreams + 1) * sizeof(THCStream*)); + THAssert(newStreams); - void** newScratchSpace = - (void**) malloc((numStreams + 1) * sizeof(void*)); - - /* Copy over old stream data - (0 is default stream, 1 ... numUserStreams are rest) */ - for (int stream = 0; stream <= state->numUserStreams; ++stream) { - newStreams[stream] = - THCState_getDeviceStream(state, dev, stream); - newScratchSpace[stream] = - THCState_getDeviceScratchSpace(state, dev, stream); - } + void** newScratchSpace = realloc(res->devScratchSpacePerStream, (numStreams + 1) * sizeof(void*)); + THAssert(newScratchSpace); /* Allocate new stream resources */ size_t scratchSpaceSize = THCState_getDeviceScratchSpaceSize(state, dev); @@ -392,16 +395,12 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) nonBlocking ? cudaStreamNonBlocking : cudaStreamDefault; for (int stream = state->numUserStreams + 1; stream <= numStreams; ++stream) { - newStreams[stream] = NULL; - THCudaCheck(cudaStreamCreateWithFlags(newStreams + stream, flags)); + newStreams[stream] = THCStream_new(flags); newScratchSpace[stream] = NULL; THCudaCheck(THCudaMalloc(state, &newScratchSpace[stream], scratchSpaceSize)); } - THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); - free(res->streams); res->streams = newStreams; - free(res->devScratchSpacePerStream); res->devScratchSpacePerStream = newScratchSpace; } @@ -473,14 +472,15 @@ THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr( return &(state->resourcesPerDevice[device]); } -cudaStream_t THCState_getDeviceStream(THCState *state, int device, int stream) +cudaStream_t THCState_getDeviceStream(THCState *state, int device, int streamIndex) { - if (stream > state->numUserStreams || stream < 0) + if (streamIndex > state->numUserStreams || streamIndex < 0) { - THError("%d is not a stream", stream); + THError("%d is not a stream", streamIndex); } - return (THCState_getDeviceResourcePtr(state, device)->streams == NULL) ? 0 - : THCState_getDeviceResourcePtr(state, device)->streams[stream]; + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); + THCStream* stream = res->streams[streamIndex]; + return stream ? stream->stream : NULL; } cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle) @@ -493,6 +493,31 @@ cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int han return THCState_getDeviceResourcePtr(state, device)->blasHandles[handle]; } +static THCStream* THCState_getStreamOnDevice(THCState* state, int device) +{ + return (THCStream*) THCThreadLocal_get(state->currentStreams[device]); +} + +static void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream) +{ + if (stream) { + if (stream->device != device) { + THError("invalid stream; expected stream for device %d, but was on %d", + device, stream->device); + } + THCStream_retain(stream); + } + THCThreadLocal local = state->currentStreams[device]; + THCStream_free((THCStream*)THCThreadLocal_get(local)); + THCThreadLocal_set(local, stream); +} + +cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device) +{ + THCStream* stream = THCState_getStreamOnDevice(state, device); + return stream ? stream->stream : NULL; +} + cudaStream_t THCState_getCurrentStream(THCState *state) { /* This is called at the point of kernel execution. @@ -501,13 +526,7 @@ cudaStream_t THCState_getCurrentStream(THCState *state) if (state) { int device; THCudaCheck(cudaGetDevice(&device)); - - int streamIndex = THCState_getCurrentStreamIndex(state); - if (streamIndex == 0) { - return NULL; - } - - return THCState_getDeviceResourcePtr(state, device)->streams[streamIndex]; + return THCState_getCurrentStreamOnDevice(state, device); } else { /* assume default stream */ return NULL; @@ -532,8 +551,21 @@ cublasHandle_t THCState_getCurrentBlasHandle(THCState *state) int THCState_getCurrentStreamIndex(THCState *state) { - void* value = THCThreadLocal_get(state->currentPerDeviceStream); - return (int) (intptr_t) value; + THCStream* stream = THCState_getStream(state); + if (!stream) { + return 0; + } + + int device; + THCudaCheck(cudaGetDevice(&device)); + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); + for (int i = 0; i < state->numUserStreams; ++i) { + if (res->streams[i] == stream) { + return i; + } + } + + return -1; } int THCState_getCurrentBlasHandleIndex(THCState *state) @@ -545,9 +577,37 @@ int THCState_getCurrentBlasHandleIndex(THCState *state) return (int) (intptr_t) value; } -void THCState_setCurrentStreamIndex(THCState *state, int stream) +THCStream* THCState_getStream(THCState *state) +{ + int device; + THCudaCheck(cudaGetDevice(&device)); + return THCState_getStreamOnDevice(state, device); +} + +void THCState_setStream(THCState *state, THCStream *stream) { - THCThreadLocal_set(state->currentPerDeviceStream, (void*)(intptr_t)stream); + int device; + THCudaCheck(cudaGetDevice(&device)); + THCState_setStreamOnDevice(state, device, stream); +} + +void THCState_setCurrentStreamIndex(THCState *state, int streamIndex) +{ + if (streamIndex < 0 || streamIndex > state->numUserStreams) { + THError("%d is not a valid stream, valid range is: (0, %d)", streamIndex, + state->numUserStreams); + } + + int device; + for (device = 0; device < state->numDevices; ++device) { + THCStream* stream = NULL; + if (streamIndex != 0) { + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); + stream = res->streams[streamIndex]; + } + + THCState_setStreamOnDevice(state, device, stream); + } } void THCState_setCurrentBlasHandleIndex(THCState *state, int handle) @@ -565,7 +625,10 @@ void* THCState_getCurrentDeviceScratchSpace(THCState* state) int device = -1; THCudaCheck(cudaGetDevice(&device)); int stream = THCState_getCurrentStreamIndex(state); - + if (stream < 0) { + // new stream API + return NULL; + } return THCState_getDeviceScratchSpace(state, device, stream); } diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index d2660e1..9135167 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -40,6 +40,7 @@ #endif struct THCRNGState; /* Random number generator state. */ +struct THCStream; typedef struct _THCDeviceAllocator { cudaError_t (*malloc)(void*, void**, size_t, cudaStream_t); @@ -87,8 +88,13 @@ THC_API int THCState_getNumDevices(THCState* state); THC_API void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking); THC_API int THCState_getNumStreams(THCState* state); -THC_API cudaStream_t THCState_getDeviceStream(THCState *state, int device, int stream); +/* Stream API */ +THC_API cudaStream_t THCState_getCurrentStreamOnDevice(THCState *state, int device); THC_API cudaStream_t THCState_getCurrentStream(THCState *state); +THC_API struct THCStream* THCState_getStream(THCState *state); +THC_API void THCState_setStream(THCState *state, struct THCStream* stream); +/* deprecated stream API */ +THC_API cudaStream_t THCState_getDeviceStream(THCState *state, int device, int stream); THC_API int THCState_getCurrentStreamIndex(THCState *state); THC_API void THCState_setCurrentStreamIndex(THCState *state, int stream); diff --git a/lib/THC/THCReduceAll.cuh b/lib/THC/THCReduceAll.cuh index a9cea84..6b00498 100644 --- a/lib/THC/THCReduceAll.cuh +++ b/lib/THC/THCReduceAll.cuh @@ -186,13 +186,21 @@ void callReduceAll(THCState* state, dim3 block; if (isTwoPassReductionSize(totalElements)) { + bool freeScratchSpace = false; + void* scratchSpace = THCState_getCurrentDeviceScratchSpace(state); + if (!scratchSpace) { + THCudaCheck(THCudaMalloc(state, &scratchSpace, + THCState_getCurrentDeviceScratchSpaceSize(state))); + freeScratchSpace = true; + } + getPass1ReduceBlockGrid<InT, AccT>(state, totalElements, grid, block); size_t smemSize = block.x * sizeof(AccT); kernelReduceAllPass1<ModifyOp, ReduceOp, ReduceAccOp, InT, AccT, IndexType, ADims> <<<grid, block, smemSize, THCState_getCurrentStream(state)>>>( in, (IndexType) totalElements, init, modifyOp, reduceOp, reduceAccOp, - (AccT*) THCState_getCurrentDeviceScratchSpace(state)); + (AccT*) scratchSpace); int numPass1Blocks = grid.x; getPass2ReduceBlockGrid<InT, AccT>(state, totalElements, grid, block); @@ -201,9 +209,11 @@ void callReduceAll(THCState* state, kernelReduceAllPass2<ReduceAccOp, AccT, IndexType> <<<grid, block, smemSize, THCState_getCurrentStream(state)>>>( numPass1Blocks, init, reduceAccOp, - (AccT*) THCState_getCurrentDeviceScratchSpace(state), - devOut); + (AccT*) scratchSpace, devOut); + if (freeScratchSpace) { + THCudaCheck(THCudaFree(state, scratchSpace)); + } } else { getSinglePassReduceBlockGrid<InT, AccT>(totalElements, grid, block); size_t smemSize = block.x * sizeof(AccT); @@ -241,11 +251,17 @@ bool THC_reduceAll(THCState* state, return true; } + bool freeDevOut = false; AccT* devOut = out; if (!outOnDevice) { // Use the stream-specific scratch space for the reduction kernel // to write out its value devOut = (AccT*) THCState_getCurrentDeviceScratchSpace(state); + if (!devOut) { + THCudaCheck(THCudaMalloc(state, (void**)&devOut, + THCState_getCurrentDeviceScratchSpaceSize(state))); + freeDevOut = true; + } } // It is possible that the tensor dimensions are able to be collapsed, @@ -313,6 +329,10 @@ bool THC_reduceAll(THCState* state, cudaMemcpy(out, devOut, sizeof(AccT), cudaMemcpyDeviceToHost); } + if (freeDevOut) { + THCudaCheck(THCudaFree(state, devOut)); + } + return true; } diff --git a/lib/THC/THCStream.c b/lib/THC/THCStream.c new file mode 100644 index 0000000..e261a51 --- /dev/null +++ b/lib/THC/THCStream.c @@ -0,0 +1,30 @@ +#include "THCStream.h" + +#include <cuda_runtime_api.h> +#include "THAtomic.h" + + +THCStream* THCStream_new(int flags) +{ + THCStream* self = (THCStream*) malloc(sizeof(THCStream)); + self->refcount = 1; + THCudaCheck(cudaGetDevice(&self->device)); + THCudaCheck(cudaStreamCreateWithFlags(&self->stream, flags)); + return self; +} + +void THCStream_free(THCStream* self) +{ + if (!self) { + return; + } + if (THAtomicDecrementRef(&self->refcount)) { + THCudaCheck(cudaStreamDestroy(self->stream)); + free(self); + } +} + +void THCStream_retain(THCStream* self) +{ + THAtomicIncrementRef(&self->refcount); +} diff --git a/lib/THC/THCStream.h b/lib/THC/THCStream.h new file mode 100644 index 0000000..7e4bb49 --- /dev/null +++ b/lib/THC/THCStream.h @@ -0,0 +1,19 @@ +#ifndef THC_STREAM_INC +#define THC_STREAM_INC + +#include <cuda_runtime_api.h> +#include "THCGeneral.h" + +typedef struct THCStream +{ + cudaStream_t stream; + int device; + int refcount; +} THCStream; + + +THC_API THCStream* THCStream_new(int flags); +THC_API void THCStream_free(THCStream* self); +THC_API void THCStream_retain(THCStream* self); + +#endif // THC_STREAM_INC diff --git a/lib/THC/THCTensorCopy.cu b/lib/THC/THCTensorCopy.cu index b0001c5..8889939 100644 --- a/lib/THC/THCTensorCopy.cu +++ b/lib/THC/THCTensorCopy.cu @@ -65,12 +65,8 @@ THC_copyTensor(THCState* state, TensorTypeDst* dst, TensorTypeSrc* src) { // user to add needed synchronization on the dst device, since the // stream on the dst device that wishes to synchronize may not be // the same index as the one on the src device. - int copyStreamIndex = - THCState_getCurrentStreamIndex(state); - cudaStream_t copyStream = - THCState_getDeviceStream(state, srcDev, copyStreamIndex); - - if (srcDev != dstDev && copyStreamIndex == 0) { + cudaStream_t copyStream = THCState_getCurrentStreamOnDevice(state, srcDev); + if (srcDev != dstDev && copyStream == NULL) { // This is a cross-device copy on the default stream. We perform a // two-way barrier between both devices' default streams before // the copy. This ensures that any write-after-write and @@ -182,7 +178,7 @@ THC_copyTensor(THCState* state, TensorTypeDst* dst, TensorTypeSrc* src) { } } - if (srcDev != dstDev && copyStreamIndex == 0) { + if (srcDev != dstDev && copyStream == NULL) { // dst waits on src barrier (dst already waits on dst). We cannot // operate on dst's copy until the copy is complete. diff --git a/lib/THC/generic/THCTensorCopy.c b/lib/THC/generic/THCTensorCopy.c index e0bcadd..64f8364 100644 --- a/lib/THC/generic/THCTensorCopy.c +++ b/lib/THC/generic/THCTensorCopy.c @@ -153,8 +153,7 @@ void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor THTensor_(data)(src), THTensor_(nElement)(src) * sizeof(real), cudaMemcpyHostToDevice, - THCState_getDeviceStream(state, tensorDevice, - THCState_getCurrentStreamIndex(state)))); + THCState_getCurrentStream(state))); if (currentDevice != tensorDevice) { THCudaCheck(cudaSetDevice(currentDevice)); @@ -182,8 +181,7 @@ void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, struct THCTensor THCTensor_(data)(state, src), THCTensor_(nElement)(state, src) * sizeof(real), cudaMemcpyDeviceToHost, - THCState_getDeviceStream(state, tensorDevice, - THCState_getCurrentStreamIndex(state)))); + THCState_getCurrentStream(state))); if (currentDevice != tensorDevice) { THCudaCheck(cudaSetDevice(currentDevice)); |