github.com/torch/cutorch.git
path: root/lib
author    Sam Gross <sgross@fb.com>  2016-09-22 21:10:47 +0300
committer Sam Gross <sgross@fb.com>  2016-09-25 22:57:50 +0300
commit    691b7b2e0b55b994e70609b7605a37fcf49f9fec (patch)
tree      5a618842732ca3747faa740eebc81d64e07213e6 /lib
parent    3624cb9a5be6646d172fa686ce423ff2c9c3b3e6 (diff)
Add CUDA caching allocator
The allocator can be enabled by setting the environment variable THC_CACHING_ALLOCATOR=1
Diffstat (limited to 'lib')
-rw-r--r--  lib/THC/CMakeLists.txt           |  29
-rw-r--r--  lib/THC/THCCachingAllocator.cpp  | 291
-rw-r--r--  lib/THC/THCCachingAllocator.h    |   8
-rw-r--r--  lib/THC/THCGeneral.c             |  36
-rw-r--r--  lib/THC/THCGeneral.h.in          |   8
5 files changed, 358 insertions(+), 14 deletions(-)
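
The enabling path itself (reading THC_CACHING_ALLOCATOR from the environment) lives in the cutorch init code outside this diff. As a rough sketch of the plumbing the patch adds, a host program could install the caching allocator before initialization roughly like this (hypothetical standalone usage, not the actual torch wiring):

#include "THCGeneral.h"
#include "THCCachingAllocator.h"
#include <cstdlib>
#include <cstring>

// Hypothetical helper: install the caching allocator when the environment
// variable is set, then initialize THC. THCudaInit() only installs the plain
// cudaMalloc/cudaFree wrappers if no malloc function has been set yet.
static void initWithOptionalCachingAllocator(THCState* state)
{
  const char* env = std::getenv("THC_CACHING_ALLOCATOR");
  if (env != NULL && std::strcmp(env, "1") == 0) {
    THCCachingAllocator_init(&state->cudaDeviceAllocator);
  }
  THCudaInit(state);
}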
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 365d88a..f2eab04 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -103,9 +103,17 @@ ENDIF()
INCLUDE_DIRECTORIES("${CMAKE_CURRENT_BINARY_DIR}")
CONFIGURE_FILE(THCGeneral.h.in "${CMAKE_CURRENT_BINARY_DIR}/THCGeneral.h")
-SET(CMAKE_C_FLAGS "-std=c99 ${CMAKE_C_FLAGS}")
+SET(CMAKE_C_FLAGS "-std=c99 ${CMAKE_C_FLAGS} -g -O0")
+SET(CMAKE_CXX_STANDARD 11)
SET(src
- THCGeneral.c THCAllocator.c THCStorage.c THCStorageCopy.c THCTensor.c THCTensorCopy.c)
+ THCAllocator.c
+ THCCachingAllocator.cpp
+ THCGeneral.c
+ THCStorage.c
+ THCStorageCopy.c
+ THCTensor.c
+ THCTensorCopy.c
+ )
SET(src-cuda
THCReduceApplyUtils.cu
@@ -189,6 +197,7 @@ INSTALL(FILES
THCScanUtils.cuh
THCSortUtils.cuh
THCAllocator.h
+ THCCachingAllocator.h
THCDeviceUtils.cuh
THCDeviceTensor.cuh
THCDeviceTensor-inl.cuh
@@ -227,8 +236,8 @@ INSTALL(FILES
generic/THCTensorMasked.cu
generic/THCTensorMath.h
generic/THCTensorMath.cu
- generic/THCTensorMathBlas.cu
- generic/THCTensorMathBlas.h
+ generic/THCTensorMathBlas.cu
+ generic/THCTensorMathBlas.h
generic/THCTensorMathCompare.h
generic/THCTensorMathCompare.cu
generic/THCTensorMathCompareT.h
@@ -239,10 +248,10 @@ INSTALL(FILES
generic/THCTensorMathPointwise.cu
generic/THCTensorMathReduce.h
generic/THCTensorMathReduce.cu
- generic/THCTensorScatterGather.h
- generic/THCTensorScatterGather.cu
- generic/THCTensorIndex.h
- generic/THCTensorIndex.cu
- generic/THCTensorSort.h
- generic/THCTensorSort.cu
+ generic/THCTensorScatterGather.h
+ generic/THCTensorScatterGather.cu
+ generic/THCTensorIndex.h
+ generic/THCTensorIndex.cu
+ generic/THCTensorSort.h
+ generic/THCTensorSort.cu
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp
new file mode 100644
index 0000000..0688197
--- /dev/null
+++ b/lib/THC/THCCachingAllocator.cpp
@@ -0,0 +1,291 @@
+#include "THCCachingAllocator.h"
+
+#include <cuda_runtime_api.h>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <set>
+#include <unordered_map>
+
+//
+// Yet another caching allocator for CUDA device allocations.
+//
+// - Allocations are associated with a stream. Once freed, blocks can be
+// re-allocated on the same stream, but not on any other stream.
+// - The allocator attempts to find the smallest cached block that will fit the
+// requested size. If the block is larger than the requested size, it may be
+// split. If no block is found, the allocator will delegate to cudaMalloc.
+// - If the cudaMalloc fails, the allocator will free all cached blocks that
+// are not split and retry the allocation.
+// - Large (>1MB) and small allocation requests are handled separately. Large
+// allocation requests can be filled by a cudaMalloc call of the exact size.
+// Small requests will allocate and split a 1MB buffer, if necessary.
+
+namespace {
+
+const size_t kRoundSmall = 512; // round up small allocs to 512 bytes
+const size_t kRoundLarge = 131072; // round up large allocs to 128 KiB
+const size_t kSmallAlloc = 1048576; // largest "small" allocation is 1 MiB
+
+struct Block {
+ int device; // gpu
+ cudaStream_t stream; // allocation stream
+ size_t size; // block size in bytes
+ char* ptr; // memory address
+ bool allocated; // in-use flag
+ Block* prev; // prev block if split from a larger allocation
+ Block* next; // next block if split from a larger allocation
+
+ Block(int device, cudaStream_t stream, size_t size, char* ptr=NULL) :
+ device(device), stream(stream), size(size), ptr(ptr), allocated(0),
+ prev(NULL), next(NULL) { }
+};
+
+static bool BlockComparator(const Block* a, const Block* b)
+{
+ if (a->device != b->device) {
+ return a->device < b->device;
+ }
+ if (a->stream != b->stream) {
+ return (uintptr_t)a->stream < (uintptr_t)b->stream;
+ }
+ if (a->size != b->size) {
+ return a->size < b->size;
+ }
+ return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
+}
+
+} // namespace
+
+struct THCCachingAllocator
+{
+ typedef bool (*Comparison)(const Block*, const Block*);
+ typedef std::set<Block*, Comparison> FreeBlocks;
+
+ // lock around malloc and free
+ std::mutex mutex;
+
+ // cached blocks larger than 1 MB
+ FreeBlocks large_blocks;
+
+ // cached blocks 1 MB or smaller
+ FreeBlocks small_blocks;
+
+ // allocated blocks by device pointer
+ std::unordered_map<char*, Block*> allocated_blocks;
+
+ THCCachingAllocator() :
+ large_blocks(BlockComparator),
+ small_blocks(BlockComparator) {}
+
+ cudaError_t malloc(void** devPtr, size_t size, cudaStream_t stream)
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+
+ int device;
+ cudaError_t err = cudaGetDevice(&device);
+ if (err != cudaSuccess) {
+ return err;
+ }
+
+ size = round_size(size);
+ bool small = size <= kSmallAlloc;
+
+ Block search_key(device, stream, size);
+ auto& free_blocks = small ? small_blocks : large_blocks;
+
+ Block* block = NULL;
+ Block* remaining = NULL;
+
+ auto it = free_blocks.lower_bound(&search_key);
+ if (it != free_blocks.end() && (*it)->device == device && (*it)->stream == stream) {
+ block = *it;
+ free_blocks.erase(it);
+ } else {
+ void* ptr;
+ size_t alloc_size = small ? kSmallAlloc : size;
+ cudaError_t err = cuda_malloc_retry(device, &ptr, alloc_size);
+ if (err != cudaSuccess) {
+ return err;
+ }
+ block = new Block(device, stream, alloc_size, (char*)ptr);
+ }
+
+ if (block->size - size >= (small ? kRoundSmall : kSmallAlloc + 1)) {
+ remaining = block;
+
+ block = new Block(device, stream, size, block->ptr);
+ block->prev = remaining->prev;
+ if (block->prev) {
+ block->prev->next = block;
+ }
+ block->next = remaining;
+
+ remaining->prev = block;
+ remaining->ptr += size;
+ remaining->size -= size;
+ free_blocks.insert(remaining);
+ }
+
+ block->allocated = true;
+ allocated_blocks[block->ptr] = block;
+
+ *devPtr = (void*)block->ptr;
+ return cudaSuccess;
+ }
+
+ cudaError_t free(void* ptr)
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!ptr) {
+ return cudaSuccess;
+ }
+
+ auto it = allocated_blocks.find((char*)ptr);
+ if (it == allocated_blocks.end()) {
+ return cudaErrorInvalidDevicePointer;
+ }
+
+ Block* block = it->second;
+ int device = block->device;
+ allocated_blocks.erase(it);
+
+ bool small = block->size <= kSmallAlloc;
+ auto& free_blocks = small ? small_blocks : large_blocks;
+ try_merge_blocks(block, block->prev, free_blocks);
+ try_merge_blocks(block, block->next, free_blocks);
+
+ block->allocated = false;
+ free_blocks.insert(block);
+
+ return cudaSuccess;
+ }
+
+ void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
+ {
+ if (!src || src->allocated) {
+ return;
+ }
+ if (dst->prev == src) {
+ dst->ptr = src->ptr;
+ dst->prev = src->prev;
+ if (dst->prev) {
+ dst->prev->next = dst;
+ }
+ } else {
+ dst->next = src->next;
+ if (dst->next) {
+ dst->next->prev = dst;
+ }
+ }
+ dst->size += src->size;
+ free_blocks.erase(src);
+ delete src;
+ }
+
+ size_t round_size(size_t size)
+ {
+ if (size < kRoundSmall) {
+ size = kRoundSmall;
+ } else if (size < kSmallAlloc) {
+ size += kRoundSmall - 1 - (size - 1) % kRoundSmall;
+ } else {
+ size += kRoundLarge - 1 - (size - 1) % kRoundLarge;
+ }
+ return size;
+ }
+
+ cudaError_t cuda_malloc_retry(int device, void** devPtr, size_t size)
+ {
+ cudaError_t err = cudaMalloc(devPtr, size);
+ if (err != cudaSuccess) {
+ cudaGetLastError();
+ err = free_cached_blocks(device);
+ if (err != cudaSuccess) {
+ return err;
+ }
+ err = cudaMalloc(devPtr, size);
+ if (err != cudaSuccess) {
+ return err;
+ }
+ }
+ return cudaSuccess;
+ }
+
+ cudaError_t free_cached_blocks(int device)
+ {
+ // Free all non-split cached blocks on device
+ Block lower_bound(device, NULL, 0);
+ Block upper_bound(device + 1, NULL, 0);
+
+ cudaError_t err = free_blocks(
+ large_blocks,
+ large_blocks.lower_bound(&lower_bound),
+ large_blocks.lower_bound(&upper_bound));
+ if (err != cudaSuccess) {
+ return err;
+ }
+ err = free_blocks(
+ small_blocks,
+ small_blocks.lower_bound(&lower_bound),
+ small_blocks.lower_bound(&upper_bound));
+ return err;
+ }
+
+ cudaError_t free_blocks(FreeBlocks& blocks, FreeBlocks::iterator it, FreeBlocks::iterator end)
+ {
+ while (it != end) {
+ Block* block = *it;
+ if (!block->prev && !block->next) {
+ cudaError_t err = cudaFree((void*)block->ptr);
+ if (err != cudaSuccess) {
+ return err;
+ }
+ auto cur = it;
+ ++it;
+ blocks.erase(cur);
+ delete block;
+ } else {
+ ++it;
+ }
+ }
+ return cudaSuccess;
+ }
+};
+
+static cudaError_t THCCachingAllocator_malloc(void* ctx, void** ptr, size_t size, cudaStream_t stream)
+{
+ THCCachingAllocator* a = (THCCachingAllocator*) ctx;
+ return a->malloc(ptr, size, stream);
+}
+
+static cudaError_t THCCachingAllocator_free(void* ctx, void* ptr)
+{
+ THCCachingAllocator* a = (THCCachingAllocator*) ctx;
+ return a->free(ptr);
+}
+
+static cudaError_t THCCachingAllocator_shutdown(void* ctx)
+{
+ cudaError_t err;
+ THCCachingAllocator* a = (THCCachingAllocator*) ctx;
+ err = a->free_blocks(a->large_blocks, a->large_blocks.begin(), a->large_blocks.end());
+ if (err != cudaSuccess) {
+ return err;
+ }
+ err = a->free_blocks(a->small_blocks, a->small_blocks.begin(), a->small_blocks.end());
+ if (err != cudaSuccess) {
+ return err;
+ }
+ delete a;
+ return cudaSuccess;
+}
+
+THC_API void THCCachingAllocator_init(THCDeviceAllocator* alloc)
+{
+ THCCachingAllocator* allocator = new THCCachingAllocator();
+ alloc->state = allocator;
+ alloc->malloc = &THCCachingAllocator_malloc;
+ alloc->free = &THCCachingAllocator_free;
+ alloc->shutdown = &THCCachingAllocator_shutdown;
+}
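
The rounding rule in round_size() above is compact but easy to misread: requests below 512 bytes become 512 bytes, requests up to 1 MiB round up to the next multiple of 512 bytes, and anything larger rounds up to the next multiple of 128 KiB; the rounded size then decides whether the request is served from the small or the large pool. A standalone sketch with the constants copied from the file and a few worked values:

#include <cassert>
#include <cstddef>

static const size_t kRoundSmall = 512;      // small requests round to 512 B
static const size_t kRoundLarge = 131072;   // large requests round to 128 KiB
static const size_t kSmallAlloc = 1048576;  // 1 MiB boundary between the pools

// Same arithmetic as THCCachingAllocator::round_size above.
static size_t round_size(size_t size)
{
  if (size < kRoundSmall) {
    size = kRoundSmall;
  } else if (size < kSmallAlloc) {
    size += kRoundSmall - 1 - (size - 1) % kRoundSmall;   // next multiple of 512
  } else {
    size += kRoundLarge - 1 - (size - 1) % kRoundLarge;   // next multiple of 128 KiB
  }
  return size;
}

int main()
{
  assert(round_size(1)       == 512);        // minimum granularity
  assert(round_size(1000)    == 1024);       // 2 * 512
  assert(round_size(1048576) == 1048576);    // already a multiple of 128 KiB
  assert(round_size(2000000) == 2097152);    // 16 * 131072
  return 0;
}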
diff --git a/lib/THC/THCCachingAllocator.h b/lib/THC/THCCachingAllocator.h
new file mode 100644
index 0000000..60ff52c
--- /dev/null
+++ b/lib/THC/THCCachingAllocator.h
@@ -0,0 +1,8 @@
+#ifndef THC_DEVICE_ALLOCATOR_INC
+#define THC_DEVICE_ALLOCATOR_INC
+
+#include "THCGeneral.h"
+
+THC_API void THCCachingAllocator_init(THCDeviceAllocator* alloc);
+
+#endif
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index feb7223..dbe0be9 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -11,10 +11,15 @@
THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr(
THCState *state, int device);
+static void THCState_initDefaultDeviceAllocator(THCDeviceAllocator* a);
+
void THCudaInit(THCState* state)
{
state->cutorchGCFunction = NULL;
state->cutorchGCData = NULL;
+ if (!state->cudaDeviceAllocator.malloc) {
+ THCState_initDefaultDeviceAllocator(&state->cudaDeviceAllocator);
+ }
int count = 0;
THCudaCheck(cudaGetDeviceCount(&count));
@@ -123,6 +128,7 @@ void THCudaShutdown(THCState* state)
free(state->resourcesPerDevice[dev].devScratchSpacePerStream);
}
free(state->resourcesPerDevice);
+ state->cudaDeviceAllocator.shutdown(state->cudaDeviceAllocator.state);
THCudaCheck(cudaSetDevice(prevDev));
}
@@ -603,22 +609,44 @@ void THCSetGCHandler(THCState *state, void (*cutorchGCFunction_)(void *data), vo
state->cutorchGCData = data;
}
+static cudaError_t cudaMallocWrapper(void* ctx, void** devPtr, size_t size, cudaStream_t stream)
+{
+ return cudaMalloc(devPtr, size);
+}
+
+static cudaError_t cudaFreeWrapper(void* ctx, void* devPtr)
+{
+ return cudaFree(devPtr);
+}
+
+static cudaError_t noop(void* ctx) { return cudaSuccess; }
+
+static void THCState_initDefaultDeviceAllocator(THCDeviceAllocator* a)
+{
+ a->malloc = &cudaMallocWrapper;
+ a->free = &cudaFreeWrapper;
+ a->shutdown = &noop;
+ a->state = NULL;
+}
+
cudaError_t THCudaMalloc(THCState *state, void** ptr, size_t size)
{
THCudaCheck(cudaGetLastError());
- cudaError_t err = cudaMalloc(ptr, size);
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ THCDeviceAllocator* allocator = &state->cudaDeviceAllocator;
+ cudaError_t err = allocator->malloc(allocator->state, ptr, size, stream);
if (state->cutorchGCFunction != NULL && err != cudaSuccess) {
cudaGetLastError(); // reset OOM error
(state->cutorchGCFunction)(state->cutorchGCData);
- err = cudaMalloc(ptr, size);
+ err = allocator->malloc(allocator->state, ptr, size, stream);
}
return err;
}
cudaError_t THCudaFree(THCState *state, void *ptr)
{
- cudaError_t err = cudaFree(ptr);
- return err;
+ THCDeviceAllocator* allocator = &state->cudaDeviceAllocator;
+ return allocator->free(allocator->state, ptr);
}
static long applyHeapDelta(THCState *state) {
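
The THCudaMalloc change above keeps the existing out-of-memory fallback: if the first allocator->malloc call fails and a garbage-collection handler was registered with THCSetGCHandler, the handler runs (in torch this triggers a Lua collection so freed tensors release device memory) and the allocation is retried once. A rough sketch of that call pattern follows; the handler is hypothetical and the third THCSetGCHandler argument is assumed to be the opaque data pointer stored in cutorchGCData:

#include "THCGeneral.h"
#include <cstdio>

// Hypothetical handler invoked by THCudaMalloc when the first attempt fails.
static void exampleGCHandler(void* data)
{
  fprintf(stderr, "device allocation failed once; freeing cached objects\n");
}

static void demoRetry(THCState* state)
{
  THCSetGCHandler(state, &exampleGCHandler, NULL);

  void* devPtr = NULL;
  // On failure THCudaMalloc resets the CUDA error, calls the handler,
  // then retries through state->cudaDeviceAllocator.
  THCudaCheck(THCudaMalloc(state, &devPtr, 1 << 20));
  THCudaCheck(THCudaFree(state, devPtr));
}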
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index 784bee1..b3df3a4 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -48,6 +48,13 @@ typedef struct _THCCudaResourcesPerDevice {
void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;
+typedef struct _THCDeviceAllocator {
+ cudaError_t (*malloc)(void*, void**, size_t, cudaStream_t);
+ cudaError_t (*free)(void*, void*);
+ cudaError_t (*shutdown)(void*);
+ void* state;
+} THCDeviceAllocator;
+
/* Global state to be held in the cutorch table. */
typedef struct THCState
@@ -74,6 +81,7 @@ typedef struct THCState
int currentPerDeviceBlasHandle;
/* Allocator using cudaMallocHost. */
THAllocator* cudaHostAllocator;
+ THCDeviceAllocator cudaDeviceAllocator;
/* Table of enabled peer-to-peer access between directed pairs of GPUs.
If i accessing allocs on j is enabled, p2pAccess[i][j] is 1; 0 otherwise. */
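
Taken together with the THCDeviceAllocator struct added above, the device allocator is now pluggable through a function table: anything that provides the three function pointers can be installed on state->cudaDeviceAllocator before THCudaInit(). A minimal hypothetical example (a logging pass-through, not part of the patch):

#include "THCGeneral.h"
#include <cuda_runtime_api.h>
#include <cstdio>

// Pass-through allocator that logs every device allocation and free.
static cudaError_t loggingMalloc(void* ctx, void** devPtr, size_t size, cudaStream_t stream)
{
  cudaError_t err = cudaMalloc(devPtr, size);
  fprintf(stderr, "cudaMalloc(%zu) on stream %p -> %d\n", size, (void*)stream, (int)err);
  return err;
}

static cudaError_t loggingFree(void* ctx, void* devPtr)
{
  fprintf(stderr, "cudaFree(%p)\n", devPtr);
  return cudaFree(devPtr);
}

static cudaError_t loggingShutdown(void* ctx)
{
  return cudaSuccess;  // nothing cached, nothing to release
}

// Install before THCudaInit(), which only fills in the default
// cudaMalloc/cudaFree wrappers when no malloc function has been set.
static void installLoggingAllocator(THCState* state)
{
  state->cudaDeviceAllocator.malloc   = &loggingMalloc;
  state->cudaDeviceAllocator.free     = &loggingFree;
  state->cudaDeviceAllocator.shutdown = &loggingShutdown;
  state->cudaDeviceAllocator.state    = NULL;
}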