Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <sgross@fb.com>2016-12-01 23:45:22 +0300
committerSam Gross <sgross@fb.com>2016-12-02 00:35:12 +0300
commiteb3fe1904d18934f177972ad89c973be92b97dc2 (patch)
tree5dd5a7723562f21d4b3bd4e65715876fa8a7ab2f
parent9d8e13d4e55d68a32d269c7f88c85d2e9b4e232f (diff)
Add caching allocator for pinned (host) memory
Adds a caching allocator for CUDA pinned (page-locked) memory. This avoid synchronization due to cudaFreeHost or cudaHostUnregister at the expense of potentially higher host memory usage. Correctness is preserved by recording CUDA events after each cudaMemcpyAsync involving the pinned memory. The pinned memory allocations are not reused until all events associated with it have completed.
-rw-r--r--README.md4
-rw-r--r--init.c2
-rw-r--r--lib/THC/CMakeLists.txt25
-rw-r--r--lib/THC/THC.h1
-rw-r--r--lib/THC/THCAllocator.c20
-rw-r--r--lib/THC/THCAllocator.h5
-rw-r--r--lib/THC/THCCachingAllocator.cpp8
-rw-r--r--lib/THC/THCCachingHostAllocator.cpp240
-rw-r--r--lib/THC/THCCachingHostAllocator.h30
-rw-r--r--lib/THC/THCGeneral.c34
-rw-r--r--lib/THC/THCGeneral.h.in2
-rw-r--r--lib/THC/THCTensorCopy.c1
-rw-r--r--lib/THC/generic/THCTensorCopy.c10
-rw-r--r--test/test.lua29
14 files changed, 368 insertions, 43 deletions
diff --git a/README.md b/README.md
index 7206b27..3b4a174 100644
--- a/README.md
+++ b/README.md
@@ -30,9 +30,9 @@ Most other (besides float) CPU torch tensor types now have a cutorch equivalent,
### CUDA memory allocation
Set the environment variable `THC_CACHING_ALLOCATOR=1` to enable the caching CUDA memory allocator.
-By default, cutorch calls `cudaMalloc` and `cudaFree` when CUDA tensors are allocated and freed. This is expensive because `cudaFree` synchronizes the CPU with the GPU. Setting `THC_CACHING_ALLOCATOR=1` will cause cutorch to cache and re-use CUDA allocations to avoid synchronizations.
+By default, cutorch calls `cudaMalloc` and `cudaFree` when CUDA tensors are allocated and freed. This is expensive because `cudaFree` synchronizes the CPU with the GPU. Setting `THC_CACHING_ALLOCATOR=1` will cause cutorch to cache and re-use CUDA device and pinned memory allocations to avoid synchronizations.
-With the caching memory allocator, allocations and frees should logically be considered "usages" of the memory segment associated with streams, just like kernel launches. The programmer must insert the proper synchronization if memory segments are used from multiple streams.
+With the caching memory allocator, device allocations and frees should logically be considered "usages" of the memory segment associated with streams, just like kernel launches. The programmer must insert the proper synchronization if memory segments are used from multiple streams.
###`cutorch.*` API
- `cutorch.synchronize()` : All of the CUDA API is asynchronous (barring a few functions), which means that you can queue up operations. To wait for the operations to finish, you can issue `cutorch.synchronize()` in your code, when the code waits for all GPU operations on the current GPU to finish. WARNING: synchronizes the CPU host with respect to the current device (as per `cutorch.getDevice()`) only.
diff --git a/init.c b/init.c
index eb52bc7..7be0823 100644
--- a/init.c
+++ b/init.c
@@ -2,6 +2,7 @@
#include "luaT.h"
#include "THCGeneral.h"
#include "THCCachingAllocator.h"
+#include "THCCachingHostAllocator.h"
#include "THCSleep.h"
#include "THCTensorRandom.h"
#include "THCHalf.h" // for CUDA_HALF_TENSOR
@@ -1005,6 +1006,7 @@ int luaopen_libcutorch(lua_State *L)
char* thc_caching_allocator = getenv("THC_CACHING_ALLOCATOR");
if (thc_caching_allocator && strcmp(thc_caching_allocator, "1") == 0) {
THCState_setDeviceAllocator(state, THCCachingAllocator_get());
+ state->cudaHostAllocator = &THCCachingHostAllocator;
}
THCudaInit(state);
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 25a793c..34a4550 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -22,17 +22,20 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
-if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
- if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" )
- # add c++11 flag
- set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11)
- else()
- # add c++0x flag
- set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x)
- endif()
-else()
+IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" )
+ SET(CXX_VERSION "c++11")
+ ELSE()
+ SET(CXX_VERSION "c++0x")
+ ENDIF()
+ SET_SOURCE_FILES_PROPERTIES(
+ THCTensorRandom.cpp
+ THCCachingAllocator.cpp
+ THCCachingHostAllocator.cpp
+ PROPERTIES COMPILE_FLAGS -std=${CXX_VERSION})
+ELSE()
SET(CMAKE_CXX_STANDARD 11)
-endif()
+ENDIF()
INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
@@ -125,6 +128,7 @@ ENDIF()
SET(src
THCCachingAllocator.cpp
+ THCCachingHostAllocator.cpp
THCGeneral.c
THCStorageCopy.c
THCStream.c
@@ -221,6 +225,7 @@ INSTALL(FILES
THCSortUtils.cuh
THCAllocator.h
THCCachingAllocator.h
+ THCCachingHostAllocator.h
THCDeviceUtils.cuh
THCDeviceTensor.cuh
THCDeviceTensor-inl.cuh
diff --git a/lib/THC/THC.h b/lib/THC/THC.h
index b9e9885..e3840dc 100644
--- a/lib/THC/THC.h
+++ b/lib/THC/THC.h
@@ -5,6 +5,7 @@
#include "THCAllocator.h"
#include "THCBlas.h"
#include "THCCachingAllocator.h"
+#include "THCCachingHostAllocator.h"
#include "THCSleep.h"
#include "THCStorage.h"
#include "THCStorageCopy.h"
diff --git a/lib/THC/THCAllocator.c b/lib/THC/THCAllocator.c
index 263d8d5..9ff447d 100644
--- a/lib/THC/THCAllocator.c
+++ b/lib/THC/THCAllocator.c
@@ -18,11 +18,11 @@ static void THCudaHostAllocator_free(void* ctx, void* ptr) {
THCudaCheck(cudaFreeHost(ptr));
}
-void THCAllocator_init(THCState *state) {
- state->cudaHostAllocator->malloc = &THCudaHostAllocator_malloc;
- state->cudaHostAllocator->realloc = NULL;
- state->cudaHostAllocator->free = &THCudaHostAllocator_free;
-}
+THAllocator THCudaHostAllocator = {
+ &THCudaHostAllocator_malloc,
+ NULL,
+ &THCudaHostAllocator_free
+};
static cudaError_t THCIpcAllocator_malloc(void* ctx, void** devPtr, size_t size, cudaStream_t stream)
{
@@ -60,8 +60,8 @@ static void THCUVAAllocator_free(void* ctx, void* ptr) {
THCudaCheck(cudaFree(ptr));
}
-void THCUVAAllocator_init(THAllocator *cudaUVAAllocator) {
- cudaUVAAllocator->malloc = &THCUVAAllocator_alloc;
- cudaUVAAllocator->realloc = NULL;
- cudaUVAAllocator->free = &THCUVAAllocator_free;
-}
+THAllocator THCUVAAllocator = {
+ &THCUVAAllocator_alloc,
+ NULL,
+ &THCUVAAllocator_free
+};
diff --git a/lib/THC/THCAllocator.h b/lib/THC/THCAllocator.h
index e6f2f28..d6a0a9b 100644
--- a/lib/THC/THCAllocator.h
+++ b/lib/THC/THCAllocator.h
@@ -3,9 +3,8 @@
#include "THCGeneral.h"
-THC_API void THCAllocator_init(THCState *state);
-THC_API void THCUVAAllocator_init(THAllocator *state);
-
+extern THAllocator THCudaHostAllocator;
+extern THAllocator THCUVAAllocator;
extern THCDeviceAllocator THCIpcAllocator;
#endif
diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp
index 6a42ff0..85cafd4 100644
--- a/lib/THC/THCCachingAllocator.cpp
+++ b/lib/THC/THCCachingAllocator.cpp
@@ -213,11 +213,12 @@ struct THCCachingAllocator
for (;it != blocks.end() && *it && (*it)->device == dev_id; ++it) {
size_t blocksize = (*it)->size;
*total += blocksize;
- if (blocksize > *largest)
- *largest = blocksize;
+ if (blocksize > *largest) {
+ *largest = blocksize;
+ }
}
}
-
+
void cacheInfo(int dev_id, size_t* total, size_t* largest)
{
std::lock_guard<std::mutex> lock(mutex);
@@ -225,7 +226,6 @@ struct THCCachingAllocator
cacheInfoAux(small_blocks, dev_id, total, largest);
}
-
/** combine previously split blocks */
void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
{
diff --git a/lib/THC/THCCachingHostAllocator.cpp b/lib/THC/THCCachingHostAllocator.cpp
new file mode 100644
index 0000000..6d1b870
--- /dev/null
+++ b/lib/THC/THCCachingHostAllocator.cpp
@@ -0,0 +1,240 @@
+#include "THCCachingHostAllocator.h"
+
+#include <cuda_runtime_api.h>
+#include <deque>
+#include <mutex>
+#include <set>
+#include <stdint.h>
+#include <unordered_map>
+#include <utility>
+
+
+namespace {
+
+struct BlockSize
+{
+ size_t size; // allocation size
+ void* ptr; // host memory pointer
+
+ BlockSize(size_t size, void* ptr=NULL) : size(size), ptr(ptr) {}
+};
+
+struct Block : public BlockSize
+{
+ bool allocated; // true if the block is currently allocated
+ int event_count; // number of outstanding cuda events
+
+ Block(size_t size, void* ptr, bool allocated) :
+ BlockSize(size, ptr), allocated(allocated), event_count(0) { }
+};
+
+static bool BlockComparator(const BlockSize& a, const BlockSize& b)
+{
+ // sort by size, break ties with pointer
+ if (a.size != b.size) {
+ return a.size < b.size;
+ }
+ return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
+}
+
+struct HostAllocator
+{
+ typedef bool (*Comparison)(const BlockSize&, const BlockSize&);
+
+ // lock around all operations
+ std::mutex mutex;
+
+ // blocks by pointer
+ std::unordered_map<void*, Block> blocks;
+
+ // pointers that are ready to be allocated (event_count=0)
+ std::set<BlockSize, Comparison> available;
+
+ // outstanding cuda events
+ std::deque<std::pair<cudaEvent_t, void*>> cuda_events;
+
+ HostAllocator() : available(BlockComparator) {}
+
+ cudaError_t malloc(void** ptr, size_t size)
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+
+ // process outstanding cuda events which may have occurred
+ cudaError_t err = processEvents();
+ if (err != cudaSuccess) {
+ return err;
+ }
+
+ // search for the smallest block which can hold this allocation
+ BlockSize search_key(size);
+ auto it = available.lower_bound(search_key);
+ if (it != available.end()) {
+ Block& block = blocks.at(it->ptr);
+ THAssert(!block.allocated && block.event_count == 0);
+ block.allocated = true;
+ *ptr = block.ptr;
+ available.erase(it);
+ return cudaSuccess;
+ }
+
+ // allocate a new block if no cached allocation is found
+ err = cudaHostAlloc(ptr, size, cudaHostAllocDefault);
+ if (err != cudaSuccess) {
+ return err;
+ }
+
+ blocks.insert({*ptr, Block(size, *ptr, true)});
+ return cudaSuccess;
+ }
+
+ cudaError_t free(void* ptr)
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+
+ if (!ptr) {
+ return cudaSuccess;
+ }
+
+ auto it = blocks.find(ptr);
+ THAssert(it != blocks.end());
+
+ Block& block = it->second;
+ THAssert(block.allocated);
+
+ block.allocated = false;
+ if (block.event_count == 0) {
+ // the block can be re-used if there are no outstanding cuda events
+ available.insert(block);
+ }
+ return cudaSuccess;
+ }
+
+ cudaError_t recordEvent(void* ptr, cudaStream_t stream)
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+ cudaError_t err;
+
+ auto it = blocks.find(ptr);
+ if (it == blocks.end()) {
+ // ignore events for untracked pointers
+ return cudaSuccess;
+ }
+
+ Block& block = it->second;
+ THAssert(block.allocated);
+
+ // create and record an event in the given stream
+ cudaEvent_t event;
+ err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
+ if (err != cudaSuccess) {
+ return err;
+ }
+ err = cudaEventRecord(event, stream);
+ if (err != cudaSuccess) {
+ return err;
+ }
+
+ // the block will not be re-used until all associated events have occured
+ block.event_count++;
+ cuda_events.emplace_back(event, ptr);
+ return cudaSuccess;
+ }
+
+ cudaError_t processEvents()
+ {
+ // Process outstanding cudaEvents. Events that are completed are removed
+ // from the queue, and the 'event_count' for the corresponding allocation
+ // is decremented. Stops at the first event which has not been completed.
+ // Since events on different devices or streams may occur out of order,
+ // the processing of some events may be delayed.
+ while (!cuda_events.empty()) {
+ auto& e = cuda_events.front();
+ cudaEvent_t event = e.first;
+
+ cudaError_t err = cudaEventQuery(event);
+ if (err == cudaErrorNotReady) {
+ break;
+ } else if (err != cudaSuccess) {
+ return err;
+ }
+ err = cudaEventDestroy(event);
+ if (err != cudaSuccess) {
+ return err;
+ }
+
+ Block& block = blocks.at(e.second);
+ block.event_count--;
+ if (block.event_count == 0 && !block.allocated) {
+ available.insert(block);
+ }
+ cuda_events.pop_front();
+ }
+ return cudaSuccess;
+ }
+
+ void emptyCache()
+ {
+ std::lock_guard<std::mutex> lock(mutex);
+
+ // remove events for freed blocks
+ std::deque<std::pair<cudaEvent_t, void*>> new_events;
+ for (auto it = cuda_events.begin(); it != cuda_events.end(); ++it) {
+ cudaEvent_t event = it->first;
+ Block& block = blocks.at(it->second);
+ if (!block.allocated) {
+ THCudaCheckWarn(cudaEventDestroy(event));
+ block.event_count--;
+ } else {
+ new_events.push_back(*it);
+ }
+ }
+ cuda_events.swap(new_events);
+
+ // clear list of available blocks
+ available.clear();
+
+ // free and erase non-allocated blocks
+ for (auto it = blocks.begin(); it != blocks.end();) {
+ Block& block = it->second;
+ if (!block.allocated) {
+ THCudaCheckWarn(cudaFreeHost(block.ptr));
+ it = blocks.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+};
+
+} // namespace
+
+static HostAllocator allocator;
+
+static void* THCCachingHostAllocator_malloc(void* ctx, ptrdiff_t size)
+{
+ THAssert(size >= 0);
+ void *ptr;
+ THCudaCheck(allocator.malloc(&ptr, size));
+ return ptr;
+}
+
+static void THCCachingHostAllocator_free(void* ctx, void* ptr)
+{
+ allocator.free(ptr);
+}
+
+cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, cudaStream_t stream)
+{
+ return allocator.recordEvent(ptr, stream);
+}
+
+void THCCachingHostAllocator_emptyCache()
+{
+ allocator.emptyCache();
+}
+
+THAllocator THCCachingHostAllocator = {
+ &THCCachingHostAllocator_malloc,
+ NULL,
+ &THCCachingHostAllocator_free,
+};
diff --git a/lib/THC/THCCachingHostAllocator.h b/lib/THC/THCCachingHostAllocator.h
new file mode 100644
index 0000000..2f2adc1
--- /dev/null
+++ b/lib/THC/THCCachingHostAllocator.h
@@ -0,0 +1,30 @@
+#ifndef THC_CACHING_HOST_ALLOCATOR_INC
+#define THC_CACHING_HOST_ALLOCATOR_INC
+
+#include "THCGeneral.h"
+
+//
+// A caching allocator for CUDA host allocations (pinned memory).
+//
+// This provides a drop-in replacement for THCudaHostAllocator, which re-uses
+// freed pinned (page-locked) memory allocations. This avoids device
+// synchronizations due to cudaFreeHost calls.
+//
+// To ensure correct behavior, THCCachingHostAllocator_recordEvent must be
+// called anytime a pointer from this allocator is used in a cudaMemcpyAsync
+// call between host and device. The THC library implements this for storages
+// and tensors in THCTensor_(copyAsyncCPU) and THCTensor_(copyAsyncCuda).
+//
+// Note that this allocator does not split larger allocations into smaller
+// blocks, unlike the caching device allocator.
+//
+extern THAllocator THCCachingHostAllocator;
+
+// Records an event in the specified stream. The allocation 'ptr' will not be
+// re-used until the event has occured.
+THC_API cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, cudaStream_t stream);
+
+// Releases cached pinned memory allocations via cudaHostFree
+THC_API void THCCachingHostAllocator_emptyCache();
+
+#endif
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index 3677432..153d41d 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -1,10 +1,11 @@
#include "THCGeneral.h"
#include "TH.h"
-#include "THCTensorRandom.h"
-#include "THCBlas.h"
#include "THCAllocator.h"
-#include "THCThreadLocal.h"
+#include "THCBlas.h"
+#include "THCCachingHostAllocator.h"
#include "THCStream.h"
+#include "THCThreadLocal.h"
+#include "THCTensorRandom.h"
#include <stdlib.h>
#include <stdint.h>
@@ -50,6 +51,12 @@ void THCudaInit(THCState* state)
if (!state->cudaDeviceAllocator) {
state->cudaDeviceAllocator = &defaultDeviceAllocator;
}
+ if (!state->cudaHostAllocator) {
+ state->cudaHostAllocator = &THCudaHostAllocator;
+ }
+ if (!state->cudaUVAAllocator) {
+ state->cudaUVAAllocator = &THCUVAAllocator;
+ }
int numDevices = 0;
THCudaCheck(cudaGetDeviceCount(&numDevices));
@@ -75,12 +82,6 @@ void THCudaInit(THCState* state)
state->rngState = (THCRNGState*)malloc(sizeof(THCRNGState));
THCRandom_init(state, numDevices, device);
- state->cudaHostAllocator = (THAllocator*)malloc(sizeof(THAllocator));
- THCAllocator_init(state);
-
- state->cudaUVAAllocator = (THAllocator*)malloc(sizeof(THAllocator));
- THCUVAAllocator_init(state->cudaUVAAllocator);
-
// By default, all direct p2p kernel access (besides copy) is disallowed,
// since direct access without knowing whether or not a certain operation
// should be cross-GPU leads to synchronization errors. The user can choose
@@ -130,8 +131,6 @@ void THCudaShutdown(THCState* state)
THCRandom_shutdown(state);
free(state->rngState);
- free(state->cudaHostAllocator);
- free(state->cudaUVAAllocator);
free(state->deviceProperties);
int deviceCount = 0;
@@ -175,6 +174,9 @@ void THCudaShutdown(THCState* state)
if (state->cudaDeviceAllocator->emptyCache) {
state->cudaDeviceAllocator->emptyCache(state->cudaDeviceAllocator->state);
}
+ if (state->cudaHostAllocator == &THCCachingHostAllocator) {
+ THCCachingHostAllocator_emptyCache();
+ }
free(state->currentStreams);
THCThreadLocal_free(state->currentPerDeviceBlasHandle);
@@ -603,6 +605,14 @@ void __THCudaCheck(cudaError_t err, const char *file, const int line)
}
}
+void __THCudaCheckWarn(cudaError_t err, const char *file, const int line)
+{
+ if(err != cudaSuccess)
+ {
+ fprintf(stderr, "THCudaCheckWarn FAIL file=%s line=%i error=%i : %s\n", file, line, err, cudaGetErrorString(err));
+ }
+}
+
void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
{
if(status != CUBLAS_STATUS_SUCCESS)
@@ -698,7 +708,7 @@ cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalB
/* not always true - our optimistic guess here */
largestBlock = *freeBytes;
-
+
if (allocator->cacheInfo != NULL)
allocator->cacheInfo(allocator->state, device, &cachedBytes, &largestBlock);
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index 1a37390..a88bd7d 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -173,9 +173,11 @@ THC_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state);
THC_API size_t THCState_getDeviceScratchSpaceSize(THCState* state, int device);
#define THCudaCheck(err) __THCudaCheck(err, __FILE__, __LINE__)
+#define THCudaCheckWarn(err) __THCudaCheckWarn(err, __FILE__, __LINE__)
#define THCublasCheck(err) __THCublasCheck(err, __FILE__, __LINE__)
THC_API void __THCudaCheck(cudaError_t err, const char *file, const int line);
+THC_API void __THCudaCheckWarn(cudaError_t err, const char *file, const int line);
THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line);
THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
diff --git a/lib/THC/THCTensorCopy.c b/lib/THC/THCTensorCopy.c
index 1bf8980..59c0934 100644
--- a/lib/THC/THCTensorCopy.c
+++ b/lib/THC/THCTensorCopy.c
@@ -1,6 +1,7 @@
#include "THCTensorCopy.h"
#include "THCGeneral.h"
#include "THCTensor.h"
+#include "THCCachingHostAllocator.h"
#include "THCHalf.h"
diff --git a/lib/THC/generic/THCTensorCopy.c b/lib/THC/generic/THCTensorCopy.c
index 64f8364..f461c0e 100644
--- a/lib/THC/generic/THCTensorCopy.c
+++ b/lib/THC/generic/THCTensorCopy.c
@@ -149,11 +149,14 @@ void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor
THCudaCheck(cudaSetDevice(tensorDevice));
}
+ cudaStream_t stream = THCState_getCurrentStream(state);
THCudaCheck(cudaMemcpyAsync(THCTensor_(data)(state, self),
THTensor_(data)(src),
THTensor_(nElement)(src) * sizeof(real),
cudaMemcpyHostToDevice,
- THCState_getCurrentStream(state)));
+ stream));
+
+ THCudaCheck(THCCachingHostAllocator_recordEvent(src->storage->data, stream));
if (currentDevice != tensorDevice) {
THCudaCheck(cudaSetDevice(currentDevice));
@@ -177,11 +180,14 @@ void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, struct THCTensor
THCudaCheck(cudaSetDevice(tensorDevice));
}
+ cudaStream_t stream = THCState_getCurrentStream(state);
THCudaCheck(cudaMemcpyAsync(THTensor_(data)(self),
THCTensor_(data)(state, src),
THCTensor_(nElement)(state, src) * sizeof(real),
cudaMemcpyDeviceToHost,
- THCState_getCurrentStream(state)));
+ stream));
+
+ THCudaCheck(THCCachingHostAllocator_recordEvent(src->storage->data, stream));
if (currentDevice != tensorDevice) {
THCudaCheck(cudaSetDevice(currentDevice));
diff --git a/test/test.lua b/test/test.lua
index 93e1419..bce8109 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3923,6 +3923,35 @@ function test.kernelP2PAccess()
end
end
+if os.getenv('THC_CACHING_ALLOCATOR') == '1' then
+ local function getCyclesPerMs()
+ cutorch.synchronize()
+ local t = torch.Timer()
+ cutorch._sleep(1e6)
+ cutorch.synchronize()
+ return 1e6 / (t:time().real * 1000)
+ end
+
+ function test.cachedPinnedMemory()
+ local cyclesPerMs = getCyclesPerMs()
+
+ -- check that allocations are re-used after deletion
+ t = cutorch.createCudaHostTensor({1})
+ ptr = t:data()
+ t = nil; collectgarbage()
+ t = cutorch.createCudaHostTensor({1})
+ tester:asserteq(t:data(), ptr, 'allocation not reused')
+
+ -- check that the allocation is not re-used if it's in-use by a copy
+ gpuTensor = torch.CudaTensor({0})
+ cutorch._sleep(50 * cyclesPerMs) -- delay the copy
+ gpuTensor:copyAsync(t)
+ t = nil; collectgarbage()
+ t = cutorch.createCudaHostTensor({1})
+ tester:assertne(t:data(), ptr, 'allocation re-used too soon')
+ end
+end
+
-- unfortunately, torch.Tester() forgot setUp and tearDown functions.
-- It would be nice to fix torch.Tester() eventually.
local function setUp()