Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBoris Fomitchev <borisfom@users.noreply.github.com>2016-11-24 01:38:57 +0300
committerSoumith Chintala <soumith@gmail.com>2016-11-24 01:38:57 +0300
commitf5932241e86087821a4c61dbde2c39a03d7c9883 (patch)
treee36533a796dc81b1c3cf75a7bc04ea41006891c0
parent2d75d411560df62f4ac291143f6b0f2e15378031 (diff)
Implemented cudaMemGetInfo for caching allocator (#600)
* Implemented cudaMemGetInfo for caching allocator
-rw-r--r--init.c5
-rw-r--r--lib/THC/THCCachingAllocator.cpp29
-rw-r--r--lib/THC/THCGeneral.c28
-rw-r--r--lib/THC/THCGeneral.h.in2
-rw-r--r--test/test_shutdown.lua61
5 files changed, 118 insertions, 7 deletions
diff --git a/init.c b/init.c
index 69f5583..124be5c 100644
--- a/init.c
+++ b/init.c
@@ -698,13 +698,14 @@ static int cutorch_getMemoryUsage(lua_State *L) {
size_t totalBytes = 0;
int curDevice;
THCudaCheck(cudaGetDevice(&curDevice));
+ THCState *state = cutorch_getstate(L);
int device = luaL_optint(L, 1, -10);
if (device == -10) { /* no argument passed, current device mem usage */
- THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes));
+ THCudaCheck(THCudaMemGetInfo(state, &freeBytes, &totalBytes));
} else { /* argument was given, particular device's memory usage */
THCudaCheck(cudaSetDevice(device-1)); /* zero indexed */
- THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes));
+ THCudaCheck(THCudaMemGetInfo(state, &freeBytes, &totalBytes));
THCudaCheck(cudaSetDevice(curDevice));
}
lua_pushnumber(L, freeBytes);
diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp
index e2fc8d8..6a42ff0 100644
--- a/lib/THC/THCCachingAllocator.cpp
+++ b/lib/THC/THCCachingAllocator.cpp
@@ -205,6 +205,27 @@ struct THCCachingAllocator
return basePtr;
}
+ // Accumulates sizes of all memory blocks for given device in given free list
+ void cacheInfoAux(FreeBlocks& blocks, int dev_id, size_t* total, size_t* largest)
+ {
+ Block search_key(dev_id, 0, 0);
+ auto it = blocks.lower_bound(&search_key);
+ for (;it != blocks.end() && *it && (*it)->device == dev_id; ++it) {
+ size_t blocksize = (*it)->size;
+ *total += blocksize;
+ if (blocksize > *largest)
+ *largest = blocksize;
+ }
+ }
+
+ void cacheInfo(int dev_id, size_t* total, size_t* largest)
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ cacheInfoAux(large_blocks, dev_id, total, largest);
+ cacheInfoAux(small_blocks, dev_id, total, largest);
+ }
+
+
/** combine previously split blocks */
void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
{
@@ -327,12 +348,20 @@ static cudaError_t THCCachingAllocator_emptyCache(void* ctx)
return a->emptyCache();
}
+static cudaError_t THCCachingAllocator_cacheInfo(void* ctx, int dev_id, size_t* cachedAndFree, size_t* largestBlock)
+{
+ THCCachingAllocator* a = (THCCachingAllocator*) ctx;
+ a->cacheInfo(dev_id, cachedAndFree, largestBlock);
+ return cudaSuccess;
+}
+
static THCCachingAllocator caching_allocator;
static THCDeviceAllocator device_allocator = {
&THCCachingAllocator_malloc,
NULL,
&THCCachingAllocator_free,
&THCCachingAllocator_emptyCache,
+ &THCCachingAllocator_cacheInfo,
&caching_allocator
};
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index 13f62be..547e060 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -41,6 +41,7 @@ static THCDeviceAllocator defaultDeviceAllocator = {
NULL,
&cudaFreeWrapper,
NULL,
+ NULL,
NULL
};
@@ -710,6 +711,33 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
return allocator->free(allocator->state, ptr);
}
+cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes)
+{
+ size_t cachedBytes = 0;
+ size_t largestBlock = 0;
+ THCDeviceAllocator* allocator = state->cudaDeviceAllocator;
+
+ /* get info from CUDA first */
+ cudaError_t ret = cudaMemGetInfo(freeBytes, totalBytes);
+ if (ret!= cudaSuccess)
+ return ret;
+
+ int device;
+ ret = cudaGetDevice(&device);
+ if (ret!= cudaSuccess)
+ return ret;
+
+ /* not always true - our optimistic guess here */
+ largestBlock = *freeBytes;
+
+ if (allocator->cacheInfo != NULL)
+ allocator->cacheInfo(allocator->state, device, &cachedBytes, &largestBlock);
+
+ /* Adjust resulting free bytes number. largestBlock unused for now */
+ *freeBytes += cachedBytes;
+ return cudaSuccess;
+}
+
static ptrdiff_t applyHeapDelta(THCState *state) {
ptrdiff_t newHeapSize = THAtomicAddPtrdiff(&heapSize, state->heapDelta) + state->heapDelta;
state->heapDelta = 0;
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index 8f55cf3..c685d37 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -49,6 +49,7 @@ typedef struct _THCDeviceAllocator {
cudaError_t (*realloc)(void*, void**, size_t, size_t, cudaStream_t);
cudaError_t (*free)(void*, void*);
cudaError_t (*emptyCache)(void*);
+ cudaError_t (*cacheInfo)(void*, int, size_t*, size_t*);
void* state;
} THCDeviceAllocator;
@@ -177,6 +178,7 @@ THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int
THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
+THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes);
THC_API void THCSetGCHandler(THCState *state,
void (*torchGCHandlerFunction)(void *data),
void *data );
diff --git a/test/test_shutdown.lua b/test/test_shutdown.lua
index e78a51e..750df06 100644
--- a/test/test_shutdown.lua
+++ b/test/test_shutdown.lua
@@ -1,13 +1,64 @@
local Threads = require 'threads'
require 'cutorch'
-print ("cutorch.hasHalf is ", cutorch.hasHalf)
+local function test_cudaEvent()
+ cutorch.reserveStreams(2)
+ cutorch.setStream(1)
+
+ local t1 = torch.CudaTensor(10000000):zero()
+ local t2 = torch.CudaTensor(1):zero()
+
+ local t1View = t1:narrow(1, 10000000, 1)
+ t1:fill(1)
+
+ -- Event is created here
+ local event = cutorch.Event()
+
+ cutorch.setStream(2)
+
+ -- assert below will fail without this
+ event:waitOn()
+ t2:copy(t1View)
+ -- revert to default stream
+ cutorch.setStream(0)
+end
+
+local Gig = 1024*1024*1024
+
+local function test_getMemInfo()
+ local sz = Gig*0.1
+ local t1 = torch.CudaTensor(sz):zero()
+ print('Memory usage after 1st allocation [free memory], [total memory]')
+ local free, total = cutorch.getMemoryUsage()
+ print(free/Gig, total/Gig)
+ local t2 = torch.CudaTensor(sz*1.3):zero()
+ print('Memory usage after 2nd allocation [free memory], [total memory]')
+ local free, total = cutorch.getMemoryUsage()
+ print(free/Gig, total/Gig)
+ t1 = nil
+ collectgarbage()
+ print('Memory usage after 1st deallocation [free memory], [total memory]')
+ local free, total = cutorch.getMemoryUsage()
+ print(free/Gig, total/Gig)
+ t2 = nil
+ collectgarbage()
+ print('Memory usage after 2nd deallocation [free memory], [total memory]')
+ free, total = cutorch.getMemoryUsage()
+ print(free/Gig, total/Gig)
+end
+
+print ("cutorch.hasHalf is ", cutorch.hasHalf)
print('Memory usage before intialization of threads [free memory], [total memory]')
-print(cutorch.getMemoryUsage())
-threads = Threads(100, function() require 'cutorch' end)
+local free, total = cutorch.getMemoryUsage()
+print(free/Gig, total/Gig)
+threads = Threads(20, function() require 'cutorch'; test_getMemInfo(); test_cudaEvent(); end)
print('Memory usage after intialization of threads [free memory], [total memory]')
-print(cutorch.getMemoryUsage())
+free, total = cutorch.getMemoryUsage()
+print(free/Gig, total/Gig)
threads:terminate()
+collectgarbage()
print('Memory usage after termination of threads [free memory], [total memory]')
-print(cutorch.getMemoryUsage())
+free, total = cutorch.getMemoryUsage()
+print(free/Gig, total/Gig)
+