diff options
Diffstat (limited to 'src/graph/xml.cc')
-rw-r--r-- | src/graph/xml.cc | 44 |
1 files changed, 33 insertions, 11 deletions
diff --git a/src/graph/xml.cc b/src/graph/xml.cc index cc91b92..b2232c2 100644 --- a/src/graph/xml.cc +++ b/src/graph/xml.cc @@ -559,7 +559,6 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm NCCLCHECK(xmlGetAttrIndex(gpuNode, "dev", &index)); if (index == -1) { if (nvmlDev == NULL) { - WARN("No NVML, trying to use CUDA instead"); const char* busId; NCCLCHECK(xmlGetAttr(pciNode, "busid", &busId)); if (busId == NULL || cudaDeviceGetByPCIBusId(&dev, busId) != cudaSuccess) dev = -1; @@ -647,6 +646,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm char* path; NCCLCHECK(getPciPath(busId, &path)); NCCLCHECK(ncclTopoSetAttrFromSys(sub, path, "class", "tclass")); + free(path); } } } @@ -658,10 +658,14 @@ ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct nccl struct ncclXmlNode* node; NCCLCHECK(ncclTopoGetPciNode(xml, busId, &node)); NCCLCHECK(ncclTopoGetXmlFromSys(node, xml)); - NCCLCHECK(wrapNvmlSymbols()); - NCCLCHECK(wrapNvmlInit()); - nvmlDevice_t nvmlDev; - if (wrapNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev) != ncclSuccess) nvmlDev = NULL; + nvmlDevice_t nvmlDev = NULL; + static int nvmlInit = 0; + if (nvmlInit == 0) { + nvmlInit = (wrapNvmlSymbols() != ncclSuccess || wrapNvmlInit() != ncclSuccess) ? 2 : 1; + } + if (nvmlInit == 1) { + if (wrapNvmlDeviceGetHandleByPciBusId(busId, &nvmlDev) != ncclSuccess) nvmlDev = NULL; + } NCCLCHECK(ncclTopoGetXmlFromGpu(node, nvmlDev, xml, gpuNode)); return ncclSuccess; } @@ -704,12 +708,8 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha for (offset=strlen(pciSysPath)-1; pciSysPath[offset] != '/'; offset--); char busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; strcpy(busId, pciSysPath+offset+1); - NCCLCHECK(xmlFindTagKv(xml, "pci", &parent, "busid", busId)); - if (parent == NULL) { - NCCLCHECK(xmlAddNode(xml, NULL, "pci", &parent)); - NCCLCHECK(xmlSetAttr(parent, "busid", busId)); - NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml)); - } + NCCLCHECK(ncclTopoGetPciNode(xml, busId, &parent)); + NCCLCHECK(ncclTopoGetXmlFromSys(parent, xml)); } else { // Virtual NIC, no PCI device, attach to first CPU NCCLCHECK(xmlFindTag(xml, "cpu", &parent)); @@ -728,6 +728,28 @@ ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const cha return ncclSuccess; } +ncclResult_t ncclTopoTrimXmlRec(struct ncclXmlNode* node) { + const char* str; + NCCLCHECK(xmlGetAttr(node, "keep", &str)); + if (str && strcmp(str, "1") == 0) { + NCCLCHECK(xmlUnsetAttr(node, "keep")); + } else { + // Copy nSubs and subs as they could change as we trim recursively. + struct ncclXmlNode* subs[MAX_SUBS]; + int nSubs = node->nSubs; + memcpy(subs, node->subs, node->nSubs*sizeof(struct ncclXmlNode*)); + for (int s=0; s<nSubs; s++) { + NCCLCHECK(ncclTopoTrimXmlRec(subs[s])); + } + if (node->nSubs == 0) NCCLCHECK(xmlRemoveNode(node)); + } + return ncclSuccess; +} +ncclResult_t ncclTopoTrimXml(struct ncclXml* xml) { + NCCLCHECK(ncclTopoTrimXmlRec(xml->nodes)); + return ncclSuccess; +} + /**************************************************/ /* Parser rules for the user-defined graph search */ /**************************************************/ |