diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-11-24 03:40:03 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-24 03:40:03 +0300 |
commit | e9e131e6181f9a01dc2367ac250fe6758bcb6056 (patch) | |
tree | 3d0dc3d3253c37a3b46355eb40bb8d6723d832d2 | |
parent | f8d05d267e3be9a67055284a22b073fc7611ac9c (diff) |
Revert "Lazily initialize CUDA devices" revert-610-lazy
-rw-r--r-- | init.c | 10 | ||||
-rw-r--r-- | lib/THC/CMakeLists.txt | 5 | ||||
-rw-r--r-- | lib/THC/THCGeneral.c | 211 | ||||
-rw-r--r-- | lib/THC/THCGeneral.h.in | 6 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cpp | 133 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cu | 124 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.h | 12 |
7 files changed, 255 insertions, 246 deletions
@@ -776,35 +776,35 @@ static int cutorch_getDeviceProperties(lua_State *L) static int cutorch_seed(lua_State *L) { - unsigned long long seed = THCRandom_seed(cutorch_getstate(L)); + unsigned long seed = THCRandom_seed(cutorch_getstate(L)); lua_pushnumber(L, seed); return 1; } static int cutorch_seedAll(lua_State *L) { - unsigned long long seed = THCRandom_seedAll(cutorch_getstate(L)); + unsigned long seed = THCRandom_seedAll(cutorch_getstate(L)); lua_pushnumber(L, seed); return 1; } static int cutorch_initialSeed(lua_State *L) { - unsigned long long seed = THCRandom_initialSeed(cutorch_getstate(L)); + unsigned long seed = THCRandom_initialSeed(cutorch_getstate(L)); lua_pushnumber(L, seed); return 1; } static int cutorch_manualSeed(lua_State *L) { - unsigned long long seed = luaL_checknumber(L, 1); + unsigned long seed = luaL_checknumber(L, 1); THCRandom_manualSeed(cutorch_getstate(L), seed); return 0; } static int cutorch_manualSeedAll(lua_State* L) { - unsigned long long seed = luaL_checknumber(L, 1); + unsigned long seed = luaL_checknumber(L, 1); THCRandom_manualSeedAll(cutorch_getstate(L), seed); return 0; } diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index 08e6fbc..c8916d3 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -25,10 +25,10 @@ endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" ) # add c++11 flag - set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11) + set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11) else() # add c++0x flag - set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x) + set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x) endif() else() SET(CMAKE_CXX_STANDARD 11) @@ -130,7 +130,6 @@ 
SET(src THCStream.c THCTensor.c THCTensorCopy.c - THCTensorRandom.cpp THCThreadLocal.c ) diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 2d4043d..547e060 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -81,21 +81,8 @@ void THCudaInit(THCState* state) state->cudaUVAAllocator = (THAllocator*)malloc(sizeof(THAllocator)); THCUVAAllocator_init(state->cudaUVAAllocator); - // By default, all direct p2p kernel access (besides copy) is disallowed, - // since direct access without knowing whether or not a certain operation - // should be cross-GPU leads to synchronization errors. The user can choose - // to disable this functionality, however. - state->p2pKernelAccessEnabled = 0; - - // p2pAccessEnabled records if p2p copies are allowed between pairs of - // devices. Values include "1" (copy allowed), "0" (copy not allowed), and - // "-1" (unknown). - state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices); - for (int i = 0; i < numDevices; ++i) { - state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices); - memset(state->p2pAccessEnabled[i], -1, sizeof(int) * numDevices); - state->p2pAccessEnabled[i][i] = 1; - } + /* Enable P2P access between all pairs, if possible */ + THCudaEnablePeerToPeerAccess(state); for (int i = 0; i < numDevices; ++i) { THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, i); @@ -111,15 +98,22 @@ void THCudaInit(THCState* state) int numSM = state->deviceProperties[i].multiProcessorCount; size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM; res->scratchSpacePerStream = sizePerStream; + + /* Allocate scratch space for each stream */ + res->devScratchSpacePerStream = (void**) malloc(sizeof(void*)); + THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0], + sizePerStream)); } /* Restore to previous device */ THCudaCheck(cudaSetDevice(device)); - // Unlike CUDA streams, there is no NULL cuBLAS handle. 
The default THC - // cuBLAS handle is the first user BLAS handle. Note that the actual BLAS - // handles are created lazily. - state->numUserBlasHandles = 1; + /* There is no such thing as a default cublas handle. + To maintain consistency with streams API, handle 0 is always NULL and we + start counting at 1. If currentPerDeviceBlasHandle is 0 (the default + thread-local value), then we assume it means 1. + */ + THCState_reserveBlasHandles(state, 1); state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically state->heapDelta = 0; @@ -153,9 +147,10 @@ void THCudaShutdown(THCState* state) for (int i = 1; i <= state->numUserStreams; ++i) { THCStream_free(res->streams[i]); } - /* Free user defined BLAS handles */ - for (int i = 0; i < res->numBlasHandles; ++i) { - THCublasCheck(cublasDestroy(res->blasHandles[i])); + /* Free Torch-defined handles (0 is NULL for consistency with streams API) */ + for (int handle = 1; handle <= state->numUserBlasHandles; ++handle) { + THCublasCheck(cublasDestroy( + THCState_getDeviceBlasHandle(state, dev, handle))); } /* Free per-stream scratch space; starts at 0 because there is space for the default stream as well*/ @@ -179,36 +174,79 @@ void THCudaShutdown(THCState* state) THCudaCheck(cudaSetDevice(prevDev)); } -int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess) +void THCudaEnablePeerToPeerAccess(THCState* state) { - if (dev < 0 || dev >= state->numDevices) { - THError("%d is not a device", dev); - } - if (devToAccess < 0 || devToAccess >= state->numDevices) { - THError("%d is not a device", devToAccess); + /* By default, all direct p2p kernel access (besides copy) is disallowed, */ + /* since direct access without knowing whether or not a certain operation */ + /* should be cross-GPU leads to synchronization errors. The user can choose */ + /* to disable this functionality, however. 
*/ + state->p2pKernelAccessEnabled = 0; + + int prevDev = -1; + THCudaCheck(cudaGetDevice(&prevDev)); + + int numDevices = -1; + THCudaCheck(cudaGetDeviceCount(&numDevices)); + + state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices); + for (int i = 0; i < numDevices; ++i) { + state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices); } - if (state->p2pAccessEnabled[dev][devToAccess] == -1) { - int prevDev = 0; - THCudaCheck(cudaGetDevice(&prevDev)); - THCudaCheck(cudaSetDevice(dev)); - int access = 0; - THCudaCheck(cudaDeviceCanAccessPeer(&access, dev, devToAccess)); - if (access) { - cudaError_t err = cudaDeviceEnablePeerAccess(devToAccess, 0); - if (err == cudaErrorPeerAccessAlreadyEnabled) { - // ignore and clear the error if access was already enabled - cudaGetLastError(); + /* Build a table of all allowed p2p accesses, to avoid checking the p2p + status at runtime. */ + for (int i = 0; i < numDevices; ++i) { + THCudaCheck(cudaSetDevice(i)); + + for (int j = 0; j < numDevices; ++j) { + /* Presume no access by default */ + state->p2pAccessEnabled[i][j] = 0; + + if (i == j) { + /* A GPU can access itself */ + state->p2pAccessEnabled[i][j] = 1; } else { - THCudaCheck(err); + int access = 0; + THCudaCheck(cudaDeviceCanAccessPeer(&access, i, j)); + + if (access) { + cudaError_t err = cudaDeviceEnablePeerAccess(j, 0); + if (err == cudaErrorPeerAccessAlreadyEnabled) { + /* It is possible that another thread has already enabled access. */ + /* Any future call to cudaGetLastError will now return an error, */ + /* even though we've already dealt with this specific error here. */ + /* Call cudaGetLastError once to reset the last error state. 
*/ + cudaGetLastError(); + + /* The above should have cleared status */ + THCudaCheck(cudaGetLastError()); + } else { + /* In case there are other unhandled errors returned from the */ + /* above */ + THCudaCheck(err); + } + + /* Access could be enabled, or was already enabled */ + state->p2pAccessEnabled[i][j] = 1; + } } - state->p2pAccessEnabled[dev][devToAccess] = 1; - } else { - state->p2pAccessEnabled[dev][devToAccess] = 0; } + } - THCudaCheck(cudaSetDevice(prevDev)); + /* Restore previous device before continuing */ + THCudaCheck(cudaSetDevice(prevDev)); +} + +int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess) +{ + if (dev < 0 || dev >= state->numDevices) { + THError("%d is not a device", dev); + } + + if (devToAccess < 0 || dev >= state->numDevices) { + THError("%d is not a device", devToAccess); } + return state->p2pAccessEnabled[dev][devToAccess]; } @@ -289,20 +327,6 @@ int THCState_getNumDevices(THCState *state) return state->numDevices; } -static void THCState_initializeScratchSpace(THCState* state, int dev) -{ - THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); - if (res->devScratchSpacePerStream) { - return; - } - size_t size = (state->numUserStreams + 1) * sizeof(void*); - void** scratch = (void**)malloc(size); - for (int i = 0; i <= state->numUserStreams; ++i) { - THCudaCheck(THCudaMalloc(state, &scratch[i], res->scratchSpacePerStream)); - } - res->devScratchSpacePerStream = scratch; -} - void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) { if (numStreams <= state->numUserStreams) @@ -322,7 +346,6 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) THCStream** newStreams = realloc(res->streams, (numStreams + 1) * sizeof(THCStream*)); THAssert(newStreams); - THCState_initializeScratchSpace(state, dev); void** newScratchSpace = realloc(res->devScratchSpacePerStream, (numStreams + 1) * sizeof(void*)); THAssert(newScratchSpace); @@ -346,37 +369,45 
@@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) THCudaCheck(cudaSetDevice(prevDev)); } -void THCState_reserveDeviceBlasHandles(THCState* state, int device, int numBlasHandles) +void THCState_reserveBlasHandles(THCState* state, int numBlasHandles) { - int prevDev = -1; - THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); - if (numBlasHandles <= res->numBlasHandles) { + if (numBlasHandles <= state->numUserBlasHandles) + { return; } + int prevDev = -1; THCudaCheck(cudaGetDevice(&prevDev)); - THCudaCheck(cudaSetDevice(device)); - size_t size = numBlasHandles * sizeof(cublasHandle_t); - cublasHandle_t* handles = (cublasHandle_t*) realloc(res->blasHandles, size); - for (int i = res->numBlasHandles; i < numBlasHandles; ++i) { - handles[i] = NULL; - THCublasCheck(cublasCreate(&handles[i])); - } - res->blasHandles = handles; - res->numBlasHandles = numBlasHandles; + /* Otherwise, we have to allocate a new set of blasHandles */ + for (int dev = 0; dev < state->numDevices; ++dev) { + THCudaCheck(cudaSetDevice(dev)); - THCudaCheck(cudaSetDevice(prevDev)); -} + /* +1 to be consistent with stream API, blas handle 0 is NULL and unused */ + cublasHandle_t* newBlasHandles = + (cublasHandle_t*) malloc((numBlasHandles + 1) * sizeof(cublasHandle_t)); -void THCState_reserveBlasHandles(THCState* state, int numBlasHandles) -{ - // cuBLAS handles are created lazily from THCState_getDeviceBlasHandle - // to avoid initializing unused devices - if (numBlasHandles > state->numUserBlasHandles) - { - state->numUserBlasHandles = numBlasHandles; + /* Copy over old blasHandles + (0 is NULL, 1 ... 
numUserBlasHandles are rest) */ + newBlasHandles[0] = NULL; + for (int hndl = 1; hndl <= state->numUserBlasHandles; ++hndl) { + newBlasHandles[hndl] = THCState_getDeviceBlasHandle(state, dev, hndl); + } + + /* Allocate new handles */ + for (int hndl = state->numUserBlasHandles + 1; hndl <= numBlasHandles; ++hndl) { + newBlasHandles[hndl] = NULL; + THCublasCheck(cublasCreate(newBlasHandles + hndl)); + } + + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); + free(res->blasHandles); + res->blasHandles = newBlasHandles; } + + state->numUserBlasHandles = numBlasHandles; + + THCudaCheck(cudaSetDevice(prevDev)); } int THCState_getNumStreams(THCState* state) @@ -414,13 +445,12 @@ cudaStream_t THCState_getDeviceStream(THCState *state, int device, int streamInd cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle) { - if (handle <= 0 || handle > state->numUserBlasHandles) { + if (handle <= 0 || handle > state->numUserBlasHandles) + { THError("%d is not a valid handle, valid range is: (1, %d)", handle, state->numUserBlasHandles); } - THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); - THCState_reserveDeviceBlasHandles(state, device, handle); - return res->blasHandles[handle - 1]; + return THCState_getDeviceResourcePtr(state, device)->blasHandles[handle]; } static THCStream* THCState_getStreamOnDevice(THCState* state, int device) @@ -562,13 +592,16 @@ void* THCState_getCurrentDeviceScratchSpace(THCState* state) return THCState_getDeviceScratchSpace(state, device, stream); } -void* THCState_getDeviceScratchSpace(THCState* state, int dev, int stream) +void* THCState_getDeviceScratchSpace(THCState* state, int device, int stream) { - THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); - if (stream > state->numUserStreams || stream < 0) { + THCCudaResourcesPerDevice* res = + THCState_getDeviceResourcePtr(state, device); + + if (stream > state->numUserStreams || stream < 0) 
+ { THError("%d is not a stream", stream); } - THCState_initializeScratchSpace(state, dev); + return res->devScratchSpacePerStream[stream]; } diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index 1a37390..c685d37 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -55,14 +55,11 @@ typedef struct _THCDeviceAllocator { typedef struct _THCCudaResourcesPerDevice { THCStream** streams; - /* Number of materialized cuBLAS handles */ - int numBlasHandles; - /* cuBLAS handes are lazily initialized */ cublasHandle_t* blasHandles; /* Size of scratch space per each stream on this device available */ size_t scratchSpacePerStream; /* Device-resident scratch space per stream, used for global memory - reduction kernels. Lazily initialized. */ + reduction kernels. */ void** devScratchSpacePerStream; } THCCudaResourcesPerDevice; @@ -118,6 +115,7 @@ THC_API void THCState_free(THCState* state); THC_API void THCudaInit(THCState* state); THC_API void THCudaShutdown(THCState* state); +THC_API void THCudaEnablePeerToPeerAccess(THCState* state); /* If device `dev` can access allocations on device `devToAccess`, this will return */ /* 1; otherwise, 0. */ diff --git a/lib/THC/THCTensorRandom.cpp b/lib/THC/THCTensorRandom.cpp deleted file mode 100644 index d7690b5..0000000 --- a/lib/THC/THCTensorRandom.cpp +++ /dev/null @@ -1,133 +0,0 @@ -#include "THCTensorRandom.h" - -#include <random> -#include <curand.h> - - -void initializeGenerator(THCState *state, Generator* gen); -void createGeneratorState(Generator* gen, unsigned long long seed); - - -/* Frees memory allocated during setup. 
*/ -void destroyGenerator(THCState *state, Generator* gen) -{ - if (gen->gen_states) - { - THCudaCheck(THCudaFree(state, gen->gen_states)); - gen->gen_states = NULL; - } - if (gen->kernel_params) - { - THCudaCheck(THCudaFree(state, gen->kernel_params)); - gen->kernel_params = NULL; - } -} - -static unsigned long long createSeed(std::random_device& rd) -{ - // limit to 53 bits to ensure unique representation in double - unsigned long long seed = (((unsigned long long)rd()) << 32) + rd(); - return seed & 0x1FFFFFFFFFFFFF; -} - -/* Initialize generator array (must be called before any other function) */ -void THCRandom_init(THCState* state, int devices, int current_device) -{ - THCRNGState* rng_state = THCState_getRngState(state); - rng_state->num_devices = devices; - rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator)); - std::random_device rd; - for (int i = 0; i < rng_state->num_devices; ++i) - { - rng_state->gen[i].initf = 0; - rng_state->gen[i].initial_seed = createSeed(rd); - rng_state->gen[i].gen_states = NULL; - rng_state->gen[i].kernel_params = NULL; - } -} - -/* Destroy generators and free memory */ -void THCRandom_shutdown(THCState* state) -{ - THCRNGState* rng_state = THCState_getRngState(state); - if (rng_state->gen == NULL) return; - for (int i = 0; i < rng_state->num_devices; ++i) - { - destroyGenerator(state, &rng_state->gen[i]); - } - free(rng_state->gen); - rng_state->gen = NULL; -} - -/* Get the generator for the current device, but does not initialize the state */ -static Generator* THCRandom_rawGenerator(THCState* state) -{ - THCRNGState* rng_state = THCState_getRngState(state); - int device; - THCudaCheck(cudaGetDevice(&device)); - if (device >= rng_state->num_devices) THError("Invalid device index."); - return &rng_state->gen[device]; -} - -/* Get the generator for the current device and initializes it if necessary */ -Generator* THCRandom_getGenerator(THCState* state) -{ - Generator* gen = 
THCRandom_rawGenerator(state); - if (gen->initf == 0) - { - initializeGenerator(state, gen); - createGeneratorState(gen, gen->initial_seed); - gen->initf = 1; - } - return gen; -} - -struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state) -{ - return THCRandom_getGenerator(state)->gen_states; -} - -/* Random seed */ -unsigned long long THCRandom_seed(THCState* state) -{ - std::random_device rd; - unsigned long long s = createSeed(rd); - THCRandom_manualSeed(state, s); - return s; -} - -unsigned long long THCRandom_seedAll(THCState* state) -{ - std::random_device rd; - unsigned long long s = createSeed(rd); - THCRandom_manualSeedAll(state, s); - return s; -} - -/* Manually set the seed */ -void THCRandom_manualSeed(THCState* state, unsigned long long seed) -{ - Generator* gen = THCRandom_rawGenerator(state); - gen->initial_seed = seed; - if (gen->initf) { - createGeneratorState(gen, seed); - } -} - -void THCRandom_manualSeedAll(THCState* state, unsigned long long seed) -{ - THCRNGState* rng_state = THCState_getRngState(state); - int currentDevice; - THCudaCheck(cudaGetDevice(&currentDevice)); - for (int i = 0; i < rng_state->num_devices; ++i) { - THCudaCheck(cudaSetDevice(i)); - THCRandom_manualSeed(state, seed); - } - THCudaCheck(cudaSetDevice(currentDevice)); -} - -/* Get the initial seed */ -unsigned long long THCRandom_initialSeed(THCState* state) -{ - return THCRandom_getGenerator(state)->initial_seed; -} diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu index e05cf82..08efc0a 100644 --- a/lib/THC/THCTensorRandom.cu +++ b/lib/THC/THCTensorRandom.cu @@ -15,9 +15,6 @@ #define MAX_NUM_BLOCKS 64 #define BLOCK_SIZE 256 - -Generator* THCRandom_getGenerator(THCState* state); - /* Sets up generator. Allocates but does not create the generator states. 
*/ __host__ void initializeGenerator(THCState *state, Generator* gen) { @@ -25,8 +22,23 @@ __host__ void initializeGenerator(THCState *state, Generator* gen) THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params))); } +/* Frees memory allocated during setup. */ +__host__ void destroyGenerator(THCState *state, Generator* gen) +{ + if (gen->gen_states) + { + THCudaCheck(THCudaFree(state, gen->gen_states)); + gen->gen_states = NULL; + } + if (gen->kernel_params) + { + THCudaCheck(THCudaFree(state, gen->kernel_params)); + gen->kernel_params = NULL; + } +} + /* Creates a new generator state given the seed. */ -__host__ void createGeneratorState(Generator* gen, unsigned long long seed) +__host__ void createGeneratorState(Generator* gen, unsigned long seed) { if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->kernel_params) != CURAND_STATUS_SUCCESS) { @@ -39,13 +51,112 @@ __host__ void createGeneratorState(Generator* gen, unsigned long long seed) } } +/* Initialize generator array (must be called before any other function) */ +__host__ void THCRandom_init(THCState* state, int devices, int current_device) +{ + THCRNGState* rng_state = THCState_getRngState(state); + rng_state->num_devices = devices; + rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator)); + for (int i = 0; i < rng_state->num_devices; ++i) + { + rng_state->gen[i].initf = 0; + rng_state->gen[i].initial_seed = 0; + rng_state->gen[i].gen_states = NULL; + rng_state->gen[i].kernel_params = NULL; + } +} + +/* Destroy generators and free memory */ +__host__ void THCRandom_shutdown(THCState* state) +{ + THCRNGState* rng_state = THCState_getRngState(state); + if (rng_state->gen == NULL) return; + for (int i = 0; i < rng_state->num_devices; ++i) + { + destroyGenerator(state, &rng_state->gen[i]); + } + free(rng_state->gen); + rng_state->gen = NULL; +} + +/* Manually set the generator seed */ +__host__ static void THCRandom_manualSeedGen(Generator* 
gen, unsigned long seed) +{ + gen->initial_seed = seed; + createGeneratorState(gen, seed); + gen->initf = 1; +} + +/* Get the generator for the current device */ +__host__ Generator* THCRandom_getGenerator(THCState* state) +{ + THCRNGState* rng_state = THCState_getRngState(state); + + int device; + THCudaCheck(cudaGetDevice(&device)); + if (device >= rng_state->num_devices) THError("Invalid device index."); + + Generator* gen = &rng_state->gen[device]; + if (gen->initf == 0) + { + initializeGenerator(state, gen); + THCRandom_manualSeedGen(gen, (unsigned long)time(0)); + } + return gen; +} + +__host__ struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state) +{ + return THCRandom_getGenerator(state)->gen_states; +} + +/* Random seed */ +__host__ unsigned long THCRandom_seed(THCState* state) +{ + unsigned long s = (unsigned long)time(0); + THCRandom_manualSeed(state, s); + return s; +} + +__host__ unsigned long THCRandom_seedAll(THCState* state) +{ + unsigned long s = (unsigned long)time(0); + THCRandom_manualSeedAll(state, s); + return s; +} + +/* Manually set the seed */ +__host__ void THCRandom_manualSeed(THCState* state, unsigned long seed) +{ + Generator* gen = THCRandom_getGenerator(state); + THCRandom_manualSeedGen(gen, seed); +} + +__host__ void THCRandom_manualSeedAll(THCState* state, unsigned long seed) +{ + THCRNGState* rng_state = THCState_getRngState(state); + int currentDevice; + THCudaCheck(cudaGetDevice(&currentDevice)); + for (int i = 0; i < rng_state->num_devices; ++i) { + THCudaCheck(cudaSetDevice(i)); + THCRandom_manualSeed(state, seed); + } + THCudaCheck(cudaSetDevice(currentDevice)); +} + +/* Get the initial seed */ +__host__ unsigned long THCRandom_initialSeed(THCState* state) +{ + return THCRandom_getGenerator(state)->initial_seed; +} + __host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state) { Generator* gen = THCRandom_getGenerator(state); // The RNG state comprises the MTPG32 states and the seed. 
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); - static const size_t seed_size = sizeof(gen->initial_seed); + static const size_t seed_size = sizeof(unsigned long); static const size_t total_size = states_size + seed_size; THByteTensor_resize1d(rng_state, total_size); THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size"); @@ -65,7 +176,7 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state) Generator* gen = THCRandom_getGenerator(state); static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); - static const size_t seed_size = sizeof(gen->initial_seed); + static const size_t seed_size = sizeof(unsigned long); static const size_t total_size = states_size + seed_size; THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size"); THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous"); @@ -129,3 +240,4 @@ GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, cura #undef GENERATE_KERNEL1 #undef GENERATE_KERNEL2 + diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h index 197a53c..12128cd 100644 --- a/lib/THC/THCTensorRandom.h +++ b/lib/THC/THCTensorRandom.h @@ -11,7 +11,7 @@ typedef struct _Generator { struct curandStateMtgp32* gen_states; struct mtgp32_kernel_params *kernel_params; int initf; - unsigned long long initial_seed; + unsigned long initial_seed; } Generator; typedef struct THCRNGState { @@ -24,11 +24,11 @@ struct THCState; THC_API void THCRandom_init(struct THCState *state, int num_devices, int current_device); THC_API void THCRandom_shutdown(struct THCState *state); -THC_API unsigned long long THCRandom_seed(struct THCState *state); -THC_API unsigned long long THCRandom_seedAll(struct THCState *state); -THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long long the_seed_); -THC_API void THCRandom_manualSeedAll(struct THCState *state, 
unsigned long long the_seed_); -THC_API unsigned long long THCRandom_initialSeed(struct THCState *state); +THC_API unsigned long THCRandom_seed(struct THCState *state); +THC_API unsigned long THCRandom_seedAll(struct THCState *state); +THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long the_seed_); +THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_seed_); +THC_API unsigned long THCRandom_initialSeed(struct THCState *state); THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state); THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state); |