author     Sylvain Jeaugey <sjeaugey@nvidia.com>    2018-09-25 02:06:59 +0300
committer  Sylvain Jeaugey <sjeaugey@nvidia.com>    2018-09-26 00:12:01 +0300
commit     f93fe9bfd94884cec2ba711897222e0df5569a53 (patch)
tree       78b91eed1abfbaa3346b85bffe0c0ef9d9fb32bf /src/collectives/device
parent     286916a1a37ca1fe8cd43e280f5c42ec29569fc5 (diff)

2.3.5-5 (tag: v2.3.5-5)
Add support for inter-node communication using sockets and InfiniBand/RoCE.
Improve latency.
Add support for aggregation.
Improve LL/regular tuning.
Remove tests, as those are now maintained at github.com/nvidia/nccl-tests.
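
As background for the aggregation support listed above: collectives issued between a group begin/end pair can be batched and launched together. A minimal host-side sketch, using the public ncclGroupStart/ncclGroupEnd and ncclAllReduce calls (the helper function and buffer names are hypothetical, not part of this change):

#include <nccl.h>
#include <cuda_runtime.h>

// Issue several all-reduces inside one group so NCCL may aggregate them
// into a single launch (illustrative helper, not NCCL library code).
ncclResult_t allreduceBatch(float** sendbufs, float** recvbufs, size_t* counts,
                            int n, ncclComm_t comm, cudaStream_t stream) {
  ncclResult_t ret = ncclGroupStart();
  if (ret != ncclSuccess) return ret;
  for (int i = 0; i < n; i++) {
    ret = ncclAllReduce(sendbufs[i], recvbufs[i], counts[i],
                        ncclFloat, ncclSum, comm, stream);
    if (ret != ncclSuccess) return ret;
  }
  // The grouped operations are enqueued on the stream when the group closes.
  return ncclGroupEnd();
}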
Diffstat (limited to 'src/collectives/device')
-rw-r--r--  src/collectives/device/Makefile           |  86
-rw-r--r--  src/collectives/device/all_gather.cu      |  15
-rw-r--r--  src/collectives/device/all_gather.h       | 269
-rw-r--r--  src/collectives/device/all_reduce.cu      |  21
-rw-r--r--  src/collectives/device/all_reduce.h       | 332
-rw-r--r--  src/collectives/device/broadcast.cu       |  15
-rw-r--r--  src/collectives/device/broadcast.h        | 228
-rw-r--r--  src/collectives/device/common.h           |  90
-rw-r--r--  src/collectives/device/common_kernel.h    | 372
-rw-r--r--  src/collectives/device/functions.cu       |  64
-rw-r--r--  src/collectives/device/ll_kernel.h        | 154
-rw-r--r--  src/collectives/device/primitives.h       | 226
-rw-r--r--  src/collectives/device/reduce.cu          |  21
-rw-r--r--  src/collectives/device/reduce.h           | 190
-rw-r--r--  src/collectives/device/reduce_kernel.h    | 364
-rw-r--r--  src/collectives/device/reduce_scatter.cu  |  21
-rw-r--r--  src/collectives/device/reduce_scatter.h   | 217
17 files changed, 2685 insertions, 0 deletions
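
One note on how the build below is organized: the Makefile compiles each collective source four times, passing -DNCCL_OP=0..3 to nvcc, and each .cu file gates which reduction operation it instantiates. The same pattern, reduced to a standalone sketch with illustrative names (not the actual NCCL macros):

// build: nvcc -DNCCL_OP=0 -dc coll.cu   (then again with 1, 2, 3 for prod, min, max)
#if NCCL_OP == 0
template<typename T> struct OpFunc { __device__ T operator()(T a, T b) const { return a + b; } };
#elif NCCL_OP == 1
template<typename T> struct OpFunc { __device__ T operator()(T a, T b) const { return a * b; } };
#elif NCCL_OP == 2
template<typename T> struct OpFunc { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
#elif NCCL_OP == 3
template<typename T> struct OpFunc { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
#endif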
diff --git a/src/collectives/device/Makefile b/src/collectives/device/Makefile
new file mode 100644
index 0000000..ccea8f5
--- /dev/null
+++ b/src/collectives/device/Makefile
@@ -0,0 +1,86 @@
+#
+# Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
+#
+# See LICENSE.txt for license information
+#
+
+include ../../../makefiles/common.mk
+include ../../../makefiles/version.mk
+
+BUILDDIR ?= $(abspath ../../../build)
+OBJDIR := $(BUILDDIR)/obj/collectives/device
+
+LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu
+
+LIBOBJ := $(patsubst %.cu,$(OBJDIR)/%_sum.o, $(LIBSRCFILES)) \
+          $(patsubst %.cu,$(OBJDIR)/%_prod.o, $(LIBSRCFILES)) \
+          $(patsubst %.cu,$(OBJDIR)/%_min.o, $(LIBSRCFILES)) \
+          $(patsubst %.cu,$(OBJDIR)/%_max.o, $(LIBSRCFILES)) \
+          $(OBJDIR)/functions.o
+
+LIBSRCFILES += functions.cu
+
+DEPFILES := $(patsubst %.cu, $(OBJDIR)/%.d, $(LIBSRCFILES))
+DEPENDFILES := $(DEPFILES:%.d=%.dep)
+STATICLIB := $(OBJDIR)/colldevice.a
+DEVOBJ := $(OBJDIR)/devlink.o
+
+NVCUFLAGS += -I. -I.. -I../.. -I../../include --compiler-options "-fPIC -fvisibility=hidden"
+
+
+all: $(STATICLIB)
+
+# Dummy rule so that the extra dependency (%.dep) files are preserved by make
+all_deps: $(DEPENDFILES)
+
+-include $(DEPFILES)
+
+$(STATICLIB): $(LIBOBJ) $(DEVOBJ)
+	@printf "Archiving %-35s > %s\n" objects $@
+	ar cr $@ $^
+
+# We do not want make to build *.d when running make clean.
+# So we only provide targets for .dep which will produce .dep and .d,
+# with only .d being included, and .dep keeping track of what needs to
+# be regenerated.
+$(OBJDIR)/%.dep : %.cu
+	@mkdir -p $(OBJDIR)
+	@$(NVCC) $(NVCUFLAGS) -M $< -o $@.tmp
+	@sed "0,/^.*:/s//$(subst /,\/,$@):/" $@.tmp > $@
+	@sed -e 's/.*://' -e 's/\\$$//' < $@.tmp | fmt -1 | \
+	  sed -e 's/^ *//' -e 's/$$/:/' >> $@
+	@rm -f $@.tmp
+	@cp $@ $(@:.dep=.d)
+
+# Compiled kernels and collectives with relocatable device code ...
+$(OBJDIR)/functions.o : functions.cu $(OBJDIR)/functions.dep
+	@printf "Compiling %-35s > %s\n" $< $@
+	mkdir -p `dirname $@`
+	$(NVCC) $(NVCUFLAGS) -dc $< -o $@
+
+$(OBJDIR)/%_sum.o : %.cu $(OBJDIR)/%.dep
+	@printf "Compiling %-35s > %s\n" $< $@
+	mkdir -p `dirname $@`
+	$(NVCC) -DNCCL_OP=0 $(NVCUFLAGS) -dc $< -o $@
+
+$(OBJDIR)/%_prod.o : %.cu $(OBJDIR)/%.dep
+	@printf "Compiling %-35s > %s\n" $< $@
+	mkdir -p `dirname $@`
+	$(NVCC) -DNCCL_OP=1 $(NVCUFLAGS) -dc $< -o $@
+
+$(OBJDIR)/%_min.o : %.cu $(OBJDIR)/%.dep
+	@printf "Compiling %-35s > %s\n" $< $@
+	mkdir -p `dirname $@`
+	$(NVCC) -DNCCL_OP=2 $(NVCUFLAGS) -dc $< -o $@
+
+$(OBJDIR)/%_max.o : %.cu $(OBJDIR)/%.dep
+	@printf "Compiling %-35s > %s\n" $< $@
+	mkdir -p `dirname $@`
+	$(NVCC) -DNCCL_OP=3 $(NVCUFLAGS) -dc $< -o $@
+
+# ... and create the device-side linked object with all those.
+$(DEVOBJ) : $(LIBOBJ)
+	$(NVCC) $(NVCUFLAGS) -dlink $^ -o $@
+
+clean:
+	rm -f $(LIBOBJ) $(DEVOBJ) $(DEPFILES) $(DEPENDFILES) $(STATICLIB) test
diff --git a/src/collectives/device/all_gather.cu b/src/collectives/device/all_gather.cu
new file mode 100644
index 0000000..0f572ce
--- /dev/null
+++ b/src/collectives/device/all_gather.cu
@@ -0,0 +1,15 @@
+/*************************************************************************
+ * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "common.h" +#include "all_gather.h" +#include "collectives.h" + +#define UNROLL 4 + +#if NCCL_OP == 0 +IMPL_COLL3(ncclAllGather, copy, FuncSum, i8, int8_t, ncclCollAllGather, ncclSum, ncclInt8); +#endif diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h new file mode 100644 index 0000000..a30e575 --- /dev/null +++ b/src/collectives/device/all_gather.h @@ -0,0 +1,269 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "core.h" +#include "primitives.h" +#include "collectives.h" + +// Increase Step and poffset/noffset for buffer sync +#define NEXT_STEP \ + step++; \ + poffset = noffset; \ + noffset += sliceSize; \ + if (noffset == buffSize) noffset = 0; + +template<int UNROLL, class FUNC, typename T> +__device__ void ncclAllGatherKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = blockDim.x - 1; + const int bid = args->bid; + __shared__ T* sharedNextOutput; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + int prevdirect = ring->recv.conn.direct; + int nextdirect = ring->send.conn.direct; + + WaitFlag waitDoneFromNext(ring->send.conn.head, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS); + WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLGATHER_SUBSTEPS); + PostFlag postDoneToPrev(ring->recv.conn.head, ALLGATHER_SUBSTEPS, NULL, 0); + PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLGATHER_BUFCHUNKS*ALLGATHER_SUBSTEPS); + + typedef Primitives<UNROLL, ALLGATHER_SUBSTEPS, T> Prims; + + const ssize_t size = args->N; + const int nranks = comm->nRanks; + const int buffSize = ring->buffSize / sizeof(T); + const int sliceSize = buffSize / ALLGATHER_BUFCHUNKS; + const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + + if (tid == 0) { + // Update in case we skipped some collectives + *ring->recv.conn.opCount = args->opCount; + // Wait for next to be ready + WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); + waitOpCountNext.wait(args->opCount); + if (prevdirect) { + *ring->recv.conn.ptrExchange = args->ThisOutput; + } + if (nextdirect) { + void* volatile* ptr = &(ring->devMemSend->ptrExchange); + while (*ptr == nullptr); + sharedNextOutput = (T*)*ptr; + *ptr = nullptr; + } + } + __syncthreads(); + + uint64_t step = 0ULL; + int poffset, noffset = 0; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + T * __restrict__ prevInput = (T*)ring->recv.conn.buff; + T * __restrict__ nextOutput = (T*)ring->send.conn.buff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings)); + ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); + ssize_t chunkOffset = gridOffset + bid*chunkSize; + + /////////////// begin AllGather steps /////////////// + ssize_t offset; + int maxOffset = min(chunkSize, size-chunkOffset); + int rankDest; + + // step 0: push data to next GPU + rankDest = ring->devUserRanks[0]; + offset = chunkOffset + rankDest * size; + + if (thisInput + chunkOffset == thisOutput + offset) { // In place + 
Prims::Copy(tid, nthreads, + thisInput + chunkOffset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + } else { + Prims::DoubleCopy(tid, nthreads, + thisInput + chunkOffset, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + } + + NEXT_STEP; // Increases step, poffset, noffset + + // k-2 steps: copy to next GPU + if (prevdirect) { + for (int j=1; j<nranks-1; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + Prims::Copy(tid, nthreads, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + } + Prims::Copy(tid, nthreads, + NULL, + NULL, + 0, 0, + step, + waitReadyFromPrev, + postDoneToPrev); + } else { + for (int j=1; j<nranks-1; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + Prims::DoubleCopy(tid, nthreads, + prevInput + poffset, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + } + + // Make final copy from buffer to dest. + rankDest = ring->devUserRanks[1]; + offset = chunkOffset + rankDest * size; + + // Here we need to copy from buffer to this output. + Prims::Copy(tid, nthreads, + prevInput + poffset, + thisOutput + offset, + sliceSize, maxOffset, + step, + waitReadyFromPrev, + postDoneToPrev); + } + } + + if (tid == 0) { + waitDoneFromNext.wait(ALLGATHER_SUBSTEPS*(step + ALLGATHER_BUFCHUNKS)); + *ring->send.conn.head = 0ULL; + *ring->recv.conn.tail = 0ULL; + __threadfence_system(); + *ring->recv.conn.opCount = args->opCount+1; + } +} + +#include "ll_kernel.h" + +#define NEXT_STEP_LL \ + poffset = noffset; \ + pflag = nflag; \ + noffset += NCCL_LL_SLICE_LINES; \ + if (noffset == NCCL_LL_BUFF_LINES) { noffset = 0; } \ + nflag++; \ + step++; + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllGatherLLKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int llNthreads = args->nThreads; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead; + volatile uint64_t * sendHeadPtr = ring->send.conn.llHead; + volatile int * sizesFifo = ring->send.conn.llFifo; + uint64_t sendHead = sendHeadPtr[0]; + + typedef LLPrimitives<T, FUNC> LL; + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nRings*chunkSize; + + uint64_t step = ring->send.conn.llStep; + uint32_t pflag, nflag = step + 1; + int poffset, noffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step); + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff; + union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + if (size-gridOffset < loopSize) 
{ + chunkSize = args->lastChunkSize; + } + ssize_t chunkOffset = gridOffset + bid*chunkSize; + + /////////////// begin AllGather steps /////////////// + ssize_t offset; + int maxOffset = min(chunkSize, size-chunkOffset); + int rankDest; + + // step 0: push data to next GPU + rankDest = ring->devUserRanks[0]; + offset = chunkOffset + rankDest * size; + + WAIT_NEXT; + if (thisInput + chunkOffset == thisOutput + offset) { // In place + LL::ReduceCopy( + thisInput + chunkOffset, + nextOutput + noffset, + maxOffset, nflag, llNthreads); + } else { + LL::ReduceCopy( + thisInput + chunkOffset, + thisOutput + offset, + nextOutput + noffset, + maxOffset, nflag, llNthreads); + } + POST_SIZE; + + NEXT_STEP_LL; + + // k-2 steps: copy to next GPU + for (int j=1; j<nranks-1; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + WAIT_NEXT; + LL::ReduceCopy( + prevInput + poffset, + thisOutput + offset, + nextOutput + noffset, + maxOffset, pflag, nflag, llNthreads); + POST_SIZE; + ACK_PREV; + + NEXT_STEP_LL; + } + + // step k-1: final store + rankDest = ring->devUserRanks[1]; + offset = chunkOffset + rankDest * size; + + LL::ReduceCopy( + prevInput + poffset, + thisOutput + offset, + maxOffset, pflag, llNthreads); + ACK_PREV; + } + + FIFO_CLEANING_AND_SAVE_STEP(nflag); +} diff --git a/src/collectives/device/all_reduce.cu b/src/collectives/device/all_reduce.cu new file mode 100644 index 0000000..caa1479 --- /dev/null +++ b/src/collectives/device/all_reduce.cu @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "common.h" +#include "all_reduce.h" +#include "collectives.h" + +#define UNROLL 4 + +#if NCCL_OP == 0 +IMPL_COLL2(ncclAllReduce, sum, FuncSum, ncclCollAllReduce, ncclSum); +#elif NCCL_OP == 1 +IMPL_COLL2(ncclAllReduce, prod, FuncProd, ncclCollAllReduce, ncclProd); +#elif NCCL_OP == 2 +IMPL_COLL2(ncclAllReduce, min, FuncMin, ncclCollAllReduce, ncclMin); +#elif NCCL_OP == 3 +IMPL_COLL2(ncclAllReduce, max, FuncMax, ncclCollAllReduce, ncclMax); +#endif diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h new file mode 100644 index 0000000..d7abc64 --- /dev/null +++ b/src/collectives/device/all_reduce.h @@ -0,0 +1,332 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "core.h" +#include "primitives.h" +#include "collectives.h" + +// Increase Step and poffset/noffset for buffer sync +#define NEXT_STEP \ + step++; \ + poffset = noffset; \ + noffset += sliceSize; \ + if (noffset == buffSize) noffset = 0; + +template<int UNROLL, class FUNC, typename T> +__device__ void ncclAllReduceKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = blockDim.x - 1; + const int bid = args->bid; + __shared__ T* sharedNextOutput; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + int prevdirect = ring->recv.conn.direct; + int nextdirect = ring->send.conn.direct; + + WaitFlag waitDoneFromNext(ring->send.conn.head, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS); + WaitFlag waitReadyFromPrev(ring->recv.conn.tail, ALLREDUCE_SUBSTEPS); + PostFlag postDoneToPrev(ring->recv.conn.head, ALLREDUCE_SUBSTEPS, NULL, 0); + PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, ALLREDUCE_BUFCHUNKS*ALLREDUCE_SUBSTEPS); + + typedef Primitives<UNROLL, ALLREDUCE_SUBSTEPS, T, FUNC> Prims; + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + const int buffSize = ring->buffSize / sizeof(T); + const int sliceSize = buffSize / ALLREDUCE_BUFCHUNKS; + const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + + if (tid == 0) { + // Update in case we skipped some collectives + *ring->recv.conn.opCount = args->opCount; + // Wait for next to be ready + WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); + waitOpCountNext.wait(args->opCount); + if (prevdirect) { + *ring->recv.conn.ptrExchange = args->ThisOutput; + } + if (nextdirect) { + void* volatile* ptr = &(ring->devMemSend->ptrExchange); + while (*ptr == nullptr); + sharedNextOutput = (T*)*ptr; + *ptr = nullptr; + } + } + __syncthreads(); + + uint64_t step = 0ULL; + int poffset, noffset = 0; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + T * __restrict__ prevInput = (T*)ring->recv.conn.buff; + T * __restrict__ nextOutput = (T*)ring->send.conn.buff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += nranks*loopSize) { + int chunkSize = min(sliceSize, DIVUP(size-gridOffset,nranks*args->nRings)); + ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); + ssize_t chunkOffset = gridOffset + bid*nranks*chunkSize; + + /////////////// begin AllReduce steps /////////////// + ssize_t offset; + int maxOffset; + int slice; + + // step 0: push data to next GPU + slice = ring->devUserRanks[nranks-1]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + Prims::Copy(tid, nthreads, + thisInput + offset, + nextOutput + noffset, + sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + + NEXT_STEP; // Increases step, poffset, noffset + + // k-2 steps: reduce and copy to next GPU + for (int j=2; j<nranks; ++j) { + slice = ring->devUserRanks[nranks-j]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + Prims::Reduce(tid, nthreads, + prevInput + poffset, + thisInput + offset, + nextOutput + noffset, + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + } + + // step k-1: reduce this buffer and data, which will produce the 
final + // result that we store in this data and push to the next GPU + slice = ring->devUserRanks[0]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + Prims::ReduceCopy(tid, nthreads, + prevInput + poffset, + thisInput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + thisOutput + offset, + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + + // k-2 steps: copy to next GPU + if (prevdirect) { + for (int j=1; j<nranks-1; ++j) { + slice = ring->devUserRanks[nranks - j]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + Prims::Copy(tid, nthreads, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + } + Prims::Copy(tid, nthreads, + NULL, + NULL, + 0, 0, + step, + waitReadyFromPrev, + postDoneToPrev); + } else { + for (int j=1; j<nranks-1; ++j) { + slice = ring->devUserRanks[nranks - j]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + Prims::DoubleCopy(tid, nthreads, + prevInput + poffset, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + noffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + } + + // Make final copy from buffer to dest. + slice = ring->devUserRanks[1]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + // Here we need to copy from buffer to this output. + Prims::Copy(tid, nthreads, + prevInput + poffset, + thisOutput + offset, + sliceSize, maxOffset, + step, + waitReadyFromPrev, + postDoneToPrev); + } + } + + if (tid == 0) { + // Wait for next to have consumed all data before we reset the flag + waitDoneFromNext.wait(ALLREDUCE_SUBSTEPS*(step + ALLREDUCE_BUFCHUNKS)); + *ring->send.conn.head = 0ULL; + *ring->recv.conn.tail = 0ULL; + __threadfence_system(); + *ring->recv.conn.opCount = args->opCount+1; + } +} + +#include "ll_kernel.h" + +#define NEXT_STEP_LL \ + poffset = noffset; \ + pflag = nflag; \ + noffset += NCCL_LL_SLICE_LINES; \ + if (noffset == NCCL_LL_BUFF_LINES) { noffset = 0; } \ + nflag++; \ + step++; + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclAllReduceLLKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int llNthreads = args->nThreads; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead; + volatile uint64_t * sendHeadPtr = ring->send.conn.llHead; + volatile int * sizesFifo = ring->send.conn.llFifo; + uint64_t sendHead = sendHeadPtr[0]; + + typedef LLPrimitives<T, FUNC> LL; + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nRings*nranks*chunkSize; + + uint64_t step = ring->send.conn.llStep; + uint32_t pflag, nflag = step + 1; + int poffset, noffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step); + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + union ncclLLFifoLine * prevInput = (union ncclLLFifoLine 
*)ring->recv.conn.llBuff; + union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + if (size-gridOffset < loopSize) { + chunkSize = args->lastChunkSize; + } + ssize_t chunkOffset = gridOffset + bid*nranks*chunkSize; + + /////////////// begin AllReduce steps /////////////// + ssize_t offset; + int maxOffset; + int slice; + + // step 0: push data to next GPU + slice = ring->devUserRanks[nranks-1]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + nextOutput + noffset, + maxOffset, nflag, llNthreads); + POST_SIZE; + + NEXT_STEP_LL; + + // k-2 steps: reduce and copy to next GPU + for (int j=2; j<nranks; ++j) { + slice = ring->devUserRanks[nranks-j]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + prevInput + poffset, + nextOutput + noffset, + maxOffset, pflag, nflag, llNthreads); + POST_SIZE; + ACK_PREV; + + NEXT_STEP_LL; + } + + // step k-1: reduce this buffer and data, which will produce the final + // result that we store in this data and push to the next GPU + slice = ring->devUserRanks[0]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + prevInput + poffset, + thisOutput + offset, + nextOutput + noffset, + maxOffset, pflag, nflag, llNthreads); + POST_SIZE; + ACK_PREV; + + NEXT_STEP_LL; + + // k-2 steps: copy to next GPU + for (int j=1; j<nranks-1; ++j) { + slice = ring->devUserRanks[nranks - j]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + WAIT_NEXT; + LL::ReduceCopy( + prevInput + poffset, + thisOutput + offset, + nextOutput + noffset, + maxOffset, pflag, nflag, llNthreads); + POST_SIZE; + ACK_PREV; + + NEXT_STEP_LL; + } + + // Make final copy from buffer to dest. + slice = ring->devUserRanks[1]; + offset = chunkOffset + slice * chunkSize; + maxOffset = min(chunkSize, size-offset); + + // Here we need to copy from buffer to this output. + LL::ReduceCopy( + prevInput + poffset, + thisOutput + offset, + maxOffset, pflag, llNthreads); + ACK_PREV; + } + + FIFO_CLEANING_AND_SAVE_STEP(nflag); +} diff --git a/src/collectives/device/broadcast.cu b/src/collectives/device/broadcast.cu new file mode 100644 index 0000000..4125de4 --- /dev/null +++ b/src/collectives/device/broadcast.cu @@ -0,0 +1,15 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "common.h" +#include "broadcast.h" +#include "collectives.h" + +#define UNROLL 4 + +#if NCCL_OP == 0 +IMPL_COLL3(ncclBroadcast, copy, FuncSum, i8, int8_t, ncclCollBroadcast, ncclSum, ncclInt8); +#endif diff --git a/src/collectives/device/broadcast.h b/src/collectives/device/broadcast.h new file mode 100644 index 0000000..c2f6d00 --- /dev/null +++ b/src/collectives/device/broadcast.h @@ -0,0 +1,228 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "core.h" +#include "primitives.h" +#include "collectives.h" + +// Increase Step and boffset for buffer sync +#define NEXT_STEP \ + step++; \ + boffset += sliceSize; \ + if (boffset == buffSize) boffset = 0; + +template<int UNROLL, class FUNC, typename T> +__device__ void ncclBroadcastKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = blockDim.x - 1; + const int bid = args->bid; + __shared__ T* sharedNextOutput; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + int prevdirect = ring->recv.conn.direct; + int nextdirect = ring->send.conn.direct; + + WaitFlag waitDoneFromNext(ring->send.conn.head, (BROADCAST_BUFCHUNKS-1)*BROADCAST_SUBSTEPS); + WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0); + PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0); + PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, BROADCAST_BUFCHUNKS*BROADCAST_SUBSTEPS); + + typedef Primitives<UNROLL, BROADCAST_SUBSTEPS, T> Prims; + + const ssize_t size = args->N; + const int buffSize = ring->buffSize / sizeof(T); + const int sliceSize = buffSize / BROADCAST_BUFCHUNKS; + const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + const int rank = ring->devUserRanks[0]; + const int nextRank = ring->devUserRanks[1]; + const int root = args->root; + + if (tid == 0) { + // Update in case we skipped some collectives + *ring->recv.conn.opCount = args->opCount; + if (nextRank != root) { + // Wait for next to be ready + WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); + waitOpCountNext.wait(args->opCount); + } + if (rank != root && prevdirect) { + *ring->recv.conn.ptrExchange = args->ThisOutput; + } + if (nextRank != root && nextdirect) { + void* volatile* ptr = &(ring->devMemSend->ptrExchange); + while (*ptr == nullptr); + sharedNextOutput = (T*)*ptr; + *ptr = nullptr; + } + } + __syncthreads(); + + uint64_t step = 0ULL; + int boffset = 0; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + T * __restrict__ prevInput = (T*)ring->recv.conn.buff; + T * __restrict__ nextOutput = (T*)ring->send.conn.buff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings)); + ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); + ssize_t offset = gridOffset + bid*chunkSize; + int maxOffset = min(chunkSize, size-offset); + + if (rank == root) { + if (thisInput == thisOutput) { + Prims::Copy(tid, nthreads, + thisInput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + } else { + Prims::DoubleCopy(tid, nthreads, + thisInput + offset, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + } + } else if (nextRank == root) { + if (prevdirect) maxOffset = 0; // Only wait for signals + Prims::Copy(tid, nthreads, + prevInput + boffset, + thisOutput + offset, + sliceSize, maxOffset, + step, + waitReadyFromPrev, + postDoneToPrev); + } else { + if (prevdirect) { + Prims::Copy(tid, nthreads, + thisOutput + offset, + nextdirect ? 
(sharedNextOutput + offset) : (nextOutput + boffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + } else { + Prims::DoubleCopy(tid, nthreads, + prevInput + boffset, + thisOutput + offset, + nextdirect ? (sharedNextOutput + offset) : (nextOutput + boffset), + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + } + } + NEXT_STEP; // Increases step, boffset + } + + if (tid == 0) { + if (nextRank != root) { + // Wait for next to have consumed data before resetting the flag + waitDoneFromNext.wait(BROADCAST_SUBSTEPS*(step + BROADCAST_BUFCHUNKS - 1)); + *ring->send.conn.head = 0ULL; + } + *ring->recv.conn.tail = 0ULL; + __threadfence_system(); + *ring->recv.conn.opCount = args->opCount+1; + } +} + +#include "ll_kernel.h" + +#define NEXT_STEP_LL \ + boffset += NCCL_LL_SLICE_LINES; \ + if (boffset == NCCL_LL_BUFF_LINES) boffset = 0; \ + flag++; \ + step++; + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclBroadcastLLKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int llNthreads = args->nThreads; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead; + volatile uint64_t * sendHeadPtr = ring->send.conn.llHead; + volatile int * sizesFifo = ring->send.conn.llFifo; + uint64_t sendHead = sendHeadPtr[0]; + const int rank = comm->rank; + const int nextRank = ring->devUserRanks[1]; + const int root = args->root; + + typedef LLPrimitives<T, FUNC> LL; + + const ssize_t size = args->N; + ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nRings*chunkSize; + + uint64_t step = ring->send.conn.llStep; + uint32_t flag = step + 1; + int boffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step); + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff; + union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + if (size-gridOffset < loopSize) { + chunkSize = args->lastChunkSize; + } + ssize_t offset = gridOffset + bid*chunkSize; + + int maxOffset = min(chunkSize, size-offset); + if (rank == root) { + WAIT_NEXT; + if (thisInput == thisOutput) { + LL::ReduceCopy( + thisInput + offset, + nextOutput + boffset, + maxOffset, flag, llNthreads); + } else { + LL::ReduceCopy( + thisInput + offset, + thisOutput + offset, + nextOutput + boffset, + maxOffset, flag, llNthreads); + } + POST_SIZE; + NEXT_STEP_LL; + } else if (nextRank == root) { + LL::ReduceCopy( + prevInput + boffset, + thisOutput + offset, + maxOffset, flag, llNthreads); + NEXT_STEP_LL; + ACK_PREV; + } else { + WAIT_NEXT; + LL::ReduceCopy( + prevInput + boffset, + thisOutput + offset, + nextOutput + boffset, + maxOffset, flag, flag, llNthreads); + POST_SIZE; + NEXT_STEP_LL; + ACK_PREV; + } + } + + // We need everyone to acknowledge data even if they didn't receive anything + // so that the next collective can start right away. 
+ ACK_PREV; + + FIFO_CLEANING_AND_SAVE_STEP(flag); +} diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h new file mode 100644 index 0000000..c988913 --- /dev/null +++ b/src/collectives/device/common.h @@ -0,0 +1,90 @@ +/************************************************************************* + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_DEVICE_COMMON_H_ +#define NCCL_DEVICE_COMMON_H_ + +#include "../collectives.h" +#include "core.h" +#include "nccl.h" + +typedef void(*ncclKern_t)(struct CollectiveArgs* args); +extern __device__ ncclKern_t ncclFuncs[]; + +static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) { + int* d = (int*)dst; + int* s = (int*)src; + __syncthreads(); + for (int o = tid; o < (size/sizeof(int)); o += blockDim.x) d[o] = s[o]; + __syncthreads(); +} +static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid) { + load_parallel(localColl, hostColl, sizeof(struct ncclColl), tid); + if (tid == 0) hostColl->active = 0; +} + +/* Functions for aggregation case */ +#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype) \ +__device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \ + coll##Kernel<UNROLL, ncclFunc<ctype>, ctype>(args); \ +} +/* Kernels with the first operation inlined */ +#define IMPL_COLL4K(coll, op, ncclFunc, dtype, ctype, fIndex) \ +__launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \ +__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ + int tid = threadIdx.x; \ + int bid = blockIdx.x; \ + __shared__ struct ncclColl localColl; \ + \ + struct ncclComm* comm = firstColl.args.comm; \ + struct ncclRing* ring = comm->rings+bid; \ + struct ncclColl* c; \ + if (bid == 0) { \ + /* To optimize for latency, (only) the first operation is passed as argument.*/ \ + c = &firstColl; \ + } else { \ + c = &localColl; \ + load_coll(c, ring->devCollectives+ring->collFifoHead, tid); \ + } \ + while (1) { \ + if (tid < c->nThreads) { \ + if (c->funcIndex == fIndex) { \ + coll##Kernel<UNROLL, ncclFunc<ctype>, ctype>(&c->args); \ + } else { \ + ncclFuncs[c->funcIndex](&c->args); \ + } \ + } \ + int nextIndex = c->nextIndex; \ + if (tid == 0) ring->collFifoHead = nextIndex; \ + \ + if (c->active == 2) { \ + return; \ + } \ + \ + /* Load next collective operation*/ \ + c = &localColl; /* for bid 0 */ \ + load_coll(c, ring->devCollectives+nextIndex, tid); \ + } \ +} + +#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \ + IMPL_COLL4(coll##LL, op, ncclFunc, dtype, ctype) \ + IMPL_COLL4K(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1)) \ + IMPL_COLL4(coll, op, ncclFunc, dtype, ctype) \ + IMPL_COLL4K(coll, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 0)) \ + +#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \ + IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8) \ + IMPL_COLL3(coll, op, ncclFunc, u8, uint8_t, ncclColl, ncclOp, ncclUint8) \ + IMPL_COLL3(coll, op, ncclFunc, i32, int32_t, ncclColl, ncclOp, ncclInt32) \ + IMPL_COLL3(coll, op, ncclFunc, u32, uint32_t, ncclColl, ncclOp, ncclUint32) \ + IMPL_COLL3(coll, op, ncclFunc, i64, int64_t, ncclColl, ncclOp, ncclInt64) \ + IMPL_COLL3(coll, op, ncclFunc, u64, uint64_t, ncclColl, ncclOp, ncclUint64) \ + IMPL_COLL3(coll, op, ncclFunc, f16, 
half, ncclColl, ncclOp, ncclFloat16) \ + IMPL_COLL3(coll, op, ncclFunc, f32, float, ncclColl, ncclOp, ncclFloat32) \ + IMPL_COLL3(coll, op, ncclFunc, f64, double, ncclColl, ncclOp, ncclFloat64) + +#endif diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h new file mode 100644 index 0000000..0eaa061 --- /dev/null +++ b/src/collectives/device/common_kernel.h @@ -0,0 +1,372 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_COMMON_KERNEL_H_ +#define NCCL_COMMON_KERNEL_H_ + +#include "core.h" +#include <cstdio> +#include <cstdint> + +#include <cuda_runtime.h> + +// Define min for ssize_t +static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; } + +typedef uint64_t PackType; + +// unpack x and y to elements of type T and apply FUNC to each element +template<class FUNC, typename T> +struct MULTI { + __device__ PackType operator()(const PackType x, const PackType y) const; +}; + +template<class FUNC> +struct MULTI<FUNC, int8_t> { + static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), + "PackType must be twice the size of uint32_t."); + union converter { + PackType storage; + struct { + uint32_t a, b; + }; + }; + + __device__ PackType operator()(const PackType x, const PackType y) const { + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + + // for char, we do these as vector ops + cr.a = FUNC()(cx.a, cy.a); + cr.b = FUNC()(cx.b, cy.b); + + return cr.storage; + } +}; + +template<class FUNC> +struct MULTI<FUNC, uint8_t> { + static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), + "PackType must be twice the size of uint32_t."); + union converter { + PackType storage; + struct { + uint32_t a, b; + }; + }; + + __device__ PackType operator()(const PackType x, const PackType y) const { + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + + // for char, we do these as vector ops + cr.a = FUNC()(cx.a, cy.a); + cr.b = FUNC()(cx.b, cy.b); + + return cr.storage; + } +}; + +template<class FUNC> +struct MULTI<FUNC, int32_t> { + static_assert(sizeof(PackType) == 2 * sizeof(int32_t), + "PackType must be twice the size of int."); + union converter { + PackType storage; + struct { + int32_t a, b; + }; + }; + + __device__ PackType operator()(const PackType x, const PackType y) const { + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + + cr.a = FUNC()(cx.a, cy.a); + cr.b = FUNC()(cx.b, cy.b); + + return cr.storage; + } +}; + +template<class FUNC> +struct MULTI<FUNC, uint32_t> { + static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), + "PackType must be twice the size of int."); + union converter { + PackType storage; + struct { + uint32_t a, b; + }; + }; + + __device__ PackType operator()(const PackType x, const PackType y) const { + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + + cr.a = FUNC()(cx.a, cy.a); + cr.b = FUNC()(cx.b, cy.b); + + return cr.storage; + } +}; + +template<class FUNC> +struct MULTI<FUNC, half> { + static_assert(sizeof(PackType) == 4 * sizeof(half), + "PackType must be four times the size of half."); + + struct PackHalf2 { + half2 a, b; + }; + + __device__ PackType operator()(const PackType x, const PackType y) const { + struct PackHalf2 cx, cy, cr; + cx = *(reinterpret_cast<const struct PackHalf2*>(&x)); + cy = *(reinterpret_cast<const struct 
PackHalf2*>(&y)); + + cr.a = FUNC()(cx.a, cy.a); + cr.b = FUNC()(cx.b, cy.b); + + return *(reinterpret_cast<PackType*>(&cr)); + } +}; + +template<class FUNC> +struct MULTI<FUNC, float> { + static_assert(sizeof(PackType) == 2 * sizeof(float), + "PackType must be twice the size of float."); + union converter { + PackType storage; + struct { + float a, b; + }; + }; + + __device__ PackType operator()(const PackType x, const PackType y) const { + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + + cr.a = FUNC()(cx.a, cy.a); + cr.b = FUNC()(cx.b, cy.b); + + return cr.storage; + } +}; + +template<class FUNC> +struct MULTI<FUNC, double> { + static_assert(sizeof(PackType) == sizeof(double), + "PackType must be the same size as double."); + __device__ PackType operator()(const PackType x, const PackType y) const { + double rv = FUNC()(__longlong_as_double(x), __longlong_as_double(y)); + return __double_as_longlong(rv); + } +}; + +template<class FUNC> +struct MULTI<FUNC, uint64_t> { + static_assert(sizeof(PackType) == sizeof(uint64_t), + "PackType must be the same size as uint64_t."); + __device__ PackType operator()(const PackType x, const PackType y) const { + uint64_t rv = FUNC()(x, y); + return rv; + } +}; + +template<class FUNC> +struct MULTI<FUNC, int64_t> { + static_assert(sizeof(PackType) == sizeof(int64_t), + "PackType must be the same size as int64_t."); + __device__ PackType operator()(const PackType x, const PackType y) const { + int64_t rv = FUNC()((int64_t)x, (int64_t)y); + return rv; + } +}; + +#define ALIGNUP(x, a) ((((x)-1) & ~((a)-1)) + (a)) + +template<typename T> +__device__ inline volatile T* AlignUp(volatile T * ptr, size_t align) { + size_t ptrval = reinterpret_cast<size_t>(ptr); + return reinterpret_cast<volatile T*>(ALIGNUP(ptrval, align)); +} + +template<typename T> inline __device__ +T vFetch(const volatile T* ptr) { + return *ptr; +} + +template<typename T> inline __device__ +void vStore(volatile T* ptr, const T val) { + *ptr = val; +} + +#if CUDART_VERSION < 9000 +template<> inline __device__ +half vFetch<half>(const volatile half* ptr) { + half r; + r.x = ptr->x; + return r; +} + +template<> inline __device__ +void vStore<half>(volatile half* ptr, const half val) { + ptr->x = val.x; +} +#else +template<> inline __device__ +half vFetch<half>(const volatile half* ptr) { + half r; + r = ((half*)ptr)[0]; + return r; +} + +template<> inline __device__ +void vStore<half>(volatile half* ptr, const half val) { + ((half*)ptr)[0] = val; +} +#endif + +template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS> +__device__ inline void ReduceCopy( + const int tid, const int nthreads, + const volatile T * __restrict__ const src0, + const volatile T * __restrict__ const src1, + volatile T * __restrict__ const dest0, + volatile T * __restrict__ const dest1, const int N) { + for (int idx = tid; idx < N; idx += nthreads) { + T val = vFetch(src0+idx); + if (TWO_INPUTS) { + val = FUNC()(val, vFetch(src1+idx)); + } + vStore(dest0+idx, val); + if (TWO_OUTPUTS) { + vStore(dest1+idx, val); + } + } +} + +typedef ulong2 Pack128; + +template<class FUNC, typename T> +struct MULTI128 { + __device__ void operator()(Pack128& x, Pack128& y) { + x.x = MULTI<FUNC, T>()(x.x, y.x); + x.y = MULTI<FUNC, T>()(x.y, y.y); + } +}; + +inline __device__ void Fetch128(Pack128& v, Pack128* p) { + asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory"); +} +inline __device__ void Store128(Pack128* p, Pack128& v) { + asm volatile("st.volatile.global.v2.u64 
[%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory"); +} + +#define WARP_SIZE 32 +template<class FUNC, typename T, bool TWO_INPUTS, bool TWO_OUTPUTS, int UNROLL> +__device__ inline void ReduceCopy128b( const int w, const int nw, const int t, + Pack128 * src0, Pack128 * src1, Pack128 * dest0, Pack128 * dest1, + const int N) { + Pack128 t0[UNROLL]; + Pack128 t1[UNROLL]; + const Pack128* src0_end = src0 + N; + const int inc = nw * UNROLL * WARP_SIZE; + const int offset = w * UNROLL * WARP_SIZE + t; + src0 += offset; if (TWO_INPUTS) src1 += offset; + dest0 += offset; if (TWO_OUTPUTS) dest1 += offset; + + while (src0 < src0_end) { +#pragma unroll + for (int u = 0; u < UNROLL; ++u) { + Fetch128(t0[u], src0+u*WARP_SIZE); + if (TWO_INPUTS) Fetch128(t1[u], src1+u*WARP_SIZE); + } +#pragma unroll + for (int u = 0; u < UNROLL; ++u) { + if (TWO_INPUTS) MULTI128<FUNC, T>()(t0[u], t1[u]); + Store128(dest0+u*WARP_SIZE, t0[u]); + if (TWO_OUTPUTS) Store128(dest1+u*WARP_SIZE, t0[u]); + } + src0 += inc; if (TWO_INPUTS) src1 += inc; + dest0 += inc; if (TWO_OUTPUTS) dest1 += inc; + } +} + +template<int UNROLL, class FUNC, typename T, bool HAS_DEST1, bool HAS_SRC1> +__device__ inline void ReduceOrCopy(const int tid, const int nthreads, + volatile T * __restrict__ dest0, volatile T * __restrict__ dest1, + const volatile T * __restrict__ src0, const volatile T * __restrict__ src1, + int N) { + int Nrem = N; + if (Nrem <= 0) return; + + int Npreamble = (Nrem<alignof(Pack128)) ? Nrem : AlignUp(dest0, alignof(Pack128)) - dest0; + + // stage 0: check if we'll be able to use the fast, 128-bit aligned path. + // If not, we'll just use the slow preamble path for the whole operation + bool alignable = (((AlignUp(src0, alignof(Pack128)) == src0 + Npreamble)) && + (!HAS_DEST1 || (AlignUp(dest1, alignof(Pack128)) == dest1 + Npreamble)) && + (!HAS_SRC1 || (AlignUp(src1, alignof(Pack128)) == src1 + Npreamble))); + + if (!alignable) { + Npreamble = Nrem; + } + + // stage 1: preamble: handle any elements up to the point of everything coming + // into alignment + ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(tid, nthreads, src0, src1, dest0, dest1, Npreamble); + + Nrem -= Npreamble; + if (Nrem == 0) return; + + dest0 += Npreamble; if (HAS_DEST1) { dest1 += Npreamble; } + src0 += Npreamble; if (HAS_SRC1) { src1 += Npreamble; } + + // stage 2: fast path: use 128b loads/stores to do the bulk of the work, + // assuming the pointers we have are all 128-bit alignable. 
+ int w = tid / WARP_SIZE; // Warp number + int nw = nthreads / WARP_SIZE; // Number of warps + int t = tid % WARP_SIZE; // Thread (inside the warp) + + const int PackFactor = sizeof(Pack128) / sizeof(T); + + // stage 2a: main loop + int Nalign2a = (Nrem / (PackFactor * UNROLL * nthreads)) + * (UNROLL * nthreads); // round down + + ReduceCopy128b<FUNC, T, HAS_SRC1, HAS_DEST1, UNROLL>(w, nw, t, (Pack128*)src0, (Pack128*)src1, (Pack128*)dest0, (Pack128*)dest1, Nalign2a); + + int Ndone2a = Nalign2a * PackFactor; + Nrem -= Ndone2a; + if (Nrem == 0) return; + dest0 += Ndone2a; if (HAS_DEST1) { dest1 += Ndone2a; } + src0 += Ndone2a; if (HAS_SRC1) { src1 += Ndone2a; } + + // stage 2b: slightly less optimized for section when we don't have full + // UNROLLs + + int Nalign2b = Nrem / PackFactor; + + ReduceCopy128b<FUNC, T, HAS_SRC1, HAS_DEST1, 1>(w, nw, t, (Pack128*)src0, (Pack128*)src1, (Pack128*)dest0, (Pack128*)dest1, Nalign2b); + + int Ndone2b = Nalign2b * PackFactor; + Nrem -= Ndone2b; + if (Nrem == 0) return; + dest0 += Ndone2b; if (HAS_DEST1) { dest1 += Ndone2b; } + src0 += Ndone2b; if (HAS_SRC1) { src1 += Ndone2b; } + + // stage 2c: tail + ReduceCopy<FUNC, T, HAS_SRC1, HAS_DEST1>(tid, nthreads, src0, src1, dest0, dest1, Nrem); +} + +#endif // COMMON_KERNEL_H_ diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu new file mode 100644 index 0000000..16f1865 --- /dev/null +++ b/src/collectives/device/functions.cu @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "core.h" +#include "collectives.h" +#include "common.h" + +#define NCCL_FUNC4(coll, op, dtype) \ + NCCL_COLL_NAME(coll, op, dtype), \ + NCCL_COLL_NAME(coll##LL, op, dtype) \ + +// Must be consistent with ncclDataType_t +#define NCCL_FUNCS3A(coll, op) \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, u8), \ + NCCL_FUNC4(coll, op, i32), \ + NCCL_FUNC4(coll, op, u32), \ + NCCL_FUNC4(coll, op, i64), \ + NCCL_FUNC4(coll, op, u64), \ + NCCL_FUNC4(coll, op, f16), \ + NCCL_FUNC4(coll, op, f32), \ + NCCL_FUNC4(coll, op, f64) +#define NCCL_FUNCS3B(coll, op) \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8), \ + NCCL_FUNC4(coll, op, i8) + +// Must be consistent with ncclRedOp_t +#define NCCL_FUNCS2A(coll) \ + NCCL_FUNCS3A(coll, sum ), \ + NCCL_FUNCS3A(coll, prod), \ + NCCL_FUNCS3A(coll, max ), \ + NCCL_FUNCS3A(coll, min ) +#define NCCL_FUNCS2B(coll) \ + NCCL_FUNCS3B(coll, copy), \ + NCCL_FUNCS3B(coll, copy), \ + NCCL_FUNCS3B(coll, copy), \ + NCCL_FUNCS3B(coll, copy) + +// Must be consistent with ncclColl_t +#define NCCL_FUNCS() { \ + NCCL_FUNCS2B(ncclBroadcast), \ + NCCL_FUNCS2A(ncclReduce), \ + NCCL_FUNCS2B(ncclAllGather), \ + NCCL_FUNCS2A(ncclReduceScatter), \ + NCCL_FUNCS2A(ncclAllReduce) } + +// Must be consistent with the ncclFuncSet enum +__device__ ncclKern_t ncclFuncs[ncclCollCount*ncclNumOps*ncclNumTypes*2] = { + NCCL_FUNCS2B(ncclBroadcast), + NCCL_FUNCS2A(ncclReduce), + NCCL_FUNCS2B(ncclAllGather), + NCCL_FUNCS2A(ncclReduceScatter), + NCCL_FUNCS2A(ncclAllReduce) +}; diff --git a/src/collectives/device/ll_kernel.h b/src/collectives/device/ll_kernel.h new file mode 100644 
index 0000000..5ec3c9a --- /dev/null +++ b/src/collectives/device/ll_kernel.h @@ -0,0 +1,154 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_LL_KERNEL_H_ +#define NCCL_LL_KERNEL_H_ + +static __device__ uint64_t readLL(union ncclLLFifoLine* src, uint32_t flag) { + uint32_t data1, flag1, data2, flag2; + do { + asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(data1), "=r"(flag1), "=r"(data2), "=r"(flag2) : "l"(&src->i4)); + } while ((flag1 != flag) || (flag2 != flag)); + uint64_t val64 = data1 + (((uint64_t)data2) << 32); + return val64; +} + +static __device__ void storeLL(union ncclLLFifoLine* dst, uint64_t val, uint32_t flag) { + asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(&dst->i4), "r"((uint32_t)val), "r"(flag), "r"((uint32_t)(val >> 32)), "r"(flag)); +} + +// Using memcpy handles misaligned pointers. +static __device__ uint64_t readAL(uint64_t* src) { + uint64_t val; + memcpy((char*)&val, (char*)src, sizeof(uint64_t)); + return val; +} +static __device__ void storeAL(uint64_t* dst, uint64_t val) { + memcpy((char*)dst, (char*)&val, sizeof(uint64_t)); +} + +template <typename T, class FUNC> +class LLPrimitives { + private: + template <int HAS_SRC1, int HAS_SRC2, int HAS_DST1, int HAS_DST2> + static __device__ void ReduceCopyGeneric(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + if (size <= 0) return; + size_t size64 = size * sizeof(T) / sizeof(uint64_t); + uint64_t* src1A = (uint64_t*)src1; + uint64_t* dst1A = (uint64_t*)dst1; + int offset = threadIdx.x; + // Do multiples of 64 bits +#pragma unroll 1 + for (; offset < size64; offset += nthreads) { + uint64_t val; + if (HAS_SRC1) { + val = readAL(src1A+offset); + if (HAS_SRC2) val = MULTI<FUNC, T>()(readLL(src2+offset, iflag), val); + } else if (HAS_SRC2) { + val = readLL(src2+offset, iflag); + } + if (HAS_DST1) storeAL(dst1A+offset, val); + if (HAS_DST2) storeLL(dst2+offset, val, oflag); + } + // Finish last word + int sizeDone = size64*(sizeof(uint64_t)/sizeof(T)); + int sizeRem = size - sizeDone; + if (threadIdx.x == 0 && sizeRem) { + const T* src1B = src1 + sizeDone; + T* dst1B = dst1 + sizeDone; + + uint64_t lastVal; + T* vals = (T*)&lastVal; + + if (HAS_SRC2) { + uint64_t lastVal2 = readLL(src2+size64, iflag); + T* src2B = (T*)&lastVal2; + for (int offset = 0; offset < sizeRem; offset++) { + vals[offset] = HAS_SRC1 ? 
FUNC()(src2B[offset], src1B[offset]) : src2B[offset]; + } + } else if (HAS_SRC1) { + for (int offset = 0; offset < sizeRem; offset++) { + vals[offset] = src1B[offset]; + } + } + if (HAS_DST2) storeLL(dst2+size64, lastVal, oflag); + if (HAS_DST1) { + for (int offset = 0; offset < sizeRem; offset++) { + dst1B[offset] = vals[offset]; + } + } + } + } + public: + static __device__ void ReduceCopy(const T* src, union ncclLLFifoLine* dst, int size, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 0, 0, 1>(src, NULL, NULL, dst, size, 0, oflag, nthreads); + } + + static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst, int size, uint32_t iflag, int nthreads) { + return ReduceCopyGeneric<0, 1, 1, 0>(NULL, src, dst, NULL, size, iflag, 0, nthreads); + } + + static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, union ncclLLFifoLine* dst, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 1, 0, 1>(src1, src2, NULL, dst, size, iflag, oflag, nthreads); + } + + static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst, int size, uint32_t iflag, int nthreads) { + return ReduceCopyGeneric<1, 1, 1, 0>(src1, src2, dst, NULL, size, iflag, 0, nthreads); + } + + static __device__ void ReduceCopy(const T* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 0, 1, 1>(src, NULL, dst1, dst2, size, 0, oflag, nthreads); + } + + static __device__ void ReduceCopy(union ncclLLFifoLine* src, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<0, 1, 1, 1>(NULL, src, dst1, dst2, size, iflag, oflag, nthreads); + } + + static __device__ void ReduceCopy(const T* src1, union ncclLLFifoLine* src2, T* dst1, union ncclLLFifoLine* dst2, int size, uint32_t iflag, uint32_t oflag, int nthreads) { + return ReduceCopyGeneric<1, 1, 1, 1>(src1, src2, dst1, dst2, size, iflag, oflag, nthreads); + } +}; + +// Common macros + +#define STEP_TO_SLOT(step) \ + (step % NCCL_LL_CHUNKS) + +#define WAIT_NEXT \ + if (tid == 0) { \ + while (sendHead + NCCL_LL_CHUNKS <= step) { \ + sendHead = sendHeadPtr[0]; \ + } \ + } \ + asm volatile ("bar.sync 1, %0;" :: "r"(llNthreads)); + +#define POST_SIZE \ + if (tid == 0 && sizesFifo) sizesFifo[step % NCCL_LL_CHUNKS] = (maxOffset <= 0) ? 
-1 : (maxOffset*2*(int)sizeof(T)); + +#define ACK_PREV \ + asm volatile ("bar.sync 1, %0;" :: "r"(llNthreads)); \ + if (tid == 0) recvHeadPtr[0] = step; + +#define FIFO_CLEANING_AND_SAVE_STEP(flag) do { \ + if (step > ring->send.conn.llLastCleaning + NCCL_LL_CLEAN_FREQ) { \ + /* Reset all flags */ \ + static_assert((NCCL_LL_BUFF_SIZE % NCCL_LL_MAX_NTHREADS) == 0, "NCCL_LL_BUFF_SIZE must be a multiple of THREADS"); \ + static_assert(NCCL_LL_BUFF_SIZE/(sizeof(union ncclLLFifoLine)*NCCL_LL_MAX_NTHREADS) > 0, "NCCL_LL_BUFF_SIZE is less than 16 bytes*THREADS"); \ + const union ncclLLFifoLine resetLine = { 0, flag, 0, flag }; \ + for (int i=0; i<NCCL_LL_BUFF_SIZE/(sizeof(union ncclLLFifoLine)*llNthreads); i++) { \ + prevInput[tid+i*llNthreads].i4 = resetLine.i4; \ + } \ + __threadfence_system(); \ + /* Restart from the same slot, only make sure sender waits for data to be reset */ \ + step += NCCL_LL_CHUNKS; \ + ACK_PREV; \ + while (sendHeadPtr[0] < step); \ + if (tid == 0) ring->send.conn.llLastCleaning = step; \ + } \ + ring->send.conn.llStep = step; \ +} while (0); + +#endif diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h new file mode 100644 index 0000000..8df152e --- /dev/null +++ b/src/collectives/device/primitives.h @@ -0,0 +1,226 @@ +/************************************************************************* + * Copyright (c) 2016-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef NCCL_PRIMITIVES_H_ +#define NCCL_PRIMITIVES_H_ + +#include <type_traits> +#include "reduce_kernel.h" // for reduction funcs + + +/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy. + * + * In order to reduce the reptetion of template arguments, the operations + * are bundled as static methods of the Primitives class. + * + * Each primitive operation copies/reduces a contiguous buffer and syncs + * an optional set of flags against a sub-step counter. The sync value is + * based on the step parameter. Sync flags must be of type WaitFlag or + * PostFlag. The primitive routines wait for all WaitFlag args to attain + * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of + * corresponding substep by previous step) before executing the transfer. + * After each substep is transfered, all PostFlag arguments get updated to + * the value SUBSTEPS*step+substep+1. + */ + + +class WaitFlag { + volatile uint64_t * const flag; + const int shift; + public: + __device__ __forceinline__ + WaitFlag(volatile uint64_t * const flag, const int shift) : flag(flag), shift(shift) { } + __device__ __forceinline__ + void wait(uint64_t val) { while ((*flag + shift) < val) /*SPIN*/; } +}; + + +class PostFlag { + volatile uint64_t * const flag; + const int shift; + volatile int * const fifo; + const int fifo_size; + public: + __device__ __forceinline__ + PostFlag(volatile uint64_t* const flag, const int shift, volatile int* const fifo, const int fifo_size) : flag(flag), shift(shift), fifo(fifo), fifo_size(fifo_size) { } + __device__ __forceinline__ + void post(uint64_t val) { *flag = (val - shift); } + __device__ __forceinline__ + void postSize(uint64_t step, int size) { if (fifo != NULL) fifo[step%fifo_size] = size; }; +}; + + +// Helper to check if any argument is of type T. +// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...) 
+template<typename T> __device__ __forceinline__
+bool AnyAre() { return false; }
+
+template<typename T, typename FIRST_T, typename... TAIL_Ts>
+__device__ __forceinline__
+bool AnyAre(FIRST_T first, TAIL_Ts... tail) {
+  return std::is_same<T, FIRST_T>::value || AnyAre<T>(tail...);
+}
+
+
+// Wait on all WaitFlags, ignore PostFlags
+__device__ __forceinline__
+void WaitOnFlags(uint64_t val) { }
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void WaitOnFlags(uint64_t val, WaitFlag flag, TAIL_Ts... tail) {
+  flag.wait(val);
+  WaitOnFlags(val, tail...);
+}
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void WaitOnFlags(uint64_t val, PostFlag, TAIL_Ts... tail) {
+  WaitOnFlags(val, tail...);
+}
+
+
+// Post all PostFlags, ignore WaitFlags
+__device__ __forceinline__
+void PostToFlags(uint64_t val) { }
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void PostToFlags(uint64_t val, WaitFlag flag, TAIL_Ts... tail) {
+  PostToFlags(val, tail...);
+}
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void PostToFlags(uint64_t val, PostFlag flag, TAIL_Ts... tail) {
+  flag.post(val);
+  PostToFlags(val, tail...);
+}
+
+
+// Post sizes for PostFlags, ignore WaitFlags
+__device__ __forceinline__
+void PostSizeToFlags(uint64_t step, int size) { }
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void PostSizeToFlags(uint64_t step, int size, WaitFlag flag, TAIL_Ts... tail) {
+  PostSizeToFlags(step, size, tail...);
+}
+
+template <typename... TAIL_Ts> __device__ __forceinline__
+void PostSizeToFlags(uint64_t step, int size, PostFlag flag, TAIL_Ts... tail) {
+  flag.postSize(step, size);
+  PostSizeToFlags(step, size, tail...);
+}
+
+
+// Create pointer arithmetic syntax that doesn't break for nullptr_t
+template <typename Tptr> __device__ __forceinline__
+Tptr ptradd(Tptr ptr, int i) {
+  return ptr + i;
+}
+
+__device__ __forceinline__
+nullptr_t ptradd(nullptr_t ptr, int i) {
+  return nullptr;
+}
+
+
+// Implementation of primitive types
+template <int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
+class Primitives {
+ private:
+  template <typename SRC2_T, // either T* or nullptr_t
+      typename DST2_T, // either T* or nullptr_t
+      typename... SYNC_Ts> // either WaitFlag or PostFlag
+  static __device__ __forceinline__ void
+  GenericOp(const int tid, const int nthreads,
+      const T* src1,
+      const SRC2_T src2,
+      T* dst1,
+      DST2_T dst2,
+      int len, int maxoffset, uint64_t step, SYNC_Ts...
flags) { + + enum { noSrc2 = std::is_same<SRC2_T, nullptr_t>::value }; + enum { noDst2 = std::is_same<DST2_T, nullptr_t>::value }; + static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value, + "src2 must be of type T* or nullptr_t"); + static_assert(noDst2 || std::is_same<DST2_T, T*>::value, + "dst2 must be of type T* or nullptr_t"); + + using OpType = typename std::conditional<noSrc2, FuncSum<T>, REDOP>::type; + + int sliceSize = len / SUBSTEPS; + int sliceOffset = 0; + +#pragma unroll 1 + for (int sub=0; sub<SUBSTEPS; ++sub) { + int realSize = max(0, min(sliceSize, maxoffset-sliceOffset)); + if (tid < nthreads) { + if (AnyAre<WaitFlag>(flags...)) { + if (tid == 0) { + WaitOnFlags(SUBSTEPS*step + sub + 1, flags...); + } + asm volatile ("bar.sync 1, %0;" :: "r"(nthreads)); + } + ReduceOrCopy + < + UNROLL, + OpType, + T, + !std::is_same<DST2_T, nullptr_t>::value, // HAS_DEST1 + !std::is_same<SRC2_T, nullptr_t>::value // HAS_SRC1 + > + ( + tid, nthreads, + ptradd(dst1, sliceOffset), + ptradd(dst2, sliceOffset), + ptradd(src1, sliceOffset), + ptradd(src2, sliceOffset), + realSize + ); + if (AnyAre<PostFlag>(flags...)) { + __syncthreads(); + } + } else { + if (AnyAre<PostFlag>(flags...)) { + __syncthreads(); + PostSizeToFlags(SUBSTEPS*step+sub, realSize*sizeof(T), flags...); + __threadfence_system(); + PostToFlags(SUBSTEPS*step + sub + 1, flags...); + } + } + sliceOffset += sliceSize; + } + } + + public: + template <typename... SYNC_Ts> + static __device__ __forceinline__ void + Copy(const int tid, const int nthreads, const T* src, T* dst, + int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { + GenericOp(tid, nthreads, src, nullptr, dst, nullptr, len, maxOffset, step, flags...); + } + + template <typename... SYNC_Ts> + static __device__ __forceinline__ void + DoubleCopy(const int tid, const int nthreads, const T* src, T* dst1, T* dst2, + int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { + GenericOp(tid, nthreads, src, nullptr, dst1, dst2, len, maxOffset, step, flags...); + } + + template <typename... SYNC_Ts> + static __device__ __forceinline__ void + Reduce(const int tid, const int nthreads, const T* src1, const T* src2, T* dst, + int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { + GenericOp(tid, nthreads, src1, src2, dst, nullptr, len, maxOffset, step, flags...); + } + + template <typename... SYNC_Ts> + static __device__ __forceinline__ void + ReduceCopy(const int tid, const int nthreads, const T* src1, const T* src2, T* dst1, T* dst2, + int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { + GenericOp(tid, nthreads, src1, src2, dst1, dst2, len, maxOffset, step, flags...); + } +}; + +#endif // end include guard diff --git a/src/collectives/device/reduce.cu b/src/collectives/device/reduce.cu new file mode 100644 index 0000000..bd1d23c --- /dev/null +++ b/src/collectives/device/reduce.cu @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "common.h" +#include "reduce.h" +#include "collectives.h" + +#define UNROLL 4 + +#if NCCL_OP == 0 +IMPL_COLL2(ncclReduce, sum, FuncSum, ncclCollReduce, ncclSum); +#elif NCCL_OP == 1 +IMPL_COLL2(ncclReduce, prod, FuncProd, ncclCollReduce, ncclProd); +#elif NCCL_OP == 2 +IMPL_COLL2(ncclReduce, min, FuncMin, ncclCollReduce, ncclMin); +#elif NCCL_OP == 3 +IMPL_COLL2(ncclReduce, max, FuncMax, ncclCollReduce, ncclMax); +#endif diff --git a/src/collectives/device/reduce.h b/src/collectives/device/reduce.h new file mode 100644 index 0000000..f5694b1 --- /dev/null +++ b/src/collectives/device/reduce.h @@ -0,0 +1,190 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "core.h" +#include "primitives.h" +#include "collectives.h" + +// Increase Step and boffset for buffer sync +#define NEXT_STEP \ + step++; \ + boffset += sliceSize; \ + if (boffset == buffSize) boffset = 0; + +template<int UNROLL, class FUNC, typename T> +__device__ void ncclReduceKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = blockDim.x - 1; + const int bid = args->bid; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + + WaitFlag waitDoneFromNext(ring->send.conn.head, (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS); + WaitFlag waitReadyFromPrev(ring->recv.conn.tail, 0); + PostFlag postDoneToPrev(ring->recv.conn.head, 0, NULL, 0); + PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCE_BUFCHUNKS*REDUCE_SUBSTEPS); + + typedef Primitives<UNROLL, REDUCE_SUBSTEPS, T, FUNC> Prims; + + const ssize_t size = args->N; + const int nranks = comm->nRanks; + const int buffSize = ring->buffSize / sizeof(T); + const int sliceSize = buffSize / REDUCE_BUFCHUNKS; + const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + const int rank = ring->devUserRanks[0]; + const int prevRank = ring->devUserRanks[nranks-1]; + const int root = args->root; + + if (tid == 0) { + // Update in case we skipped some collectives + *ring->recv.conn.opCount = args->opCount; + + if (rank != root) { + // Wait for next to be ready + WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); + waitOpCountNext.wait(args->opCount); + } + } + __syncthreads(); + + uint64_t step = 0ULL; + int boffset = 0; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + T * __restrict__ prevInput = (T*)ring->recv.conn.buff; + T * __restrict__ nextOutput = (T*)ring->send.conn.buff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings)); + ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); + ssize_t offset = gridOffset + bid*chunkSize; + int maxOffset = min(chunkSize, size-offset); + if (prevRank == root) { + Prims::Copy(tid, nthreads, + thisInput + offset, + nextOutput + boffset, + sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + } else if (rank == root) { + Prims::Reduce(tid, nthreads, + prevInput + boffset, + thisInput + offset, + thisOutput + offset, + sliceSize, maxOffset, + step, + waitReadyFromPrev, + postDoneToPrev); + } 
else { + Prims::Reduce(tid, nthreads, + prevInput + boffset, + thisInput + offset, + nextOutput + boffset, + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + } + NEXT_STEP; // Increases step, boffset + } + + if (tid == 0) { + if (rank != root) { + // Wait for next to have consumed data before resetting the flag + waitDoneFromNext.wait(REDUCE_SUBSTEPS*(step + REDUCE_BUFCHUNKS - 1)); + *ring->send.conn.head = 0ULL; + } + *ring->recv.conn.tail = 0ULL; + __threadfence_system(); + *ring->recv.conn.opCount = args->opCount+1; + } +} + +#include "ll_kernel.h" + +#define NEXT_STEP_LL \ + boffset += NCCL_LL_SLICE_LINES; \ + if (boffset == NCCL_LL_BUFF_LINES) boffset = 0; \ + flag++; \ + step++; + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceLLKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int llNthreads = args->nThreads; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead; + volatile uint64_t * sendHeadPtr = ring->send.conn.llHead; + volatile int * sizesFifo = ring->send.conn.llFifo; + uint64_t sendHead = sendHeadPtr[0]; + const int nranks = comm->nRanks; + const int rank = comm->rank; + const int prevRank = ring->devUserRanks[nranks-1]; + const int root = args->root; + + typedef LLPrimitives<T, FUNC> LL; + + const ssize_t size = args->N; + ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nRings*chunkSize; + + uint64_t step = ring->send.conn.llStep; + uint32_t flag = step + 1; + int boffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step); + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff; + union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + if (size-gridOffset < loopSize) { + chunkSize = args->lastChunkSize; + } + ssize_t offset = gridOffset + bid*chunkSize; + + int maxOffset = min(chunkSize, size-offset); + if (prevRank == root) { + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + nextOutput + boffset, + maxOffset, flag, llNthreads); + POST_SIZE; + NEXT_STEP_LL; + } else if (rank == root) { + LL::ReduceCopy( + thisInput + offset, + prevInput + boffset, + thisOutput + offset, + maxOffset, flag, llNthreads); + NEXT_STEP_LL; + ACK_PREV; + } else { + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + prevInput + boffset, + nextOutput + boffset, + maxOffset, flag, flag, llNthreads); + POST_SIZE; + NEXT_STEP_LL; + ACK_PREV; + } + } + + // We need everyone to acknowledge data even if they didn't receive anything + // so that the next collective can start right away. + ACK_PREV; + + FIFO_CLEANING_AND_SAVE_STEP(flag); +} diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h new file mode 100644 index 0000000..0cb8f13 --- /dev/null +++ b/src/collectives/device/reduce_kernel.h @@ -0,0 +1,364 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + + +#ifndef NCCL_REDUCE_KERNEL_H_ +#define NCCL_REDUCE_KERNEL_H_ + +#include "common_kernel.h" +#include <limits> + +template<typename T> +struct FuncNull { + __device__ T operator()(const T x, const T y) const { + return 0; + } +}; + +template<typename T> +struct FuncSum { + __device__ T operator()(const T x, const T y) const { + return x + y; + } +}; + +template<typename T> +struct FuncProd { + __device__ T operator()(const T x, const T y) const { + return x * y; + } +}; + +template<typename T> +struct FuncMax { + __device__ T operator()(const T x, const T y) const { + return (x < y) ? y : x; + } +}; + +template<typename T> +struct FuncMin { + __device__ T operator()(const T x, const T y) const { + return (x < y) ? x : y; + } +}; + +template<> +struct FuncSum<int8_t> { + union converter { uint32_t storage; char4 a; }; + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { +#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) + int32_t rv, z=0; + asm("vadd4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); + return rv; +#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700) + int32_t rv; + asm("vadd.s32.s32.s32 %0, %1.b0, %2.b0; \n\t" + "vadd.s32.s32.s32 %0.b1, %1.b1, %2.b1, %0;\n\t" + "vadd.s32.s32.s32 %0.b2, %1.b2, %2.b2, %0;\n\t" + "vadd.s32.s32.s32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = cx.a.x + cy.a.x; + cr.a.y = cx.a.y + cy.a.y; + cr.a.z = cx.a.z + cy.a.z; + cr.a.w = cx.a.w + cy.a.w; + return cr.storage; +#endif + } + __device__ int8_t operator()(const int8_t x, const int8_t y) const { + return x+y; + } +}; +template<> +struct FuncSum<uint8_t> { + union converter { uint32_t storage; uchar4 a; }; + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { +#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) + int32_t rv, z=0; + asm("vadd4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); + return rv; +#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700) + int32_t rv; + asm("vadd.u32.u32.u32 %0, %1.b0, %2.b0; \n\t" + "vadd.u32.u32.u32 %0.b1, %1.b1, %2.b1, %0;\n\t" + "vadd.u32.u32.u32 %0.b2, %1.b2, %2.b2, %0;\n\t" + "vadd.u32.u32.u32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = cx.a.x + cy.a.x; + cr.a.y = cx.a.y + cy.a.y; + cr.a.z = cx.a.z + cy.a.z; + cr.a.w = cx.a.w + cy.a.w; + return cr.storage; +#endif + } + __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { + return x+y; + } +}; + +static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) { + /* This can be used both for signed and unsigned 8-bit multiplication */ +#if (__CUDA_ARCH__ >= 300) + uint32_t rv; + asm("{ .reg .u32 t0, t1, t2, t3;\n\t" + " vmad.u32.u32.u32 t3, %1.b3, %2.b3, 0;\n\t" + " vmad.u32.u32.u32 t2, %1.b2, %2.b2, 0;\n\t" + " shl.b32 t3, t3, 16;\n\t" + " shl.b32 t2, t2, 16;\n\t" + " vmad.u32.u32.u32 t1, %1.b1, %2.b1, t3;\n\t" + " shl.b32 t1, t1, 8;\n\t" + " vmad.u32.u32.u32 t0, %1.b0, %2.b0, t2;\n\t" + " and.b32 t1, t1, 0xff00ff00;\n\t" + " and.b32 t0, t0, 0x00ff00ff;\n\t" + " or.b32 %0, t0, t1;\n\t" + "}" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + union converter { uint32_t storage; char4 a; }; + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = cx.a.x * 
cy.a.x; + cr.a.y = cx.a.y * cy.a.y; + cr.a.z = cx.a.z * cy.a.z; + cr.a.w = cx.a.w * cy.a.w; + return cr.storage; +#endif +} + +template<> +struct FuncProd<int8_t> { + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { + return mulChar4(x, y); + } + __device__ int8_t operator()(const int8_t x, const int8_t y) const { + return x*y; + } +}; +template<> +struct FuncProd<uint8_t> { + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { + return mulChar4(x, y); + } + __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { + return x*y; + } +}; + +template<> +struct FuncMax<int8_t> { + union converter { uint32_t storage; char4 a; }; + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { +#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) + int32_t rv, z=0; + asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); + return rv; +#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700) + int32_t rv; + asm("vmax.s32.s32.s32 %0, %1.b0, %2.b0; \n\t" + "vmax.s32.s32.s32 %0.b1, %1.b1, %2.b1, %0;\n\t" + "vmax.s32.s32.s32 %0.b2, %1.b2, %2.b2, %0;\n\t" + "vmax.s32.s32.s32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = max(cx.a.x, cy.a.x); + cr.a.y = max(cx.a.y, cy.a.y); + cr.a.z = max(cx.a.z, cy.a.z); + cr.a.w = max(cx.a.w, cy.a.w); + return cr.storage; +#endif + } + __device__ int8_t operator()(const int8_t x, const int8_t y) const { + return (x>y) ? x : y; + } +}; +template<> +struct FuncMax<uint8_t> { + union converter { uint32_t storage; uchar4 a; }; + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { +#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) + int32_t rv, z=0; + asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); + return rv; +#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700) + int32_t rv; + asm("vmax.u32.u32.u32 %0, %1.b0, %2.b0; \n\t" + "vmax.u32.u32.u32 %0.b1, %1.b1, %2.b1, %0;\n\t" + "vmax.u32.u32.u32 %0.b2, %1.b2, %2.b2, %0;\n\t" + "vmax.u32.u32.u32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = max(cx.a.x, cy.a.x); + cr.a.y = max(cx.a.y, cy.a.y); + cr.a.z = max(cx.a.z, cy.a.z); + cr.a.w = max(cx.a.w, cy.a.w); + return cr.storage; +#endif + } + __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { + return (x>y) ? x : y; + } +}; + +template<> +struct FuncMin<int8_t> { + union converter { uint32_t storage; char4 a; }; + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { +#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) + int32_t rv, z=0; + asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); + return rv; +#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700) + int32_t rv; + asm("vmin.s32.s32.s32 %0, %1.b0, %2.b0; \n\t" + "vmin.s32.s32.s32 %0.b1, %1.b1, %2.b1, %0;\n\t" + "vmin.s32.s32.s32 %0.b2, %1.b2, %2.b2, %0;\n\t" + "vmin.s32.s32.s32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = min(cx.a.x, cy.a.x); + cr.a.y = min(cx.a.y, cy.a.y); + cr.a.z = min(cx.a.z, cy.a.z); + cr.a.w = min(cx.a.w, cy.a.w); + return cr.storage; +#endif + } + __device__ int8_t operator()(const int8_t x, const int8_t y) const { + return (x<y) ? 
x : y; + } +}; +template<> +struct FuncMin<uint8_t> { + union converter { uint32_t storage; uchar4 a; }; + __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { +#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) + int32_t rv, z=0; + asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); + return rv; +#elif (__CUDA_ARCH__ >= 500) && (__CUDA_ARCH__ < 700) + int32_t rv; + asm("vmin.u32.u32.u32 %0, %1.b0, %2.b0; \n\t" + "vmin.u32.u32.u32 %0.b1, %1.b1, %2.b1, %0;\n\t" + "vmin.u32.u32.u32 %0.b2, %1.b2, %2.b2, %0;\n\t" + "vmin.u32.u32.u32 %0.b3, %1.b3, %2.b3, %0;" : "=r"(rv) : "r"(x), "r"(y)); + return rv; +#else + converter cx, cy, cr; + cx.storage = x; + cy.storage = y; + cr.a.x = min(cx.a.x, cy.a.x); + cr.a.y = min(cx.a.y, cy.a.y); + cr.a.z = min(cx.a.z, cy.a.z); + cr.a.w = min(cx.a.w, cy.a.w); + return cr.storage; +#endif + } + __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { + return (x<y) ? x : y; + } +}; + +template<> +struct FuncSum<half> { + __device__ half2 operator()(const half2 x, const half2 y) const { +#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + return __hadd2(x, y); +#else + float2 fx, fy, fr; + fx = __half22float2(x); + fy = __half22float2(y); + fr.x = fx.x + fy.x; + fr.y = fx.y + fy.y; + return __float22half2_rn(fr); +#endif + } + __device__ half operator()(const half x, const half y) const { +#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + return __hadd(x, y); +#else + return __float2half( __half2float(x) + __half2float(y) ); +#endif + } +}; + +template<> +struct FuncProd<half> { + __device__ half2 operator()(const half2 x, const half2 y) const { +#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + return __hmul2(x, y); +#else + float2 fx, fy, fr; + fx = __half22float2(x); + fy = __half22float2(y); + fr.x = fx.x * fy.x; + fr.y = fx.y * fy.y; + return __float22half2_rn(fr); +#endif + } + __device__ half operator()(const half x, const half y) const { +#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + return __hmul(x, y); +#else + return __float2half( __half2float(x) * __half2float(y) ); +#endif + } +}; + +template<> +struct FuncMax<half> { + __device__ half2 operator()(const half2 x, const half2 y) const { + float2 fx, fy, fr; + fx = __half22float2(x); + fy = __half22float2(y); + fr.x = fmaxf(fx.x, fy.x); + fr.y = fmaxf(fx.y, fy.y); + return __float22half2_rn(fr); + } + __device__ half operator()(const half x, const half y) const { + float fx, fy, fm; + fx = __half2float(x); + fy = __half2float(y); + fm = fmaxf(fx, fy); + return __float2half(fm); + } +}; + +template<> +struct FuncMin<half> { + __device__ half2 operator()(const half2 x, const half2 y) const { + float2 fx, fy, fr; + fx = __half22float2(x); + fy = __half22float2(y); + fr.x = fminf(fx.x, fy.x); + fr.y = fminf(fx.y, fy.y); + return __float22half2_rn(fr); + } + __device__ half operator()(const half x, const half y) const { + float fx, fy, fm; + fx = __half2float(x); + fy = __half2float(y); + fm = fminf(fx, fy); + return __float2half(fm); + } +}; +#endif // REDUCE_KERNEL_H_ diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu new file mode 100644 index 0000000..b16053c --- /dev/null +++ b/src/collectives/device/reduce_scatter.cu @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "common.h" +#include "reduce_scatter.h" +#include "collectives.h" + +#define UNROLL 4 + +#if NCCL_OP == 0 +IMPL_COLL2(ncclReduceScatter, sum, FuncSum, ncclCollReduceScatter, ncclSum); +#elif NCCL_OP == 1 +IMPL_COLL2(ncclReduceScatter, prod, FuncProd, ncclCollReduceScatter, ncclProd); +#elif NCCL_OP == 2 +IMPL_COLL2(ncclReduceScatter, min, FuncMin, ncclCollReduceScatter, ncclMin); +#elif NCCL_OP == 3 +IMPL_COLL2(ncclReduceScatter, max, FuncMax, ncclCollReduceScatter, ncclMax); +#endif diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h new file mode 100644 index 0000000..cad011b --- /dev/null +++ b/src/collectives/device/reduce_scatter.h @@ -0,0 +1,217 @@ +/************************************************************************* + * Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "core.h" +#include "primitives.h" +#include "collectives.h" + +// Increase Step and poffset/noffset for buffer sync +#define NEXT_STEP \ + step++; \ + poffset = noffset; \ + noffset += sliceSize; \ + if (noffset == buffSize) noffset = 0; + +template<int UNROLL, class FUNC, typename T> +__device__ void ncclReduceScatterKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int nthreads = blockDim.x - 1; + const int bid = args->bid; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + + WaitFlag waitDoneFromNext(ring->send.conn.head, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS); + WaitFlag waitReadyFromPrev(ring->recv.conn.tail, REDUCESCATTER_SUBSTEPS); + PostFlag postDoneToPrev(ring->recv.conn.head, REDUCESCATTER_SUBSTEPS, NULL, 0); + PostFlag postReadyToNext(ring->send.conn.tail, 0, ring->send.conn.fifo, REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS); + + typedef Primitives<UNROLL, REDUCESCATTER_SUBSTEPS, T, FUNC> Prims; + + const ssize_t size = args->N; + const int nranks = comm->nRanks; + const int buffSize = ring->buffSize / sizeof(T); + const int sliceSize = buffSize / REDUCESCATTER_BUFCHUNKS; + const ssize_t loopSize = args->nRings*(ssize_t)sliceSize; + + if (tid == 0) { + // Update in case we skipped some collectives + *ring->recv.conn.opCount = args->opCount; + // Wait for next to be ready + WaitFlag waitOpCountNext(ring->send.conn.opCount, 0); + waitOpCountNext.wait(args->opCount); + } + __syncthreads(); + + uint64_t step = 0ULL; + int poffset, noffset = 0; + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + T * __restrict__ prevInput = (T*)ring->recv.conn.buff; + T * __restrict__ nextOutput = (T*)ring->send.conn.buff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + int chunkSize = min(sliceSize, DIVUP(size-gridOffset,args->nRings)); + ALIGN_SIZE(chunkSize, nthreads*sizeof(uint64_t)/sizeof(T)); + ssize_t chunkOffset = gridOffset + bid*chunkSize; + + /////////////// begin ReduceScatter steps /////////////// + ssize_t offset; + int maxOffset = min(chunkSize, size-chunkOffset); + int rankDest; + + // step 0: push data to next GPU + rankDest = ring->devUserRanks[nranks-1]; + offset = chunkOffset + rankDest * size; + + Prims::Copy(tid, nthreads, + thisInput + offset, + nextOutput + noffset, + 
sliceSize, maxOffset, + step, + waitDoneFromNext, + postReadyToNext); + + NEXT_STEP; // Increases step, poffset, noffset + + // k-2 steps: reduce and copy to next GPU + for (int j=2; j<nranks; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + Prims::Reduce(tid, nthreads, + prevInput + poffset, + thisInput + offset, + nextOutput + noffset, + sliceSize, maxOffset, + step, + waitDoneFromNext, waitReadyFromPrev, + postReadyToNext, postDoneToPrev); + + NEXT_STEP; + } + + // step k-1: reduce this buffer and data, which will produce the final + // result that we store in this data and push to the next GPU + rankDest = ring->devUserRanks[0]; + offset = chunkOffset + rankDest * size; + + Prims::Reduce(tid, nthreads, + prevInput + poffset, + thisInput + offset, + thisOutput + chunkOffset, + sliceSize, maxOffset, + step, + waitReadyFromPrev, + postDoneToPrev); + } + + if (tid == 0) { + waitDoneFromNext.wait(REDUCESCATTER_SUBSTEPS*(step + REDUCESCATTER_BUFCHUNKS)); + *ring->send.conn.head = 0ULL; + *ring->recv.conn.tail = 0ULL; + __threadfence_system(); + *ring->recv.conn.opCount = args->opCount+1; + } +} + +#include "ll_kernel.h" + +#define NEXT_STEP_LL \ + poffset = noffset; \ + pflag = nflag; \ + noffset += NCCL_LL_SLICE_LINES; \ + if (noffset == NCCL_LL_BUFF_LINES) { noffset = 0; } \ + nflag++; \ + step++; + +template<int UNUSED, class FUNC, typename T> +__device__ void ncclReduceScatterLLKernel(struct CollectiveArgs* args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int llNthreads = args->nThreads; + struct ncclComm* comm = args->comm; + struct ncclRing* ring = comm->rings+blockIdx.x; + volatile uint64_t * recvHeadPtr = ring->recv.conn.llHead; + volatile uint64_t * sendHeadPtr = ring->send.conn.llHead; + volatile int * sizesFifo = ring->send.conn.llFifo; + uint64_t sendHead = sendHeadPtr[0]; + + typedef LLPrimitives<T, FUNC> LL; + + const ssize_t size = args->N; + //const int rank = comm->rank; + const int nranks = comm->nRanks; + ssize_t chunkSize = NCCL_LL_SLICE_LINES * sizeof(uint64_t) / sizeof(T); + const ssize_t loopSize = args->nRings*chunkSize; + + uint64_t step = ring->send.conn.llStep; + uint32_t pflag, nflag = step + 1; + int poffset, noffset = NCCL_LL_SLICE_LINES * STEP_TO_SLOT(step); + + // Compute pointers + const T * __restrict__ thisInput = (const T*)args->ThisInput; + T * __restrict__ thisOutput = (T*)args->ThisOutput; + union ncclLLFifoLine * prevInput = (union ncclLLFifoLine *)ring->recv.conn.llBuff; + union ncclLLFifoLine * nextOutput = (union ncclLLFifoLine *)ring->send.conn.llBuff; + + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + if (size-gridOffset < loopSize) { + chunkSize = args->lastChunkSize; + } + ssize_t chunkOffset = gridOffset + bid*chunkSize; + + /////////////// begin ReduceScatter steps /////////////// + ssize_t offset; + int maxOffset = min(chunkSize, size-chunkOffset); + int rankDest; + + // step 0: push data to next GPU + rankDest = ring->devUserRanks[nranks-1]; + offset = chunkOffset + rankDest * size; + + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + nextOutput + noffset, + maxOffset, nflag, llNthreads); + POST_SIZE; + + NEXT_STEP_LL; + + // k-2 steps: reduce and copy to next GPU + for (int j=2; j<nranks; ++j) { + rankDest = ring->devUserRanks[nranks-j]; + offset = chunkOffset + rankDest * size; + + WAIT_NEXT; + LL::ReduceCopy( + thisInput + offset, + prevInput + poffset, + nextOutput + noffset, + maxOffset, pflag, nflag, llNthreads); + POST_SIZE; + 
ACK_PREV;
+
+      NEXT_STEP_LL;
+    }
+
+    // step k-1: reduce this buffer and data, which will produce the final
+    // result that we store in this data
+    rankDest = ring->devUserRanks[0];
+    offset = chunkOffset + rankDest * size;
+
+    LL::ReduceCopy(
+        thisInput + offset,
+        prevInput + poffset,
+        thisOutput + chunkOffset,
+        maxOffset, pflag, llNthreads);
+    ACK_PREV;
+  }
+
+  FIFO_CLEANING_AND_SAVE_STEP(nflag);
+}
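
The GenericOp loop in primitives.h drives all of this synchronization through the WaitFlag/PostFlag counters: before moving a substep it waits until the remote counter plus its shift credit reaches SUBSTEPS*step + substep + 1, and after the transfer it posts that same value. The standalone host-side sketch below is not NCCL code; SUBSTEPS, BUFCHUNKS and the consumer model are assumed values used only to walk through the arithmetic:

// Host model of the WaitFlag/PostFlag accounting. The 'shift' credit mirrors
// waitDoneFromNext in reduce.h, which lets the producer run
// (BUFCHUNKS-1) chunks ahead of the consumer's head counter.
#include <cstdint>
#include <cstdio>

int main() {
  const int SUBSTEPS = 4;                        // substeps per buffer chunk (assumed)
  const int BUFCHUNKS = 2;                       // chunks resident in the ring buffer (assumed)
  const int shift = (BUFCHUNKS - 1) * SUBSTEPS;  // credit handed to WaitFlag
  uint64_t head = 0;                             // substeps acknowledged by the consumer

  for (uint64_t step = 0; step < 3; ++step) {
    for (int sub = 0; sub < SUBSTEPS; ++sub) {
      uint64_t val = SUBSTEPS * step + sub + 1;  // value waited on, then posted
      bool ready = (head + shift) >= val;        // WaitFlag::wait() predicate
      printf("step=%llu sub=%d wait_for=%llu ready_without_spinning=%d\n",
             (unsigned long long)step, sub, (unsigned long long)val, (int)ready);
      // Pretend the consumer drains substeps exactly one chunk behind the producer.
      if (val > (uint64_t)shift) head = val - shift;
    }
  }
  return 0;
}

With reduce.h's credit of (REDUCE_BUFCHUNKS-1)*REDUCE_SUBSTEPS the sender may fill all but one chunk before it has to spin on the consumer; reduce_scatter.h grants a full REDUCESCATTER_BUFCHUNKS*REDUCESCATTER_SUBSTEPS of credit instead.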
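
ncclReduceKernel splits the ring into three roles: the rank whose predecessor is the root (prevRank == root) only copies its own chunk to the next rank, intermediate ranks reduce the incoming chunk with their own input and forward the partial result, and the root reduces the last incoming chunk into the user output. The plain C++ trace below follows that chain under assumed values (4 ranks, sum reduction, a simple root-first ring order); it is an illustration, not the device code:

// Illustrative host-side trace of the reduce ring: the chain starts at the
// rank after the root and ends at the root, each hop adding one rank's input.
#include <cstdio>
#include <vector>

int main() {
  const int nranks = 4, root = 0, n = 8;
  // Rank r contributes the value (r+1) in every element.
  std::vector<std::vector<int>> input(nranks, std::vector<int>(n));
  for (int r = 0; r < nranks; ++r)
    for (int i = 0; i < n; ++i) input[r][i] = r + 1;

  std::vector<int> partial(n, 0);
  for (int hop = 1; hop <= nranks; ++hop) {
    int r = (root + hop) % nranks;     // current rank; hop == nranks lands on the root
    if (hop == 1) {
      partial = input[r];              // prevRank == root: plain copy to the next rank
    } else {
      for (int i = 0; i < n; ++i)      // intermediate ranks, and finally the root: reduce
        partial[i] += input[r][i];
    }
  }
  printf("result[0] = %d (expected %d)\n", partial[0], nranks * (nranks + 1) / 2);
  return 0;
}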
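
The LL kernels take a different route: FIFO_CLEANING_AND_SAVE_STEP resets every ncclLLFifoLine to { 0, flag, 0, flag }, which suggests each 16-byte line interleaves two 32-bit data words with two copies of a per-step flag, so the receiver polls flags embedded in the payload rather than a separate head/tail counter (hence POST_SIZE reporting maxOffset*2*sizeof(T) bytes). A minimal host-side sketch of that idea follows; the LLLine struct and helper functions are illustrative stand-ins, not the actual ncclLLFifoLine accessors:

// Toy model of a flag-in-payload FIFO line. In the real kernels the write is
// a single 128-bit store and the flag is derived from the step counter.
#include <cstdint>
#include <cstdio>

struct LLLine { uint32_t data1, flag1, data2, flag2; };  // mirrors { 0, flag, 0, flag }

static void writeLine(volatile LLLine* l, uint32_t d1, uint32_t d2, uint32_t flag) {
  l->data1 = d1; l->flag1 = flag;
  l->data2 = d2; l->flag2 = flag;
}

static void readLine(volatile LLLine* l, uint32_t flag, uint32_t* d1, uint32_t* d2) {
  // Spin until both flags show this step's value; the data is then valid.
  while (l->flag1 != flag || l->flag2 != flag) { }
  *d1 = l->data1; *d2 = l->data2;
}

int main() {
  LLLine line = {0, 0, 0, 0};
  writeLine(&line, 42, 43, /*flag=*/1);   // flag = step + 1, as in the LL kernels
  uint32_t a = 0, b = 0;
  readLine(&line, 1, &a, &b);
  printf("%u %u\n", a, b);
  return 0;
}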