Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/net.h')
-rw-r--r-- src/include/net.h 53
1 files changed, 30 insertions, 23 deletions
diff --git a/src/include/net.h b/src/include/net.h
index 3d37c8c..bc81965 100644
--- a/src/include/net.h
+++ b/src/include/net.h
@@ -16,7 +16,7 @@ typedef char ncclNetHandle_t[NCCL_NET_HANDLE_MAXSIZE];
// Translation to external API
static const char* ncclNetName() { return ncclNet->name; }
static ncclResult_t ncclNetDevices(int* ndev) { NCCLCHECK(ncclNet->devices(ndev)); return ncclSuccess; }
-static ncclResult_t ncclNetPciPath(int dev, char** path) { NCCLCHECK(ncclNet->pciPath(dev, path)); return ncclSuccess; }
+static ncclResult_t ncclNetGetProperties(int dev, ncclNetProperties_t* props) { NCCLCHECK(ncclNet->getProperties(dev, props)); return ncclSuccess; }
static ncclResult_t ncclNetListen(int dev, void* handle, void** listenComm) { NCCLCHECK(ncclNet->listen(dev, handle, listenComm)); return ncclSuccess; }
static ncclResult_t ncclNetConnect(int dev, void* handle, void** sendComm) { NCCLCHECK(ncclNet->connect(dev, handle, sendComm)); return ncclSuccess; }
static ncclResult_t ncclNetAccept(void* listenComm, void** recvComm) { NCCLCHECK(ncclNet->accept(listenComm, recvComm)); return ncclSuccess; }
@@ -30,33 +30,40 @@ static ncclResult_t ncclNetCloseSend(void* sendComm) { NCCLCHECK(ncclNet->closeS
static ncclResult_t ncclNetCloseRecv(void* recvComm) { NCCLCHECK(ncclNet->closeRecv(recvComm)); return ncclSuccess; }
static ncclResult_t ncclNetCloseListen(void* listenComm) { NCCLCHECK(ncclNet->closeListen(listenComm)); return ncclSuccess; }
+// Test whether the current GPU supports GPU Direct RDMA.
#define GPU_BUF_SIZE (2*1024*1024)
-static ncclResult_t ncclNetPtrSupport(int dev, int* supportedTypes) {
- int support;
- NCCLCHECK(ncclNet->ptrSupport(dev, &support));
- *supportedTypes = support & ~NCCL_PTR_CUDA;
- // The network supports GPU Direct RDMA ; verify the GPU supports it as well.
- if (support & NCCL_PTR_CUDA) {
+static ncclResult_t ncclGpuGdrSupport(int* gdrSupport) {
+ int netDevs;
+ NCCLCHECK(ncclNetDevices(&netDevs));
+ *gdrSupport = 0;
+ for (int dev=0; dev<netDevs; dev++) {
+ // Find a net device which is GDR-capable
+ ncclNetProperties_t props;
+ NCCLCHECK(ncclNet->getProperties(dev, &props));
+ if ((props.ptrSupport & NCCL_PTR_CUDA) == 0) continue;
+
+ // Allocate memory on the GPU and try to register it on the NIC.
void *lComm = NULL, *sComm = NULL, *rComm = NULL;
ncclNetHandle_t handle;
void* gpuPtr = NULL;
void* mHandle = NULL;
- ncclResult_t res;
- NCCLCHECKGOTO(ncclNetListen(dev, &handle, &lComm), res, cleanup);
- NCCLCHECKGOTO(ncclNetConnect(dev, &handle, &sComm), res, cleanup);
- NCCLCHECKGOTO(ncclNetAccept(lComm, &rComm), res, cleanup);
- CUDACHECKGOTO(cudaMalloc(&gpuPtr, GPU_BUF_SIZE), res, cleanup);
- NOWARN(ncclNetRegMr(sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle), res);
- if (res != ncclSuccess) goto cleanup;
- NCCLCHECKGOTO(ncclNetDeregMr(sComm, mHandle), res, cleanup);
- NCCLCHECKGOTO(ncclNetRegMr(rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle), res, cleanup);
- NCCLCHECKGOTO(ncclNetDeregMr(rComm, mHandle), res, cleanup);
- *supportedTypes |= NCCL_PTR_CUDA;
-cleanup:
- if (gpuPtr) cudaFree(gpuPtr);
- if (rComm) ncclNetCloseRecv(rComm);
- if (sComm) ncclNetCloseSend(sComm);
- if (lComm) ncclNetCloseListen(lComm);
+ NCCLCHECK(ncclNetListen(dev, &handle, &lComm));
+ NCCLCHECK(ncclNetConnect(dev, &handle, &sComm));
+ NCCLCHECK(ncclNetAccept(lComm, &rComm));
+ CUDACHECK(cudaMalloc(&gpuPtr, GPU_BUF_SIZE));
+ ncclDebugNoWarn = NCCL_NET;
+ if (ncclNetRegMr(sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) {
+ NCCLCHECK(ncclNetDeregMr(sComm, mHandle));
+ NCCLCHECK(ncclNetRegMr(rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle));
+ NCCLCHECK(ncclNetDeregMr(rComm, mHandle));
+ *gdrSupport = 1;
+ }
+ ncclDebugNoWarn = 0;
+ CUDACHECK(cudaFree(gpuPtr));
+ NCCLCHECK(ncclNetCloseRecv(rComm));
+ NCCLCHECK(ncclNetCloseSend(sComm));
+ NCCLCHECK(ncclNetCloseListen(lComm));
+ break;
}
return ncclSuccess;
}