Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-10-25 17:27:40 +0300
committerGitHub <noreply@github.com>2016-10-25 17:27:40 +0300
commit17300d9cc0c462dfde81eb81f89ba0a15e095844 (patch)
tree68df56e2fa2bb38d60bfaefda562699dcf96fc9e
parente56feb8f9859990fda4c53fb59c69f08b77e85bf (diff)
parent69c9454c07a6da74ab469bb9866c17656fbed8a2 (diff)
Merge pull request #566 from torch/cachedevice
Store the device in THCStorage
-rw-r--r--FFI.lua1
-rw-r--r--lib/THC/generic/THCStorage.c12
-rw-r--r--lib/THC/generic/THCStorage.cu10
-rw-r--r--lib/THC/generic/THCStorage.h1
4 files changed, 20 insertions, 4 deletions
diff --git a/FFI.lua b/FFI.lua
index f347a89..b2777a2 100644
--- a/FFI.lua
+++ b/FFI.lua
@@ -63,6 +63,7 @@ typedef struct THCStorage
THAllocator *allocator;
void *allocatorContext;
struct THCStorage *view;
+ int device;
} THCStorage;
typedef struct THCTensor
diff --git a/lib/THC/generic/THCStorage.c b/lib/THC/generic/THCStorage.c
index e51d1ee..eb4777c 100644
--- a/lib/THC/generic/THCStorage.c
+++ b/lib/THC/generic/THCStorage.c
@@ -51,6 +51,8 @@ THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size,
void* allocatorContext)
{
THArgCheck(size >= 0, 2, "invalid size");
+ int device;
+ THCudaCheck(cudaGetDevice(&device));
THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage));
memset(storage, 0, sizeof(THCStorage));
@@ -59,6 +61,7 @@ THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size,
storage->allocator = allocator;
storage->allocatorContext = allocatorContext;
storage->size = size;
+ storage->device = device;
if(size > 0)
{
@@ -138,6 +141,15 @@ THCStorage* THCStorage_(newWithDataAndAllocator)(
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
storage->allocator = allocator;
storage->allocatorContext = allocatorContext;
+ int device;
+ if (data) {
+ struct cudaPointerAttributes attr;
+ THCudaCheck(cudaPointerGetAttributes(&attr, data));
+ device = attr.device;
+ } else {
+ THCudaCheck(cudaGetDevice(&device));
+ }
+ storage->device = device;
return storage;
}
diff --git a/lib/THC/generic/THCStorage.cu b/lib/THC/generic/THCStorage.cu
index bdef7d3..22c900a 100644
--- a/lib/THC/generic/THCStorage.cu
+++ b/lib/THC/generic/THCStorage.cu
@@ -16,6 +16,8 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
{
THArgCheck(size >= 0, 2, "invalid size");
THAssert(self->allocator != NULL);
+ int device;
+ THCudaCheck(cudaGetDevice(&device));
if(!(self->flag & TH_STORAGE_RESIZABLE))
THError("Trying to resize storage that is not resizable");
@@ -32,6 +34,7 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
THCudaCheck(err);
}
self->size = size;
+ self->device = device;
return;
}
@@ -44,6 +47,7 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
}
self->data = NULL;
self->size = 0;
+ self->device = device;
}
else
{
@@ -75,14 +79,12 @@ void THCStorage_(resize)(THCState *state, THCStorage *self, ptrdiff_t size)
self->data = data;
self->size = size;
+ self->device = device;
}
}
THC_API int THCStorage_(getDevice)(THCState* state, const THCStorage* storage) {
- if (!storage->data) return -1;
- cudaPointerAttributes attr;
- THCudaCheck(cudaPointerGetAttributes(&attr, storage->data));
- return attr.device;
+ return storage->device;
}
#endif
diff --git a/lib/THC/generic/THCStorage.h b/lib/THC/generic/THCStorage.h
index f621c20..e768ec6 100644
--- a/lib/THC/generic/THCStorage.h
+++ b/lib/THC/generic/THCStorage.h
@@ -15,6 +15,7 @@ typedef struct THCStorage
THCDeviceAllocator *allocator;
void *allocatorContext;
struct THCStorage *view;
+ int device;
} THCStorage;