diff options
Diffstat (limited to 'src/collectives/device/prims_ll128.h')
-rw-r--r-- | src/collectives/device/prims_ll128.h | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index f97b25c..999d0d5 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -211,14 +211,14 @@ class ncclLL128Primitives { /************************ Send **************************/ if (SEND) { for (int i=1; i<NSEND && i<nsend; i++) { - int flag = sendFlag(i); + uint64_t flag = sendFlag(i); uint64_t* ptr = sendPtr(i)+ll128Offset; #pragma unroll for (int u=0; u<ELEMS_PER_THREAD; u+=2) { store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]); } } - int flag = sendFlag(0); + uint64_t flag = sendFlag(0); uint64_t* ptr = sendPtr(0)+ll128Offset; #pragma unroll for (int u=0; u<ELEMS_PER_THREAD; u+=2) { @@ -318,10 +318,10 @@ class ncclLL128Primitives { sendConnHeadPtr = sendConn->head; sendConnHeadCache = *sendConnHeadPtr; sendConnHead = sendConn->step; - sendConnFifoPtr = sendConn->fifo; + sendConnFifoPtr = sendConn->sizesFifo; } if (tid >= nthreads-WARP_SIZE && wid<nsend) { - if (sendConn->fifo) { + if (sendConn->sizesFifo) { sendConnTailPtr = sendConn->tail; sendConnTail = sendConn->step; } @@ -345,7 +345,7 @@ class ncclLL128Primitives { public: __device__ __forceinline__ ncclLL128Primitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm) - : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) { + : comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), stepSize(stepSize), shmem(ncclShmem->data+(threadIdx.x/WARP_SIZE)*NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE+2*wid) { // Make sure step is updated before we read it. barrier(); |