diff options
author | soumith <soumith@fb.com> | 2016-10-25 08:06:00 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-10-25 17:07:52 +0300 |
commit | 94232e233cb1934ea23119d16abb7416ce9c98e9 (patch) | |
tree | 783588edf0c485bfa8fc197c2242120b32ee6f43 | |
parent | dce6cd89aa78d308dd4360484ac21da3fc80c9d0 (diff) |
allocator updates
-rw-r--r-- | lib/THC/CMakeLists.txt | 2 |
-rw-r--r-- | lib/THC/THCAllocator.c | 10 |
-rw-r--r-- | lib/THC/THCAllocator.h | 2 |
-rw-r--r-- | lib/THC/THCCachingAllocator.cpp | 1 |
-rw-r--r-- | lib/THC/THCGeneral.c | 6 |
-rw-r--r-- | lib/THC/THCGeneral.h.in | 3 |
-rw-r--r-- | lib/THC/generic/THCStorage.c | 63 |
-rw-r--r-- | lib/THC/generic/THCStorage.cu | 32 |
-rw-r--r-- | lib/THC/generic/THCStorage.h | 13 |
9 files changed, 93 insertions, 39 deletions
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index edc0af0..244568f 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -123,10 +123,8 @@ ELSE() ENDIF() SET(src - THCAllocator.c THCCachingAllocator.cpp THCGeneral.c - THCStorage.c THCStorageCopy.c THCStream.c THCTensor.c diff --git a/lib/THC/THCAllocator.c b/lib/THC/THCAllocator.c index fa55c40..5d36d4c 100644 --- a/lib/THC/THCAllocator.c +++ b/lib/THC/THCAllocator.c @@ -1,6 +1,6 @@ #include "THCAllocator.h" -static void *THCudaHostAllocator_alloc(void* ctx, ptrdiff_t size) { +static void *THCudaHostAllocator_malloc(void* ctx, ptrdiff_t size) { void* ptr; if (size < 0) THError("Invalid memory size: %ld", size); @@ -18,8 +18,8 @@ static void THCudaHostAllocator_free(void* ctx, void* ptr) { THCudaCheck(cudaFreeHost(ptr)); } -void THCAllocator_init(THAllocator *cudaHostAllocator) { - cudaHostAllocator->malloc = &THCudaHostAllocator_alloc; - cudaHostAllocator->realloc = NULL; - cudaHostAllocator->free = &THCudaHostAllocator_free; +void THCAllocator_init(THCState *state) { + state->cudaHostAllocator->malloc = &THCudaHostAllocator_malloc; + state->cudaHostAllocator->realloc = NULL; + state->cudaHostAllocator->free = &THCudaHostAllocator_free; } diff --git a/lib/THC/THCAllocator.h b/lib/THC/THCAllocator.h index 3481304..2f85eec 100644 --- a/lib/THC/THCAllocator.h +++ b/lib/THC/THCAllocator.h @@ -3,6 +3,6 @@ #include "THCGeneral.h" -THC_API void THCAllocator_init(THAllocator *state); +THC_API void THCAllocator_init(THCState *state); #endif diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp index 73b81f6..54db20d 100644 --- a/lib/THC/THCCachingAllocator.cpp +++ b/lib/THC/THCCachingAllocator.cpp @@ -300,6 +300,7 @@ static cudaError_t THCCachingAllocator_emptyCache(void* ctx) static THCCachingAllocator caching_allocator; static THCDeviceAllocator device_allocator = { &THCCachingAllocator_malloc, + NULL, &THCCachingAllocator_free, &THCCachingAllocator_emptyCache, 
&caching_allocator diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 9b4764d..0b75399 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -93,6 +93,7 @@ static cudaError_t cudaFreeWrapper(void* ctx, void* devPtr) static THCDeviceAllocator defaultDeviceAllocator = { &cudaMallocWrapper, + NULL, &cudaFreeWrapper, NULL, NULL @@ -129,7 +130,7 @@ void THCudaInit(THCState* state) THCRandom_init(state, numDevices, device); state->cudaHostAllocator = (THAllocator*)malloc(sizeof(THAllocator)); - THCAllocator_init(state->cudaHostAllocator); + THCAllocator_init(state); /* Enable P2P access between all pairs, if possible */ THCudaEnablePeerToPeerAccess(state); @@ -792,3 +793,6 @@ void THCHeapUpdate(THCState *state, ptrdiff_t size) { } #undef GLOBAL_SCRATCH_SPACE_PER_SM_STREAM + +#include "THCStorage.c" +#include "THCAllocator.c" diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index 9135167..8b3ac74 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -43,7 +43,8 @@ struct THCRNGState; /* Random number generator state. 
*/ struct THCStream; typedef struct _THCDeviceAllocator { - cudaError_t (*malloc)(void*, void**, size_t, cudaStream_t); + cudaError_t (*malloc)( void*, void**, size_t, cudaStream_t); + cudaError_t (*realloc)(void*, void**, size_t, size_t, cudaStream_t); cudaError_t (*free)(void*, void*); cudaError_t (*emptyCache)(void*); void* state; diff --git a/lib/THC/generic/THCStorage.c b/lib/THC/generic/THCStorage.c index ad68526..e51d1ee 100644 --- a/lib/THC/generic/THCStorage.c +++ b/lib/THC/generic/THCStorage.c @@ -20,53 +20,64 @@ int THCStorage_(elementSize)(THCState *state) void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, real value) { THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds"); - THCudaCheck(cudaMemcpy(self->data + index, &value, sizeof(real), cudaMemcpyHostToDevice)); + THCudaCheck(cudaMemcpy(self->data + index, &value, sizeof(real), + cudaMemcpyHostToDevice)); } real THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t index) { THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds"); real value; - THCudaCheck(cudaMemcpy(&value, self->data + index, sizeof(real), cudaMemcpyDeviceToHost)); + THCudaCheck(cudaMemcpy(&value, self->data + index, sizeof(real), + cudaMemcpyDeviceToHost)); return value; } THCStorage* THCStorage_(new)(THCState *state) { - THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); - storage->data = NULL; - storage->size = 0; - storage->refcount = 1; - storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - return storage; + return THCStorage_(newWithSize)(state, 0); } THCStorage* THCStorage_(newWithSize)(THCState *state, ptrdiff_t size) { + return THCStorage_(newWithAllocator)( + state, size, + state->cudaDeviceAllocator, + state->cudaDeviceAllocator->state); +} + +THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size, + THCDeviceAllocator* allocator, + void* allocatorContext) +{ THArgCheck(size >= 0, 2, 
"invalid size"); + THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); + memset(storage, 0, sizeof(THCStorage)); + storage->refcount = 1; + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + storage->allocator = allocator; + storage->allocatorContext = allocatorContext; + storage->size = size; + if(size > 0) { - THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); - // update heap *before* attempting malloc, to free space for the malloc THCHeapUpdate(state, size * sizeof(real)); cudaError_t err = - THCudaMalloc(state, (void**)&(storage->data), size * sizeof(real)); + (*allocator->malloc)(allocatorContext, (void**)&(storage->data), + size * sizeof(real), + THCState_getCurrentStream(state)); if(err != cudaSuccess){ THCHeapUpdate(state, -size * sizeof(real)); + free(storage); } THCudaCheck(err); - storage->size = size; - storage->refcount = 1; - storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - return storage; - } - else - { - return THCStorage_(new)(state); + } else { + storage->data = NULL; } + return storage; } THCStorage* THCStorage_(newWithSize1)(THCState *state, real data0) @@ -111,11 +122,22 @@ THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *fileName, p THCStorage* THCStorage_(newWithData)(THCState *state, real *data, ptrdiff_t size) { + return THCStorage_(newWithDataAndAllocator)(state, data, size, + state->cudaDeviceAllocator, + state->cudaDeviceAllocator->state); +} + +THCStorage* THCStorage_(newWithDataAndAllocator)( + THCState *state, real *data, long size, + THCDeviceAllocator *allocator, void *allocatorContext) { THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); + memset(storage, 0, sizeof(THCStorage)); storage->data = data; storage->size = size; storage->refcount = 1; storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + storage->allocator = allocator; + storage->allocatorContext = 
allocatorContext; return storage; } @@ -144,7 +166,8 @@ void THCStorage_(free)(THCState *state, THCStorage *self) { if(self->flag & TH_STORAGE_FREEMEM) { THCHeapUpdate(state, -self->size * sizeof(real)); - THCudaCheck(THCudaFree(state, self->data)); + THCudaCheck( + (*self->allocator->free)(self->allocatorContext, self->data)); } THFree(self); } diff --git a/lib/THC/generic/THCStorage.cu b/lib/THC/generic/THCStorage.cu index 63bccd7..bdef7d3 100644 --- a/lib/THC/generic/THCStorage.cu +++ b/lib/THC/generic/THCStorage.cu @@ -15,14 +15,31 @@ void THCStorage_(fill)(THCState *state, THCStorage *self, real value) void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) { THArgCheck(size >= 0, 2, "invalid size"); + THAssert(self->allocator != NULL); if(!(self->flag & TH_STORAGE_RESIZABLE)) THError("Trying to resize storage that is not resizable"); + if (self->allocator->realloc) { + THCHeapUpdate(state, (size - self->size) * sizeof(real)); + cudaError_t err = (*self->allocator->realloc)( + self->allocatorContext, + (void**)&(self->data), + self->size * sizeof(real), + size * sizeof(real), THCState_getCurrentStream(state)); + if (err != cudaSuccess) { + THCHeapUpdate(state, (self->size - size) * sizeof(real)); + THCudaCheck(err); + } + self->size = size; + return; + } + if(size == 0) { if(self->flag & TH_STORAGE_FREEMEM) { - THCudaCheck(THCudaFree(state, self->data)); + THCudaCheck( + (*self->allocator->free)(self->allocatorContext, self->data)); THCHeapUpdate(state, -self->size * sizeof(real)); } self->data = NULL; @@ -33,7 +50,11 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) real *data = NULL; // update heap *before* attempting malloc, to free space for the malloc THCHeapUpdate(state, size * sizeof(real)); - cudaError_t err = THCudaMalloc(state, (void**)(&data), size * sizeof(real)); + cudaError_t err = + (*self->allocator->malloc)(self->allocatorContext, + (void**)&(data), + size * sizeof(real), + 
THCState_getCurrentStream(state)); if(err != cudaSuccess) { THCHeapUpdate(state, -size * sizeof(real)); } @@ -45,8 +66,11 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) THMin(self->size, size) * sizeof(real), cudaMemcpyDeviceToDevice, THCState_getCurrentStream(state))); - THCudaCheck(THCudaFree(state, self->data)); - THCHeapUpdate(state, -self->size * sizeof(real)); + if(self->flag & TH_STORAGE_FREEMEM) { + THCudaCheck( + (*self->allocator->free)(self->allocatorContext, self->data)); + THCHeapUpdate(state, -self->size * sizeof(real)); + } } self->data = data; diff --git a/lib/THC/generic/THCStorage.h b/lib/THC/generic/THCStorage.h index a46caad..f621c20 100644 --- a/lib/THC/generic/THCStorage.h +++ b/lib/THC/generic/THCStorage.h @@ -12,7 +12,7 @@ typedef struct THCStorage ptrdiff_t size; int refcount; char flag; - THAllocator *allocator; + THCDeviceAllocator *allocator; void *allocatorContext; struct THCStorage *view; } THCStorage; @@ -37,11 +37,14 @@ THC_API THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *fil /* takes ownership of data */ THC_API THCStorage* THCStorage_(newWithData)(THCState *state, real *data, ptrdiff_t size); -THC_API THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size, - THAllocator* allocator, - void *allocatorContext); +THC_API THCStorage* THCStorage_(newWithAllocator)( + THCState *state, ptrdiff_t size, + THCDeviceAllocator* allocator, + void *allocatorContext); THC_API THCStorage* THCStorage_(newWithDataAndAllocator)( - THCState *state, real* data, ptrdiff_t size, THAllocator* allocator, void *allocatorContext); + THCState *state, real* data, ptrdiff_t size, + THCDeviceAllocator* allocator, + void *allocatorContext); THC_API void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const char flag); THC_API void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag); |