Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/misc/argcheck.cc')
-rw-r--r--src/misc/argcheck.cc29
1 files changed, 18 insertions, 11 deletions
diff --git a/src/misc/argcheck.cc b/src/misc/argcheck.cc
index 67931f8..c262f8c 100644
--- a/src/misc/argcheck.cc
+++ b/src/misc/argcheck.cc
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -34,7 +34,6 @@ ncclResult_t PtrCheck(void* ptr, const char* opname, const char* ptrname) {
}
ncclResult_t ArgsCheck(struct ncclInfo* info) {
- NCCLCHECK(PtrCheck(info->comm, info->opName, "comm"));
// First, the easy ones
if (info->root < 0 || info->root >= info->comm->nRanks) {
WARN("%s : invalid root %d (root should be in the 0..%d range)", info->opName, info->root, info->comm->nRanks);
@@ -44,13 +43,13 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
WARN("%s : invalid type %d", info->opName, info->datatype);
return ncclInvalidArgument;
}
- // Type is OK, compute nbytes. Convert Allgather/Broadcast calls to chars.
+ // Type is OK, compute nbytes. Convert Allgather/Broadcast/P2P calls to chars.
info->nBytes = info->count * ncclTypeSize(info->datatype);
- if (info->coll == ncclCollAllGather || info->coll == ncclCollBroadcast) {
+ if (info->coll == ncclFuncAllGather || info->coll == ncclFuncBroadcast) {
info->count = info->nBytes;
info->datatype = ncclInt8;
}
- if (info->coll == ncclCollAllGather || info->coll == ncclCollReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank
+ if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank
if (info->op < 0 || info->op >= ncclNumOps) {
WARN("%s : invalid reduction operation %d", info->opName, info->op);
@@ -58,12 +57,20 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
}
if (info->comm->checkPointers) {
- // Check CUDA device pointers
- if (info->coll != ncclCollBroadcast || info->comm->rank == info->root) {
- NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName));
- }
- if (info->coll != ncclCollReduce || info->comm->rank == info->root) {
- NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName));
+ if (info->coll == ncclFuncSendRecv) {
+ if (strcmp(info->opName, "Send") == 0) {
+ NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", "Send"));
+ } else {
+ NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", "Recv"));
+ }
+ } else {
+ // Check CUDA device pointers
+ if (info->coll != ncclFuncBroadcast || info->comm->rank == info->root) {
+ NCCLCHECK(CudaPtrCheck(info->sendbuff, info->comm, "sendbuff", info->opName));
+ }
+ if (info->coll != ncclFuncReduce || info->comm->rank == info->root) {
+ NCCLCHECK(CudaPtrCheck(info->recvbuff, info->comm, "recvbuff", info->opName));
+ }
}
}
return ncclSuccess;