Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Jeaugey <sjeaugey@nvidia.com>2020-05-13 00:40:18 +0300
committerSylvain Jeaugey <sjeaugey@nvidia.com>2020-06-08 19:31:44 +0300
commit5949d96f36d050e59d05872f8bbffd2549318e95 (patch)
treee56476c71668bbd1ce4ddbc189b1be7d037b065c /src/collectives/device/prims_ll128.h
parentf36540f55a15683a121b6c330657af442b85c796 (diff)
2.7.3-1
Add support for A100 GPU and related platforms. Add support for CUDA 11. Add support for send/receive operations (beta).
Diffstat (limited to 'src/collectives/device/prims_ll128.h')
-rw-r--r--src/collectives/device/prims_ll128.h19
1 files changed, 10 insertions, 9 deletions
diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h
index 40a8cff..f445e0d 100644
--- a/src/collectives/device/prims_ll128.h
+++ b/src/collectives/device/prims_ll128.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -14,6 +14,7 @@ class ncclLL128Primitives {
const int tid;
const int nthreads;
const int wid;
+ const int stepSize;
const int warp;
const bool flagThread;
int nrecv = 0;
@@ -38,8 +39,8 @@ class ncclLL128Primitives {
volatile uint64_t* shmem;
- inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*NCCL_LL128_SLICE_ELEMS; }
- inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*NCCL_LL128_SLICE_ELEMS; }
+ inline __device__ int recvOffset(int i) { return (recvStep[i]%NCCL_STEPS)*stepSize; }
+ inline __device__ int sendOffset(int i) { return (sendStep[i]%NCCL_STEPS)*stepSize; }
inline __device__ uint64_t* recvPtr(int i) { return recvBuff[i]+recvOffset(i); }
inline __device__ uint64_t* sendPtr(int i) { return sendBuff[i]+sendOffset(i); }
inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; }
@@ -47,9 +48,9 @@ class ncclLL128Primitives {
inline __device__ void barrier() {
if (NSEND>NRECV) {
- asm volatile ("bar.sync 2, %0;" :: "r"(nthreads));
+ asm volatile ("bar.sync 1, %0;" :: "r"(nthreads));
} else {
- asm volatile ("bar.sync 3, %0;" :: "r"(nthreads));
+ asm volatile ("bar.sync 2, %0;" :: "r"(nthreads));
}
}
@@ -309,7 +310,7 @@ class ncclLL128Primitives {
}
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i) {
- recvBuff[i] = conn->ll128Buff;
+ recvBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
recvStep[i] = conn->step;
if (wid == i) recvConn = conn;
nrecv++;
@@ -324,7 +325,7 @@ class ncclLL128Primitives {
}
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
- sendBuff[i] = conn->ll128Buff;
+ sendBuff[i] = (uint64_t*)conn->buffs[NCCL_PROTO_LL128];
sendStep[i] = conn->step;
if (wid == i) sendConn = conn;
nsend++;
@@ -363,8 +364,8 @@ class ncclLL128Primitives {
public:
__device__ __forceinline__
- ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
- : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), opCount(opCount), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
+ ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm, const uint64_t opCount)
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), opCount(opCount), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
// Make sure step is updated before we read it.
barrier();