diff options
author | Dominik Grewe <dominikg@google.com> | 2014-11-19 19:49:01 +0300 |
---|---|---|
committer | Dominik Grewe <dominikg@google.com> | 2014-11-19 20:07:52 +0300 |
commit | d33f309b0f4e8e2205998b146940fe70954df432 (patch) | |
tree | 3d4232f85166c24eb1f9e2ac7b823bf5f5f917d5 /init.c | |
parent | 2fee05a20efc83943251dc53f2a6d63ea1d7f452 (diff) |
Reset RNG state after device reset.
A device reset destroys the state of the RNG, so we have to re-initialize
it after each reset.
Diffstat (limited to 'init.c')
-rw-r--r-- | init.c | 19 |
1 files changed, 10 insertions, 9 deletions
@@ -7,6 +7,15 @@ extern void cutorch_CudaTensor_init(lua_State* L); extern void cutorch_CudaTensorMath_init(lua_State* L); extern void cutorch_CudaTensorOperator_init(lua_State* L); +static THCudaState* getState(lua_State *L) +{ + lua_getglobal(L, "cutorch"); + lua_getfield(L, -1, "_state"); + THCudaState *state = lua_touserdata(L, -1); + lua_pop(L, 2); + return state; +} + static int cutorch_synchronize(lua_State *L) { cudaDeviceSynchronize(); @@ -25,6 +34,7 @@ static int cutorch_getDevice(lua_State *L) static int cutorch_deviceReset(lua_State *L) { THCudaCheck(cudaDeviceReset()); + THCRandom_resetGenerator(getState(L)->rngState); return 0; } @@ -36,15 +46,6 @@ static int cutorch_getDeviceCount(lua_State *L) return 1; } -static THCudaState* getState(lua_State *L) -{ - lua_getglobal(L, "cutorch"); - lua_getfield(L, -1, "_state"); - THCudaState *state = lua_touserdata(L, -1); - lua_pop(L, 2); - return state; -} - static int cutorch_setDevice(lua_State *L) { int device = (int)luaL_checknumber(L, 1)-1; |