Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/collectives/all_reduce.cu')
-rw-r--r--src/collectives/all_reduce.cu32
1 files changed, 32 insertions, 0 deletions
diff --git a/src/collectives/all_reduce.cu b/src/collectives/all_reduce.cu
new file mode 100644
index 0000000..cca9886
--- /dev/null
+++ b/src/collectives/all_reduce.cu
@@ -0,0 +1,32 @@
+/*************************************************************************
+ * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * See LICENSE.txt for license information
+ ************************************************************************/
+
+#include "core.h"
+#include "common_coll.h"
+#include "enqueue.h"
+#include "collectives.h"
+
+ncclResult_t ncclAllReduceFunc(const void* sendbuff, void* recvbuff, size_t count,
+ ncclDataType_t datatype, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
+ size_t nbytes = count*ncclTypeSize(datatype);
+ INFO(COLL,"opCount %lx sendbuff %p recvbuff %p count %zi size %zi datatype %d op %d comm %p [nranks=%d] stream %p", comm->opCount, sendbuff, recvbuff, count, nbytes, datatype, op, comm, comm->nRanks, stream);
+ if (comm->nRanks == 1) {
+ if (sendbuff != recvbuff)
+ CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, nbytes, cudaMemcpyDeviceToDevice, stream));
+ } else {
+ NCCLCHECK(transportSaveProxies(ALLREDUCE_SUBSTEPS, ALLREDUCE_BUFCHUNKS, (comm->nRanks)*2-2, comm->nRanks, nbytes, proxyPatternRing, comm));
+ NCCLCHECK(saveKernel(ncclCollAllReduce, sendbuff, recvbuff, count, datatype, op, root, comm, stream, nbytes, comm->nRanks));
+ }
+ return ncclSuccess;
+}
+
+NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
+ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
+ return ncclEnqueueCheck(ncclAllReduceFunc, "AllReduce", sendbuff, recvbuff, count, datatype,
+ op, 0, comm, stream);
+}