diff options
author | Nathan Luehr <nluehr@nvidia.com> | 2016-01-21 04:58:25 +0300 |
---|---|---|
committer | Przemek Tredak <ptredak@nvidia.com> | 2016-01-21 21:36:03 +0300 |
commit | 130ee246e21d3f73c977eda496ac9c90c3aa520b (patch) | |
tree | 8111722778177ea8c5b686f5743cffd968d21dd1 | |
parent | 90af7c73efc51ca47bd669bb1cddd6814aaf64bf (diff) |
Fixed deadlock in back-to-back reduce_scatters.
Change-Id: I92d32b15e516a39710b676aee692ae9b70638937
Reviewed-on: http://git-master/r/935458
Reviewed-by: Przemek Tredak <ptredak@nvidia.com>
Tested-by: Przemek Tredak <ptredak@nvidia.com>
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | README.md | 16 | ||||
-rw-r--r-- | src/all_gather.cu | 4 | ||||
-rw-r--r-- | src/all_reduce_test.cu | 4 | ||||
-rw-r--r-- | src/core.cu | 22 | ||||
-rw-r--r-- | src/libwrap.cu | 8 | ||||
-rw-r--r-- | src/mpi_test.cu | 4 | ||||
-rw-r--r-- | src/nccl.h | 14 | ||||
-rw-r--r-- | src/reduce.cu | 4 | ||||
-rw-r--r-- | src/reduce_scatter.cu | 25 | ||||
-rw-r--r-- | src/reduce_test.cu | 8 | ||||
-rw-r--r-- | src/test_utilities.h | 4 |
12 files changed, 70 insertions, 44 deletions
@@ -1 +1,2 @@ +# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. /build @@ -88,7 +88,7 @@ typedef struct { int size; cudaStream_t stream; } PerThreadData; - + int main(int argc, char* argv[]) { int nGPUs; @@ -96,20 +96,26 @@ int main(int argc, char* argv[]) ncclComm_t* comms = (ncclComm_t*)malloc(sizeof(ncclComm_t)*nGPUs); ncclCommInitAll(comms, nGPUs); // initialize communicator // One communicator per process - + PerThreadData* data; - + ... // Allocate data and issue work to each GPU's // perDevStream to populate the sendBuffs. - + for(int i=0; i<nGPUs; ++i) { cudaSetDevice(i); // Correct device must be set // prior to each collective call. ncclAllReduce(data[i].sendBuff, data[i].recvBuff, size, ncclDouble, ncclSum, comms[i], data[i].stream); } - + ... // Issue work into data[*].stream to consume buffers, etc. } ``` +## Copyright and License + +NCCL is provided under the [BSD licence](LICENSE.txt). All source code and +accompanying documentation is copyright (c) 2015-2016, NVIDIA CORPORATION. All +rights reserved. + diff --git a/src/all_gather.cu b/src/all_gather.cu index 0f90efd..a034bb1 100644 --- a/src/all_gather.cu +++ b/src/all_gather.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -218,7 +218,7 @@ __global__ void AllGatherKernel(const AllGatherKernelArgs<T> args) { if (!PUSHRECV) WAIT_FOR_PREV_CHUNK(chunk, s); - + if (PUSHRECV) { DoubleCopy<UNROLL, THREADS>( args.ThisOutput + outputOffset, diff --git a/src/all_reduce_test.cu b/src/all_reduce_test.cu index f46bd48..a2fcb3d 100644 --- a/src/all_reduce_test.cu +++ b/src/all_reduce_test.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -50,7 +50,7 @@ void RunTest(T** sendbuff, T** recvbuff, const int N, const ncclDataType_t type, int nDev = 0; ncclCommCount(comms[0], &nDev); cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev); - + for (int i = 0; i < nDev; ++i) { CUDACHECK(cudaSetDevice(dList[i])); CUDACHECK(cudaStreamCreate(s+i)); diff --git a/src/core.cu b/src/core.cu index 4e357bb..95482c7 100644 --- a/src/core.cu +++ b/src/core.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -139,7 +139,7 @@ typedef struct { RankEntry ranks[1]; } RankGather; -static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId, +static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId, int ndev, int rank, RankEntry myInfo) { size_t bytes = offsetof(RankGather, ranks) + ndev*sizeof(RankEntry); RankGather* tmp = NULL; @@ -164,7 +164,7 @@ static ncclResult_t initGather(RankGather** gather, ncclUniqueId commId, shmUnmap(tmp, bytes); return res; } - + orderRanks(tmp->ranks, ndev); } swapped = __sync_bool_compare_and_swap(&tmp->bar, bar_tmp, bar_tmp+1); @@ -264,7 +264,7 @@ static ncclResult_t populateRankInfo(RankEntry* info, int rank, ncclComm_t comm) return ncclUnhandledCudaError; } INFO("rank %d using device %d (%s)", rank, comm->cudaDev, busId); - + if (wrapNvmlDeviceGetHandleByPciBusId(busId, &nvmlHandle) != ncclSuccess) { WARN("rank %d failed to get nvml handle for device %s", rank, busId); return ncclUnhandledCudaError; @@ -306,7 +306,7 @@ static ncclResult_t commClearMaps(ncclComm_t comm) { case CLEANUP_CUIPC: res = wrapCuIpcCloseMemHandle((CUdeviceptr)comm->cleanup[d].handle); if (res != ncclSuccess) { - WARN("rank %d failed to close IPC handle to rank %d", + WARN("rank %d failed to close IPC handle to rank %d", comm->userFromRing[comm->ncclId], comm->userFromRing[d]); retval = (retval == ncclSuccess) ? res : retval; } @@ -382,7 +382,7 @@ static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int ran return ncclInvalidRank; } comm->ncclId = myId; - + int myDev = ranks[myId].cudaDev; pid_t myPid = ranks[myId].pid; comm->useRemoteRecv = 1; // Assume we directly write to result ptrs. @@ -407,7 +407,7 @@ static ncclResult_t commBuildMaps(ncclComm_t comm, ncclUniqueId* commId, int ran } else if (err != cudaSuccess) { INFO("peer access failed between rank %d (dev %d) and rank %d (dev %d)\n", rank, myDev, iRank, iDev); - + canpeer = 0; } } @@ -609,7 +609,7 @@ static ncclResult_t commUnlinkHostMem(ncclComm_t comm, ncclUniqueId commId, int extern "C" DSOGLOBAL ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId, int myrank) { - if (strlen(commId.internal) < 1 || + if (strlen(commId.internal) < 1 || strlen(commId.internal) >= NCCL_UNIQUE_ID_BYTES) { WARN("rank %d invalid commId", myrank); return ncclInvalidArgument; @@ -675,7 +675,7 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int ndev, ncclUniqueId commId if (commUnlinkHostMem(*newcomm, commId, myrank) != ncclSuccess) INFO("rank %d failed to unlink host mem shm segment", myrank); } - + if (wrapNvmlShutdown() != ncclSuccess) INFO("rank %d did not shutdown nvml properly", myrank); return res; @@ -739,8 +739,8 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, int* devlist) { INFO("rank %d failed to set affinity", rank); goto skipaffinity; } - affinity_set = 1; - skipaffinity: + affinity_set = 1; + skipaffinity: res = commAlloc(&comm, ndev, NULL, rank); if (res != ncclSuccess) { diff --git a/src/libwrap.cu b/src/libwrap.cu index 67a699b..1b3eb75 100644 --- a/src/libwrap.cu +++ b/src/libwrap.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -51,7 +51,7 @@ ncclResult_t wrapSymbols(void) { if (symbolsLoaded) return ncclSuccess; - + static void* nvmlhandle = NULL; static void* cuhandle = NULL; void* tmp; @@ -91,7 +91,7 @@ ncclResult_t wrapSymbols(void) { LOAD_SYM(cuhandle, "cuIpcGetMemHandle", cuInternalIpcGetMemHandle); LOAD_SYM(cuhandle, "cuIpcOpenMemHandle", cuInternalIpcOpenMemHandle); LOAD_SYM(cuhandle, "cuIpcCloseMemHandle", cuInternalIpcCloseMemHandle); - + symbolsLoaded = 1; return ncclSuccess; @@ -102,7 +102,7 @@ ncclResult_t wrapSymbols(void) { nvmlInternalDeviceGetIndex = NULL; nvmlInternalDeviceSetCpuAffinity = NULL; nvmlInternalDeviceClearCpuAffinity = NULL; - + cuInternalGetErrorString = NULL; cuInternalIpcGetMemHandle = NULL; cuInternalIpcOpenMemHandle = NULL; diff --git a/src/mpi_test.cu b/src/mpi_test.cu index 5768a20..600228c 100644 --- a/src/mpi_test.cu +++ b/src/mpi_test.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -72,7 +72,7 @@ int main(int argc, char *argv[]) { CUDACHECK(cudaMemcpy(dptr, &val, sizeof(int), cudaMemcpyHostToDevice)); ncclAllReduce((const void*)dptr, (void*)(dptr+1024), 1024, ncclInt, ncclSum, comm, cudaStreamDefault); - + CUDACHECK(cudaMemcpy(&val, (dptr+1024), sizeof(int), cudaMemcpyDeviceToHost)); printf("Sum is %d\n", val); CUDACHECK(cudaFree(dptr)); @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -124,7 +124,7 @@ typedef enum { ncclChar = 0, /* Reduces data arrays of length count in sendbuff into recvbuf using op operation. * recvbuf may be NULL on all calls except for root device. * On the root device, sendbuff and recvbuff are assumed to reside on - * the same device. + * the same device. * Must be called separately for each communicator in communicator clique. */ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuf, int count, ncclDataType_t datatype, @@ -137,11 +137,11 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuf, int count, ncclData ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, int count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); -/* Reduces data in sendbuff using op operation and leaves reduced result scattered - * over the devices so that recvbuff on the i-th GPU will contain the i-th block of - * the result. Sendbuff and recvbuff are assumed to reside on same device. Assumes - * sendbuff has size at least ndev*recvcount elements, where ndev is number of - * communicators in communicator clique +/* Reduces data in sendbuff using op operation and leaves reduced result scattered + * over the devices so that recvbuff on the i-th GPU will contain the i-th block of + * the result. Sendbuff and recvbuff are assumed to reside on same device. Assumes + * sendbuff has size at least ndev*recvcount elements, where ndev is number of + * communicators in communicator clique * Must be called separately for each communicator in communicator clique.*/ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, int recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, diff --git a/src/reduce.cu b/src/reduce.cu index 6752d24..37acb33 100644 --- a/src/reduce.cu +++ b/src/reduce.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -271,7 +271,7 @@ __global__ void ReduceKernel(const ReduceKernelArgs<T> args) { } template<class FUNC, typename T> -ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff, +ncclResult_t ncclReduceWithTypeAndFunc(const void* sendbuff, void* recvbuff, const int count, const int root, ncclComm* comm, cudaStream_t stream) { if (count == 0) return ncclSuccess; diff --git a/src/reduce_scatter.cu b/src/reduce_scatter.cu index e1860c5..039e95f 100644 --- a/src/reduce_scatter.cu +++ b/src/reduce_scatter.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -164,6 +164,9 @@ struct ReduceScatterKernelArgs { int BufferSliceStride; int BufferMisalignedN; + T ** ThisPtrToNextOutput; + T ** PrevPtrToThisOutput; + // local and remote input, output, and buffer const T * __restrict__ ThisInput; volatile T * __restrict__ ThisOutput; @@ -187,6 +190,20 @@ __global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) { if (args.N == 0) return; int tid = threadIdx.x; + // First wait for args.PrevPtrToThisOutput to become nullptr to ensure that + // the previous GPU is done with a previous collective operation. + if (tid == 0) { + Wait([=] { + return *((T * volatile *)args.PrevPtrToThisOutput) == nullptr; // Wait for previous processor to be done + }); + + *((T * volatile *)args.PrevPtrToThisOutput) = (T*)args.ThisOutput; // Tell Previous I'm starting + Wait([=] { + return *((T * volatile *)args.ThisPtrToNextOutput) != nullptr; // Wait till I've been told next started + }); + } + __syncthreads(); + for (int chunk = 0; chunk < args.NumChunks; ++chunk) { // calculate slice size. for all chunks except (possibly) the last one, // this will just be args.SliceSize. For the last one, it may be smaller @@ -311,6 +328,7 @@ __global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) { if (tid == 0) { args.ThisNewDataAvailableFlag[tid] = 0; args.ThisChunkDoneFlag[tid] = 0; + *args.ThisPtrToNextOutput = nullptr; } } } @@ -410,7 +428,8 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff, args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize; } -// printf("sliceSize = %i, chunkSize = %i, numChunks = %i, sliceStride = %i, misalignedN = %i\n", args.SliceSize, args.ChunkSize, args.NumChunks, args.BufferSliceStride, args.BufferMisalignedN); + args.ThisPtrToNextOutput = (T**)&(comm->local[nextId]->recvPtrs[0]); + args.PrevPtrToThisOutput = (T**)&(comm->remote[prevId]->recvPtrs[0]); args.ThisInput = (const T*)sendbuff; args.ThisOutput = (volatile T*)recvbuff; @@ -426,7 +445,7 @@ ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff, args.PrevChunkDoneFlag = comm->remote[prevId]->flags + 1; ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T> - <<<1, NUM_THREADS + NUM_SUBCHUNKS * WARP_SIZE, 0, stream>>>(args); + <<<1, NUM_THREADS + 1, 0, stream>>>(args); return ncclSuccess; } diff --git a/src/reduce_test.cu b/src/reduce_test.cu index ce17e32..42b1e9b 100644 --- a/src/reduce_test.cu +++ b/src/reduce_test.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -52,7 +52,7 @@ void RunTest(T** sendbuff, T** recvbuff, const int N, const ncclDataType_t type, int nDev = 0; ncclCommCount(comms[0], &nDev); cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev); - + for (int i = 0; i < nDev; ++i) { CUDACHECK(cudaSetDevice(dList[i])); CUDACHECK(cudaStreamCreate(s+i)); @@ -68,7 +68,7 @@ void RunTest(T** sendbuff, T** recvbuff, const int N, const ncclDataType_t type, // warm up GPU for (int i = 0; i < nDev; ++i) { CUDACHECK(cudaSetDevice(dList[i])); - ncclReduce((const void*)sendbuff[i], (void*)recvbuff[i], std::min(N, 1024 * 1024), + ncclReduce((const void*)sendbuff[i], (void*)recvbuff[i], std::min(N, 1024 * 1024), type, op, root, comms[i], s[i]); } @@ -270,7 +270,7 @@ int main(int argc, char* argv[]) { printf("\n"); printf("# %10s %12s %6s %6s %4s out-of-place in-place\n", "", "", "", "", ""); - printf("# %10s %12s %6s %6s %4s %7s %5s %5s %7s %7s %5s %5s %7s\n", + printf("# %10s %12s %6s %6s %4s %7s %5s %5s %7s %7s %5s %5s %7s\n", "bytes", "N", "type", "op", "root", "time", "algbw", "busbw", "res", "time", "algbw", "busbw", "res"); } diff --git a/src/test_utilities.h b/src/test_utilities.h index a5d3661..c929a9e 100644 --- a/src/test_utilities.h +++ b/src/test_utilities.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -250,7 +250,7 @@ void deltaKern(const T* A, const T* B, int N, double* max) { int tid = threadIdx.x; double locmax = 0.0; for(int i=tid; i<N; i+=blockDim.x) { - + double delta = absDiff(A[i], B[i]); if( delta > locmax ) locmax = delta; |