Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/collectives/device/prims_ll128.h')
-rw-r--r-- src/collectives/device/prims_ll128.h 10
1 file changed, 5 insertions, 5 deletions
diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h
index f97b25c..999d0d5 100644
--- a/src/collectives/device/prims_ll128.h
+++ b/src/collectives/device/prims_ll128.h
@@ -211,14 +211,14 @@ class ncclLL128Primitives {
/************************ Send **************************/
if (SEND) {
for (int i=1; i<NSEND && i<nsend; i++) {
- int flag = sendFlag(i);
+ uint64_t flag = sendFlag(i);
uint64_t* ptr = sendPtr(i)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
}
}
- int flag = sendFlag(0);
+ uint64_t flag = sendFlag(0);
uint64_t* ptr = sendPtr(0)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
@@ -318,10 +318,10 @@ class ncclLL128Primitives {
sendConnHeadPtr = sendConn->head;
sendConnHeadCache = *sendConnHeadPtr;
sendConnHead = sendConn->step;
- sendConnFifoPtr = sendConn->fifo;
+ sendConnFifoPtr = sendConn->sizesFifo;
}
if (tid >= nthreads-WARP_SIZE && wid<nsend) {
- if (sendConn->fifo) {
+ if (sendConn->sizesFifo) {
sendConnTailPtr = sendConn->tail;
sendConnTail = sendConn->step;
}
@@ -345,7 +345,7 @@ class ncclLL128Primitives {
public:
__device__ __forceinline__
ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
- : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
+ : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem->data+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) {
// Make sure step is updated before we read it.
barrier();