diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-25 17:27:40 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-25 17:27:40 +0300 |
commit | 17300d9cc0c462dfde81eb81f89ba0a15e095844 (patch) | |
tree | 68df56e2fa2bb38d60bfaefda562699dcf96fc9e | |
parent | e56feb8f9859990fda4c53fb59c69f08b77e85bf (diff) | |
parent | 69c9454c07a6da74ab469bb9866c17656fbed8a2 (diff) |
Merge pull request #566 from torch/cachedevice
Store the device in THCStorage
-rw-r--r-- | FFI.lua | 1 | ||||
-rw-r--r-- | lib/THC/generic/THCStorage.c | 12 | ||||
-rw-r--r-- | lib/THC/generic/THCStorage.cu | 10 | ||||
-rw-r--r-- | lib/THC/generic/THCStorage.h | 1 |
4 files changed, 20 insertions, 4 deletions
@@ -63,6 +63,7 @@ typedef struct THCStorage THAllocator *allocator; void *allocatorContext; struct THCStorage *view; + int device; } THCStorage; typedef struct THCTensor diff --git a/lib/THC/generic/THCStorage.c b/lib/THC/generic/THCStorage.c index e51d1ee..eb4777c 100644 --- a/lib/THC/generic/THCStorage.c +++ b/lib/THC/generic/THCStorage.c @@ -51,6 +51,8 @@ THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size, void* allocatorContext) { THArgCheck(size >= 0, 2, "invalid size"); + int device; + THCudaCheck(cudaGetDevice(&device)); THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); memset(storage, 0, sizeof(THCStorage)); @@ -59,6 +61,7 @@ THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size, storage->allocator = allocator; storage->allocatorContext = allocatorContext; storage->size = size; + storage->device = device; if(size > 0) { @@ -138,6 +141,15 @@ THCStorage* THCStorage_(newWithDataAndAllocator)( storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; storage->allocator = allocator; storage->allocatorContext = allocatorContext; + int device; + if (data) { + struct cudaPointerAttributes attr; + THCudaCheck(cudaPointerGetAttributes(&attr, data)); + device = attr.device; + } else { + THCudaCheck(cudaGetDevice(&device)); + } + storage->device = device; return storage; } diff --git a/lib/THC/generic/THCStorage.cu b/lib/THC/generic/THCStorage.cu index bdef7d3..22c900a 100644 --- a/lib/THC/generic/THCStorage.cu +++ b/lib/THC/generic/THCStorage.cu @@ -16,6 +16,8 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) { THArgCheck(size >= 0, 2, "invalid size"); THAssert(self->allocator != NULL); + int device; + THCudaCheck(cudaGetDevice(&device)); if(!(self->flag & TH_STORAGE_RESIZABLE)) THError("Trying to resize storage that is not resizable"); @@ -32,6 +34,7 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) THCudaCheck(err); } self->size = size; + self->device = device; return; } @@ -44,6 +47,7 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) } self->data = NULL; self->size = 0; + self->device = device; } else { @@ -75,14 +79,12 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size) self->data = data; self->size = size; + self->device = device; } } THC_API int THCStorage_(getDevice)(THCState* state, const THCStorage* storage) { - if (!storage->data) return -1; - cudaPointerAttributes attr; - THCudaCheck(cudaPointerGetAttributes(&attr, storage->data)); - return attr.device; + return storage->device; } #endif diff --git a/lib/THC/generic/THCStorage.h b/lib/THC/generic/THCStorage.h index f621c20..e768ec6 100644 --- a/lib/THC/generic/THCStorage.h +++ b/lib/THC/generic/THCStorage.h @@ -15,6 +15,7 @@ typedef struct THCStorage THCDeviceAllocator *allocator; void *allocatorContext; struct THCStorage *view; + int device; } THCStorage; |