Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author soumith <soumith@fb.com> 2016-10-25 08:06:00 +0300
committer soumith <soumith@fb.com> 2016-10-25 17:07:52 +0300
commit 94232e233cb1934ea23119d16abb7416ce9c98e9 (patch)
tree 783588edf0c485bfa8fc197c2242120b32ee6f43
parent dce6cd89aa78d308dd4360484ac21da3fc80c9d0 (diff)
allocator updates
-rw-r--r-- lib/THC/CMakeLists.txt 2
-rw-r--r-- lib/THC/THCAllocator.c 10
-rw-r--r-- lib/THC/THCAllocator.h 2
-rw-r--r-- lib/THC/THCCachingAllocator.cpp 1
-rw-r--r-- lib/THC/THCGeneral.c 6
-rw-r--r-- lib/THC/THCGeneral.h.in 3
-rw-r--r-- lib/THC/generic/THCStorage.c 63
-rw-r--r-- lib/THC/generic/THCStorage.cu 32
-rw-r--r-- lib/THC/generic/THCStorage.h 13
9 files changed, 93 insertions, 39 deletions
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index edc0af0..244568f 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -123,10 +123,8 @@ ELSE()
ENDIF()
SET(src
- THCAllocator.c
THCCachingAllocator.cpp
THCGeneral.c
- THCStorage.c
THCStorageCopy.c
THCStream.c
THCTensor.c
diff --git a/lib/THC/THCAllocator.c b/lib/THC/THCAllocator.c
index fa55c40..5d36d4c 100644
--- a/lib/THC/THCAllocator.c
+++ b/lib/THC/THCAllocator.c
@@ -1,6 +1,6 @@
#include "THCAllocator.h"
-static void *THCudaHostAllocator_alloc(void* ctx, ptrdiff_t size) {
+static void *THCudaHostAllocator_malloc(void* ctx, ptrdiff_t size) {
void* ptr;
if (size < 0) THError("Invalid memory size: %ld", size);
@@ -18,8 +18,8 @@ static void THCudaHostAllocator_free(void* ctx, void* ptr) {
THCudaCheck(cudaFreeHost(ptr));
}
-void THCAllocator_init(THAllocator *cudaHostAllocator) {
- cudaHostAllocator->malloc = &THCudaHostAllocator_alloc;
- cudaHostAllocator->realloc = NULL;
- cudaHostAllocator->free = &THCudaHostAllocator_free;
+void THCAllocator_init(THCState *state) {
+ state->cudaHostAllocator->malloc = &THCudaHostAllocator_malloc;
+ state->cudaHostAllocator->realloc = NULL;
+ state->cudaHostAllocator->free = &THCudaHostAllocator_free;
}
diff --git a/lib/THC/THCAllocator.h b/lib/THC/THCAllocator.h
index 3481304..2f85eec 100644
--- a/lib/THC/THCAllocator.h
+++ b/lib/THC/THCAllocator.h
@@ -3,6 +3,6 @@
#include "THCGeneral.h"
-THC_API void THCAllocator_init(THAllocator *state);
+THC_API void THCAllocator_init(THCState *state);
#endif
diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp
index 73b81f6..54db20d 100644
--- a/lib/THC/THCCachingAllocator.cpp
+++ b/lib/THC/THCCachingAllocator.cpp
@@ -300,6 +300,7 @@ static cudaError_t THCCachingAllocator_emptyCache(void* ctx)
static THCCachingAllocator caching_allocator;
static THCDeviceAllocator device_allocator = {
&THCCachingAllocator_malloc,
+ NULL,
&THCCachingAllocator_free,
&THCCachingAllocator_emptyCache,
&caching_allocator
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index 9b4764d..0b75399 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -93,6 +93,7 @@ static cudaError_t cudaFreeWrapper(void* ctx, void* devPtr)
static THCDeviceAllocator defaultDeviceAllocator = {
&cudaMallocWrapper,
+ NULL,
&cudaFreeWrapper,
NULL,
NULL
@@ -129,7 +130,7 @@ void THCudaInit(THCState* state)
THCRandom_init(state, numDevices, device);
state->cudaHostAllocator = (THAllocator*)malloc(sizeof(THAllocator));
- THCAllocator_init(state->cudaHostAllocator);
+ THCAllocator_init(state);
/* Enable P2P access between all pairs, if possible */
THCudaEnablePeerToPeerAccess(state);
@@ -792,3 +793,6 @@ void THCHeapUpdate(THCState *state, ptrdiff_t size) {
}
#undef GLOBAL_SCRATCH_SPACE_PER_SM_STREAM
+
+#include "THCStorage.c"
+#include "THCAllocator.c"
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index 9135167..8b3ac74 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -43,7 +43,8 @@ struct THCRNGState; /* Random number generator state. */
struct THCStream;
typedef struct _THCDeviceAllocator {
- cudaError_t (*malloc)(void*, void**, size_t, cudaStream_t);
+ cudaError_t (*malloc)( void*, void**, size_t, cudaStream_t);
+ cudaError_t (*realloc)(void*, void**, size_t, size_t, cudaStream_t);
cudaError_t (*free)(void*, void*);
cudaError_t (*emptyCache)(void*);
void* state;
diff --git a/lib/THC/generic/THCStorage.c b/lib/THC/generic/THCStorage.c
index ad68526..e51d1ee 100644
--- a/lib/THC/generic/THCStorage.c
+++ b/lib/THC/generic/THCStorage.c
@@ -20,53 +20,64 @@ int THCStorage_(elementSize)(THCState *state)
void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, real value)
{
THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds");
- THCudaCheck(cudaMemcpy(self->data + index, &value, sizeof(real), cudaMemcpyHostToDevice));
+ THCudaCheck(cudaMemcpy(self->data + index, &value, sizeof(real),
+ cudaMemcpyHostToDevice));
}
real THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t index)
{
THArgCheck((index >= 0) && (index < self->size), 2, "index out of bounds");
real value;
- THCudaCheck(cudaMemcpy(&value, self->data + index, sizeof(real), cudaMemcpyDeviceToHost));
+ THCudaCheck(cudaMemcpy(&value, self->data + index, sizeof(real),
+ cudaMemcpyDeviceToHost));
return value;
}
THCStorage* THCStorage_(new)(THCState *state)
{
- THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage));
- storage->data = NULL;
- storage->size = 0;
- storage->refcount = 1;
- storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
- return storage;
+ return THCStorage_(newWithSize)(state, 0);
}
THCStorage* THCStorage_(newWithSize)(THCState *state, ptrdiff_t size)
{
+ return THCStorage_(newWithAllocator)(
+ state, size,
+ state->cudaDeviceAllocator,
+ state->cudaDeviceAllocator->state);
+}
+
+THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size,
+ THCDeviceAllocator* allocator,
+ void* allocatorContext)
+{
THArgCheck(size >= 0, 2, "invalid size");
+ THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage));
+ memset(storage, 0, sizeof(THCStorage));
+ storage->refcount = 1;
+ storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
+ storage->allocator = allocator;
+ storage->allocatorContext = allocatorContext;
+ storage->size = size;
+
if(size > 0)
{
- THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage));
-
// update heap *before* attempting malloc, to free space for the malloc
THCHeapUpdate(state, size * sizeof(real));
cudaError_t err =
- THCudaMalloc(state, (void**)&(storage->data), size * sizeof(real));
+ (*allocator->malloc)(allocatorContext, (void**)&(storage->data),
+ size * sizeof(real),
+ THCState_getCurrentStream(state));
if(err != cudaSuccess){
THCHeapUpdate(state, -size * sizeof(real));
+ free(storage);
}
THCudaCheck(err);
- storage->size = size;
- storage->refcount = 1;
- storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
- return storage;
- }
- else
- {
- return THCStorage_(new)(state);
+ } else {
+ storage->data = NULL;
}
+ return storage;
}
THCStorage* THCStorage_(newWithSize1)(THCState *state, real data0)
@@ -111,11 +122,22 @@ THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *fileName, p
THCStorage* THCStorage_(newWithData)(THCState *state, real *data, ptrdiff_t size)
{
+ return THCStorage_(newWithDataAndAllocator)(state, data, size,
+ state->cudaDeviceAllocator,
+ state->cudaDeviceAllocator->state);
+}
+
+THCStorage* THCStorage_(newWithDataAndAllocator)(
+ THCState *state, real *data, long size,
+ THCDeviceAllocator *allocator, void *allocatorContext) {
THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage));
+ memset(storage, 0, sizeof(THCStorage));
storage->data = data;
storage->size = size;
storage->refcount = 1;
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
+ storage->allocator = allocator;
+ storage->allocatorContext = allocatorContext;
return storage;
}
@@ -144,7 +166,8 @@ void THCStorage_(free)(THCState *state, THCStorage *self)
{
if(self->flag & TH_STORAGE_FREEMEM) {
THCHeapUpdate(state, -self->size * sizeof(real));
- THCudaCheck(THCudaFree(state, self->data));
+ THCudaCheck(
+ (*self->allocator->free)(self->allocatorContext, self->data));
}
THFree(self);
}
diff --git a/lib/THC/generic/THCStorage.cu b/lib/THC/generic/THCStorage.cu
index 63bccd7..bdef7d3 100644
--- a/lib/THC/generic/THCStorage.cu
+++ b/lib/THC/generic/THCStorage.cu
@@ -15,14 +15,31 @@ void THCStorage_(fill)(THCState *state, THCStorage *self, real value)
void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
{
THArgCheck(size >= 0, 2, "invalid size");
+ THAssert(self->allocator != NULL);
if(!(self->flag & TH_STORAGE_RESIZABLE))
THError("Trying to resize storage that is not resizable");
+ if (self->allocator->realloc) {
+ THCHeapUpdate(state, (size - self->size) * sizeof(real));
+ cudaError_t err = (*self->allocator->realloc)(
+ self->allocatorContext,
+ (void**)&(self->data),
+ self->size * sizeof(real),
+ size * sizeof(real), THCState_getCurrentStream(state));
+ if (err != cudaSuccess) {
+ THCHeapUpdate(state, (self->size - size) * sizeof(real));
+ THCudaCheck(err);
+ }
+ self->size = size;
+ return;
+ }
+
if(size == 0)
{
if(self->flag & TH_STORAGE_FREEMEM) {
- THCudaCheck(THCudaFree(state, self->data));
+ THCudaCheck(
+ (*self->allocator->free)(self->allocatorContext, self->data));
THCHeapUpdate(state, -self->size * sizeof(real));
}
self->data = NULL;
@@ -33,7 +50,11 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
real *data = NULL;
// update heap *before* attempting malloc, to free space for the malloc
THCHeapUpdate(state, size * sizeof(real));
- cudaError_t err = THCudaMalloc(state, (void**)(&data), size * sizeof(real));
+ cudaError_t err =
+ (*self->allocator->malloc)(self->allocatorContext,
+ (void**)&(data),
+ size * sizeof(real),
+ THCState_getCurrentStream(state));
if(err != cudaSuccess) {
THCHeapUpdate(state, -size * sizeof(real));
}
@@ -45,8 +66,11 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
THMin(self->size, size) * sizeof(real),
cudaMemcpyDeviceToDevice,
THCState_getCurrentStream(state)));
- THCudaCheck(THCudaFree(state, self->data));
- THCHeapUpdate(state, -self->size * sizeof(real));
+ if(self->flag & TH_STORAGE_FREEMEM) {
+ THCudaCheck(
+ (*self->allocator->free)(self->allocatorContext, self->data));
+ THCHeapUpdate(state, -self->size * sizeof(real));
+ }
}
self->data = data;
diff --git a/lib/THC/generic/THCStorage.h b/lib/THC/generic/THCStorage.h
index a46caad..f621c20 100644
--- a/lib/THC/generic/THCStorage.h
+++ b/lib/THC/generic/THCStorage.h
@@ -12,7 +12,7 @@ typedef struct THCStorage
ptrdiff_t size;
int refcount;
char flag;
- THAllocator *allocator;
+ THCDeviceAllocator *allocator;
void *allocatorContext;
struct THCStorage *view;
} THCStorage;
@@ -37,11 +37,14 @@ THC_API THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *fil
/* takes ownership of data */
THC_API THCStorage* THCStorage_(newWithData)(THCState *state, real *data, ptrdiff_t size);
-THC_API THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size,
- THAllocator* allocator,
- void *allocatorContext);
+THC_API THCStorage* THCStorage_(newWithAllocator)(
+ THCState *state, ptrdiff_t size,
+ THCDeviceAllocator* allocator,
+ void *allocatorContext);
THC_API THCStorage* THCStorage_(newWithDataAndAllocator)(
- THCState *state, real* data, ptrdiff_t size, THAllocator* allocator, void *allocatorContext);
+ THCState *state, real* data, ptrdiff_t size,
+ THCDeviceAllocator* allocator,
+ void *allocatorContext);
THC_API void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const char flag);
THC_API void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag);