diff options
Diffstat (limited to 'src/misc/argcheck.cc')
-rw-r--r-- | src/misc/argcheck.cc | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc index 67931f8..c262f8c 100644 --- a/src/misc/argcheck.cc +++ b/src/misc/argcheck.cc @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -34,7 +34,6 @@ ncclResult_t PtrCheck(void* ptr, const char* opname, const char* ptrname) { } ncclResult_t ArgsCheck(struct ncclInfo* info) { - NCCLCHECK(PtrCheck(info->comm, info->opName, "comm")); // First, the easy ones if (info->root < 0 || info->root >= info->comm->nRanks) { WARN("%s : invalid root %d (root should be in the 0..%d range)", info->opName, info->root, info->comm->nRanks); @@ -44,13 +43,13 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) { WARN("%s : invalid type %d", info->opName, info->datatype); return ncclInvalidArgument; } - // Type is OK, compute nbytes. Convert Allgather/Broadcast calls to chars. + // Type is OK, compute nbytes. Convert Allgather/Broadcast/P2P calls to chars. info->nBytes = info->count * ncclTypeSize(info->datatype); - if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast) { + if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast) { info->count = info->nBytes; info->datatype = ncclInt8; } - if (info->coll == ncclCollAllGather || info->coll == ncclCollReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank + if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank if (info->op < 0 || info->op >= ncclNumOps) { WARN("%s : invalid reduction operation %d", info->opName, info->op); @@ -58,12 +57,20 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) { } if (info->comm->checkPointers) { - // Check CUDA device pointers - if (info->coll != ncclCollBroadcast || info->comm->rank == info->root) { - NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName)); - } - if (info->coll != ncclCollReduce || info->comm->rank == info->root) { - NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName)); + if (info->coll == ncclFuncSendRecv) { + if (strcmp(info->opName, "Send") == 0) { + NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", "Send")); + } else { + NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", "Recv")); + } + } else { + // Check CUDA device pointers + if (info->coll != ncclFuncBroadcast || info->comm->rank == info->root) { + NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName)); + } + if (info->coll != ncclFuncReduce || info->comm->rank == info->root) { + NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName)); + } } } return ncclSuccess; |