diff options
author | Sam Gross <sgross@fb.com> | 2016-10-14 20:08:56 +0300 |
---|---|---|
committer | Sam Gross <sgross@fb.com> | 2016-10-14 20:08:56 +0300 |
commit | fa5739ea0ee26d6ecfc7b3ab1647b32f6fa91a39 (patch) | |
tree | 9091d630f884637277a2b94bb2ad01f4a4617db4 /lib/THC | |
parent | 073ba88337ff7a584abc6ee11a0f06a0696d8a69 (diff) |
Fix caching allocator when used from multiple Lua threads
Use a single, global THCCachingAllocator instance.
Previously, each Lua thread had its own THCCachingAllocator instance.
However, threads can share storages, which means a segment could be
allocated from on THCCachingAllocator and freed on another, which
breaks.
Fixes #539
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/THCCachingAllocator.cpp | 57 | ||||
-rw-r--r-- | lib/THC/THCCachingAllocator.h | 2 | ||||
-rw-r--r-- | lib/THC/THCGeneral.c | 58 | ||||
-rw-r--r-- | lib/THC/THCGeneral.h.in | 4 |
4 files changed, 69 insertions, 52 deletions
diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp index 93f8327..73b81f6 100644 --- a/lib/THC/THCCachingAllocator.cpp +++ b/lib/THC/THCCachingAllocator.cpp @@ -20,6 +20,13 @@ // - Large (>1MB) and small allocation requestss are handled separately. Large // allocation requests can be filled by a cudaMalloc call of the exact size. // Small requests will allocate and split a 1MB buffer, if necessary. +// +// With this allocator, allocations and frees should logically be considered +// "usages" of the memory segment associated with streams, just like kernel +// launches. The programmer must insert the proper synchronization if memory +// segments are used from multiple streams. +// + namespace { @@ -78,6 +85,7 @@ struct THCCachingAllocator large_blocks(BlockComparator), small_blocks(BlockComparator) {} + /** allocates a block which is safe to use from the provided stream */ cudaError_t malloc(void** devPtr, size_t size, cudaStream_t stream) { std::lock_guard<std::mutex> lock(mutex); @@ -160,6 +168,22 @@ struct THCCachingAllocator return cudaSuccess; } + /** returns cached blocks to the system allocator */ + cudaError_t emptyCache() + { + std::lock_guard<std::mutex> lock(mutex); + cudaError_t err = free_blocks(large_blocks, large_blocks.begin(), large_blocks.end()); + if (err != cudaSuccess) { + return err; + } + err = free_blocks(small_blocks, small_blocks.begin(), small_blocks.end()); + if (err != cudaSuccess) { + return err; + } + return cudaSuccess; + } + + /** combine previously split blocks */ void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks) { if (!src || src->allocated) { @@ -196,6 +220,8 @@ struct THCCachingAllocator cudaError_t cuda_malloc_retry(int device, void** devPtr, size_t size) { + // Try cudaMalloc. If cudaMalloc fails, frees all non-split cached blocks + // and retries. cudaError_t err = cudaMalloc(devPtr, size); if (err != cudaSuccess) { cudaGetLastError(); @@ -233,6 +259,7 @@ struct THCCachingAllocator cudaError_t free_blocks(FreeBlocks& blocks, FreeBlocks::iterator it, FreeBlocks::iterator end) { + // Frees all non-split blocks between `it` and `end` while (it != end) { Block* block = *it; if (!block->prev && !block->next) { @@ -264,27 +291,21 @@ static cudaError_t THCCachingAllocator_free(void* ctx, void* ptr) return a->free(ptr); } -static cudaError_t THCCachingAllocator_shutdown(void* ctx) +static cudaError_t THCCachingAllocator_emptyCache(void* ctx) { - cudaError_t err; THCCachingAllocator* a = (THCCachingAllocator*) ctx; - err = a->free_blocks(a->large_blocks, a->large_blocks.begin(), a->large_blocks.end()); - if (err != cudaSuccess) { - return err; - } - err = a->free_blocks(a->small_blocks, a->small_blocks.begin(), a->small_blocks.end()); - if (err != cudaSuccess) { - return err; - } - delete a; - return cudaSuccess; + return a->emptyCache(); } -THC_API void THCCachingAllocator_init(THCDeviceAllocator* alloc) +static THCCachingAllocator caching_allocator; +static THCDeviceAllocator device_allocator = { + &THCCachingAllocator_malloc, + &THCCachingAllocator_free, + &THCCachingAllocator_emptyCache, + &caching_allocator +}; + +THC_API THCDeviceAllocator* THCCachingAllocator_get() { - THCCachingAllocator* allocator = new THCCachingAllocator(); - alloc->state = allocator; - alloc->malloc = &THCCachingAllocator_malloc; - alloc->free = &THCCachingAllocator_free; - alloc->shutdown = &THCCachingAllocator_shutdown; + return &device_allocator; } diff --git a/lib/THC/THCCachingAllocator.h b/lib/THC/THCCachingAllocator.h index 60ff52c..711b1da 100644 --- a/lib/THC/THCCachingAllocator.h +++ b/lib/THC/THCCachingAllocator.h @@ -3,6 +3,6 @@ #include "THCGeneral.h" -THC_API void THCCachingAllocator_init(THCDeviceAllocator* alloc); +THC_API THCDeviceAllocator* THCCachingAllocator_get(); #endif diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 5bcce19..7cc7818 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -38,7 +38,7 @@ struct THCState { /* Allocator using cudaMallocHost. */ THAllocator* cudaHostAllocator; - THCDeviceAllocator cudaDeviceAllocator; + THCDeviceAllocator* cudaDeviceAllocator; /* Index of the current selected per-device resource. Actual CUDA resource changes based on the current device, since resources are per-device */ @@ -67,8 +67,6 @@ struct THCState { THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr( THCState *state, int device); -static void THCState_initDefaultDeviceAllocator(THCDeviceAllocator* a); - THCState* THCState_alloc() { THCState* state = (THCState*) malloc(sizeof(THCState)); @@ -81,10 +79,27 @@ void THCState_free(THCState* state) free(state); } +static cudaError_t cudaMallocWrapper(void* ctx, void** devPtr, size_t size, cudaStream_t stream) +{ + return cudaMalloc(devPtr, size); +} + +static cudaError_t cudaFreeWrapper(void* ctx, void* devPtr) +{ + return cudaFree(devPtr); +} + +static THCDeviceAllocator defaultDeviceAllocator = { + &cudaMallocWrapper, + &cudaFreeWrapper, + NULL, + NULL +}; + void THCudaInit(THCState* state) { - if (!state->cudaDeviceAllocator.malloc) { - THCState_initDefaultDeviceAllocator(&state->cudaDeviceAllocator); + if (!state->cudaDeviceAllocator) { + state->cudaDeviceAllocator = &defaultDeviceAllocator; } int numDevices = 0; @@ -188,7 +203,9 @@ void THCudaShutdown(THCState* state) free(state->resourcesPerDevice[dev].devScratchSpacePerStream); } free(state->resourcesPerDevice); - state->cudaDeviceAllocator.shutdown(state->cudaDeviceAllocator.state); + if (state->cudaDeviceAllocator->emptyCache) { + state->cudaDeviceAllocator->emptyCache(state->cudaDeviceAllocator->state); + } THCThreadLocal_free(state->currentPerDeviceStream); THCThreadLocal_free(state->currentPerDeviceBlasHandle); @@ -329,12 +346,11 @@ THAllocator* THCState_getCudaHostAllocator(THCState* state) return state->cudaHostAllocator; } -THCDeviceAllocator* THCState_getDeviceAllocator(THCState* state) +void THCState_setDeviceAllocator(THCState* state, THCDeviceAllocator* allocator) { - return &state->cudaDeviceAllocator; + state->cudaDeviceAllocator = allocator; } - int THCState_getNumDevices(THCState *state) { return state->numDevices; @@ -652,31 +668,11 @@ void THCSetGCHandler(THCState *state, void (*cutorchGCFunction_)(void *data), vo state->cutorchGCData = data; } -static cudaError_t cudaMallocWrapper(void* ctx, void** devPtr, size_t size, cudaStream_t stream) -{ - return cudaMalloc(devPtr, size); -} - -static cudaError_t cudaFreeWrapper(void* ctx, void* devPtr) -{ - return cudaFree(devPtr); -} - -static cudaError_t noop(void* ctx) { return cudaSuccess; } - -static void THCState_initDefaultDeviceAllocator(THCDeviceAllocator* a) -{ - a->malloc = &cudaMallocWrapper; - a->free = &cudaFreeWrapper; - a->shutdown = &noop; - a->state = NULL; -} - cudaError_t THCudaMalloc(THCState *state, void** ptr, size_t size) { THCudaCheck(cudaGetLastError()); cudaStream_t stream = THCState_getCurrentStream(state); - THCDeviceAllocator* allocator = &state->cudaDeviceAllocator; + THCDeviceAllocator* allocator = state->cudaDeviceAllocator; cudaError_t err = allocator->malloc(allocator->state, ptr, size, stream); if (state->cutorchGCFunction != NULL && err != cudaSuccess) { cudaGetLastError(); // reset OOM error @@ -688,7 +684,7 @@ cudaError_t THCudaMalloc(THCState *state, void** ptr, size_t size) cudaError_t THCudaFree(THCState *state, void *ptr) { - THCDeviceAllocator* allocator = &state->cudaDeviceAllocator; + THCDeviceAllocator* allocator = state->cudaDeviceAllocator; return allocator->free(allocator->state, ptr); } diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index ce3ffc5..23b00af 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -41,7 +41,7 @@ struct THCRNGState; /* Random number generator state. */ typedef struct _THCDeviceAllocator { cudaError_t (*malloc)(void*, void**, size_t, cudaStream_t); cudaError_t (*free)(void*, void*); - cudaError_t (*shutdown)(void*); + cudaError_t (*emptyCache)(void*); void* state; } THCDeviceAllocator; @@ -75,7 +75,7 @@ THC_API struct cudaDeviceProp* THCState_getCurrentDeviceProperties(THCState* sta THC_API struct THCRNGState* THCState_getRngState(THCState* state); THC_API THAllocator* THCState_getCudaHostAllocator(THCState* state); -THC_API THCDeviceAllocator* THCState_getDeviceAllocator(THCState* state); +THC_API void THCState_setDeviceAllocator(THCState* state, THCDeviceAllocator* allocator); THC_API void THCMagma_init(THCState *state); |