Diffstat (limited to 'src/collectives/device/common_kernel.h')
-rw-r--r-- | src/collectives/device/common_kernel.h | 174
1 file changed, 104 insertions(+), 70 deletions(-)
diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h
index aa1e936..ff466a0 100644
--- a/src/collectives/device/common_kernel.h
+++ b/src/collectives/device/common_kernel.h
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
  *
  * See LICENSE.txt for license information
  ************************************************************************/
@@ -16,6 +16,12 @@
 // Define min for ssize_t
 static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; }
 
+template <typename T>
+inline __device__ void loadPtr(void** ptr, T* &v) {
+  asm volatile("ld.volatile.global.u64 %0, [%1];"
+      : "=l"(v) : "l"(ptr));
+}
+
 typedef uint64_t PackType;
 
 // unpack x and y to elements of type T and apply FUNC to each element
@@ -245,28 +251,57 @@ inline __device__ void Store128(Pack128* p, Pack128& v) {
   asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
 }
 
-template<class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
-__device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthreads,
-    int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
-    const int offset, const int N) {
-  for (int idx = offset+tid; idx < offset+N; idx += nthreads) {
-    T val = vFetch(srcs[0]+idx);
+template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
+__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
+    int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
+  const int inc = nw * UNROLL * WARP_SIZE;
+  int offset = w * UNROLL * WARP_SIZE + t;
+
+  const T* srcs[MAXSRCS];
+  for (int i=0; i<MAXSRCS; i++) srcs[i] = s[i]+elemOffset+offset;
+  T* dsts[MAXDSTS];
+  for (int i=0; i<MAXDSTS; i++) dsts[i] = d[i]+elemOffset+offset;
+
+  while (offset < Nelem) {
+    T vals[UNROLL];
+    // Load and reduce
+    for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE);
+
+    #pragma unroll
+    for (int i=1; i<MINSRCS; i++) {
+      T vals2[UNROLL];
+      for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
+      for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
+    }
     #pragma unroll
-    for (int i=1; i<MINSRCS; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
-    #pragma unroll 1
-    for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
+    for (int i=MINSRCS; i<MAXSRCS; i++) {
+      if (i<nsrcs) {
+        T vals2[UNROLL];
+        for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
+        for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
+      }
+    }
+
+    // Store
     #pragma unroll
-    for (int i=0; i<MINDSTS; i++) vStore(dsts[i]+idx, val);
-    #pragma unroll 1
-    for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) vStore(dsts[i]+idx, val);
+    for (int i = 0; i < MINDSTS; i++) {
+      for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]);
+    }
+    #pragma unroll
+    for (int i=MINDSTS; i<MAXDSTS; i++) {
+      if (i<ndsts) {
+        for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]);
+      }
+    }
+    for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
+    for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
+    offset += inc;
   }
 }
 
 template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
-__device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t,
-    int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS],
-    const int elemOffset, const int Npack) {
+__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
+    int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
   const int inc = nw * UNROLL * WARP_SIZE;
   int offset = w * UNROLL * WARP_SIZE + t;
 
@@ -280,25 +315,31 @@ __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw,
     // Load and reduce
     for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
 
+    #pragma unroll
     for (int i=1; i<MINSRCS; i++) {
       Pack128 vals2[UNROLL];
       for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
       for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
     }
-    #pragma unroll 1
-    for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
-      Pack128 vals2[UNROLL];
-      for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
-      for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
+    #pragma unroll
+    for (int i=MINSRCS; i<MAXSRCS; i++) {
+      if (i<nsrcs) {
+        Pack128 vals2[UNROLL];
+        for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
+        for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
+      }
     }
 
     // Store
+    #pragma unroll
     for (int i = 0; i < MINDSTS; i++) {
       for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
     }
-    #pragma unroll 1
-    for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) {
-      for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+    #pragma unroll
+    for (int i=MINDSTS; i<MAXDSTS; i++) {
+      if (i<ndsts) {
+        for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+      }
     }
     for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
     for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
@@ -309,72 +350,65 @@ __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw,
 template <typename T>
 __device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); }
 
-// Try to limit consecutive load/stores to 8.
-// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
-#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
+#define PACKELEMS (sizeof(Pack128) / sizeof(T))
 
 template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
 __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
-    int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
+    int nsrcs, const T** srcs, int ndsts, T** dsts,
     int N) {
   int Nrem = N;
   if (Nrem <= 0) return;
 
-  int alignDiff = 0;
-  int align = ptrAlign128(srcs[0]);
-  #pragma unroll
-  for (int i=1; i<MINSRCS; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
-  for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
-  #pragma unroll
-  for (int i=0; i<MINDSTS; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
-  for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
-
-  int Npreamble = alignDiff ? Nrem :
-    N < alignof(Pack128) ? N :
-    (alignof(Pack128) - align) % alignof(Pack128);
-
-  // stage 1: preamble: handle any elements up to the point of everything coming
-  // into alignment
-  if (Npreamble) {
-    ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, 0, Npreamble);
-    Nrem -= Npreamble;
-    if (Nrem == 0) return;
-  }
-  int offset = Npreamble;
-
-  // stage 2: fast path: use 128b loads/stores to do the bulk of the work,
-  // assuming the pointers we have are all 128-bit alignable.
   int w = tid / WARP_SIZE;       // Warp number
   int nw = nthreads / WARP_SIZE; // Number of warps
   int t = tid % WARP_SIZE;       // Thread (inside the warp)
 
-  const int packFactor = sizeof(Pack128) / sizeof(T);
+  // Check that all is 16B aligned. If not don't use 16B load/stores.
+  int align = 0;
+  #pragma unroll
+  for (int i=0; i<MINSRCS; i++) align |= ptrAlign128(srcs[i]);
+  for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) align |= ptrAlign128(srcs[i]);
+  #pragma unroll
+  for (int i=0; i<MINDSTS; i++) align |= ptrAlign128(dsts[i]);
+  for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) align |= ptrAlign128(dsts[i]);
 
-  // stage 2a: main loop
-  int Npack2a = (Nrem / (packFactor * AUTOUNROLL * WARP_SIZE))
-      * (AUTOUNROLL * WARP_SIZE); // round down
-  int Nelem2a = Npack2a * packFactor;
+  int offset = 0;
+  if (align == 0) {
+    // fast path: use 128b loads/stores to do the bulk of the work,
+    // assuming the pointers we have are all 128-bit aligned.
 
-  ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2a);
+    // main loop
+    int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE); // round down
+    int Nelem = Npack * PACKELEMS;
 
-  Nrem -= Nelem2a;
-  if (Nrem == 0) return;
-  offset += Nelem2a;
+    ReduceCopy128bMulti<FUNC, T, UNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
 
-  // stage 2b: slightly less optimized for section when we don't have full
-  // unrolling
+    Nrem -= Nelem;
+    if (Nrem == 0) return;
+    offset += Nelem;
+
+    // slightly less optimized for section when we don't have full unrolling
+    Npack = Nrem / PACKELEMS;
+    Nelem = Npack * PACKELEMS;
+
+    ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
+
+    Nrem -= Nelem;
+    if (Nrem == 0) return;
+    offset += Nelem;
+  }
 
-  int Npack2b = Nrem / packFactor;
-  int Nelem2b = Npack2b * packFactor;
+  // unrolled, by-type (mostly for unaligned buffers)
+  int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down
 
-  ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2b);
+  ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem);
 
-  Nrem -= Nelem2b;
+  Nrem -= Nelem;
   if (Nrem == 0) return;
-  offset += Nelem2b;
+  offset += Nelem;
 
-  // stage 2c: tail
-  ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, offset, Nrem);
+  // no unroll, by type. Should finish what's remaining.
+  ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem);
 }
 
 #endif // COMMON_KERNEL_H_
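Note: the rewritten ReduceCopyMulti adopts the same warp-strided, unrolled access pattern that ReduceCopy128bMulti already used: each thread handles UNROLL elements spaced WARP_SIZE apart, and each warp advances by nw*UNROLL*WARP_SIZE elements per iteration. The standalone CUDA sketch below illustrates that pattern for two sources, one destination, and a sum reduction; the kernel name, the fixed UNROLL of 4, and the single-block launch are illustrative only, and it assumes Nelem is a multiple of UNROLL*WARP_SIZE (the real code rounds down and finishes the tail with a non-unrolled pass).

// warp-strided, unrolled reduce-copy: a minimal sketch, not NCCL code.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define UNROLL 4

__global__ void reduceCopySketch(const float* src0, const float* src1,
                                 float* dst, int Nelem) {
  const int t   = threadIdx.x % WARP_SIZE;   // lane within the warp
  const int w   = threadIdx.x / WARP_SIZE;   // warp number
  const int nw  = blockDim.x / WARP_SIZE;    // number of warps
  const int inc = nw * UNROLL * WARP_SIZE;   // elements consumed per iteration
  int offset = w * UNROLL * WARP_SIZE + t;

  const float* s0 = src0 + offset;
  const float* s1 = src1 + offset;
  float* d = dst + offset;

  while (offset < Nelem) {
    // Load UNROLL values per source, WARP_SIZE apart, so the warp's loads
    // stay coalesced and several loads are in flight at once.
    float vals[UNROLL];
    #pragma unroll
    for (int u = 0; u < UNROLL; ++u) vals[u] = s0[u*WARP_SIZE];
    #pragma unroll
    for (int u = 0; u < UNROLL; ++u) vals[u] += s1[u*WARP_SIZE]; // reduce
    #pragma unroll
    for (int u = 0; u < UNROLL; ++u) d[u*WARP_SIZE] = vals[u];   // store
    s0 += inc; s1 += inc; d += inc;
    offset += inc;
  }
}

int main() {
  const int N = 1 << 20; // multiple of UNROLL*WARP_SIZE, so no tail handling
  float *a, *b, *c;
  cudaMallocManaged(&a, N*sizeof(float));
  cudaMallocManaged(&b, N*sizeof(float));
  cudaMallocManaged(&c, N*sizeof(float));
  for (int i = 0; i < N; i++) { a[i] = 1.0f; b[i] = 2.0f; }
  reduceCopySketch<<<1, 256>>>(a, b, c, N);
  cudaDeviceSynchronize();
  printf("c[0]=%g c[N-1]=%g\n", c[0], c[N-1]); // expect 3 and 3
  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}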
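The change also removes the alignment preamble from ReduceOrCopyMulti. Instead of copying elements until all pointers come into alignment, it ORs together every pointer's residue mod alignof(Pack128); the result is zero only if every buffer is already 16B aligned, and only then does it take the 128-bit path. The host-side trace below (hypothetical, not NCCL code) shows how N elements are split across the four passes for T = float with UNROLL = 4, and checks the pieces sum back to N.

// hypothetical host-side trace of ReduceOrCopyMulti's pass sizes
#include <cassert>
#include <cstdio>

int main() {
  const int WARP_SIZE = 32, UNROLL = 4;
  const int PACKELEMS = 16 / sizeof(float); // elements of T per Pack128 (4 floats)
  int N = 1000003, Nrem = N, done = 0;

  // Aligned fast path, pass 1: fully unrolled 128-bit pass, rounded
  // down to whole UNROLL*WARP_SIZE-pack chunks.
  int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE);
  int Nelem = Npack * PACKELEMS;
  printf("128b unrolled : %d elems\n", Nelem);
  Nrem -= Nelem; done += Nelem;

  // Pass 2: 128-bit loads/stores without unrolling, whole packs only.
  Npack = Nrem / PACKELEMS;
  Nelem = Npack * PACKELEMS;
  printf("128b no-unroll: %d elems\n", Nelem);
  Nrem -= Nelem; done += Nelem;

  // Pass 3: element-wise but unrolled (UNROLL*PACKELEMS/2 per thread).
  // On the aligned path fewer than PACKELEMS elements remain here, so it
  // contributes 0; it does the bulk of the work when buffers are unaligned.
  int chunk = (UNROLL*PACKELEMS/2) * WARP_SIZE;
  Nelem = (Nrem / chunk) * chunk;
  printf("T unrolled    : %d elems\n", Nelem);
  Nrem -= Nelem; done += Nelem;

  // Pass 4: plain element-wise loop finishes whatever remains.
  printf("T tail        : %d elems\n", Nrem);
  assert(done + Nrem == N);
  return 0;
}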