Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Luehr <nluehr@nvidia.com>2016-01-21 04:58:25 +0300
committerPrzemek Tredak <ptredak@nvidia.com>2016-01-21 21:36:03 +0300
commit130ee246e21d3f73c977eda496ac9c90c3aa520b (patch)
tree8111722778177ea8c5b686f5743cffd968d21dd1
parent90af7c73efc51ca47bd669bb1cddd6814aaf64bf (diff)
Fixed deadlock in back-to-back reduce_scatters.
Change-Id: I92d32b15e516a39710b676aee692ae9b70638937 Reviewed-on: http://git-master/r/935458 Reviewed-by: Przemek Tredak <ptredak@nvidia.com> Tested-by: Przemek Tredak <ptredak@nvidia.com>
-rw-r--r--.gitignore1
-rw-r--r--README.md16
-rw-r--r--src/all_gather.cu4
-rw-r--r--src/all_reduce_test.cu4
-rw-r--r--src/core.cu22
-rw-r--r--src/libwrap.cu8
-rw-r--r--src/mpi_test.cu4
-rw-r--r--src/nccl.h14
-rw-r--r--src/reduce.cu4
-rw-r--r--src/reduce_scatter.cu25
-rw-r--r--src/reduce_test.cu8
-rw-r--r--src/test_utilities.h4
12 files changed, 70 insertions, 44 deletions
diff --git a/.gitignore b/.gitignore
index 796b96d..34a07c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
+# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
/build
diff --git a/README.md b/README.md
index 289b46e..e65e1ba 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ typedef struct {
int size;
cudaStream_t stream;
} PerThreadData;
-
+
int main(int argc, char* argv[])
{
int nGPUs;
@@ -96,20 +96,26 @@ int main(int argc, char* argv[])
ncclComm_t* comms = (ncclComm_t*)malloc(sizeof(ncclComm_t)*nGPUs);
ncclCommInitAll(comms, nGPUs); // initialize communicator
// One communicator per process
-
+
PerThreadData* data;
-
+
... // Allocate data and issue work to each GPU's
// perDevStream to populate the sendBuffs.
-
+
for(int i=0; i<nGPUs; ++i) {
cudaSetDevice(i); // Correct device must be set
// prior to each collective call.
ncclAllReduce(data[i].sendBuff, data[i].recvBuff, size,
ncclDouble, ncclSum, comms[i], data[i].stream);
}
-
+
... // Issue work into data[*].stream to consume buffers, etc.
}
```
+## Copyright and License
+
+NCCL is provided under the [BSD licence](LICENSE.txt). All source code and
+accompanying documentation is copyright (c) 2015-2016, NVIDIA CORPORATION. All
+rights reserved.
+
diff --git a/src/all_gather.cu b/src/all_gather.cu
index 0f90efd..a034bb1 100644
--- a/src/all_gather.cu
+++ b/src/all_gather.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -218,7 +218,7 @@ __global__ void AllGatherKernel(const AllGatherKernelArgs<T> args) {
if (!PUSHRECV)
WAIT_FOR_PREV_CHUNK(chunk, s);
-
+
if (PUSHRECV) {
DoubleCopy<UNROLL, THREADS>(
args.ThisOutput + outputOffset,
diff --git a/src/all_reduce_test.cu b/src/all_reduce_test.cu
index f46bd48..a2fcb3d 100644
--- a/src/all_reduce_test.cu
+++ b/src/all_reduce_test.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -50,7 +50,7 @@ void RunTest(T** sendbuff, T** recvbuff, const int N, const ncclDataType_t type,
int nDev = 0;
ncclCommCount(comms[0], &nDev);
cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);
-
+
for (int i = 0; i < nDev; ++i) {
CUDACHECK(cudaSetDevice(dList[i]));
CUDACHECK(cudaStreamCreate(s+i));
diff --git a/src/core.cu b/src/core.cu
index 4e357bb..95482c7 100644
--- a/src/core.cu
+++ b/src/core.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -139,7 +139,7 @@ typedef struct {
RankEntry ranks[1];
} RankGather;
-static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId,
+static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId,
int ndev, int rank, RankEntry myInfo) {
size_t bytes = offsetof(RankGather, ranks) + ndev*sizeof(RankEntry);
RankGather* tmp = NULL;
@@ -164,7 +164,7 @@ static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId,
shmUnmap(tmp, bytes);
return res;
}
-
+
orderRanks(tmp->ranks, ndev);
}
swapped = __sync_bool_compare_and_swap(&tmp->bar, bar_tmp, bar_tmp+1);
@@ -264,7 +264,7 @@ static ncclResult_t populateRankInfo(RankEntry* info, int rank, ncclComm_t comm)
return ncclUnhandledCudaError;
}
INFO("rank %d using device %d (%s)", rank, comm->cudaDev, busId);
-
+
if (wrapNvmlDeviceGetHandleByPciBusId(busId, &nvmlHandle) != ncclSuccess) {
WARN("rank %d failed to get nvml handle for device %s", rank, busId);
return ncclUnhandledCudaError;
@@ -306,7 +306,7 @@ static ncclResult_t commClearMaps(ncclComm_t comm) {
case CLEANUP_CUIPC:
res = wrapCuIpcCloseMemHandle((CUdeviceptr)comm->cleanup[d].handle);
if (res != ncclSuccess) {
- WARN("rank %d failed to close IPC handle to rank %d",
+ WARN("rank %d failed to close IPC handle to rank %d",
comm->userFromRing[comm->ncclId], comm->userFromRing[d]);
retval = (retval == ncclSuccess) ? res : retval;
}
@@ -382,7 +382,7 @@ static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int ran
return ncclInvalidRank;
}
comm->ncclId = myId;
-
+
int myDev = ranks[myId].cudaDev;
pid_t myPid = ranks[myId].pid;
comm->useRemoteRecv = 1; // Assume we directly write to result ptrs.
@@ -407,7 +407,7 @@ static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int ran
} else if (err != cudaSuccess) {
INFO("peer access failed between rank %d (dev %d) and rank %d (dev %d)\n",
rank, myDev, iRank, iDev);
-
+
canpeer = 0;
}
}
@@ -609,7 +609,7 @@ static ncclResult_t commUnlinkHostMem(ncclComm_t comm, ncclUniqueId commId, int
extern "C" DSOGLOBAL
ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) {
- if (strlen(commId.internal) < 1 ||
+ if (strlen(commId.internal) < 1 ||
strlen(commId.internal) >= NCCL_UNIQUE_ID_BYTES) {
WARN("rank %d invalid commId", myrank);
return ncclInvalidArgument;
@@ -675,7 +675,7 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId
if (commUnlinkHostMem(*newcomm, commId, myrank) != ncclSuccess)
INFO("rank %d failed to unlink host mem shm segment", myrank);
}
-
+
if (wrapNvmlShutdown() != ncclSuccess)
INFO("rank %d did not shutdown nvml properly", myrank);
return res;
@@ -739,8 +739,8 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, int* devlist) {
INFO("rank %d failed to set affinity", rank);
goto skipaffinity;
}
- affinity_set = 1;
- skipaffinity:
+ affinity_set = 1;
+ skipaffinity:
res = commAlloc(&comm, ndev, NULL, rank);
if (res != ncclSuccess) {
diff --git a/src/libwrap.cu b/src/libwrap.cu
index 67a699b..1b3eb75 100644
--- a/src/libwrap.cu
+++ b/src/libwrap.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -51,7 +51,7 @@ ncclResult_t wrapSymbols(void) {
if (symbolsLoaded)
return ncclSuccess;
-
+
static void* nvmlhandle = NULL;
static void* cuhandle = NULL;
void* tmp;
@@ -91,7 +91,7 @@ ncclResult_t wrapSymbols(void) {
LOAD_SYM(cuhandle, "cuIpcGetMemHandle", cuInternalIpcGetMemHandle);
LOAD_SYM(cuhandle, "cuIpcOpenMemHandle", cuInternalIpcOpenMemHandle);
LOAD_SYM(cuhandle, "cuIpcCloseMemHandle", cuInternalIpcCloseMemHandle);
-
+
symbolsLoaded = 1;
return ncclSuccess;
@@ -102,7 +102,7 @@ ncclResult_t wrapSymbols(void) {
nvmlInternalDeviceGetIndex = NULL;
nvmlInternalDeviceSetCpuAffinity = NULL;
nvmlInternalDeviceClearCpuAffinity = NULL;
-
+
cuInternalGetErrorString = NULL;
cuInternalIpcGetMemHandle = NULL;
cuInternalIpcOpenMemHandle = NULL;
diff --git a/src/mpi_test.cu b/src/mpi_test.cu
index 5768a20..600228c 100644
--- a/src/mpi_test.cu
+++ b/src/mpi_test.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -72,7 +72,7 @@ int main(int argc, char *argv[]) {
CUDACHECK(cudaMemcpy(dptr, &val, sizeof(int), cudaMemcpyHostToDevice));
ncclAllReduce((const void*)dptr, (void*)(dptr+1024), 1024, ncclInt, ncclSum, comm, cudaStreamDefault);
-
+
CUDACHECK(cudaMemcpy(&val, (dptr+1024), sizeof(int), cudaMemcpyDeviceToHost));
printf("Sum is %d\n", val);
CUDACHECK(cudaFree(dptr));
diff --git a/src/nccl.h b/src/nccl.h
index 5173b13..a0a71fc 100644
--- a/src/nccl.h
+++ b/src/nccl.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -124,7 +124,7 @@ typedef enum { ncclChar = 0,
/* Reduces data arrays of length count in sendbuff into recvbuf using op operation.
* recvbuf may be NULL on all calls except for root device.
* On the root device, sendbuff and recvbuff are assumed to reside on
- * the same device.
+ * the same device.
* Must be called separately for each communicator in communicator clique.
*/
ncclResult_t ncclReduce(const void* sendbuff, void* recvbuf, int count, ncclDataType_t datatype,
@@ -137,11 +137,11 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuf, int count, ncclData
ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
-/* Reduces data in sendbuff using op operation and leaves reduced result scattered
- * over the devices so that recvbuff on the i-th GPU will contain the i-th block of
- * the result. Sendbuff and recvbuff are assumed to reside on same device. Assumes
- * sendbuff has size at least ndev*recvcount elements, where ndev is number of
- * communicators in communicator clique
+/* Reduces data in sendbuff using op operation and leaves reduced result scattered
+ * over the devices so that recvbuff on the i-th GPU will contain the i-th block of
+ * the result. Sendbuff and recvbuff are assumed to reside on same device. Assumes
+ * sendbuff has size at least ndev*recvcount elements, where ndev is number of
+ * communicators in communicator clique
* Must be called separately for each communicator in communicator clique.*/
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
int recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
diff --git a/src/reduce.cu b/src/reduce.cu
index 6752d24..37acb33 100644
--- a/src/reduce.cu
+++ b/src/reduce.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -271,7 +271,7 @@ __global__ void ReduceKernel(const ReduceKernelArgs<T> args) {
}
template<class FUNC, typename T>
-ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
+ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff,
const int count, const int root, ncclComm* comm, cudaStream_t stream) {
if (count == 0)
return ncclSuccess;
diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu
index e1860c5..039e95f 100644
--- a/src/reduce_scatter.cu
+++ b/src/reduce_scatter.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -164,6 +164,9 @@ struct ReduceScatterKernelArgs {
int BufferSliceStride;
int BufferMisalignedN;
+ T ** ThisPtrToNextOutput;
+ T ** PrevPtrToThisOutput;
+
// local and remote input, output, and buffer
const T * __restrict__ ThisInput;
volatile T * __restrict__ ThisOutput;
@@ -187,6 +190,20 @@ __global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) {
if (args.N == 0) return;
int tid = threadIdx.x;
+ // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that
+ // the previous GPU is done with a previous collective operation.
+ if (tid == 0) {
+ Wait([=] {
+ return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr; // Wait for previous processor to be done
+ });
+
+ *((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput; // Tell Previous I'm starting
+ Wait([=] {
+ return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr; // Wait till I've been told next started
+ });
+ }
+ __syncthreads();
+
for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
// calculate slice size. for all chunks except (possibly) the last one,
// this will just be args.SliceSize. For the last one, it may be smaller
@@ -311,6 +328,7 @@ __global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) {
if (tid == 0) {
args.ThisNewDataAvailableFlag[tid] = 0;
args.ThisChunkDoneFlag[tid] = 0;
+ *args.ThisPtrToNextOutput = nullptr;
}
}
}
@@ -410,7 +428,8 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
}
-// printf("sliceSize = %i, chunkSize = %i, numChunks = %i, sliceStride = %i, misalignedN = %i\n", args.SliceSize, args.ChunkSize, args.NumChunks, args.BufferSliceStride, args.BufferMisalignedN);
+ args.ThisPtrToNextOutput = (T**)&(comm->local[nextId]->recvPtrs[0]);
+ args.PrevPtrToThisOutput = (T**)&(comm->remote[prevId]->recvPtrs[0]);
args.ThisInput = (const T*)sendbuff;
args.ThisOutput = (volatile T*)recvbuff;
@@ -426,7 +445,7 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
args.PrevChunkDoneFlag = comm->remote[prevId]->flags + 1;
ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
- <<<1, NUM_THREADS + NUM_SUBCHUNKS * WARP_SIZE, 0, stream>>>(args);
+ <<<1, NUM_THREADS + 1, 0, stream>>>(args);
return ncclSuccess;
}
diff --git a/src/reduce_test.cu b/src/reduce_test.cu
index ce17e32..42b1e9b 100644
--- a/src/reduce_test.cu
+++ b/src/reduce_test.cu
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -52,7 +52,7 @@ void RunTest(T** sendbuff, T** recvbuff, const int N, const ncclDataType_t type,
int nDev = 0;
ncclCommCount(comms[0], &nDev);
cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);
-
+
for (int i = 0; i < nDev; ++i) {
CUDACHECK(cudaSetDevice(dList[i]));
CUDACHECK(cudaStreamCreate(s+i));
@@ -68,7 +68,7 @@ void RunTest(T** sendbuff, T** recvbuff, const int N, const ncclDataType_t type,
// warm up GPU
for (int i = 0; i < nDev; ++i) {
CUDACHECK(cudaSetDevice(dList[i]));
- ncclReduce((const void*)sendbuff[i], (void*)recvbuff[i], std::min(N, 1024 * 1024),
+ ncclReduce((const void*)sendbuff[i], (void*)recvbuff[i], std::min(N, 1024 * 1024),
type, op, root, comms[i], s[i]);
}
@@ -270,7 +270,7 @@ int main(int argc, char* argv[]) {
printf("\n");
printf("# %10s %12s %6s %6s %4s out-of-place in-place\n", "", "", "", "", "");
- printf("# %10s %12s %6s %6s %4s %7s %5s %5s %7s %7s %5s %5s %7s\n",
+ printf("# %10s %12s %6s %6s %4s %7s %5s %5s %7s %7s %5s %5s %7s\n",
"bytes", "N", "type", "op", "root",
"time", "algbw", "busbw", "res", "time", "algbw", "busbw", "res");
}
diff --git a/src/test_utilities.h b/src/test_utilities.h
index a5d3661..c929a9e 100644
--- a/src/test_utilities.h
+++ b/src/test_utilities.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
@@ -250,7 +250,7 @@ void deltaKern(const T* A, const T* B, int N, double* max) {
int tid = threadIdx.x;
double locmax = 0.0;
for(int i=tid; i<N; i+=blockDim.x) {
-
+
double delta = absDiff(A[i], B[i]);
if( delta > locmax )
locmax = delta;