Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Soumith Chintala <soumith@gmail.com> 2016-11-24 01:48:00 +0300
committer GitHub <noreply@github.com> 2016-11-24 01:48:00 +0300
commit f46ca3974ffbe65adee97111e7415bdcf9de3f4e (patch)
tree 9a16520b568ae50ff9d95c0a04738983ea606e9d
parent f5932241e86087821a4c61dbde2c39a03d7c9883 (diff)
parent 39a13d08c252121ffeebf03717c2266133b392ea (diff)
Merge pull request #610 from colesbury/lazy
Lazily initialize CUDA devices
-rw-r--r-- init.c 10
-rw-r--r-- lib/THC/CMakeLists.txt 5
-rw-r--r-- lib/THC/THCGeneral.c 211
-rw-r--r-- lib/THC/THCGeneral.h.in 6
-rw-r--r-- lib/THC/THCTensorRandom.cpp 133
-rw-r--r-- lib/THC/THCTensorRandom.cu 124
-rw-r--r-- lib/THC/THCTensorRandom.h 12
7 files changed, 246 insertions, 255 deletions
diff --git a/init.c b/init.c
index 124be5c..02960cd 100644
--- a/init.c
+++ b/init.c
@@ -776,35 +776,35 @@ static int cutorch_getDeviceProperties(lua_State *L)
static int cutorch_seed(lua_State *L)
{
- unsigned long seed = THCRandom_seed(cutorch_getstate(L));
+ unsigned long long seed = THCRandom_seed(cutorch_getstate(L));
lua_pushnumber(L, seed);
return 1;
}
static int cutorch_seedAll(lua_State *L)
{
- unsigned long seed = THCRandom_seedAll(cutorch_getstate(L));
+ unsigned long long seed = THCRandom_seedAll(cutorch_getstate(L));
lua_pushnumber(L, seed);
return 1;
}
static int cutorch_initialSeed(lua_State *L)
{
- unsigned long seed = THCRandom_initialSeed(cutorch_getstate(L));
+ unsigned long long seed = THCRandom_initialSeed(cutorch_getstate(L));
lua_pushnumber(L, seed);
return 1;
}
static int cutorch_manualSeed(lua_State *L)
{
- unsigned long seed = luaL_checknumber(L, 1);
+ unsigned long long seed = luaL_checknumber(L, 1);
THCRandom_manualSeed(cutorch_getstate(L), seed);
return 0;
}
static int cutorch_manualSeedAll(lua_State* L)
{
- unsigned long seed = luaL_checknumber(L, 1);
+ unsigned long long seed = luaL_checknumber(L, 1);
THCRandom_manualSeedAll(cutorch_getstate(L), seed);
return 0;
}
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index c8916d3..08e6fbc 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -25,10 +25,10 @@ endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" )
# add c++11 flag
- set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11)
+ set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11)
else()
# add c++0x flag
- set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x)
+ set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x)
endif()
else()
SET(CMAKE_CXX_STANDARD 11)
@@ -130,6 +130,7 @@ SET(src
THCStream.c
THCTensor.c
THCTensorCopy.c
+ THCTensorRandom.cpp
THCThreadLocal.c
)
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index 547e060..2d4043d 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -81,8 +81,21 @@ void THCudaInit(THCState* state)
state->cudaUVAAllocator = (THAllocator*)malloc(sizeof(THAllocator));
THCUVAAllocator_init(state->cudaUVAAllocator);
- /* Enable P2P access between all pairs, if possible */
- THCudaEnablePeerToPeerAccess(state);
+ // By default, all direct p2p kernel access (besides copy) is disallowed,
+ // since direct access without knowing whether or not a certain operation
+ // should be cross-GPU leads to synchronization errors. The user can choose
+ // to override this default, however.
+ state->p2pKernelAccessEnabled = 0;
+
+ // p2pAccessEnabled records if p2p copies are allowed between pairs of
+ // devices. Values include "1" (copy allowed), "0" (copy not allowed), and
+ // "-1" (unknown).
+ state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices);
+ for (int i = 0; i < numDevices; ++i) {
+ state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices);
+ memset(state->p2pAccessEnabled[i], -1, sizeof(int) * numDevices);
+ state->p2pAccessEnabled[i][i] = 1;
+ }
for (int i = 0; i < numDevices; ++i) {
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, i);
@@ -98,22 +111,15 @@ void THCudaInit(THCState* state)
int numSM = state->deviceProperties[i].multiProcessorCount;
size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
res->scratchSpacePerStream = sizePerStream;
-
- /* Allocate scratch space for each stream */
- res->devScratchSpacePerStream = (void**) malloc(sizeof(void*));
- THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0],
- sizePerStream));
}
/* Restore to previous device */
THCudaCheck(cudaSetDevice(device));
- /* There is no such thing as a default cublas handle.
- To maintain consistency with streams API, handle 0 is always NULL and we
- start counting at 1. If currentPerDeviceBlasHandle is 0 (the default
- thread-local value), then we assume it means 1.
- */
- THCState_reserveBlasHandles(state, 1);
+ // Unlike CUDA streams, there is no NULL cuBLAS handle. The default THC
+ // cuBLAS handle is the first user BLAS handle. Note that the actual BLAS
+ // handles are created lazily.
+ state->numUserBlasHandles = 1;
state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
state->heapDelta = 0;
@@ -147,10 +153,9 @@ void THCudaShutdown(THCState* state)
for (int i = 1; i <= state->numUserStreams; ++i) {
THCStream_free(res->streams[i]);
}
- /* Free Torch-defined handles (0 is NULL for consistency with streams API) */
- for (int handle = 1; handle <= state->numUserBlasHandles; ++handle) {
- THCublasCheck(cublasDestroy(
- THCState_getDeviceBlasHandle(state, dev, handle)));
+ /* Free user defined BLAS handles */
+ for (int i = 0; i < res->numBlasHandles; ++i) {
+ THCublasCheck(cublasDestroy(res->blasHandles[i]));
}
/* Free per-stream scratch space; starts at 0 because there is space for
the default stream as well*/
@@ -174,79 +179,36 @@ void THCudaShutdown(THCState* state)
THCudaCheck(cudaSetDevice(prevDev));
}
-void THCudaEnablePeerToPeerAccess(THCState* state)
-{
- /* By default, all direct p2p kernel access (besides copy) is disallowed, */
- /* since direct access without knowing whether or not a certain operation */
- /* should be cross-GPU leads to synchronization errors. The user can choose */
- /* to disable this functionality, however. */
- state->p2pKernelAccessEnabled = 0;
-
- int prevDev = -1;
- THCudaCheck(cudaGetDevice(&prevDev));
-
- int numDevices = -1;
- THCudaCheck(cudaGetDeviceCount(&numDevices));
-
- state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices);
- for (int i = 0; i < numDevices; ++i) {
- state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices);
- }
-
- /* Build a table of all allowed p2p accesses, to avoid checking the p2p
- status at runtime. */
- for (int i = 0; i < numDevices; ++i) {
- THCudaCheck(cudaSetDevice(i));
-
- for (int j = 0; j < numDevices; ++j) {
- /* Presume no access by default */
- state->p2pAccessEnabled[i][j] = 0;
-
- if (i == j) {
- /* A GPU can access itself */
- state->p2pAccessEnabled[i][j] = 1;
- } else {
- int access = 0;
- THCudaCheck(cudaDeviceCanAccessPeer(&access, i, j));
-
- if (access) {
- cudaError_t err = cudaDeviceEnablePeerAccess(j, 0);
- if (err == cudaErrorPeerAccessAlreadyEnabled) {
- /* It is possible that another thread has already enabled access. */
- /* Any future call to cudaGetLastError will now return an error, */
- /* even though we've already dealt with this specific error here. */
- /* Call cudaGetLastError once to reset the last error state. */
- cudaGetLastError();
-
- /* The above should have cleared status */
- THCudaCheck(cudaGetLastError());
- } else {
- /* In case there are other unhandled errors returned from the */
- /* above */
- THCudaCheck(err);
- }
-
- /* Access could be enabled, or was already enabled */
- state->p2pAccessEnabled[i][j] = 1;
- }
- }
- }
- }
-
- /* Restore previous device before continuing */
- THCudaCheck(cudaSetDevice(prevDev));
-}
-
int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess)
{
if (dev < 0 || dev >= state->numDevices) {
THError("%d is not a device", dev);
}
-
- if (devToAccess < 0 || dev >= state->numDevices) {
+ if (devToAccess < 0 || devToAccess >= state->numDevices) {
THError("%d is not a device", devToAccess);
}
+ if (state->p2pAccessEnabled[dev][devToAccess] == -1) {
+ int prevDev = 0;
+ THCudaCheck(cudaGetDevice(&prevDev));
+ THCudaCheck(cudaSetDevice(dev));
+
+ int access = 0;
+ THCudaCheck(cudaDeviceCanAccessPeer(&access, dev, devToAccess));
+ if (access) {
+ cudaError_t err = cudaDeviceEnablePeerAccess(devToAccess, 0);
+ if (err == cudaErrorPeerAccessAlreadyEnabled) {
+ // ignore and clear the error if access was already enabled
+ cudaGetLastError();
+ } else {
+ THCudaCheck(err);
+ }
+ state->p2pAccessEnabled[dev][devToAccess] = 1;
+ } else {
+ state->p2pAccessEnabled[dev][devToAccess] = 0;
+ }
+ THCudaCheck(cudaSetDevice(prevDev));
+ }
return state->p2pAccessEnabled[dev][devToAccess];
}
@@ -327,6 +289,20 @@ int THCState_getNumDevices(THCState *state)
return state->numDevices;
}
+static void THCState_initializeScratchSpace(THCState* state, int dev)
+{
+ THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
+ if (res->devScratchSpacePerStream) {
+ return;
+ }
+ size_t size = (state->numUserStreams + 1) * sizeof(void*);
+ void** scratch = (void**)malloc(size);
+ for (int i = 0; i <= state->numUserStreams; ++i) {
+ THCudaCheck(THCudaMalloc(state, &scratch[i], res->scratchSpacePerStream));
+ }
+ res->devScratchSpacePerStream = scratch;
+}
+
void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
{
if (numStreams <= state->numUserStreams)
@@ -346,6 +322,7 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
THCStream** newStreams = realloc(res->streams, (numStreams + 1) * sizeof(THCStream*));
THAssert(newStreams);
+ THCState_initializeScratchSpace(state, dev);
void** newScratchSpace = realloc(res->devScratchSpacePerStream, (numStreams + 1) * sizeof(void*));
THAssert(newScratchSpace);
@@ -369,47 +346,39 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
THCudaCheck(cudaSetDevice(prevDev));
}
-void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
+void THCState_reserveDeviceBlasHandles(THCState* state, int device, int numBlasHandles)
{
- if (numBlasHandles <= state->numUserBlasHandles)
- {
+ int prevDev = -1;
+ THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
+ if (numBlasHandles <= res->numBlasHandles) {
return;
}
- int prevDev = -1;
THCudaCheck(cudaGetDevice(&prevDev));
+ THCudaCheck(cudaSetDevice(device));
- /* Otherwise, we have to allocate a new set of blasHandles */
- for (int dev = 0; dev < state->numDevices; ++dev) {
- THCudaCheck(cudaSetDevice(dev));
-
- /* +1 to be consistent with stream API, blas handle 0 is NULL and unused */
- cublasHandle_t* newBlasHandles =
- (cublasHandle_t*) malloc((numBlasHandles + 1) * sizeof(cublasHandle_t));
-
- /* Copy over old blasHandles
- (0 is NULL, 1 ... numUserBlasHandles are rest) */
- newBlasHandles[0] = NULL;
- for (int hndl = 1; hndl <= state->numUserBlasHandles; ++hndl) {
- newBlasHandles[hndl] = THCState_getDeviceBlasHandle(state, dev, hndl);
- }
-
- /* Allocate new handles */
- for (int hndl = state->numUserBlasHandles + 1; hndl <= numBlasHandles; ++hndl) {
- newBlasHandles[hndl] = NULL;
- THCublasCheck(cublasCreate(newBlasHandles + hndl));
- }
-
- THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
- free(res->blasHandles);
- res->blasHandles = newBlasHandles;
+ size_t size = numBlasHandles * sizeof(cublasHandle_t);
+ cublasHandle_t* handles = (cublasHandle_t*) realloc(res->blasHandles, size);
+ for (int i = res->numBlasHandles; i < numBlasHandles; ++i) {
+ handles[i] = NULL;
+ THCublasCheck(cublasCreate(&handles[i]));
}
-
- state->numUserBlasHandles = numBlasHandles;
+ res->blasHandles = handles;
+ res->numBlasHandles = numBlasHandles;
THCudaCheck(cudaSetDevice(prevDev));
}
+void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
+{
+ // cuBLAS handles are created lazily from THCState_getDeviceBlasHandle
+ // to avoid initializing unused devices
+ if (numBlasHandles > state->numUserBlasHandles)
+ {
+ state->numUserBlasHandles = numBlasHandles;
+ }
+}
+
int THCState_getNumStreams(THCState* state)
{
return state->numUserStreams;
@@ -445,12 +414,13 @@ cudaStream_t THCState_getDeviceStream(THCState *state, int device, int streamInd
cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle)
{
- if (handle <= 0 || handle > state->numUserBlasHandles)
- {
+ if (handle <= 0 || handle > state->numUserBlasHandles) {
THError("%d is not a valid handle, valid range is: (1, %d)",
handle, state->numUserBlasHandles);
}
- return THCState_getDeviceResourcePtr(state, device)->blasHandles[handle];
+ THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
+ THCState_reserveDeviceBlasHandles(state, device, handle);
+ return res->blasHandles[handle - 1];
}
static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
@@ -592,16 +562,13 @@ void* THCState_getCurrentDeviceScratchSpace(THCState* state)
return THCState_getDeviceScratchSpace(state, device, stream);
}
-void* THCState_getDeviceScratchSpace(THCState* state, int device, int stream)
+void* THCState_getDeviceScratchSpace(THCState* state, int dev, int stream)
{
- THCCudaResourcesPerDevice* res =
- THCState_getDeviceResourcePtr(state, device);
-
- if (stream > state->numUserStreams || stream < 0)
- {
+ THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
+ if (stream > state->numUserStreams || stream < 0) {
THError("%d is not a stream", stream);
}
-
+ THCState_initializeScratchSpace(state, dev);
return res->devScratchSpacePerStream[stream];
}
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index c685d37..1a37390 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -55,11 +55,14 @@ typedef struct _THCDeviceAllocator {
typedef struct _THCCudaResourcesPerDevice {
THCStream** streams;
+ /* Number of materialized cuBLAS handles */
+ int numBlasHandles;
+ /* cuBLAS handles are lazily initialized */
cublasHandle_t* blasHandles;
/* Size of scratch space per each stream on this device available */
size_t scratchSpacePerStream;
/* Device-resident scratch space per stream, used for global memory
- reduction kernels. */
+ reduction kernels. Lazily initialized. */
void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;
@@ -115,7 +118,6 @@ THC_API void THCState_free(THCState* state);
THC_API void THCudaInit(THCState* state);
THC_API void THCudaShutdown(THCState* state);
-THC_API void THCudaEnablePeerToPeerAccess(THCState* state);
/* If device `dev` can access allocations on device `devToAccess`, this will return */
/* 1; otherwise, 0. */
diff --git a/lib/THC/THCTensorRandom.cpp b/lib/THC/THCTensorRandom.cpp
new file mode 100644
index 0000000..d7690b5
--- /dev/null
+++ b/lib/THC/THCTensorRandom.cpp
@@ -0,0 +1,133 @@
+#include "THCTensorRandom.h"
+
+#include <random>
+#include <curand.h>
+
+
+void initializeGenerator(THCState *state, Generator* gen);
+void createGeneratorState(Generator* gen, unsigned long long seed);
+
+
+/* Frees memory allocated during setup. */
+void destroyGenerator(THCState *state, Generator* gen)
+{
+ if (gen->gen_states)
+ {
+ THCudaCheck(THCudaFree(state, gen->gen_states));
+ gen->gen_states = NULL;
+ }
+ if (gen->kernel_params)
+ {
+ THCudaCheck(THCudaFree(state, gen->kernel_params));
+ gen->kernel_params = NULL;
+ }
+}
+
+static unsigned long long createSeed(std::random_device& rd)
+{
+ // limit to 53 bits to ensure unique representation in double
+ unsigned long long seed = (((unsigned long long)rd()) << 32) + rd();
+ return seed & 0x1FFFFFFFFFFFFF;
+}
+
+/* Initialize generator array (must be called before any other function) */
+void THCRandom_init(THCState* state, int devices, int current_device)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ rng_state->num_devices = devices;
+ rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator));
+ std::random_device rd;
+ for (int i = 0; i < rng_state->num_devices; ++i)
+ {
+ rng_state->gen[i].initf = 0;
+ rng_state->gen[i].initial_seed = createSeed(rd);
+ rng_state->gen[i].gen_states = NULL;
+ rng_state->gen[i].kernel_params = NULL;
+ }
+}
+
+/* Destroy generators and free memory */
+void THCRandom_shutdown(THCState* state)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ if (rng_state->gen == NULL) return;
+ for (int i = 0; i < rng_state->num_devices; ++i)
+ {
+ destroyGenerator(state, &rng_state->gen[i]);
+ }
+ free(rng_state->gen);
+ rng_state->gen = NULL;
+}
+
+/* Gets the generator for the current device, but does not initialize the state */
+static Generator* THCRandom_rawGenerator(THCState* state)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ int device;
+ THCudaCheck(cudaGetDevice(&device));
+ if (device >= rng_state->num_devices) THError("Invalid device index.");
+ return &rng_state->gen[device];
+}
+
+/* Gets the generator for the current device and initializes it if necessary */
+Generator* THCRandom_getGenerator(THCState* state)
+{
+ Generator* gen = THCRandom_rawGenerator(state);
+ if (gen->initf == 0)
+ {
+ initializeGenerator(state, gen);
+ createGeneratorState(gen, gen->initial_seed);
+ gen->initf = 1;
+ }
+ return gen;
+}
+
+struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state)
+{
+ return THCRandom_getGenerator(state)->gen_states;
+}
+
+/* Random seed */
+unsigned long long THCRandom_seed(THCState* state)
+{
+ std::random_device rd;
+ unsigned long long s = createSeed(rd);
+ THCRandom_manualSeed(state, s);
+ return s;
+}
+
+unsigned long long THCRandom_seedAll(THCState* state)
+{
+ std::random_device rd;
+ unsigned long long s = createSeed(rd);
+ THCRandom_manualSeedAll(state, s);
+ return s;
+}
+
+/* Manually set the seed */
+void THCRandom_manualSeed(THCState* state, unsigned long long seed)
+{
+ Generator* gen = THCRandom_rawGenerator(state);
+ gen->initial_seed = seed;
+ if (gen->initf) {
+ createGeneratorState(gen, seed);
+ }
+}
+
+void THCRandom_manualSeedAll(THCState* state, unsigned long long seed)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ int currentDevice;
+ THCudaCheck(cudaGetDevice(&currentDevice));
+ for (int i = 0; i < rng_state->num_devices; ++i) {
+ THCudaCheck(cudaSetDevice(i));
+ THCRandom_manualSeed(state, seed);
+ }
+ THCudaCheck(cudaSetDevice(currentDevice));
+}
+
+/* Get the initial seed */
+unsigned long long THCRandom_initialSeed(THCState* state)
+{
+ return THCRandom_getGenerator(state)->initial_seed;
+}
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index 08efc0a..e05cf82 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -15,6 +15,9 @@
#define MAX_NUM_BLOCKS 64
#define BLOCK_SIZE 256
+
+Generator* THCRandom_getGenerator(THCState* state);
+
/* Sets up generator. Allocates but does not create the generator states. */
__host__ void initializeGenerator(THCState *state, Generator* gen)
{
@@ -22,23 +25,8 @@ __host__ void initializeGenerator(THCState *state, Generator* gen)
THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
}
-/* Frees memory allocated during setup. */
-__host__ void destroyGenerator(THCState *state, Generator* gen)
-{
- if (gen->gen_states)
- {
- THCudaCheck(THCudaFree(state, gen->gen_states));
- gen->gen_states = NULL;
- }
- if (gen->kernel_params)
- {
- THCudaCheck(THCudaFree(state, gen->kernel_params));
- gen->kernel_params = NULL;
- }
-}
-
/* Creates a new generator state given the seed. */
-__host__ void createGeneratorState(Generator* gen, unsigned long seed)
+__host__ void createGeneratorState(Generator* gen, unsigned long long seed)
{
if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->kernel_params) != CURAND_STATUS_SUCCESS)
{
@@ -51,112 +39,13 @@ __host__ void createGeneratorState(Generator* gen, unsigned long seed)
}
}
-/* Initialize generator array (must be called before any other function) */
-__host__ void THCRandom_init(THCState* state, int devices, int current_device)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- rng_state->num_devices = devices;
- rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator));
- for (int i = 0; i < rng_state->num_devices; ++i)
- {
- rng_state->gen[i].initf = 0;
- rng_state->gen[i].initial_seed = 0;
- rng_state->gen[i].gen_states = NULL;
- rng_state->gen[i].kernel_params = NULL;
- }
-}
-
-/* Destroy generators and free memory */
-__host__ void THCRandom_shutdown(THCState* state)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- if (rng_state->gen == NULL) return;
- for (int i = 0; i < rng_state->num_devices; ++i)
- {
- destroyGenerator(state, &rng_state->gen[i]);
- }
- free(rng_state->gen);
- rng_state->gen = NULL;
-}
-
-/* Manually set the generator seed */
-__host__ static void THCRandom_manualSeedGen(Generator* gen, unsigned long seed)
-{
- gen->initial_seed = seed;
- createGeneratorState(gen, seed);
- gen->initf = 1;
-}
-
-/* Get the generator for the current device */
-__host__ Generator* THCRandom_getGenerator(THCState* state)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
-
- int device;
- THCudaCheck(cudaGetDevice(&device));
- if (device >= rng_state->num_devices) THError("Invalid device index.");
-
- Generator* gen = &rng_state->gen[device];
- if (gen->initf == 0)
- {
- initializeGenerator(state, gen);
- THCRandom_manualSeedGen(gen, (unsigned long)time(0));
- }
- return gen;
-}
-
-__host__ struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state)
-{
- return THCRandom_getGenerator(state)->gen_states;
-}
-
-/* Random seed */
-__host__ unsigned long THCRandom_seed(THCState* state)
-{
- unsigned long s = (unsigned long)time(0);
- THCRandom_manualSeed(state, s);
- return s;
-}
-
-__host__ unsigned long THCRandom_seedAll(THCState* state)
-{
- unsigned long s = (unsigned long)time(0);
- THCRandom_manualSeedAll(state, s);
- return s;
-}
-
-/* Manually set the seed */
-__host__ void THCRandom_manualSeed(THCState* state, unsigned long seed)
-{
- Generator* gen = THCRandom_getGenerator(state);
- THCRandom_manualSeedGen(gen, seed);
-}
-
-__host__ void THCRandom_manualSeedAll(THCState* state, unsigned long seed)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- int currentDevice;
- THCudaCheck(cudaGetDevice(&currentDevice));
- for (int i = 0; i < rng_state->num_devices; ++i) {
- THCudaCheck(cudaSetDevice(i));
- THCRandom_manualSeed(state, seed);
- }
- THCudaCheck(cudaSetDevice(currentDevice));
-}
-
-/* Get the initial seed */
-__host__ unsigned long THCRandom_initialSeed(THCState* state)
-{
- return THCRandom_getGenerator(state)->initial_seed;
-}
-
__host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state)
{
Generator* gen = THCRandom_getGenerator(state);
// The RNG state comprises the MTGP32 states and the seed.
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
- static const size_t seed_size = sizeof(unsigned long);
+ static const size_t seed_size = sizeof(gen->initial_seed);
static const size_t total_size = states_size + seed_size;
THByteTensor_resize1d(rng_state, total_size);
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
@@ -176,7 +65,7 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
Generator* gen = THCRandom_getGenerator(state);
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
- static const size_t seed_size = sizeof(unsigned long);
+ static const size_t seed_size = sizeof(gen->initial_seed);
static const size_t total_size = states_size + seed_size;
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
@@ -240,4 +129,3 @@ GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, cura
#undef GENERATE_KERNEL1
#undef GENERATE_KERNEL2
-
diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h
index 12128cd..197a53c 100644
--- a/lib/THC/THCTensorRandom.h
+++ b/lib/THC/THCTensorRandom.h
@@ -11,7 +11,7 @@ typedef struct _Generator {
struct curandStateMtgp32* gen_states;
struct mtgp32_kernel_params *kernel_params;
int initf;
- unsigned long initial_seed;
+ unsigned long long initial_seed;
} Generator;
typedef struct THCRNGState {
@@ -24,11 +24,11 @@ struct THCState;
THC_API void THCRandom_init(struct THCState *state, int num_devices, int current_device);
THC_API void THCRandom_shutdown(struct THCState *state);
-THC_API unsigned long THCRandom_seed(struct THCState *state);
-THC_API unsigned long THCRandom_seedAll(struct THCState *state);
-THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long the_seed_);
-THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_seed_);
-THC_API unsigned long THCRandom_initialSeed(struct THCState *state);
+THC_API unsigned long long THCRandom_seed(struct THCState *state);
+THC_API unsigned long long THCRandom_seedAll(struct THCState *state);
+THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long long the_seed_);
+THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long long the_seed_);
+THC_API unsigned long long THCRandom_initialSeed(struct THCState *state);
THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);