Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/xml.h')
-rw-r--r--src/graph/xml.h237
1 files changed, 237 insertions, 0 deletions
diff --git a/src/graph/xml.h b/src/graph/xml.h
new file mode 100644
index 0000000..fa04527
--- /dev/null
+++ b/src/graph/xml.h
@@ -0,0 +1,237 @@
+/*************************************************************************
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#ifndef XML_H_
+#define XML_H_
+
+// A few constraints to make the implementation easy
+#define MAX_STR_LEN 256
+#define MAX_ATTR_COUNT 16
+#define MAX_SUBS 32
+#define MAX_NODES 1024
+
+#define NODE_TYPE_NONE 0
+#define NODE_TYPE_OPEN 1
+#define NODE_TYPE_CLOSE 2
+#define NODE_TYPE_SINGLE 3
+
+struct ncclXmlNode {
+ char name[MAX_STR_LEN];
+ struct {
+ char key[MAX_STR_LEN];
+ char value[MAX_STR_LEN];
+ } attrs[MAX_ATTR_COUNT+1]; // Need an extra one to consume extra params
+ int nAttrs;
+ int type;
+ struct ncclXmlNode* parent;
+ struct ncclXmlNode* subs[MAX_SUBS];
+ int nSubs;
+};
+
+struct ncclXml {
+ struct ncclXmlNode nodes[MAX_NODES];
+ int maxIndex;
+};
+
+/* File functions */
+#define NCCL_TOPO_XML_VERSION 1
+ncclResult_t ncclTopoGetXmlFromFile(const char* xmlTopoFile, struct ncclXml* xml);
+ncclResult_t ncclTopoDumpXmlToFile(const char* xmlTopoFile, struct ncclXml* xml);
+#define NCCL_GRAPH_XML_VERSION 1
+ncclResult_t ncclTopoGetXmlGraphFromFile(const char* xmlGraphFile, struct ncclXml* xml);
+
+/* Auto-detect functions */
+ncclResult_t ncclTopoFillGpu(struct ncclXml* xml, const char* busId, struct ncclXmlNode** gpuNode);
+ncclResult_t ncclTopoFillNet(struct ncclXml* xml, const char* pciPath, const char* netName, struct ncclXmlNode** netNode);
+
+/**************/
+/* XML Struct */
+/* Functions */
+/**************/
+
+static ncclResult_t xmlGetAttrIndex(struct ncclXmlNode* node, const char* attrName, int* index) {
+ *index = -1;
+ const int nAttrs = node->nAttrs;
+ for (int a=0; a<nAttrs; a++) {
+ if (strncmp(node->attrs[a].key, attrName, MAX_STR_LEN-1) == 0) {
+ *index = a;
+ return ncclSuccess;
+ }
+ }
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlGetAttr(struct ncclXmlNode* node, const char* attrName, const char** value) {
+ int index;
+ NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
+ *value = index == -1 ? NULL : node->attrs[index].value;
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlGetAttrStr(struct ncclXmlNode* node, const char* attrName, const char** value) {
+ NCCLCHECK(xmlGetAttr(node, attrName, value));
+ if (*value == NULL) {
+ WARN("Attribute %s of node %s not found", attrName, node->name);
+ return ncclInternalError;
+ }
+ return ncclSuccess;
+}
+static ncclResult_t xmlGetAttrInt(struct ncclXmlNode* node, const char* attrName, int* value) {
+ const char* str;
+ NCCLCHECK(xmlGetAttrStr(node, attrName, &str));
+ *value = strtol(str, NULL, 0);
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlGetAttrFloat(struct ncclXmlNode* node, const char* attrName, float* value) {
+ const char* str;
+ NCCLCHECK(xmlGetAttrStr(node, attrName, &str));
+ *value = strtof(str, NULL);
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlFindTag(struct ncclXml* xml, const char* tagName, struct ncclXmlNode** node) {
+ *node = NULL;
+ for (int i=0; i<xml->maxIndex; i++) {
+ struct ncclXmlNode* n = xml->nodes+i;
+ if (strcmp(n->name, tagName) == 0) {
+ *node = n;
+ return ncclSuccess;
+ }
+ }
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlFindTagKv(struct ncclXml* xml, const char* tagName, struct ncclXmlNode** node, const char* attrName, const char* attrValue) {
+ *node = NULL;
+ for (int i=0; i<xml->maxIndex; i++) {
+ struct ncclXmlNode* n = xml->nodes+i;
+ if (strcmp(n->name, tagName) == 0) {
+ const char* value;
+ NCCLCHECK(xmlGetAttr(n, attrName, &value));
+ if (value && strcmp(value, attrValue) == 0) {
+ *node = n;
+ return ncclSuccess;
+ }
+ }
+ }
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlSetAttr(struct ncclXmlNode* node, const char* attrName, const char* value) {
+ int index;
+ NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
+ if (index == -1) {
+ index = node->nAttrs++;
+ strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+ }
+ strncpy(node->attrs[index].value, value, MAX_STR_LEN);
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlSetAttrInt(struct ncclXmlNode* node, const char* attrName, const int value) {
+ int index;
+ NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
+ if (index == -1) {
+ index = node->nAttrs++;
+ strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+ }
+ snprintf(node->attrs[index].value, MAX_STR_LEN, "%d", value);
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlSetAttrFloat(struct ncclXmlNode* node, const char* attrName, const float value) {
+ int index;
+ NCCLCHECK(xmlGetAttrIndex(node, attrName, &index));
+ if (index == -1) {
+ index = node->nAttrs++;
+ strncpy(node->attrs[index].key, attrName, MAX_STR_LEN);
+ }
+ snprintf(node->attrs[index].value, MAX_STR_LEN, "%g", value);
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlGetSub(struct ncclXmlNode* node, const char* subName, struct ncclXmlNode** sub) {
+ *sub = NULL;
+ for (int s=0; s<node->nSubs; s++) {
+ if (strcmp(node->subs[s]->name, subName) == 0) {
+ *sub = node->subs[s];
+ return ncclSuccess;
+ }
+ }
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlGetSubKv(struct ncclXmlNode* node, const char* subName, struct ncclXmlNode** sub, const char* attrName, const char* attrValue) {
+ *sub = NULL;
+ for (int s=0; s<node->nSubs; s++) {
+ struct ncclXmlNode* subNode = node->subs[s];
+ if (strcmp(subNode->name, subName) == 0) {
+ const char* value;
+ NCCLCHECK(xmlGetAttr(subNode, attrName, &value));
+ if (value && strcmp(value, attrValue) == 0) {
+ *sub = node->subs[s];
+ return ncclSuccess;
+ }
+ }
+ }
+ return ncclSuccess;
+}
+static ncclResult_t xmlGetSubKvInt(struct ncclXmlNode* node, const char* subName, struct ncclXmlNode** sub, const char* attrName, const int attrValue) {
+ char strValue[10];
+ snprintf(strValue, 10, "%d", attrValue);
+ NCCLCHECK(xmlGetSubKv(node, subName, sub, attrName, strValue));
+ return ncclSuccess;
+}
+
+static ncclResult_t xmlAddNode(struct ncclXml* xml, struct ncclXmlNode* parent, const char* subName, struct ncclXmlNode** sub) {
+ if (xml->maxIndex == MAX_NODES) {
+ WARN("Error : too many XML nodes (max %d)", MAX_NODES);
+ return ncclInternalError;
+ }
+ struct ncclXmlNode* s = xml->nodes+xml->maxIndex++;
+ s->nSubs = 0;
+ s->nAttrs = 0;
+ *sub = s;
+ s->parent = parent;
+ if (parent) parent->subs[parent->nSubs++] = s;
+ strncpy(s->name, subName, MAX_STR_LEN);
+ return ncclSuccess;
+}
+
+// Dictionary for STR -> INT conversions. No dictionary size information,
+// there needs to be a last element with str == NULL.
+struct kvDict {
+ const char* str;
+ int value;
+};
+
+static ncclResult_t kvConvertToInt(const char* str, int* value, struct kvDict* dict) {
+ struct kvDict* d = dict;
+ while (d->str) {
+ if (strncmp(str, d->str, strlen(d->str)) == 0) {
+ *value = d->value;
+ return ncclSuccess;
+ }
+ d++;
+ }
+ WARN("KV Convert to int : could not find value of '%s' in dictionary", str);
+ return ncclInternalError;
+}
+static ncclResult_t kvConvertToStr(int value, const char** str, struct kvDict* dict) {
+ struct kvDict* d = dict;
+ while (d->str) {
+ if (value == d->value) {
+ *str = d->str;
+ return ncclSuccess;
+ }
+ d++;
+ }
+ WARN("KV Convert to str : could not find value %d in dictionary", value);
+ return ncclInternalError;
+}
+
+#endif