diff options
Diffstat (limited to 'src/misc/argcheck.cc')
-rw-r--r-- | src/misc/argcheck.cc | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc new file mode 100644 index 0000000..364f041 --- /dev/null +++ b/src/misc/argcheck.cc @@ -0,0 +1,69 @@ +/************************************************************************* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "argcheck.h" + +static ncclResult_t CudaPtrCheck(const void* pointer, struct ncclComm* comm, const char* ptrname, const char* opname) { + cudaPointerAttributes attr; + cudaError_t err = cudaPointerGetAttributes(&attr, pointer); + if (err != cudaSuccess || attr.devicePointer == NULL) { + WARN("%s : %s is not a valid pointer", opname, ptrname); + return ncclInvalidArgument; + } +#if CUDART_VERSION >= 10000 + if (attr.type == cudaMemoryTypeDevice && attr.device != comm->cudaDev) { +#else + if (attr.memoryType == cudaMemoryTypeDevice && attr.device != comm->cudaDev) { +#endif + WARN("%s : %s allocated on device %d mismatchs with NCCL device %d", opname, ptrname, attr.device, comm->cudaDev); + return ncclInvalidArgument; + } + return ncclSuccess; +} + +ncclResult_t PtrCheck(void* ptr, const char* opname, const char* ptrname) { + if (ptr == NULL) { + WARN("%s : %s argument is NULL", opname, ptrname); + return ncclInvalidArgument; + } + return ncclSuccess; +} + +ncclResult_t ArgsCheck(struct ncclInfo* info) { + NCCLCHECK(PtrCheck(info->comm, info->opName, "comm")); + // First, the easy ones + if (info->root < 0 || info->root >= info->comm->nRanks) { + WARN("%s : invalid root %d (root should be in the 0..%d range)", info->opName, info->root, info->comm->nRanks); + return ncclInvalidArgument; + } + if (info->datatype < 0 || info->datatype >= ncclNumTypes) { + WARN("%s : invalid type %d", info->opName, info->datatype); + return ncclInvalidArgument; + } + // Type is OK, compute nbytes. Convert Allgather/Broadcast calls to chars. + info->nBytes = info->count * ncclTypeSize(info->datatype); + if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast) { + info->count = info->nBytes; + info->datatype = ncclInt8; + } + if (info->coll == ncclCollAllGather || info->coll == ncclCollReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank + + if (info->op < 0 || info->op >= ncclNumOps) { + WARN("%s : invalid reduction operation %d", info->opName, info->op); + return ncclInvalidArgument; + } + + if (info->comm->checkPointers) { + // Check CUDA device pointers + if (info->coll != ncclCollBroadcast || info->comm->rank == info->root) { + NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName)); + } + if (info->coll != ncclCollReduce || info->comm->rank == info->root) { + NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName)); + } + } + return ncclSuccess; +} |