diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-11-24 01:48:00 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-24 01:48:00 +0300 |
commit | f46ca3974ffbe65adee97111e7415bdcf9de3f4e (patch) | |
tree | 9a16520b568ae50ff9d95c0a04738983ea606e9d | |
parent | f5932241e86087821a4c61dbde2c39a03d7c9883 (diff) | |
parent | 39a13d08c252121ffeebf03717c2266133b392ea (diff) |
Merge pull request #610 from colesbury/lazy
Lazily initialize CUDA devices
-rw-r--r-- | init.c | 10 | ||||
-rw-r--r-- | lib/THC/CMakeLists.txt | 5 | ||||
-rw-r--r-- | lib/THC/THCGeneral.c | 211 | ||||
-rw-r--r-- | lib/THC/THCGeneral.h.in | 6 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cpp | 133 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.cu | 124 | ||||
-rw-r--r-- | lib/THC/THCTensorRandom.h | 12 |
7 files changed, 246 insertions, 255 deletions
@@ -776,35 +776,35 @@ static int cutorch_getDeviceProperties(lua_State *L) static int cutorch_seed(lua_State *L) { - unsigned long seed = THCRandom_seed(cutorch_getstate(L)); + unsigned long long seed = THCRandom_seed(cutorch_getstate(L)); lua_pushnumber(L, seed); return 1; } static int cutorch_seedAll(lua_State *L) { - unsigned long seed = THCRandom_seedAll(cutorch_getstate(L)); + unsigned long long seed = THCRandom_seedAll(cutorch_getstate(L)); lua_pushnumber(L, seed); return 1; } static int cutorch_initialSeed(lua_State *L) { - unsigned long seed = THCRandom_initialSeed(cutorch_getstate(L)); + unsigned long long seed = THCRandom_initialSeed(cutorch_getstate(L)); lua_pushnumber(L, seed); return 1; } static int cutorch_manualSeed(lua_State *L) { - unsigned long seed = luaL_checknumber(L, 1); + unsigned long long seed = luaL_checknumber(L, 1); THCRandom_manualSeed(cutorch_getstate(L), seed); return 0; } static int cutorch_manualSeedAll(lua_State* L) { - unsigned long seed = luaL_checknumber(L, 1); + unsigned long long seed = luaL_checknumber(L, 1); THCRandom_manualSeedAll(cutorch_getstate(L), seed); return 0; } diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index c8916d3..08e6fbc 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -25,10 +25,10 @@ endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" ) # add c++11 flag - set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11) + set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11) else() # add c++0x flag - set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x) + set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x) endif() else() SET(CMAKE_CXX_STANDARD 11) @@ -130,6 +130,7 @@ 
SET(src THCStream.c THCTensor.c THCTensorCopy.c + THCTensorRandom.cpp THCThreadLocal.c ) diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 547e060..2d4043d 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -81,8 +81,21 @@ void THCudaInit(THCState* state) state->cudaUVAAllocator = (THAllocator*)malloc(sizeof(THAllocator)); THCUVAAllocator_init(state->cudaUVAAllocator); - /* Enable P2P access between all pairs, if possible */ - THCudaEnablePeerToPeerAccess(state); + // By default, all direct p2p kernel access (besides copy) is disallowed, + // since direct access without knowing whether or not a certain operation + // should be cross-GPU leads to synchronization errors. The user can choose + // to disable this functionality, however. + state->p2pKernelAccessEnabled = 0; + + // p2pAccessEnabled records if p2p copies are allowed between pairs of + // devices. Values include "1" (copy allowed), "0" (copy not allowed), and + // "-1" (unknown). + state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices); + for (int i = 0; i < numDevices; ++i) { + state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices); + memset(state->p2pAccessEnabled[i], -1, sizeof(int) * numDevices); + state->p2pAccessEnabled[i][i] = 1; + } for (int i = 0; i < numDevices; ++i) { THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, i); @@ -98,22 +111,15 @@ void THCudaInit(THCState* state) int numSM = state->deviceProperties[i].multiProcessorCount; size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM; res->scratchSpacePerStream = sizePerStream; - - /* Allocate scratch space for each stream */ - res->devScratchSpacePerStream = (void**) malloc(sizeof(void*)); - THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0], - sizePerStream)); } /* Restore to previous device */ THCudaCheck(cudaSetDevice(device)); - /* There is no such thing as a default cublas handle. 
- To maintain consistency with streams API, handle 0 is always NULL and we - start counting at 1. If currentPerDeviceBlasHandle is 0 (the default - thread-local value), then we assume it means 1. - */ - THCState_reserveBlasHandles(state, 1); + // Unlike CUDA streams, there is no NULL cuBLAS handle. The default THC + // cuBLAS handle is the first user BLAS handle. Note that the actual BLAS + // handles are created lazily. + state->numUserBlasHandles = 1; state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically state->heapDelta = 0; @@ -147,10 +153,9 @@ void THCudaShutdown(THCState* state) for (int i = 1; i <= state->numUserStreams; ++i) { THCStream_free(res->streams[i]); } - /* Free Torch-defined handles (0 is NULL for consistency with streams API) */ - for (int handle = 1; handle <= state->numUserBlasHandles; ++handle) { - THCublasCheck(cublasDestroy( - THCState_getDeviceBlasHandle(state, dev, handle))); + /* Free user defined BLAS handles */ + for (int i = 0; i < res->numBlasHandles; ++i) { + THCublasCheck(cublasDestroy(res->blasHandles[i])); } /* Free per-stream scratch space; starts at 0 because there is space for the default stream as well*/ @@ -174,79 +179,36 @@ void THCudaShutdown(THCState* state) THCudaCheck(cudaSetDevice(prevDev)); } -void THCudaEnablePeerToPeerAccess(THCState* state) -{ - /* By default, all direct p2p kernel access (besides copy) is disallowed, */ - /* since direct access without knowing whether or not a certain operation */ - /* should be cross-GPU leads to synchronization errors. The user can choose */ - /* to disable this functionality, however. 
*/ - state->p2pKernelAccessEnabled = 0; - - int prevDev = -1; - THCudaCheck(cudaGetDevice(&prevDev)); - - int numDevices = -1; - THCudaCheck(cudaGetDeviceCount(&numDevices)); - - state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices); - for (int i = 0; i < numDevices; ++i) { - state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices); - } - - /* Build a table of all allowed p2p accesses, to avoid checking the p2p - status at runtime. */ - for (int i = 0; i < numDevices; ++i) { - THCudaCheck(cudaSetDevice(i)); - - for (int j = 0; j < numDevices; ++j) { - /* Presume no access by default */ - state->p2pAccessEnabled[i][j] = 0; - - if (i == j) { - /* A GPU can access itself */ - state->p2pAccessEnabled[i][j] = 1; - } else { - int access = 0; - THCudaCheck(cudaDeviceCanAccessPeer(&access, i, j)); - - if (access) { - cudaError_t err = cudaDeviceEnablePeerAccess(j, 0); - if (err == cudaErrorPeerAccessAlreadyEnabled) { - /* It is possible that another thread has already enabled access. */ - /* Any future call to cudaGetLastError will now return an error, */ - /* even though we've already dealt with this specific error here. */ - /* Call cudaGetLastError once to reset the last error state. 
*/ - cudaGetLastError(); - - /* The above should have cleared status */ - THCudaCheck(cudaGetLastError()); - } else { - /* In case there are other unhandled errors returned from the */ - /* above */ - THCudaCheck(err); - } - - /* Access could be enabled, or was already enabled */ - state->p2pAccessEnabled[i][j] = 1; - } - } - } - } - - /* Restore previous device before continuing */ - THCudaCheck(cudaSetDevice(prevDev)); -} - int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess) { if (dev < 0 || dev >= state->numDevices) { THError("%d is not a device", dev); } - - if (devToAccess < 0 || dev >= state->numDevices) { + if (devToAccess < 0 || devToAccess >= state->numDevices) { THError("%d is not a device", devToAccess); } + if (state->p2pAccessEnabled[dev][devToAccess] == -1) { + int prevDev = 0; + THCudaCheck(cudaGetDevice(&prevDev)); + THCudaCheck(cudaSetDevice(dev)); + + int access = 0; + THCudaCheck(cudaDeviceCanAccessPeer(&access, dev, devToAccess)); + if (access) { + cudaError_t err = cudaDeviceEnablePeerAccess(devToAccess, 0); + if (err == cudaErrorPeerAccessAlreadyEnabled) { + // ignore and clear the error if access was already enabled + cudaGetLastError(); + } else { + THCudaCheck(err); + } + state->p2pAccessEnabled[dev][devToAccess] = 1; + } else { + state->p2pAccessEnabled[dev][devToAccess] = 0; + } + THCudaCheck(cudaSetDevice(prevDev)); + } return state->p2pAccessEnabled[dev][devToAccess]; } @@ -327,6 +289,20 @@ int THCState_getNumDevices(THCState *state) return state->numDevices; } +static void THCState_initializeScratchSpace(THCState* state, int dev) +{ + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); + if (res->devScratchSpacePerStream) { + return; + } + size_t size = (state->numUserStreams + 1) * sizeof(void*); + void** scratch = (void**)malloc(size); + for (int i = 0; i <= state->numUserStreams; ++i) { + THCudaCheck(THCudaMalloc(state, &scratch[i], res->scratchSpacePerStream)); + } + 
res->devScratchSpacePerStream = scratch; +} + void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) { if (numStreams <= state->numUserStreams) @@ -346,6 +322,7 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) THCStream** newStreams = realloc(res->streams, (numStreams + 1) * sizeof(THCStream*)); THAssert(newStreams); + THCState_initializeScratchSpace(state, dev); void** newScratchSpace = realloc(res->devScratchSpacePerStream, (numStreams + 1) * sizeof(void*)); THAssert(newScratchSpace); @@ -369,47 +346,39 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking) THCudaCheck(cudaSetDevice(prevDev)); } -void THCState_reserveBlasHandles(THCState* state, int numBlasHandles) +void THCState_reserveDeviceBlasHandles(THCState* state, int device, int numBlasHandles) { - if (numBlasHandles <= state->numUserBlasHandles) - { + int prevDev = -1; + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); + if (numBlasHandles <= res->numBlasHandles) { return; } - int prevDev = -1; THCudaCheck(cudaGetDevice(&prevDev)); + THCudaCheck(cudaSetDevice(device)); - /* Otherwise, we have to allocate a new set of blasHandles */ - for (int dev = 0; dev < state->numDevices; ++dev) { - THCudaCheck(cudaSetDevice(dev)); - - /* +1 to be consistent with stream API, blas handle 0 is NULL and unused */ - cublasHandle_t* newBlasHandles = - (cublasHandle_t*) malloc((numBlasHandles + 1) * sizeof(cublasHandle_t)); - - /* Copy over old blasHandles - (0 is NULL, 1 ... 
numUserBlasHandles are rest) */ - newBlasHandles[0] = NULL; - for (int hndl = 1; hndl <= state->numUserBlasHandles; ++hndl) { - newBlasHandles[hndl] = THCState_getDeviceBlasHandle(state, dev, hndl); - } - - /* Allocate new handles */ - for (int hndl = state->numUserBlasHandles + 1; hndl <= numBlasHandles; ++hndl) { - newBlasHandles[hndl] = NULL; - THCublasCheck(cublasCreate(newBlasHandles + hndl)); - } - - THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); - free(res->blasHandles); - res->blasHandles = newBlasHandles; + size_t size = numBlasHandles * sizeof(cublasHandle_t); + cublasHandle_t* handles = (cublasHandle_t*) realloc(res->blasHandles, size); + for (int i = res->numBlasHandles; i < numBlasHandles; ++i) { + handles[i] = NULL; + THCublasCheck(cublasCreate(&handles[i])); } - - state->numUserBlasHandles = numBlasHandles; + res->blasHandles = handles; + res->numBlasHandles = numBlasHandles; THCudaCheck(cudaSetDevice(prevDev)); } +void THCState_reserveBlasHandles(THCState* state, int numBlasHandles) +{ + // cuBLAS handles are created lazily from THCState_getDeviceBlasHandle + // to avoid initializing unused devices + if (numBlasHandles > state->numUserBlasHandles) + { + state->numUserBlasHandles = numBlasHandles; + } +} + int THCState_getNumStreams(THCState* state) { return state->numUserStreams; @@ -445,12 +414,13 @@ cudaStream_t THCState_getDeviceStream(THCState *state, int device, int streamInd cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle) { - if (handle <= 0 || handle > state->numUserBlasHandles) - { + if (handle <= 0 || handle > state->numUserBlasHandles) { THError("%d is not a valid handle, valid range is: (1, %d)", handle, state->numUserBlasHandles); } - return THCState_getDeviceResourcePtr(state, device)->blasHandles[handle]; + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device); + THCState_reserveDeviceBlasHandles(state, device, handle); + return 
res->blasHandles[handle - 1]; } static THCStream* THCState_getStreamOnDevice(THCState* state, int device) @@ -592,16 +562,13 @@ void* THCState_getCurrentDeviceScratchSpace(THCState* state) return THCState_getDeviceScratchSpace(state, device, stream); } -void* THCState_getDeviceScratchSpace(THCState* state, int device, int stream) +void* THCState_getDeviceScratchSpace(THCState* state, int dev, int stream) { - THCCudaResourcesPerDevice* res = - THCState_getDeviceResourcePtr(state, device); - - if (stream > state->numUserStreams || stream < 0) - { + THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev); + if (stream > state->numUserStreams || stream < 0) { THError("%d is not a stream", stream); } - + THCState_initializeScratchSpace(state, dev); return res->devScratchSpacePerStream[stream]; } diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index c685d37..1a37390 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -55,11 +55,14 @@ typedef struct _THCDeviceAllocator { typedef struct _THCCudaResourcesPerDevice { THCStream** streams; + /* Number of materialized cuBLAS handles */ + int numBlasHandles; + /* cuBLAS handes are lazily initialized */ cublasHandle_t* blasHandles; /* Size of scratch space per each stream on this device available */ size_t scratchSpacePerStream; /* Device-resident scratch space per stream, used for global memory - reduction kernels. */ + reduction kernels. Lazily initialized. */ void** devScratchSpacePerStream; } THCCudaResourcesPerDevice; @@ -115,7 +118,6 @@ THC_API void THCState_free(THCState* state); THC_API void THCudaInit(THCState* state); THC_API void THCudaShutdown(THCState* state); -THC_API void THCudaEnablePeerToPeerAccess(THCState* state); /* If device `dev` can access allocations on device `devToAccess`, this will return */ /* 1; otherwise, 0. 
*/ diff --git a/lib/THC/THCTensorRandom.cpp b/lib/THC/THCTensorRandom.cpp new file mode 100644 index 0000000..d7690b5 --- /dev/null +++ b/lib/THC/THCTensorRandom.cpp @@ -0,0 +1,133 @@ +#include "THCTensorRandom.h" + +#include <random> +#include <curand.h> + + +void initializeGenerator(THCState *state, Generator* gen); +void createGeneratorState(Generator* gen, unsigned long long seed); + + +/* Frees memory allocated during setup. */ +void destroyGenerator(THCState *state, Generator* gen) +{ + if (gen->gen_states) + { + THCudaCheck(THCudaFree(state, gen->gen_states)); + gen->gen_states = NULL; + } + if (gen->kernel_params) + { + THCudaCheck(THCudaFree(state, gen->kernel_params)); + gen->kernel_params = NULL; + } +} + +static unsigned long long createSeed(std::random_device& rd) +{ + // limit to 53 bits to ensure unique representation in double + unsigned long long seed = (((unsigned long long)rd()) << 32) + rd(); + return seed & 0x1FFFFFFFFFFFFF; +} + +/* Initialize generator array (must be called before any other function) */ +void THCRandom_init(THCState* state, int devices, int current_device) +{ + THCRNGState* rng_state = THCState_getRngState(state); + rng_state->num_devices = devices; + rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator)); + std::random_device rd; + for (int i = 0; i < rng_state->num_devices; ++i) + { + rng_state->gen[i].initf = 0; + rng_state->gen[i].initial_seed = createSeed(rd); + rng_state->gen[i].gen_states = NULL; + rng_state->gen[i].kernel_params = NULL; + } +} + +/* Destroy generators and free memory */ +void THCRandom_shutdown(THCState* state) +{ + THCRNGState* rng_state = THCState_getRngState(state); + if (rng_state->gen == NULL) return; + for (int i = 0; i < rng_state->num_devices; ++i) + { + destroyGenerator(state, &rng_state->gen[i]); + } + free(rng_state->gen); + rng_state->gen = NULL; +} + +/* Get the generator for the current device, but does not initialize the state */ +static Generator* 
THCRandom_rawGenerator(THCState* state) +{ + THCRNGState* rng_state = THCState_getRngState(state); + int device; + THCudaCheck(cudaGetDevice(&device)); + if (device >= rng_state->num_devices) THError("Invalid device index."); + return &rng_state->gen[device]; +} + +/* Get the generator for the current device and initializes it if necessary */ +Generator* THCRandom_getGenerator(THCState* state) +{ + Generator* gen = THCRandom_rawGenerator(state); + if (gen->initf == 0) + { + initializeGenerator(state, gen); + createGeneratorState(gen, gen->initial_seed); + gen->initf = 1; + } + return gen; +} + +struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state) +{ + return THCRandom_getGenerator(state)->gen_states; +} + +/* Random seed */ +unsigned long long THCRandom_seed(THCState* state) +{ + std::random_device rd; + unsigned long long s = createSeed(rd); + THCRandom_manualSeed(state, s); + return s; +} + +unsigned long long THCRandom_seedAll(THCState* state) +{ + std::random_device rd; + unsigned long long s = createSeed(rd); + THCRandom_manualSeedAll(state, s); + return s; +} + +/* Manually set the seed */ +void THCRandom_manualSeed(THCState* state, unsigned long long seed) +{ + Generator* gen = THCRandom_rawGenerator(state); + gen->initial_seed = seed; + if (gen->initf) { + createGeneratorState(gen, seed); + } +} + +void THCRandom_manualSeedAll(THCState* state, unsigned long long seed) +{ + THCRNGState* rng_state = THCState_getRngState(state); + int currentDevice; + THCudaCheck(cudaGetDevice(&currentDevice)); + for (int i = 0; i < rng_state->num_devices; ++i) { + THCudaCheck(cudaSetDevice(i)); + THCRandom_manualSeed(state, seed); + } + THCudaCheck(cudaSetDevice(currentDevice)); +} + +/* Get the initial seed */ +unsigned long long THCRandom_initialSeed(THCState* state) +{ + return THCRandom_getGenerator(state)->initial_seed; +} diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu index 08efc0a..e05cf82 100644 --- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu @@ -15,6 +15,9 @@ #define MAX_NUM_BLOCKS 64 #define BLOCK_SIZE 256 + +Generator* THCRandom_getGenerator(THCState* state); + /* Sets up generator. Allocates but does not create the generator states. */ __host__ void initializeGenerator(THCState *state, Generator* gen) { @@ -22,23 +25,8 @@ __host__ void initializeGenerator(THCState *state, Generator* gen) THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params))); } -/* Frees memory allocated during setup. */ -__host__ void destroyGenerator(THCState *state, Generator* gen) -{ - if (gen->gen_states) - { - THCudaCheck(THCudaFree(state, gen->gen_states)); - gen->gen_states = NULL; - } - if (gen->kernel_params) - { - THCudaCheck(THCudaFree(state, gen->kernel_params)); - gen->kernel_params = NULL; - } -} - /* Creates a new generator state given the seed. */ -__host__ void createGeneratorState(Generator* gen, unsigned long seed) +__host__ void createGeneratorState(Generator* gen, unsigned long long seed) { if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->kernel_params) != CURAND_STATUS_SUCCESS) { @@ -51,112 +39,13 @@ __host__ void createGeneratorState(Generator* gen, unsigned long seed) } } -/* Initialize generator array (must be called before any other function) */ -__host__ void THCRandom_init(THCState* state, int devices, int current_device) -{ - THCRNGState* rng_state = THCState_getRngState(state); - rng_state->num_devices = devices; - rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator)); - for (int i = 0; i < rng_state->num_devices; ++i) - { - rng_state->gen[i].initf = 0; - rng_state->gen[i].initial_seed = 0; - rng_state->gen[i].gen_states = NULL; - rng_state->gen[i].kernel_params = NULL; - } -} - -/* Destroy generators and free memory */ -__host__ void THCRandom_shutdown(THCState* state) -{ - THCRNGState* rng_state = THCState_getRngState(state); - if (rng_state->gen == NULL) return; - for (int i = 0; i < 
rng_state->num_devices; ++i) - { - destroyGenerator(state, &rng_state->gen[i]); - } - free(rng_state->gen); - rng_state->gen = NULL; -} - -/* Manually set the generator seed */ -__host__ static void THCRandom_manualSeedGen(Generator* gen, unsigned long seed) -{ - gen->initial_seed = seed; - createGeneratorState(gen, seed); - gen->initf = 1; -} - -/* Get the generator for the current device */ -__host__ Generator* THCRandom_getGenerator(THCState* state) -{ - THCRNGState* rng_state = THCState_getRngState(state); - - int device; - THCudaCheck(cudaGetDevice(&device)); - if (device >= rng_state->num_devices) THError("Invalid device index."); - - Generator* gen = &rng_state->gen[device]; - if (gen->initf == 0) - { - initializeGenerator(state, gen); - THCRandom_manualSeedGen(gen, (unsigned long)time(0)); - } - return gen; -} - -__host__ struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state) -{ - return THCRandom_getGenerator(state)->gen_states; -} - -/* Random seed */ -__host__ unsigned long THCRandom_seed(THCState* state) -{ - unsigned long s = (unsigned long)time(0); - THCRandom_manualSeed(state, s); - return s; -} - -__host__ unsigned long THCRandom_seedAll(THCState* state) -{ - unsigned long s = (unsigned long)time(0); - THCRandom_manualSeedAll(state, s); - return s; -} - -/* Manually set the seed */ -__host__ void THCRandom_manualSeed(THCState* state, unsigned long seed) -{ - Generator* gen = THCRandom_getGenerator(state); - THCRandom_manualSeedGen(gen, seed); -} - -__host__ void THCRandom_manualSeedAll(THCState* state, unsigned long seed) -{ - THCRNGState* rng_state = THCState_getRngState(state); - int currentDevice; - THCudaCheck(cudaGetDevice(&currentDevice)); - for (int i = 0; i < rng_state->num_devices; ++i) { - THCudaCheck(cudaSetDevice(i)); - THCRandom_manualSeed(state, seed); - } - THCudaCheck(cudaSetDevice(currentDevice)); -} - -/* Get the initial seed */ -__host__ unsigned long THCRandom_initialSeed(THCState* state) -{ - return
THCRandom_getGenerator(state)->initial_seed; -} - __host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state) { Generator* gen = THCRandom_getGenerator(state); // The RNG state comprises the MTPG32 states and the seed. static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); - static const size_t seed_size = sizeof(unsigned long); + static const size_t seed_size = sizeof(gen->initial_seed); static const size_t total_size = states_size + seed_size; THByteTensor_resize1d(rng_state, total_size); THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size"); @@ -176,7 +65,7 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state) Generator* gen = THCRandom_getGenerator(state); static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32); - static const size_t seed_size = sizeof(unsigned long); + static const size_t seed_size = sizeof(gen->initial_seed); static const size_t total_size = states_size + seed_size; THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size"); THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous"); @@ -240,4 +129,3 @@ GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, cura #undef GENERATE_KERNEL1 #undef GENERATE_KERNEL2 - diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h index 12128cd..197a53c 100644 --- a/lib/THC/THCTensorRandom.h +++ b/lib/THC/THCTensorRandom.h @@ -11,7 +11,7 @@ typedef struct _Generator { struct curandStateMtgp32* gen_states; struct mtgp32_kernel_params *kernel_params; int initf; - unsigned long initial_seed; + unsigned long long initial_seed; } Generator; typedef struct THCRNGState { @@ -24,11 +24,11 @@ struct THCState; THC_API void THCRandom_init(struct THCState *state, int num_devices, int current_device); THC_API void THCRandom_shutdown(struct THCState *state); -THC_API unsigned long THCRandom_seed(struct THCState 
*state); -THC_API unsigned long THCRandom_seedAll(struct THCState *state); -THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long the_seed_); -THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_seed_); -THC_API unsigned long THCRandom_initialSeed(struct THCState *state); +THC_API unsigned long long THCRandom_seed(struct THCState *state); +THC_API unsigned long long THCRandom_seedAll(struct THCState *state); +THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long long the_seed_); +THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long long the_seed_); +THC_API unsigned long long THCRandom_initialSeed(struct THCState *state); THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state); THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state); |