
github.com/torch/cutorch.git
path: root/init.c
author    Soumith Chintala <soumith@gmail.com>  2015-08-21 23:03:31 +0300
committer Soumith Chintala <soumith@gmail.com>  2015-08-21 23:03:31 +0300
commit    8fad17bbff237ad9191114088873ae83535dfed3 (patch)
tree      3bf7657c52dc6276f80ab54e9d77bac0601c7787 /init.c
parent    d220328e9e2b42f3df6f40f15932cf2719bb0e17 (diff)
parent    b6c0b9d97f2a0f9c637ba3ba67b17df24c12a4b5 (diff)
Merge pull request #222 from torch/streamfixes
stream event fixes
Diffstat (limited to 'init.c')
-rw-r--r--  init.c  69
1 file changed, 37 insertions(+), 32 deletions(-)
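
Before the diff itself, here is a minimal, self-contained sketch (not part of this patch) of the single-device pattern the merge moves to. Because cudaEventRecord() on an already-recorded event overwrites its previous state, recording one event per device across several streams only tracks the last stream recorded; the fix creates and records one event per stream and waits on all of them. The stream count and the `check` macro below are illustrative only.

/* sketch: one event per stream, so the waiting stream depends on all of them */
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#define check(err) do { \
  cudaError_t e_ = (err); \
  if (e_ != cudaSuccess) { fprintf(stderr, "%s\n", cudaGetErrorString(e_)); exit(1); } \
} while (0)

int main(void) {
  const int nStreams = 3;                 /* illustrative */
  cudaStream_t streams[3];
  cudaStream_t waiter;
  check(cudaStreamCreate(&waiter));
  for (int i = 0; i < nStreams; ++i) check(cudaStreamCreate(&streams[i]));

  /* Fixed pattern: create and record one event per stream to wait on. */
  cudaEvent_t *events = (cudaEvent_t *)malloc(sizeof(cudaEvent_t) * nStreams);
  for (int i = 0; i < nStreams; ++i) {
    check(cudaEventCreateWithFlags(&events[i], cudaEventDisableTiming));
    check(cudaEventRecord(events[i], streams[i]));
  }

  /* `waiter` now depends on work queued in every stream, not just the last. */
  for (int i = 0; i < nStreams; ++i) {
    check(cudaStreamWaitEvent(waiter, events[i], 0));
    check(cudaEventDestroy(events[i]));
  }
  free(events);

  check(cudaStreamDestroy(waiter));
  for (int i = 0; i < nStreams; ++i) check(cudaStreamDestroy(streams[i]));
  return 0;
}
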
diff --git a/init.c b/init.c
index e724b28..495af43 100644
--- a/init.c
+++ b/init.c
@@ -81,26 +81,28 @@ void checkAndCountListOfGPUStreamPairs(lua_State *L, THCState *state, int arg,
lua_pop(L, 1);
}
-void createSingleDeviceEvent(lua_State *L, THCState *state, int arg,
+int createSingleDeviceEvents(lua_State *L, THCState *state, int arg,
int device, cudaEvent_t* event)
{
- THCudaCheck(cudaEventCreateWithFlags(event, cudaEventDisableTiming));
/* Push table to top */
lua_pushvalue(L, arg);
/* Record events */
lua_pushnil(L);
+ int i = 0;
while (lua_next(L, -2)) {
int streamId = (int) lua_tonumber(L, -1);
cudaStream_t streamWaitingOn =
THCState_getDeviceStream(state, device, streamId);
- THCudaCheck(cudaEventRecord(*event, streamWaitingOn));
+ THCudaCheck(cudaEventCreateWithFlags(&event[i], cudaEventDisableTiming));
+ THCudaCheck(cudaEventRecord(event[i], streamWaitingOn));
lua_pop(L, 1);
+ i++;
}
-
/* Pop table from top */
lua_pop(L, 1);
+ return i;
}
void createMultiDeviceEvents(lua_State *L, THCState *state, int arg,
@@ -115,7 +117,7 @@ void createMultiDeviceEvents(lua_State *L, THCState *state, int arg,
while (lua_next(L, -2)) {
int device = (int) lua_tonumber(L, -2) - 1;
THCudaCheck(cudaSetDevice(device));
- createSingleDeviceEvent(L, state, -1, device, &events[gpu]);
+ events += createSingleDeviceEvents(L, state, -1, device, events);
++gpu;
lua_pop(L, 1);
@@ -125,8 +127,8 @@ void createMultiDeviceEvents(lua_State *L, THCState *state, int arg,
lua_pop(L, 1);
}
-void waitSingleDeviceEvent(lua_State *L, THCState *state, int arg,
- int device, cudaEvent_t event)
+void waitSingleDeviceEvents(lua_State *L, THCState *state, int arg,
+ int device, cudaEvent_t * event, int numEvents)
{
/* Push table to top */
lua_pushvalue(L, arg);
@@ -136,9 +138,11 @@ void waitSingleDeviceEvent(lua_State *L, THCState *state, int arg,
lua_pushnil(L);
while (lua_next(L, -2)) {
int streamId = (int) lua_tonumber(L, -1);
- cudaStream_t stream =
+ cudaStream_t stream =
THCState_getDeviceStream(state, device, streamId);
- THCudaCheck(cudaStreamWaitEvent(stream, event, 0));
+ for (int i = 0; i < numEvents; i++) {
+ THCudaCheck(cudaStreamWaitEvent(stream, event[i], 0));
+ }
lua_pop(L, 1);
}
@@ -146,8 +150,9 @@ void waitSingleDeviceEvent(lua_State *L, THCState *state, int arg,
lua_pop(L, 1);
}
+
void waitMultiDeviceEvents(lua_State *L, THCState *state, int arg,
- cudaEvent_t* events, int gpus)
+ cudaEvent_t* events, int streams)
{
/* Push {gpu={streams...}} table */
lua_pushvalue(L, arg);
@@ -169,7 +174,7 @@ void waitMultiDeviceEvents(lua_State *L, THCState *state, int arg,
THCState_getDeviceStream(state, device, streamId);
/* Each stream waits on all events */
- for (int i = 0; i < gpus; ++i) {
+ for (int i = 0; i < streams; ++i) {
THCudaCheck(cudaStreamWaitEvent(stream, events[i], 0));
}
@@ -367,16 +372,16 @@ static int cutorch_streamWaitFor(lua_State *L)
/* nothing to synchronize */
return 0;
}
-
/* One-way dependency; streamWaiting will wait for the list of streams to
wait on to complete execution of pending scheduled kernels/events */
- cudaEvent_t event;
- createSingleDeviceEvent(L, state, 2, curDev, &event);
-
+ cudaEvent_t * events = (cudaEvent_t*)malloc(sizeof(cudaEvent_t) * streams);
+ createSingleDeviceEvents(L, state, 2, curDev, events);
/* Then, wait on them */
- THCudaCheck(cudaStreamWaitEvent(streamWaiting, event, 0));
- THCudaCheck(cudaEventDestroy(event));
-
+ for (int i = 0; i < streams; i++) {
+ THCudaCheck(cudaStreamWaitEvent(streamWaiting, events[i], 0));
+ THCudaCheck(cudaEventDestroy(events[i]));
+ }
+ free(events);
return 0;
}
@@ -411,18 +416,18 @@ static int cutorch_streamWaitForMultiDevice(lua_State *L)
int streams = 0;
checkAndCountListOfGPUStreamPairs(L, state, 3, &gpus, &streams);
- if (streams < 2) {
+ if (streams < 1) {
/* nothing to synchronize together */
return 0;
}
/*
Events can only be recorded on the same device on which they are created.
- -For each GPU, create an event, and record that event on each stream given
+ -For each GPU, create and record event per each stream given
for that GPU.
-For (gpuWaiter, streamWaiter), wait on all of the above events.
*/
- cudaEvent_t* events = (cudaEvent_t*) malloc(sizeof(cudaEvent_t) * gpus);
+ cudaEvent_t* events = (cudaEvent_t*) malloc(sizeof(cudaEvent_t) * streams);
/* First, create an event per GPU and record events for the specified stream
on that GPU */
@@ -430,12 +435,12 @@ static int cutorch_streamWaitForMultiDevice(lua_State *L)
/* Then, wait on the events */
THCudaCheck(cudaSetDevice(gpuWaiter));
- for (int i = 0; i < gpus; ++i) {
+ for (int i = 0; i < streams; ++i) {
THCudaCheck(cudaStreamWaitEvent(streamWaiting, events[i], 0));
}
/* Clean up events */
- for (int i = 0; i < gpus; ++i) {
+ for (int i = 0; i < streams; ++i) {
THCudaCheck(cudaEventDestroy(events[i]));
}
free(events);
@@ -463,19 +468,19 @@ static int cutorch_streamBarrier(lua_State *L)
/* nothing to synchronize together */
return 0;
}
-
/* Multi-way dependency (barrier); all streams must complete execution
of pending scheduled kernels/events */
- cudaEvent_t event;
-
+ cudaEvent_t * events = (cudaEvent_t*)malloc(sizeof(cudaEvent_t) * streams);
/* First, create an event and record them for all streams */
- createSingleDeviceEvent(L, state, 1, curDev, &event);
+ int eventsCreated = createSingleDeviceEvents(L, state, 1, curDev, events);
/* Then, wait on the event. Each stream is actually waiting on itself here
too, but that's harmless and isn't worth weeding out. */
- waitSingleDeviceEvent(L, state, 1, curDev, event);
- THCudaCheck(cudaEventDestroy(event));
+ waitSingleDeviceEvents(L, state, 1, curDev, events, eventsCreated);
+ for (int i = 0; i < eventsCreated; i++)
+ THCudaCheck(cudaEventDestroy(events[i]));
+ free(events);
return 0;
}
@@ -513,7 +518,7 @@ static int cutorch_streamBarrierMultiDevice(lua_State *L)
-For each GPU, for each stream, wait on the event created by each other
GPU.
*/
- cudaEvent_t* events = (cudaEvent_t*) malloc(sizeof(cudaEvent_t) * gpus);
+ cudaEvent_t* events = (cudaEvent_t*) malloc(sizeof(cudaEvent_t) * streams);
/* First, create an event per GPU and record events for the specified stream
on that GPU */
@@ -521,10 +526,10 @@ static int cutorch_streamBarrierMultiDevice(lua_State *L)
/* Then, wait on the events. Each stream is actually waiting on itself here
too, but that's harmless and isn't worth weeding out. */
- waitMultiDeviceEvents(L, state, 1, events, gpus);
+ waitMultiDeviceEvents(L, state, 1, events, streams);
/* Clean up events */
- for (int i = 0; i < gpus; ++i) {
+ for (int i = 0; i < streams; ++i) {
THCudaCheck(cudaEventDestroy(events[i]));
}
free(events);
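
For the multi-device paths (cutorch_streamWaitForMultiDevice, cutorch_streamBarrierMultiDevice), the patch sizes the event array by the total stream count rather than the GPU count, since there is now one event per stream. Below is a hedged sketch of that barrier pattern, with illustrative names and none of the cutorch/Lua plumbing; error checks are omitted for brevity, and `streams[d][s]` is assumed to have been created on device `d`.

/* sketch: record one event per (device, stream) on its owning device,
 * then have every stream wait on every event */
#include <cuda_runtime.h>
#include <stdlib.h>

void barrier_sketch(int nDevices, int streamsPerDevice, cudaStream_t **streams)
{
  int total = nDevices * streamsPerDevice;
  cudaEvent_t *events = (cudaEvent_t *)malloc(sizeof(cudaEvent_t) * total);

  /* Events can only be recorded on the device that owns their stream. */
  int e = 0;
  for (int d = 0; d < nDevices; ++d) {
    cudaSetDevice(d);
    for (int s = 0; s < streamsPerDevice; ++s, ++e) {
      cudaEventCreateWithFlags(&events[e], cudaEventDisableTiming);
      cudaEventRecord(events[e], streams[d][s]);
    }
  }

  /* Every stream waits on every event; each also waits on its own event,
   * which, as the patch comments note, is harmless. */
  for (int d = 0; d < nDevices; ++d) {
    cudaSetDevice(d);
    for (int s = 0; s < streamsPerDevice; ++s)
      for (int i = 0; i < total; ++i)
        cudaStreamWaitEvent(streams[d][s], events[i], 0);
  }

  for (int i = 0; i < total; ++i) cudaEventDestroy(events[i]);
  free(events);
}
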