github.com/marian-nmt/nccl.git

Diffstat (limited to 'src/collectives/device/common_kernel.h')
-rw-r--r--   src/collectives/device/common_kernel.h   174
1 file changed, 104 insertions(+), 70 deletions(-)
diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h
index aa1e936..ff466a0 100644
--- a/src/collectives/device/common_kernel.h
+++ b/src/collectives/device/common_kernel.h
@@ -1,5 +1,5 @@
/*************************************************************************
- * Copyright (c) 2015-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
@@ -16,6 +16,12 @@
// Define min for ssize_t
static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; }
+template <typename T>
+inline __device__ void loadPtr(void** ptr, T* &v) {
+ asm volatile("ld.volatile.global.u64 %0, [%1];"
+ : "=l"(v) : "l"(ptr));
+}
+
typedef uint64_t PackType;
// unpack x and y to elements of type T and apply FUNC to each element
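
The new loadPtr helper forces a fresh 64-bit read of a pointer slot in global memory: ld.volatile.global.u64 bypasses any value cached in a register, which is what a consumer needs when another device or the host updates the slot concurrently. Below is a minimal sketch of how such a helper could be used; the waitForPtr wrapper and the slot naming are hypothetical, only loadPtr itself comes from this diff.

// Hypothetical polling wrapper around loadPtr from the hunk above.
// The volatile load guarantees each iteration re-reads global memory
// instead of reusing a stale value held in a register.
template <typename T>
__device__ T* waitForPtr(void** slot) {
  T* p = nullptr;
  do {
    loadPtr(slot, p);           // ld.volatile.global.u64
  } while (p == nullptr);
  return p;
}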
@@ -245,28 +251,57 @@ inline __device__ void Store128(Pack128* p, Pack128& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
-template<class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
-__device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthreads,
- int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
- const int offset, const int N) {
- for (int idx = offset+tid; idx < offset+N; idx += nthreads) {
- T val = vFetch(srcs[0]+idx);
+template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
+__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
+ int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
+ const int inc = nw * UNROLL * WARP_SIZE;
+ int offset = w * UNROLL * WARP_SIZE + t;
+
+ const T* srcs[MAXSRCS];
+ for (int i=0; i<MAXSRCS; i++) srcs[i] = s[i]+elemOffset+offset;
+ T* dsts[MAXDSTS];
+ for (int i=0; i<MAXDSTS; i++) dsts[i] = d[i]+elemOffset+offset;
+
+ while (offset < Nelem) {
+ T vals[UNROLL];
+ // Load and reduce
+ for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE);
+
+ #pragma unroll
+ for (int i=1; i<MINSRCS; i++) {
+ T vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
+ }
#pragma unroll
- for (int i=1; i<MINSRCS; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
- #pragma unroll 1
- for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) val = FUNC()(val, vFetch(srcs[i]+idx));
+ for (int i=MINSRCS; i<MAXSRCS; i++) {
+ if (i<nsrcs) {
+ T vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) vals[u] = FUNC()(vals[u], vals2[u]);
+ }
+ }
+ // Store
#pragma unroll
- for (int i=0; i<MINDSTS; i++) vStore(dsts[i]+idx, val);
- #pragma unroll 1
- for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) vStore(dsts[i]+idx, val);
+ for (int i = 0; i < MINDSTS; i++) {
+ for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
+ #pragma unroll
+ for (int i=MINDSTS; i<MAXDSTS; i++) {
+ if (i<ndsts) {
+ for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
+ }
+ for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
+ for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
+ offset += inc;
}
}
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
-__device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t,
- int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS],
- const int elemOffset, const int Npack) {
+__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
+ int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
const int inc = nw * UNROLL * WARP_SIZE;
int offset = w * UNROLL * WARP_SIZE + t;
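
The rewritten ReduceCopyMulti adopts the same warp-strided, unrolled indexing that ReduceCopy128bMulti already used: lane t of warp w owns elements offset + u*WARP_SIZE for u in [0, UNROLL), and all warps then advance by inc = nw*UNROLL*WARP_SIZE. Here is a minimal single-buffer sketch of that indexing, with an assumed UNROLL of 4 and WARP_SIZE as defined by the surrounding headers; it is illustrative, not the function above.

// Illustrative only: one source, one destination, float elements.
// As in the diff, callers are expected to round Nelem down to a
// multiple of UNROLL*WARP_SIZE before taking this path.
__device__ void warpStridedCopy(const float* src, float* dst,
                                int w, int nw, int t, int Nelem) {
  const int UNROLL = 4;                       // assumed unroll factor
  const int inc = nw * UNROLL * WARP_SIZE;
  int offset = w * UNROLL * WARP_SIZE + t;
  src += offset; dst += offset;
  while (offset < Nelem) {
    #pragma unroll
    for (int u = 0; u < UNROLL; ++u) dst[u*WARP_SIZE] = src[u*WARP_SIZE];
    src += inc; dst += inc; offset += inc;
  }
}

Keeping UNROLL independent accesses in flight per lane lets each warp issue coalesced transactions, and the loads from the nsrcs buffers can overlap before the FUNC() reduction consumes them.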
@@ -280,25 +315,31 @@ __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw,
// Load and reduce
for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
+ #pragma unroll
for (int i=1; i<MINSRCS; i++) {
Pack128 vals2[UNROLL];
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
}
- #pragma unroll 1
- for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
- Pack128 vals2[UNROLL];
- for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
- for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
+ #pragma unroll
+ for (int i=MINSRCS; i<MAXSRCS; i++) {
+ if (i<nsrcs) {
+ Pack128 vals2[UNROLL];
+ for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
+ for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
+ }
}
// Store
+ #pragma unroll
for (int i = 0; i < MINDSTS; i++) {
for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
}
- #pragma unroll 1
- for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) {
- for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+ #pragma unroll
+ for (int i=MINDSTS; i<MAXDSTS; i++) {
+ if (i<ndsts) {
+ for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
+ }
}
for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
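
Fetch128 and Store128 (the latter shown in the hunk above) move 16 bytes per instruction with ld/st.volatile.global.v2.u64, so this path issues a quarter as many memory instructions as a 4-byte element loop. A sketch of the access pattern in isolation, assuming Pack128 is the 16-byte pack type used throughout this header and that Fetch128 mirrors Store128's signature:

// Grid-stride 16B copy built on the header's volatile vector accessors.
// Illustrative; the real kernels layer warp ownership and unrolling on top.
__device__ void copyPacks(const Pack128* src, Pack128* dst, int Npack) {
  for (int i = threadIdx.x; i < Npack; i += blockDim.x) {
    Pack128 v;
    Fetch128(v, src + i);    // ld.volatile.global.v2.u64 {x,y}
    Store128(dst + i, v);    // st.volatile.global.v2.u64 {x,y}
  }
}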
@@ -309,72 +350,65 @@ __device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw,
template <typename T>
__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(Pack128); }
-// Try to limit consecutive load/stores to 8.
-// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
-#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
+#define PACKELEMS (sizeof(Pack128) / sizeof(T))
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
- int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
+ int nsrcs, const T** srcs, int ndsts, T** dsts,
int N) {
int Nrem = N;
if (Nrem <= 0) return;
- int alignDiff = 0;
- int align = ptrAlign128(srcs[0]);
- #pragma unroll
- for (int i=1; i<MINSRCS; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
- for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
- #pragma unroll
- for (int i=0; i<MINDSTS; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
- for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
-
- int Npreamble = alignDiff ? Nrem :
- N < alignof(Pack128) ? N :
- (alignof(Pack128) - align) % alignof(Pack128);
-
- // stage 1: preamble: handle any elements up to the point of everything coming
- // into alignment
- if (Npreamble) {
- ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, 0, Npreamble);
- Nrem -= Npreamble;
- if (Nrem == 0) return;
- }
- int offset = Npreamble;
-
- // stage 2: fast path: use 128b loads/stores to do the bulk of the work,
- // assuming the pointers we have are all 128-bit alignable.
int w = tid / WARP_SIZE; // Warp number
int nw = nthreads / WARP_SIZE; // Number of warps
int t = tid % WARP_SIZE; // Thread (inside the warp)
- const int packFactor = sizeof(Pack128) / sizeof(T);
+ // Check that all pointers are 16B aligned. If not, don't use 16B loads/stores.
+ int align = 0;
+ #pragma unroll
+ for (int i=0; i<MINSRCS; i++) align |= ptrAlign128(srcs[i]);
+ for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) align |= ptrAlign128(srcs[i]);
+ #pragma unroll
+ for (int i=0; i<MINDSTS; i++) align |= ptrAlign128(dsts[i]);
+ for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) align |= ptrAlign128(dsts[i]);
- // stage 2a: main loop
- int Npack2a = (Nrem / (packFactor * AUTOUNROLL * WARP_SIZE))
- * (AUTOUNROLL * WARP_SIZE); // round down
- int Nelem2a = Npack2a * packFactor;
+ int offset = 0;
+ if (align == 0) {
+ // fast path: use 128b loads/stores to do the bulk of the work,
+ // assuming the pointers we have are all 128-bit aligned.
- ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2a);
+ // main loop
+ int Npack = (Nrem / (PACKELEMS*UNROLL*WARP_SIZE)) * (UNROLL*WARP_SIZE); // round down
+ int Nelem = Npack * PACKELEMS;
- Nrem -= Nelem2a;
- if (Nrem == 0) return;
- offset += Nelem2a;
+ ReduceCopy128bMulti<FUNC, T, UNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
- // stage 2b: slightly less optimized for section when we don't have full
- // unrolling
+ Nrem -= Nelem;
+ if (Nrem == 0) return;
+ offset += Nelem;
+
+ // slightly less optimized section for when we don't have full unrolling
+ Npack = Nrem / PACKELEMS;
+ Nelem = Npack * PACKELEMS;
+
+ ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);
+
+ Nrem -= Nelem;
+ if (Nrem == 0) return;
+ offset += Nelem;
+ }
- int Npack2b = Nrem / packFactor;
- int Nelem2b = Npack2b * packFactor;
+ // unrolled, by-type (mostly for unaligned buffers)
+ int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down
- ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2b);
+ ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem);
- Nrem -= Nelem2b;
+ Nrem -= Nelem;
if (Nrem == 0) return;
- offset += Nelem2b;
+ offset += Nelem;
- // stage 2c: tail
- ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, offset, Nrem);
+ // no unroll, by type. Should finish what's remaining.
+ ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem);
}
#endif // COMMON_KERNEL_H_
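
The restructured ReduceOrCopyMulti replaces the old preamble-plus-XOR scheme with a simpler absolute test: it ORs together every pointer's offset within a 16-byte line (ptrAlign128) and takes the Pack128 fast path only when the result is zero, falling back to the element-typed ReduceCopyMulti otherwise. A standalone sketch of that gate, under the same assumptions as the code above:

// Hypothetical standalone version of the alignment gate in the final
// hunk; any misaligned source or destination disables the 16B path
// for the entire operation.
template <typename T>
__device__ bool allAligned16(const T** srcs, int nsrcs, T** dsts, int ndsts) {
  int align = 0;
  for (int i = 0; i < nsrcs; i++) align |= ptrAlign128(srcs[i]);
  for (int i = 0; i < ndsts; i++) align |= ptrAlign128(dsts[i]);
  return align == 0;   // zero means every pointer is 16B aligned
}

Note the behavioral change: the old code only required all pointers to share the same misalignment (alignDiff via XOR) and used a preamble loop to walk them into alignment together, whereas the new code requires absolute 16-byte alignment and otherwise stays on the unrolled element-typed path for the whole buffer.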