Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-11-24 03:40:03 +0300
committerGitHub <noreply@github.com>2016-11-24 03:40:03 +0300
commite9e131e6181f9a01dc2367ac250fe6758bcb6056 (patch)
tree3d0dc3d3253c37a3b46355eb40bb8d6723d832d2
parentf8d05d267e3be9a67055284a22b073fc7611ac9c (diff)
Revert "Lazily initialize CUDA devices"revert-610-lazy
-rw-r--r--init.c10
-rw-r--r--lib/THC/CMakeLists.txt5
-rw-r--r--lib/THC/THCGeneral.c211
-rw-r--r--lib/THC/THCGeneral.h.in6
-rw-r--r--lib/THC/THCTensorRandom.cpp133
-rw-r--r--lib/THC/THCTensorRandom.cu124
-rw-r--r--lib/THC/THCTensorRandom.h12
7 files changed, 255 insertions, 246 deletions
diff --git a/init.c b/init.c
index 02960cd..124be5c 100644
--- a/init.c
+++ b/init.c
@@ -776,35 +776,35 @@ static int cutorch_getDeviceProperties(lua_State *L)
static int cutorch_seed(lua_State *L)
{
- unsigned long long seed = THCRandom_seed(cutorch_getstate(L));
+ unsigned long seed = THCRandom_seed(cutorch_getstate(L));
lua_pushnumber(L, seed);
return 1;
}
static int cutorch_seedAll(lua_State *L)
{
- unsigned long long seed = THCRandom_seedAll(cutorch_getstate(L));
+ unsigned long seed = THCRandom_seedAll(cutorch_getstate(L));
lua_pushnumber(L, seed);
return 1;
}
static int cutorch_initialSeed(lua_State *L)
{
- unsigned long long seed = THCRandom_initialSeed(cutorch_getstate(L));
+ unsigned long seed = THCRandom_initialSeed(cutorch_getstate(L));
lua_pushnumber(L, seed);
return 1;
}
static int cutorch_manualSeed(lua_State *L)
{
- unsigned long long seed = luaL_checknumber(L, 1);
+ unsigned long seed = luaL_checknumber(L, 1);
THCRandom_manualSeed(cutorch_getstate(L), seed);
return 0;
}
static int cutorch_manualSeedAll(lua_State* L)
{
- unsigned long long seed = luaL_checknumber(L, 1);
+ unsigned long seed = luaL_checknumber(L, 1);
THCRandom_manualSeedAll(cutorch_getstate(L), seed);
return 0;
}
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 08e6fbc..c8916d3 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -25,10 +25,10 @@ endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" )
# add c++11 flag
- set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11)
+ set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++11)
else()
# add c++0x flag
- set_source_files_properties(THCTensorRandom.cpp THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x)
+ set_source_files_properties(THCCachingAllocator.cpp PROPERTIES COMPILE_FLAGS -std=c++0x)
endif()
else()
SET(CMAKE_CXX_STANDARD 11)
@@ -130,7 +130,6 @@ SET(src
THCStream.c
THCTensor.c
THCTensorCopy.c
- THCTensorRandom.cpp
THCThreadLocal.c
)
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index 2d4043d..547e060 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -81,21 +81,8 @@ void THCudaInit(THCState* state)
state->cudaUVAAllocator = (THAllocator*)malloc(sizeof(THAllocator));
THCUVAAllocator_init(state->cudaUVAAllocator);
- // By default, all direct p2p kernel access (besides copy) is disallowed,
- // since direct access without knowing whether or not a certain operation
- // should be cross-GPU leads to synchronization errors. The user can choose
- // to disable this functionality, however.
- state->p2pKernelAccessEnabled = 0;
-
- // p2pAccessEnabled records if p2p copies are allowed between pairs of
- // devices. Values include "1" (copy allowed), "0" (copy not allowed), and
- // "-1" (unknown).
- state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices);
- for (int i = 0; i < numDevices; ++i) {
- state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices);
- memset(state->p2pAccessEnabled[i], -1, sizeof(int) * numDevices);
- state->p2pAccessEnabled[i][i] = 1;
- }
+ /* Enable P2P access between all pairs, if possible */
+ THCudaEnablePeerToPeerAccess(state);
for (int i = 0; i < numDevices; ++i) {
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, i);
@@ -111,15 +98,22 @@ void THCudaInit(THCState* state)
int numSM = state->deviceProperties[i].multiProcessorCount;
size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
res->scratchSpacePerStream = sizePerStream;
+
+ /* Allocate scratch space for each stream */
+ res->devScratchSpacePerStream = (void**) malloc(sizeof(void*));
+ THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0],
+ sizePerStream));
}
/* Restore to previous device */
THCudaCheck(cudaSetDevice(device));
- // Unlike CUDA streams, there is no NULL cuBLAS handle. The default THC
- // cuBLAS handle is the first user BLAS handle. Note that the actual BLAS
- // handles are created lazily.
- state->numUserBlasHandles = 1;
+ /* There is no such thing as a default cublas handle.
+ To maintain consistency with streams API, handle 0 is always NULL and we
+ start counting at 1. If currentPerDeviceBlasHandle is 0 (the default
+ thread-local value), then we assume it means 1.
+ */
+ THCState_reserveBlasHandles(state, 1);
state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
state->heapDelta = 0;
@@ -153,9 +147,10 @@ void THCudaShutdown(THCState* state)
for (int i = 1; i <= state->numUserStreams; ++i) {
THCStream_free(res->streams[i]);
}
- /* Free user defined BLAS handles */
- for (int i = 0; i < res->numBlasHandles; ++i) {
- THCublasCheck(cublasDestroy(res->blasHandles[i]));
+ /* Free Torch-defined handles (0 is NULL for consistency with streams API) */
+ for (int handle = 1; handle <= state->numUserBlasHandles; ++handle) {
+ THCublasCheck(cublasDestroy(
+ THCState_getDeviceBlasHandle(state, dev, handle)));
}
/* Free per-stream scratch space; starts at 0 because there is space for
the default stream as well*/
@@ -179,36 +174,79 @@ void THCudaShutdown(THCState* state)
THCudaCheck(cudaSetDevice(prevDev));
}
-int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess)
+void THCudaEnablePeerToPeerAccess(THCState* state)
{
- if (dev < 0 || dev >= state->numDevices) {
- THError("%d is not a device", dev);
- }
- if (devToAccess < 0 || devToAccess >= state->numDevices) {
- THError("%d is not a device", devToAccess);
+ /* By default, all direct p2p kernel access (besides copy) is disallowed, */
+ /* since direct access without knowing whether or not a certain operation */
+ /* should be cross-GPU leads to synchronization errors. The user can choose */
+ /* to disable this functionality, however. */
+ state->p2pKernelAccessEnabled = 0;
+
+ int prevDev = -1;
+ THCudaCheck(cudaGetDevice(&prevDev));
+
+ int numDevices = -1;
+ THCudaCheck(cudaGetDeviceCount(&numDevices));
+
+ state->p2pAccessEnabled = (int**) malloc(sizeof(int*) * numDevices);
+ for (int i = 0; i < numDevices; ++i) {
+ state->p2pAccessEnabled[i] = (int*) malloc(sizeof(int) * numDevices);
}
- if (state->p2pAccessEnabled[dev][devToAccess] == -1) {
- int prevDev = 0;
- THCudaCheck(cudaGetDevice(&prevDev));
- THCudaCheck(cudaSetDevice(dev));
- int access = 0;
- THCudaCheck(cudaDeviceCanAccessPeer(&access, dev, devToAccess));
- if (access) {
- cudaError_t err = cudaDeviceEnablePeerAccess(devToAccess, 0);
- if (err == cudaErrorPeerAccessAlreadyEnabled) {
- // ignore and clear the error if access was already enabled
- cudaGetLastError();
+ /* Build a table of all allowed p2p accesses, to avoid checking the p2p
+ status at runtime. */
+ for (int i = 0; i < numDevices; ++i) {
+ THCudaCheck(cudaSetDevice(i));
+
+ for (int j = 0; j < numDevices; ++j) {
+ /* Presume no access by default */
+ state->p2pAccessEnabled[i][j] = 0;
+
+ if (i == j) {
+ /* A GPU can access itself */
+ state->p2pAccessEnabled[i][j] = 1;
} else {
- THCudaCheck(err);
+ int access = 0;
+ THCudaCheck(cudaDeviceCanAccessPeer(&access, i, j));
+
+ if (access) {
+ cudaError_t err = cudaDeviceEnablePeerAccess(j, 0);
+ if (err == cudaErrorPeerAccessAlreadyEnabled) {
+ /* It is possible that another thread has already enabled access. */
+ /* Any future call to cudaGetLastError will now return an error, */
+ /* even though we've already dealt with this specific error here. */
+ /* Call cudaGetLastError once to reset the last error state. */
+ cudaGetLastError();
+
+ /* The above should have cleared status */
+ THCudaCheck(cudaGetLastError());
+ } else {
+ /* In case there are other unhandled errors returned from the */
+ /* above */
+ THCudaCheck(err);
+ }
+
+ /* Access could be enabled, or was already enabled */
+ state->p2pAccessEnabled[i][j] = 1;
+ }
}
- state->p2pAccessEnabled[dev][devToAccess] = 1;
- } else {
- state->p2pAccessEnabled[dev][devToAccess] = 0;
}
+ }
- THCudaCheck(cudaSetDevice(prevDev));
+ /* Restore previous device before continuing */
+ THCudaCheck(cudaSetDevice(prevDev));
+}
+
+int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess)
+{
+ if (dev < 0 || dev >= state->numDevices) {
+ THError("%d is not a device", dev);
+ }
+
+ if (devToAccess < 0 || dev >= state->numDevices) {
+ THError("%d is not a device", devToAccess);
}
+
return state->p2pAccessEnabled[dev][devToAccess];
}
@@ -289,20 +327,6 @@ int THCState_getNumDevices(THCState *state)
return state->numDevices;
}
-static void THCState_initializeScratchSpace(THCState* state, int dev)
-{
- THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
- if (res->devScratchSpacePerStream) {
- return;
- }
- size_t size = (state->numUserStreams + 1) * sizeof(void*);
- void** scratch = (void**)malloc(size);
- for (int i = 0; i <= state->numUserStreams; ++i) {
- THCudaCheck(THCudaMalloc(state, &scratch[i], res->scratchSpacePerStream));
- }
- res->devScratchSpacePerStream = scratch;
-}
-
void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
{
if (numStreams <= state->numUserStreams)
@@ -322,7 +346,6 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
THCStream** newStreams = realloc(res->streams, (numStreams + 1) * sizeof(THCStream*));
THAssert(newStreams);
- THCState_initializeScratchSpace(state, dev);
void** newScratchSpace = realloc(res->devScratchSpacePerStream, (numStreams + 1) * sizeof(void*));
THAssert(newScratchSpace);
@@ -346,37 +369,45 @@ void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
THCudaCheck(cudaSetDevice(prevDev));
}
-void THCState_reserveDeviceBlasHandles(THCState* state, int device, int numBlasHandles)
+void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
{
- int prevDev = -1;
- THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
- if (numBlasHandles <= res->numBlasHandles) {
+ if (numBlasHandles <= state->numUserBlasHandles)
+ {
return;
}
+ int prevDev = -1;
THCudaCheck(cudaGetDevice(&prevDev));
- THCudaCheck(cudaSetDevice(device));
- size_t size = numBlasHandles * sizeof(cublasHandle_t);
- cublasHandle_t* handles = (cublasHandle_t*) realloc(res->blasHandles, size);
- for (int i = res->numBlasHandles; i < numBlasHandles; ++i) {
- handles[i] = NULL;
- THCublasCheck(cublasCreate(&handles[i]));
- }
- res->blasHandles = handles;
- res->numBlasHandles = numBlasHandles;
+ /* Otherwise, we have to allocate a new set of blasHandles */
+ for (int dev = 0; dev < state->numDevices; ++dev) {
+ THCudaCheck(cudaSetDevice(dev));
- THCudaCheck(cudaSetDevice(prevDev));
-}
+ /* +1 to be consistent with stream API, blas handle 0 is NULL and unused */
+ cublasHandle_t* newBlasHandles =
+ (cublasHandle_t*) malloc((numBlasHandles + 1) * sizeof(cublasHandle_t));
-void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
-{
- // cuBLAS handles are created lazily from THCState_getDeviceBlasHandle
- // to avoid initializing unused devices
- if (numBlasHandles > state->numUserBlasHandles)
- {
- state->numUserBlasHandles = numBlasHandles;
+ /* Copy over old blasHandles
+ (0 is NULL, 1 ... numUserBlasHandles are rest) */
+ newBlasHandles[0] = NULL;
+ for (int hndl = 1; hndl <= state->numUserBlasHandles; ++hndl) {
+ newBlasHandles[hndl] = THCState_getDeviceBlasHandle(state, dev, hndl);
+ }
+
+ /* Allocate new handles */
+ for (int hndl = state->numUserBlasHandles + 1; hndl <= numBlasHandles; ++hndl) {
+ newBlasHandles[hndl] = NULL;
+ THCublasCheck(cublasCreate(newBlasHandles + hndl));
+ }
+
+ THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
+ free(res->blasHandles);
+ res->blasHandles = newBlasHandles;
}
+
+ state->numUserBlasHandles = numBlasHandles;
+
+ THCudaCheck(cudaSetDevice(prevDev));
}
int THCState_getNumStreams(THCState* state)
@@ -414,13 +445,12 @@ cudaStream_t THCState_getDeviceStream(THCState *state, int device, int streamInd
cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle)
{
- if (handle <= 0 || handle > state->numUserBlasHandles) {
+ if (handle <= 0 || handle > state->numUserBlasHandles)
+ {
THError("%d is not a valid handle, valid range is: (1, %d)",
handle, state->numUserBlasHandles);
}
- THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
- THCState_reserveDeviceBlasHandles(state, device, handle);
- return res->blasHandles[handle - 1];
+ return THCState_getDeviceResourcePtr(state, device)->blasHandles[handle];
}
static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
@@ -562,13 +592,16 @@ void* THCState_getCurrentDeviceScratchSpace(THCState* state)
return THCState_getDeviceScratchSpace(state, device, stream);
}
-void* THCState_getDeviceScratchSpace(THCState* state, int dev, int stream)
+void* THCState_getDeviceScratchSpace(THCState* state, int device, int stream)
{
- THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
- if (stream > state->numUserStreams || stream < 0) {
+ THCCudaResourcesPerDevice* res =
+ THCState_getDeviceResourcePtr(state, device);
+
+ if (stream > state->numUserStreams || stream < 0)
+ {
THError("%d is not a stream", stream);
}
- THCState_initializeScratchSpace(state, dev);
+
return res->devScratchSpacePerStream[stream];
}
diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in
index 1a37390..c685d37 100644
--- a/lib/THC/THCGeneral.h.in
+++ b/lib/THC/THCGeneral.h.in
@@ -55,14 +55,11 @@ typedef struct _THCDeviceAllocator {
typedef struct _THCCudaResourcesPerDevice {
THCStream** streams;
- /* Number of materialized cuBLAS handles */
- int numBlasHandles;
- /* cuBLAS handes are lazily initialized */
cublasHandle_t* blasHandles;
/* Size of scratch space per each stream on this device available */
size_t scratchSpacePerStream;
/* Device-resident scratch space per stream, used for global memory
- reduction kernels. Lazily initialized. */
+ reduction kernels. */
void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;
@@ -118,6 +115,7 @@ THC_API void THCState_free(THCState* state);
THC_API void THCudaInit(THCState* state);
THC_API void THCudaShutdown(THCState* state);
+THC_API void THCudaEnablePeerToPeerAccess(THCState* state);
/* If device `dev` can access allocations on device `devToAccess`, this will return */
/* 1; otherwise, 0. */
diff --git a/lib/THC/THCTensorRandom.cpp b/lib/THC/THCTensorRandom.cpp
deleted file mode 100644
index d7690b5..0000000
--- a/lib/THC/THCTensorRandom.cpp
+++ /dev/null
@@ -1,133 +0,0 @@
-#include "THCTensorRandom.h"
-
-#include <random>
-#include <curand.h>
-
-
-void initializeGenerator(THCState *state, Generator* gen);
-void createGeneratorState(Generator* gen, unsigned long long seed);
-
-
-/* Frees memory allocated during setup. */
-void destroyGenerator(THCState *state, Generator* gen)
-{
- if (gen->gen_states)
- {
- THCudaCheck(THCudaFree(state, gen->gen_states));
- gen->gen_states = NULL;
- }
- if (gen->kernel_params)
- {
- THCudaCheck(THCudaFree(state, gen->kernel_params));
- gen->kernel_params = NULL;
- }
-}
-
-static unsigned long long createSeed(std::random_device& rd)
-{
- // limit to 53 bits to ensure unique representation in double
- unsigned long long seed = (((unsigned long long)rd()) << 32) + rd();
- return seed & 0x1FFFFFFFFFFFFF;
-}
-
-/* Initialize generator array (must be called before any other function) */
-void THCRandom_init(THCState* state, int devices, int current_device)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- rng_state->num_devices = devices;
- rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator));
- std::random_device rd;
- for (int i = 0; i < rng_state->num_devices; ++i)
- {
- rng_state->gen[i].initf = 0;
- rng_state->gen[i].initial_seed = createSeed(rd);
- rng_state->gen[i].gen_states = NULL;
- rng_state->gen[i].kernel_params = NULL;
- }
-}
-
-/* Destroy generators and free memory */
-void THCRandom_shutdown(THCState* state)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- if (rng_state->gen == NULL) return;
- for (int i = 0; i < rng_state->num_devices; ++i)
- {
- destroyGenerator(state, &rng_state->gen[i]);
- }
- free(rng_state->gen);
- rng_state->gen = NULL;
-}
-
-/* Get the generator for the current device, but does not initialize the state */
-static Generator* THCRandom_rawGenerator(THCState* state)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- int device;
- THCudaCheck(cudaGetDevice(&device));
- if (device >= rng_state->num_devices) THError("Invalid device index.");
- return &rng_state->gen[device];
-}
-
-/* Get the generator for the current device and initializes it if necessary */
-Generator* THCRandom_getGenerator(THCState* state)
-{
- Generator* gen = THCRandom_rawGenerator(state);
- if (gen->initf == 0)
- {
- initializeGenerator(state, gen);
- createGeneratorState(gen, gen->initial_seed);
- gen->initf = 1;
- }
- return gen;
-}
-
-struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state)
-{
- return THCRandom_getGenerator(state)->gen_states;
-}
-
-/* Random seed */
-unsigned long long THCRandom_seed(THCState* state)
-{
- std::random_device rd;
- unsigned long long s = createSeed(rd);
- THCRandom_manualSeed(state, s);
- return s;
-}
-
-unsigned long long THCRandom_seedAll(THCState* state)
-{
- std::random_device rd;
- unsigned long long s = createSeed(rd);
- THCRandom_manualSeedAll(state, s);
- return s;
-}
-
-/* Manually set the seed */
-void THCRandom_manualSeed(THCState* state, unsigned long long seed)
-{
- Generator* gen = THCRandom_rawGenerator(state);
- gen->initial_seed = seed;
- if (gen->initf) {
- createGeneratorState(gen, seed);
- }
-}
-
-void THCRandom_manualSeedAll(THCState* state, unsigned long long seed)
-{
- THCRNGState* rng_state = THCState_getRngState(state);
- int currentDevice;
- THCudaCheck(cudaGetDevice(&currentDevice));
- for (int i = 0; i < rng_state->num_devices; ++i) {
- THCudaCheck(cudaSetDevice(i));
- THCRandom_manualSeed(state, seed);
- }
- THCudaCheck(cudaSetDevice(currentDevice));
-}
-
-/* Get the initial seed */
-unsigned long long THCRandom_initialSeed(THCState* state)
-{
- return THCRandom_getGenerator(state)->initial_seed;
-}
diff --git a/lib/THC/THCTensorRandom.cu b/lib/THC/THCTensorRandom.cu
index e05cf82..08efc0a 100644
--- a/lib/THC/THCTensorRandom.cu
+++ b/lib/THC/THCTensorRandom.cu
@@ -15,9 +15,6 @@
#define MAX_NUM_BLOCKS 64
#define BLOCK_SIZE 256
-
-Generator* THCRandom_getGenerator(THCState* state);
-
/* Sets up generator. Allocates but does not create the generator states. */
__host__ void initializeGenerator(THCState *state, Generator* gen)
{
@@ -25,8 +22,23 @@ __host__ void initializeGenerator(THCState *state, Generator* gen)
THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
}
+/* Frees memory allocated during setup. */
+__host__ void destroyGenerator(THCState *state, Generator* gen)
+{
+ if (gen->gen_states)
+ {
+ THCudaCheck(THCudaFree(state, gen->gen_states));
+ gen->gen_states = NULL;
+ }
+ if (gen->kernel_params)
+ {
+ THCudaCheck(THCudaFree(state, gen->kernel_params));
+ gen->kernel_params = NULL;
+ }
+}
+
/* Creates a new generator state given the seed. */
-__host__ void createGeneratorState(Generator* gen, unsigned long long seed)
+__host__ void createGeneratorState(Generator* gen, unsigned long seed)
{
if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->kernel_params) != CURAND_STATUS_SUCCESS)
{
@@ -39,13 +51,112 @@ __host__ void createGeneratorState(Generator* gen, unsigned long long seed)
}
}
+/* Initialize generator array (must be called before any other function) */
+__host__ void THCRandom_init(THCState* state, int devices, int current_device)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ rng_state->num_devices = devices;
+ rng_state->gen = (Generator*)malloc(rng_state->num_devices * sizeof(Generator));
+ for (int i = 0; i < rng_state->num_devices; ++i)
+ {
+ rng_state->gen[i].initf = 0;
+ rng_state->gen[i].initial_seed = 0;
+ rng_state->gen[i].gen_states = NULL;
+ rng_state->gen[i].kernel_params = NULL;
+ }
+}
+
+/* Destroy generators and free memory */
+__host__ void THCRandom_shutdown(THCState* state)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ if (rng_state->gen == NULL) return;
+ for (int i = 0; i < rng_state->num_devices; ++i)
+ {
+ destroyGenerator(state, &rng_state->gen[i]);
+ }
+ free(rng_state->gen);
+ rng_state->gen = NULL;
+}
+
+/* Manually set the generator seed */
+__host__ static void THCRandom_manualSeedGen(Generator* gen, unsigned long seed)
+{
+ gen->initial_seed = seed;
+ createGeneratorState(gen, seed);
+ gen->initf = 1;
+}
+
+/* Get the generator for the current device */
+__host__ Generator* THCRandom_getGenerator(THCState* state)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+
+ int device;
+ THCudaCheck(cudaGetDevice(&device));
+ if (device >= rng_state->num_devices) THError("Invalid device index.");
+
+ Generator* gen = &rng_state->gen[device];
+ if (gen->initf == 0)
+ {
+ initializeGenerator(state, gen);
+ THCRandom_manualSeedGen(gen, (unsigned long)time(0));
+ }
+ return gen;
+}
+
+__host__ struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state)
+{
+ return THCRandom_getGenerator(state)->gen_states;
+}
+
+/* Random seed */
+__host__ unsigned long THCRandom_seed(THCState* state)
+{
+ unsigned long s = (unsigned long)time(0);
+ THCRandom_manualSeed(state, s);
+ return s;
+}
+
+__host__ unsigned long THCRandom_seedAll(THCState* state)
+{
+ unsigned long s = (unsigned long)time(0);
+ THCRandom_manualSeedAll(state, s);
+ return s;
+}
+
+/* Manually set the seed */
+__host__ void THCRandom_manualSeed(THCState* state, unsigned long seed)
+{
+ Generator* gen = THCRandom_getGenerator(state);
+ THCRandom_manualSeedGen(gen, seed);
+}
+
+__host__ void THCRandom_manualSeedAll(THCState* state, unsigned long seed)
+{
+ THCRNGState* rng_state = THCState_getRngState(state);
+ int currentDevice;
+ THCudaCheck(cudaGetDevice(&currentDevice));
+ for (int i = 0; i < rng_state->num_devices; ++i) {
+ THCudaCheck(cudaSetDevice(i));
+ THCRandom_manualSeed(state, seed);
+ }
+ THCudaCheck(cudaSetDevice(currentDevice));
+}
+
+/* Get the initial seed */
+__host__ unsigned long THCRandom_initialSeed(THCState* state)
+{
+ return THCRandom_getGenerator(state)->initial_seed;
+}
+
__host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state)
{
Generator* gen = THCRandom_getGenerator(state);
// The RNG state comprises the MTPG32 states and the seed.
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
- static const size_t seed_size = sizeof(gen->initial_seed);
+ static const size_t seed_size = sizeof(unsigned long);
static const size_t total_size = states_size + seed_size;
THByteTensor_resize1d(rng_state, total_size);
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
@@ -65,7 +176,7 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
Generator* gen = THCRandom_getGenerator(state);
static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
- static const size_t seed_size = sizeof(gen->initial_seed);
+ static const size_t seed_size = sizeof(unsigned long);
static const size_t total_size = states_size + seed_size;
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
@@ -129,3 +240,4 @@ GENERATE_KERNEL2(generate_cauchy, half, double median, double sigma, float, cura
#undef GENERATE_KERNEL1
#undef GENERATE_KERNEL2
+
diff --git a/lib/THC/THCTensorRandom.h b/lib/THC/THCTensorRandom.h
index 197a53c..12128cd 100644
--- a/lib/THC/THCTensorRandom.h
+++ b/lib/THC/THCTensorRandom.h
@@ -11,7 +11,7 @@ typedef struct _Generator {
struct curandStateMtgp32* gen_states;
struct mtgp32_kernel_params *kernel_params;
int initf;
- unsigned long long initial_seed;
+ unsigned long initial_seed;
} Generator;
typedef struct THCRNGState {
@@ -24,11 +24,11 @@ struct THCState;
THC_API void THCRandom_init(struct THCState *state, int num_devices, int current_device);
THC_API void THCRandom_shutdown(struct THCState *state);
-THC_API unsigned long long THCRandom_seed(struct THCState *state);
-THC_API unsigned long long THCRandom_seedAll(struct THCState *state);
-THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long long the_seed_);
-THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long long the_seed_);
-THC_API unsigned long long THCRandom_initialSeed(struct THCState *state);
+THC_API unsigned long THCRandom_seed(struct THCState *state);
+THC_API unsigned long THCRandom_seedAll(struct THCState *state);
+THC_API void THCRandom_manualSeed(struct THCState *state, unsigned long the_seed_);
+THC_API void THCRandom_manualSeedAll(struct THCState *state, unsigned long the_seed_);
+THC_API unsigned long THCRandom_initialSeed(struct THCState *state);
THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state);
THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state);