Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/init.c
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2014-11-19 19:49:01 +0300
committerDominik Grewe <dominikg@google.com>2014-11-19 20:07:52 +0300
commitd33f309b0f4e8e2205998b146940fe70954df432 (patch)
tree3d4232f85166c24eb1f9e2ac7b823bf5f5f917d5 /init.c
parent2fee05a20efc83943251dc53f2a6d63ea1d7f452 (diff)
Reset RNG state after device reset.
A device reset destroys the state of the RNG, so we have to re-initialize it after each reset.
Diffstat (limited to 'init.c')
-rw-r--r--init.c19
1 files changed, 10 insertions, 9 deletions
diff --git a/init.c b/init.c
index 365b200..c3accd5 100644
--- a/init.c
+++ b/init.c
@@ -7,6 +7,15 @@ extern void cutorch_CudaTensor_init(lua_State* L);
extern void cutorch_CudaTensorMath_init(lua_State* L);
extern void cutorch_CudaTensorOperator_init(lua_State* L);
+static THCudaState* getState(lua_State *L)
+{
+ lua_getglobal(L, "cutorch");
+ lua_getfield(L, -1, "_state");
+ THCudaState *state = lua_touserdata(L, -1);
+ lua_pop(L, 2);
+ return state;
+}
+
static int cutorch_synchronize(lua_State *L)
{
cudaDeviceSynchronize();
@@ -25,6 +34,7 @@ static int cutorch_getDevice(lua_State *L)
static int cutorch_deviceReset(lua_State *L)
{
THCudaCheck(cudaDeviceReset());
+ THCRandom_resetGenerator(getState(L)->rngState);
return 0;
}
@@ -36,15 +46,6 @@ static int cutorch_getDeviceCount(lua_State *L)
return 1;
}
-static THCudaState* getState(lua_State *L)
-{
- lua_getglobal(L, "cutorch");
- lua_getfield(L, -1, "_state");
- THCudaState *state = lua_touserdata(L, -1);
- lua_pop(L, 2);
- return state;
-}
-
static int cutorch_setDevice(lua_State *L)
{
int device = (int)luaL_checknumber(L, 1)-1;