torch/cutorch (github.com/torch/cutorch.git)

commit  f7c67993e277e01081d5efafbfba4cbaf51ec1ae
tree    85ebe7af33efc69c597f5be3197a409dc91e7e75
parent  d4c2b1d34631b78717289e0d219751f448d32b9c
Author:    Guillaume Klein <guillaume.klein@systrangroup.com>  2017-03-08 19:29:50 +0300
Committer: Guillaume Klein <guillaume.klein@systrangroup.com>  2017-03-08 19:29:50 +0300

    Add CUDA caching allocator accessor

 README.md               | 1 +
 init.c                  | 9 +++++++++
 lib/THC/THCGeneral.c    | 4 ++++
 lib/THC/THCGeneral.h.in | 1 +
 4 files changed, 15 insertions(+), 0 deletions(-)
diff --git a/README.md b/README.md
index 3b4a174..263a131 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,7 @@ With the caching memory allocator, device allocations and frees should logically
- `cutorch.getState()` - Returns the global state of the cutorch package. This state is not for users, it stores the raw RNG states, cublas handles and other thread and device-specific stuff.
- `cutorch.withDevice(devID, f)` - This is a convenience for multi-GPU code, that takes in a device ID as well as a function f. It switches cutorch to the new device, executes the function f, and switches back cutorch to the original device.
- `cutorch.createCudaHostTensor([...])` - Allocates a `torch.FloatTensor` of [host-pinned memory](https://devblogs.nvidia.com/parallelforall/how-optimize-data-transfers-cuda-cc/), where dimensions can be given as an argument list of sizes or a `torch.LongStorage`.
+- `cutorch.isCachingAllocatorEnabled()` - Returns whether the caching CUDA memory allocator is enabled or not.
#### Low-level streams functions (dont use this as a user, easy to shoot yourself in the foot):
- `cutorch.reserveStreams(n [, nonblocking])`: creates n user streams for use on every device. NOTE: stream index `s` on device 1 is a different cudaStream_t than stream `s` on device 2. Takes an optional non-blocking flag; by default, this is assumed to be false. If true, then the stream is created with cudaStreamNonBlocking.
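As a usage note outside the patch itself: a minimal Lua sketch of how the new accessor could be queried from user code. It assumes cutorch was started with the caching allocator turned on (for example via the THC_CACHING_ALLOCATOR environment variable); that switch is not part of this diff.

    require 'cutorch'

    -- Query the accessor added by this commit.
    if cutorch.isCachingAllocatorEnabled() then
      print('CUDA caching allocator is active; freed blocks are cached and reused')
    else
      print('default cudaMalloc/cudaFree allocator is in use')
    end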
diff --git a/init.c b/init.c
index 26b6935..8b32a1a 100644
--- a/init.c
+++ b/init.c
@@ -699,6 +699,14 @@ static int cutorch_setKernelPeerToPeerAccess(lua_State *L)
return 0;
}
+static int cutorch_isCachingAllocatorEnabled(lua_State *L)
+{
+ THCState *state = cutorch_getstate(L);
+ lua_pushboolean(L, THCState_isCachingAllocatorEnabled(state));
+
+ return 1;
+}
+
static int cutorch_getMemoryUsage(lua_State *L) {
size_t freeBytes = 0;
size_t totalBytes = 0;
@@ -993,6 +1001,7 @@ static const struct luaL_Reg cutorch_stuff__ [] = {
{"setPeerToPeerAccess", cutorch_setPeerToPeerAccess},
{"setKernelPeerToPeerAccess", cutorch_setKernelPeerToPeerAccess},
{"getKernelPeerToPeerAccess", cutorch_getKernelPeerToPeerAccess},
+ {"isCachingAllocatorEnabled", cutorch_isCachingAllocatorEnabled},
{"getDeviceProperties", cutorch_getDeviceProperties},
{"getRuntimeVersion", cutorch_getRuntimeVersion},
{"getDriverVersion", cutorch_getDriverVersion},
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index f78518c..09bb43f 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -303,6 +303,10 @@ void THCState_setDeviceAllocator(THCState* state, THCDeviceAllocator* allocator)
state->cudaDeviceAllocator = allocator;
}
+int THCState_isCachingAllocatorEnabled(THCState* state) {
+ return state->cudaHostAllocator == &THCCachingHostAllocator;
+}
+
int THCState_getNumDevices(THCState *state)
{
return state->numDevices;
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index 7af3d79..06930cf 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -141,6 +141,7 @@ THC_API THAllocator* THCState_getCudaHostAllocator(THCState* state);
THC_API THAllocator* THCState_getCudaUVAAllocator(THCState* state);
THC_API THCDeviceAllocator* THCState_getDeviceAllocator(THCState* state);
THC_API void THCState_setDeviceAllocator(THCState* state, THCDeviceAllocator* allocator);
+THC_API int THCState_isCachingAllocatorEnabled(THCState* state);
THC_API void THCMagma_init(THCState *state);
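To close, a hypothetical Lua sketch of why the flag can be worth checking: when the caching allocator is enabled, pinned host allocations (see `cutorch.createCudaHostTensor` in the README hunk above) are recycled, so frequently re-created staging buffers stay cheap. The buffer size and the pageable fallback here are illustrative choices, not from the commit.

    require 'cutorch'

    -- Sketch: pick a host staging buffer based on the allocator mode.
    local staging
    if cutorch.isCachingAllocatorEnabled() then
      -- Pinned allocations are cached by the caching host allocator, so
      -- repeatedly creating small host buffers avoids fresh cudaHostAlloc calls.
      staging = cutorch.createCudaHostTensor(4096)
    else
      -- Plain pageable memory avoids cudaHostAlloc/cudaFreeHost churn entirely.
      staging = torch.FloatTensor(4096)
    end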