Welcome to mirror list, hosted at ThFree Co, Russian Federation.

reduce_scatter.cu « src - github.com/marian-nmt/nccl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: e1860c54b9f73e245876b389f91a3a4c8693e4ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
/*************************************************************************
 * Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ************************************************************************/

#include <cassert>

#include "core.h"
#include "common_kernel.h"
#include "copy_kernel.h"
#include "enqueue.h"
#include "reduce_kernel.h"

/* HIERARCHY
 *
 * The data is split into CHUNKS, and each CHUNK is split into NUM_SUBCHUNKS
 * SUBCHUNKS, where each SUBCHUNK is an independent, complete reduction. Each
 * GPU has a buffer that can fit an entire CHUNK, so that all SUBCHUNKS can be
 * processed without checking that the buffer on the receiving GPU is empty. A
 * SUBCHUNK is split into NUM_GPUS SLICES and each GPU works on a different
 * SLICE at the same time. Before moving on to the next SLICE in the reduction
 * algorithm, the GPU has to check whether it has received the data from the
 * previous GPU it needs for this SLICE. To hide the latency of this
 * communication, each GPU processes all the SLICES of all the SUBCHUNKS in
 * sequence before moving on to the next SLICE. Each SLICE is split into a
 * certain number of UNROLLS (determined by the buffer size) and each thread
 * performs UNROLL_COUNT single-data-element operations inside an UNROLL. As the
 * name suggests, the UNROLL_COUNT operations within an UNROLL are unrolled.
*/

// Number of producer threads used to perform copies, reductions, etc. Must be
// a multiple of 32. The kernel is launched with NUM_SUBCHUNKS * WARP_SIZE
// additional threads (tid >= NUM_THREADS) that only synchronize with the
// producers and then set the inter-GPU flags, so the CUDA blocks have
// dimension NUM_THREADS + NUM_SUBCHUNKS * WARP_SIZE.
#define NUM_THREADS     256

// Each thread unrolls the innermost loop of the copy or reduction operations
// to this many single-data-element instructions
#define UNROLL_COUNT    8

// Number of elements handled by all producer threads together in one unrolled
// iteration. The host rounds the slice size down to a multiple of this.
#define UNROLL_SIZE     (UNROLL_COUNT * NUM_THREADS)

// To hide the latency associated with the synchronization between different
// subchunks, we interleave the independent subchunks so that more data can be
// transferred while the sync is in progress. This is the number of subchunks
// that are active at the same time.
// NOTE: the SIGNAL_*/WAIT_* flag encodings below use a literal factor 2 that
// must stay equal to NUM_SUBCHUNKS.
#define NUM_SUBCHUNKS   2

/*
 * numGPUs BLOCKs consisting of recvcount elements each
 * BLOCK is split up into NumChunks CHUNKs
 * CHUNK is split up into NUM_SUBCHUNKS SUBCHUNKs
 * SUBCHUNK consists of exactly one SLICE
 * SLICE is most efficiently processed in multiples of UNROLL_SIZE
 *
 * The algorithm has numGPUs steps and each step processes a SLICE (i.e.
 * SUBCHUNK) of a different BLOCK. Only data of the BLOCKs not resident on the
 * GPU need to be communicated, hence (numGPUs - 1) BLOCKs. So the buffer needs
 * to have room for (numGPUs - 1) SLICEs. In practice the buffer is indexed by
 * BLOCK number, so the host code reserves NUM_SUBCHUNKS SLICEs (plus padding)
 * for each of the numGPUs BLOCKs.
 */


// Flag encoding: a single "new data available" flag is used for all
// subchunks. Its value encodes (chunk, step, subchunk) as
//   2 * (chunk * NumGPUs + step) + subchunk + 1
// (the factor 2 matches NUM_SUBCHUNKS). The value grows with every signal in
// the order the kernel issues them, so waiters only need a ">=" comparison.

// If this is called with STEP, it means that we just finished processing the
// data for step STEP on this GPU, which is the data required on the next GPU
// for step STEP + 1, so we signal the next GPU that its data for step STEP + 1
// is available. The kernel executes this in every consumer thread
// (tid >= NUM_THREADS); they all store the same value, so the redundant writes
// are harmless. Macro arguments are parenthesized so that expression arguments
// expand with the intended precedence.
#define SIGNAL_NEW_DATA_AVAILABLE(chunk, subchunk, step)                        \
    do {                                                                        \
      args.NextNewDataAvailableFlag[0] =                                        \
          2*((chunk) * args.NumGPUs + (step)) + (subchunk) + 1;                 \
    } while (0)

// Block until the previous GPU has signalled (via SIGNAL_NEW_DATA_AVAILABLE)
// that it finished subchunk SUBCHUNK of step STEP - 1 of chunk CHUNK, before
// this GPU reads ThisBuffer. For STEP == 0 this is the last step of chunk
// CHUNK - 1. The threshold 2*(chunk*NumGPUs + step) + subchunk - 1 is exactly
// the value signalled for that predecessor; the factor 2 matches NUM_SUBCHUNKS.
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning (BAR with id 1 and
// NUM_THREADS participants — presumably a named barrier that excludes the
// consumer threads). Macro arguments are parenthesized so that expression
// arguments expand with the intended precedence.
#define WAIT_FOR_NEW_DATA(chunk, subchunk, step)                                \
    do {                                                                        \
      if (tid == 0) {                                                           \
        Wait([=] {                                                              \
          return ((volatile int *)args.ThisNewDataAvailableFlag)[0] >=          \
              2*((chunk) * args.NumGPUs + (step)) + (subchunk) - 1;             \
        });                                                                     \
      }                                                                         \
      BAR(sync, 1, NUM_THREADS);                                                \
    } while (0)

// If this is called with CHUNK, it means that this GPU has just finished
// processing the chunk CHUNK and so the previous GPU can start with CHUNK + 1.
// The stored value 2 * chunk + subchunk + 1 (the factor 2 matches
// NUM_SUBCHUNKS) grows with every (chunk, subchunk) pair, which lets
// WAIT_FOR_CHUNK use a ">=" comparison. Macro arguments are parenthesized so
// that expression arguments expand with the intended precedence.
#define SIGNAL_CHUNK_DONE(chunk, subchunk)                                      \
    do {                                                                        \
      args.PrevChunkDoneFlag[0] = 2*(chunk) + (subchunk) + 1;                   \
    } while (0)

// Block until the next GPU has reported (via SIGNAL_CHUNK_DONE) that it
// finished subchunk SUBCHUNK of chunk CHUNK - 1. Only then may this GPU
// overwrite that part of the next GPU's buffer. The threshold
// 2*chunk + subchunk - 1 is the value signalled for (chunk - 1, subchunk). For
// chunk 0 it is at most 0, so a flag that was reset to 0 already satisfies it.
// The factor 2 matches NUM_SUBCHUNKS.
// This is called by all producer threads, but only thread 0 spins on the flag,
// all threads synchronize after thread 0 is done spinning (BAR with id 1 and
// NUM_THREADS participants). Macro arguments are parenthesized so that
// expression arguments expand with the intended precedence.
#define WAIT_FOR_CHUNK(chunk, subchunk)                                       \
    do {                                                                      \
      if (tid == 0) {                                                         \
        Wait([=] {                                                            \
          return ((volatile int *)args.ThisChunkDoneFlag)[0] >=               \
              2*(chunk) + (subchunk) - 1;                                     \
        });                                                                   \
      }                                                                       \
      BAR(sync, 1, NUM_THREADS);                                              \
    } while (0)


// Stores in *sliceSize the number of elements in slice `slice` of the current
// chunk:
//   - slices [0, numBigSlices) hold bigSliceN elements,
//   - the next numSmallSlices slices hold smallSliceN elements,
//   - otherwise the slice at index numSlices - 1 holds lastSliceN elements,
//   - and any remaining slice is empty.
// Despite its name, only the slice size is computed.
__device__ inline void getSliceSizeAndChunkSize(int *sliceSize, int slice,
    int numSlices, int numBigSlices, int numSmallSlices, int bigSliceN,
    int smallSliceN, int lastSliceN) {
  int elems = 0;
  if (slice < numBigSlices) {
    elems = bigSliceN;
  } else if (slice < numBigSlices + numSmallSlices) {
    elems = smallSliceN;
  } else if (slice == numSlices - 1) {
    elems = lastSliceN;
  }
  *sliceSize = elems;
}

// Parameters of ReduceScatterKernel. They are filled in on the host by
// ncclReduceScatterWithTypeAndFunc and passed to the kernel by value.
template<typename T>
struct ReduceScatterKernelArgs {
  // general parameters
  // This GPU's position in the ring (comm->ncclId).
  int ThisId;
  // Number of GPUs in the ring (comm->nDev).
  int NumGPUs;
  // Elements per block, i.e. recvcount.
  int N;
  // comm->devUserFromRing; GetBlock() uses it to map (ring index, step) to a
  // block.
  int * UserFromRing;

  // some pre-computed sizes
  // Elements per slice (one slice per subchunk), a multiple of UNROLL_SIZE.
  int SliceSize;
  // NUM_SUBCHUNKS * SliceSize.
  int ChunkSize;
  // Number of chunks each block is processed in.
  int NumChunks;

  // Distance between consecutive slices in the staging buffer
  // (slice size before any halving, plus padding).
  int BufferSliceStride;
  // Elements by which one block's byte size exceeds a multiple of 8 bytes;
  // used to shift per-block buffer offsets.
  int BufferMisalignedN;

  // local and remote input, output, and buffer
  // sendbuff: NumGPUs consecutive blocks of N elements.
  const T * __restrict__ ThisInput;
  // recvbuff: N elements receiving this GPU's fully reduced block.
  volatile T * __restrict__ ThisOutput;
  // comm->local[prevId]->buff: source of the incoming partial results.
  volatile T * __restrict__ ThisBuffer;
  // comm->remote[nextId]->buff: destination of this GPU's Copy/Reduce output.
  volatile T * __restrict__ NextBuffer;

  // local and remote flags
  // comm->local[prevId]->flags[0]; polled by WAIT_FOR_NEW_DATA.
  volatile int * __restrict__ ThisNewDataAvailableFlag;
  // comm->remote[nextId]->flags[0]; set by SIGNAL_NEW_DATA_AVAILABLE.
  volatile int * __restrict__ NextNewDataAvailableFlag;
  // comm->local[nextId]->flags[1]; polled by WAIT_FOR_CHUNK.
  volatile int * __restrict__ ThisChunkDoneFlag;
  // comm->remote[prevId]->flags[1]; set by SIGNAL_CHUNK_DONE.
  volatile int * __restrict__ PrevChunkDoneFlag;
};

// Returns the user-visible block that the GPU at ring position `index` works
// on during `step`. The ring is walked backwards one position per step,
// starting one slot behind `index`, so the block at the final step is the one
// this GPU's output receives. Adding numGPUs first keeps the dividend
// non-negative for 0 <= step < numGPUs.
__device__ inline int GetBlock(const int index, const int step,
    const int * const userFromRing, const int numGPUs) {
  const int ringPos = (index - 1 - step + numGPUs) % numGPUs;
  return userFromRing[ringPos];
}

// Ring reduce-scatter kernel. ncclReduceScatterWithTypeAndFunc launches it as
// a single CUDA block of NUM_THREADS + NUM_SUBCHUNKS * WARP_SIZE threads, with
// THREADS == NUM_THREADS. The threads split into two roles:
//  * Producer threads (tid < NUM_THREADS) wait on the inter-GPU flags and
//    perform the Copy/Reduce work for each subchunk.
//  * The remaining consumer threads only meet the producers at the matching
//    __syncthreads() and then publish the SIGNAL_* flags. A flag is therefore
//    written only after the producers have passed the barrier for that
//    subchunk.
// For every chunk:
//  * Step 0 copies one input block into the next GPU's buffer.
//  * Steps 1..NumGPUs-2 reduce the incoming buffer with the local input and
//    forward the partial result to the next GPU.
//  * Step NumGPUs-1 writes the final reduction into ThisOutput.
// GetBlock() selects which block is handled at each step.
// NOTE(review): offsets such as block * args.N are computed in int and could
// overflow if NumGPUs * N exceeds INT_MAX — confirm callers bound this.
template<int THREADS, int UNROLL, class FUNC, typename T>
__global__ void ReduceScatterKernel(const ReduceScatterKernelArgs<T> args) {
  if (args.N == 0) return;
  int tid = threadIdx.x;

  for (int chunk = 0; chunk < args.NumChunks; ++chunk) {
    // calculate slice size.  for all chunks except (possibly) the last one,
    // this will just be args.SliceSize. For the last one, it may be smaller
    int bigSliceN   = args.SliceSize;
    int smallSliceN = 0;
    int lastSliceN  = 0;
    int numSlices   = NUM_SUBCHUNKS;
    int numBigSlices   = numSlices;
    int numSmallSlices = 0;

    // last chunk
    if ((chunk + 1 == args.NumChunks) && (args.N % args.ChunkSize > 0))
      CalcLastChunk<THREADS, UNROLL, T>(&bigSliceN, &smallSliceN, &lastSliceN,
          &numSlices, &numBigSlices, &numSmallSlices, args.N, args.NumChunks,
          args.ChunkSize);


    // this offset is only applied to Data pointers, not to Buffer pointers,
    // since we only have one buffer per chunk
    int chunkOffset = chunk * args.ChunkSize;

    // step 0: push data to next GPU
    int step = 0;
    int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
    int blockOffset = chunkOffset + block * args.N;
    // The buffer is indexed by block: NUM_SUBCHUNKS slices per block, shifted
    // by the block's misalignment modulo the PackType alignment.
    int bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
        ((block * args.BufferMisalignedN) % alignof(PackType));
    int sliceSize;

    if (tid < NUM_THREADS) {
      for(int s=0; s<NUM_SUBCHUNKS; ++s) {
        getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
            numSmallSlices, bigSliceN, smallSliceN, lastSliceN);

        // The next GPU must have finished this subchunk of the previous chunk
        // before its buffer region is overwritten.
        WAIT_FOR_CHUNK(chunk, s);
        Copy<UNROLL, THREADS>(
            args.NextBuffer + bufferOffset,
            args.ThisInput + blockOffset,
            sliceSize);
        // Pairs with the consumer threads' __syncthreads() below.
        __syncthreads();
        bufferOffset += sliceSize;
        blockOffset += sliceSize;
      }
    } else { // Is consumer
      for(int s=0; s<NUM_SUBCHUNKS; ++s) {
        // Wait for the producers to finish subchunk s, then tell the next GPU.
        __syncthreads();
        SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
      }
    }

    // steps j with 0 < j < k - 1, where k = number of GPUs: reduce and copy to
    // next GPU
    for (step = 1; step < args.NumGPUs - 1; ++step) {
      // These declarations shadow the step-0 block/blockOffset/bufferOffset.
      int block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
      int blockOffset = chunkOffset + block * args.N;
      int bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
          ((block * args.BufferMisalignedN) % alignof(PackType));

      if (tid < NUM_THREADS) {
        for(int s=0; s<NUM_SUBCHUNKS; ++s) {
            getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
                numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
          // Wait until the previous GPU's step-1 result for subchunk s is
          // signalled, then combine it with the local input and forward it.
          WAIT_FOR_NEW_DATA(chunk, s, step);
          Reduce<UNROLL, THREADS, FUNC>(
              args.NextBuffer + bufferOffset,
              args.ThisBuffer + bufferOffset,
              args.ThisInput + blockOffset,
              sliceSize);
          __syncthreads();
          bufferOffset += sliceSize;
          blockOffset += sliceSize;
        }
      } else {
        for(int s=0; s<NUM_SUBCHUNKS; ++s) {
          __syncthreads();
          SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);
        }
      }
    }

    // step k - 1: reduce this buffer and data, which produces the final
    // result; store it in this GPU's output. Only a flag (no data) is sent to
    // the next GPU at this step.
    step = args.NumGPUs - 1;
    block = GetBlock(args.ThisId, step, args.UserFromRing, args.NumGPUs);
    blockOffset = chunkOffset + block * args.N;
    bufferOffset = block * NUM_SUBCHUNKS * args.BufferSliceStride +
        ((block * args.BufferMisalignedN) % alignof(PackType));

    if (tid < NUM_THREADS) {
      int outputOffset = 0;
      for (int s=0; s<NUM_SUBCHUNKS; ++s) {
        getSliceSizeAndChunkSize(&sliceSize, s, numSlices, numBigSlices,
            numSmallSlices, bigSliceN, smallSliceN, lastSliceN);
        WAIT_FOR_NEW_DATA(chunk, s, step);
        Reduce<UNROLL, THREADS, FUNC>(
            args.ThisOutput + (chunkOffset + outputOffset),
            args.ThisBuffer + bufferOffset,
            args.ThisInput + blockOffset,
            sliceSize);
        __syncthreads();
        outputOffset += sliceSize;
        bufferOffset += sliceSize;
        blockOffset += sliceSize;
      }
    } else {
      for (int s=0; s<NUM_SUBCHUNKS; ++s) {
        __syncthreads();
        SIGNAL_NEW_DATA_AVAILABLE(chunk, s, step);

        // signal that chunk is done if this is not the last chunk
        if (chunk + 1 < args.NumChunks) {
          SIGNAL_CHUNK_DONE(chunk, s);
        }
      }
    }
  }

  // Wait for the previous GPU's final signal (last chunk, last step, last
  // subchunk). After it, no further writes to our flags are pending, so
  // thread 0 resets both counters to 0 — presumably so the next collective
  // on this communicator starts from zero.
  if (tid < NUM_THREADS) {
    WAIT_FOR_NEW_DATA(args.NumChunks, NUM_SUBCHUNKS-1, 0);

    if (tid == 0) {
      args.ThisNewDataAvailableFlag[tid] = 0;
      args.ThisChunkDoneFlag[tid] = 0;
    }
  }
}

// Host-side launcher for the ring reduce-scatter of element type T with
// reduction functor FUNC.
//
// sendbuff holds comm->nDev consecutive blocks of recvcount elements; the
// kernel addresses block j at offset j * recvcount. recvbuff receives
// recvcount elements: the reduction, across all ranks, of the block assigned
// to this rank. The kernel is enqueued on `stream` and runs asynchronously.
//
// Returns:
//   ncclSuccess            the kernel was launched, or recvcount == 0
//   ncclInvalidArgument    recvcount is negative, or the communication buffer
//                          cannot hold one UNROLL_SIZE-element slice per
//                          block and subchunk
//   ncclUnhandledCudaError the kernel launch was rejected by the runtime
template<class FUNC, typename T>
ncclResult_t ncclReduceScatterWithTypeAndFunc(const void* sendbuff,
    void* recvbuff, const int recvcount, ncclComm* comm, cudaStream_t stream) {
  if (recvcount == 0) {
    return ncclSuccess;
  }
  if (recvcount < 0) {
    return ncclInvalidArgument;
  }
  int index = comm->ncclId;

  // The block size in bytes is only needed modulo 8. Compute it in size_t so a
  // large recvcount cannot overflow an int before the modulo is taken.
  size_t blockSizeInBytes = (size_t)recvcount * sizeof(T);
  int misalignedBytes = (int)(blockSizeInBytes % alignof(uint64_t));

  assert((int)((misalignedBytes / sizeof(T)) * sizeof(T)) == misalignedBytes);

  int misalignedN = misalignedBytes / sizeof(T);
  assert(misalignedN < (int)(sizeof(uint64_t) / sizeof(T)));

  // Spare elements reserved after each buffer slice when blocks are not a
  // multiple of 8 bytes long.
  int paddingN = (misalignedN > 0) ? sizeof(uint64_t) / sizeof(T) : 0;

  // bufferN is the number of elements of type T that fit into the buffer. The
  // kernel indexes the buffer by block number, giving each of the comm->nDev
  // blocks NUM_SUBCHUNKS slices, and every slice is followed by paddingN spare
  // elements. For efficiency, the slice size is rounded down to a multiple of
  // UNROLL_SIZE.
  int bufferN = comm->buffSize / sizeof(T);
  int bufferNPerSlice = (bufferN - NUM_SUBCHUNKS * comm->nDev * paddingN) /
      (NUM_SUBCHUNKS * comm->nDev);
  int sliceSize = (bufferNPerSlice / UNROLL_SIZE) * UNROLL_SIZE;

  // If the buffer cannot hold one UNROLL_SIZE slice per (block, subchunk), the
  // chunk size would be zero and the chunk computations below would divide by
  // zero. Report the configuration error instead.
  if (sliceSize <= 0) {
    return ncclInvalidArgument;
  }

  int nextId = (index + 1) % comm->nDev;
  int prevId = (index + comm->nDev - 1) % comm->nDev;

  ReduceScatterKernelArgs<T> args;

  args.ThisId = index;
  args.NumGPUs = comm->nDev;
  args.N = recvcount;

  /* Block j must end up in recvbuff[j], which lives on device with logical
   * index comm->ringFromUser[j]. But the block ordering does not necessarily
   * follow the ring ordering. Hence the order in which a particular GPU
   * processes the different blocks (the correspondence between the step in
   * the reduction algorithm and the block on which a GPU operates in that
   * particular step) is not the same as the ring order.
   *
   * Say we have 4 GPUs and comm->userFromRing = { 1, 2, 0, 3 }. Then there are
   * 4 steps in the reduction algorithm and block 0 needs to end up on device
   * 2, block 1 on device 0, block 2 on device 1, and block 3 needs to end up
   * on device 3. In the last step of the algorithm, each GPU must be
   * processing the block that will end up on that GPU. The blocks that a GPU
   * has to process in the previous steps are determined by the next step
   * because each GPU only hands off data to the next GPU in the ring.
   *
   * In the above example, we get the following table of which block is
   * processed by each GPU in a given step. The columns correspond to the
   * different GPUs while the rows are the steps in the algorithm.
   *
   *      GPU 0   1   2   3
   * step
   *    0     3   1   2   0
   *    1     0   3   1   2
   *    2     2   0   3   1
   *    3     1   2   0   3
   *
   * We note that the rows in the above table are just comm->userFromRing in
   * the last step and the list is cyclically permuted to the left for each
   * previous step. The columns, which are what the individual GPUs need to
   * know, are comm->userFromRing traversed backwards and starting at index
   * k-1 for GPU k. The kernel's GetBlock() reproduces these columns at run
   * time from args.UserFromRing. */
  args.UserFromRing = comm->devUserFromRing;

  args.SliceSize = sliceSize;
  args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;

  // don't reduce this if we cut the slice size in half below, because if that
  // happens, the last chunk will be larger than the other chunks, and we will
  // need the extra buffer space
  args.BufferSliceStride = args.SliceSize + paddingN;

  args.BufferMisalignedN = misalignedN;

  // avoid a case where we have one or more big chunks and one tiny one
  int remainder = args.N % args.ChunkSize;
  if ((args.N > args.ChunkSize) && (remainder > 0) &&
      (args.N < 5 * args.ChunkSize) && (2 * remainder < args.ChunkSize)) {
    args.SliceSize /= 2;
    args.ChunkSize = NUM_SUBCHUNKS * args.SliceSize;

    // round down so we end up with a big last chunk
    args.NumChunks = args.N / args.ChunkSize;
  } else {
    // round up
    args.NumChunks = (args.N + args.ChunkSize - 1) / args.ChunkSize;
  }

  args.ThisInput = (const T*)sendbuff;
  args.ThisOutput = (volatile T*)recvbuff;
  args.ThisBuffer = (volatile T*)comm->local[prevId]->buff;
  args.NextBuffer = (volatile T*)comm->remote[nextId]->buff;

  // Two flags are used; the subchunk and step are encoded in the flag values
  // rather than in separate flags:
  //  - flags[0] of comm->local[prevId] / comm->remote[nextId] carries the
  //    "new data available" counter,
  //  - flags[1] of comm->local[nextId] / comm->remote[prevId] carries the
  //    "chunk done" counter.
  args.ThisNewDataAvailableFlag = comm->local[prevId]->flags;
  args.NextNewDataAvailableFlag = comm->remote[nextId]->flags;
  args.ThisChunkDoneFlag = comm->local[nextId]->flags + 1;
  args.PrevChunkDoneFlag = comm->remote[prevId]->flags + 1;

  // One CUDA block: NUM_THREADS producer threads plus one warp per subchunk
  // of consumer threads that publish the flags.
  ReduceScatterKernel<NUM_THREADS, UNROLL_COUNT, FUNC, T>
      <<<1, NUM_THREADS + NUM_SUBCHUNKS * WARP_SIZE, 0, stream>>>(args);
  // Launch-configuration failures are only reported through
  // cudaGetLastError(); do not claim success when the launch was rejected.
  if (cudaGetLastError() != cudaSuccess) {
    return ncclUnhandledCudaError;
  }
  return ncclSuccess;
}

// Maps the runtime reduction operator onto the matching compile-time functor
// and forwards to ncclReduceScatterWithTypeAndFunc. Operators without a
// functor yield ncclInvalidOperation.
template<typename T>
ncclResult_t ncclReduceScatterWithType(const void* sendbuff, void* recvbuff,
    int recvcount, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream) {
  if (op == ncclSum)
    return ncclReduceScatterWithTypeAndFunc<FuncSum<T>, T>(
        sendbuff, recvbuff, recvcount, comm, stream);
  if (op == ncclProd)
    return ncclReduceScatterWithTypeAndFunc<FuncProd<T>, T>(
        sendbuff, recvbuff, recvcount, comm, stream);
  if (op == ncclMax)
    return ncclReduceScatterWithTypeAndFunc<FuncMax<T>, T>(
        sendbuff, recvbuff, recvcount, comm, stream);
  if (op == ncclMin)
    return ncclReduceScatterWithTypeAndFunc<FuncMin<T>, T>(
        sendbuff, recvbuff, recvcount, comm, stream);
  return ncclInvalidOperation;
}

// Collective functor handed to enqueue(): dispatches on the runtime element
// type to the matching ncclReduceScatterWithType instantiation. The root
// argument is ignored because reduce-scatter has no root. Types without an
// instantiation yield ncclInvalidType.
class ReduceScatterFunctor {
public:
  ncclResult_t operator()(const void* sendbuff, void* recvbuff,
      int recvcount, ncclDataType_t datatype, ncclRedOp_t op, int /*root*/,
      ncclComm* comm, cudaStream_t stream) {

    if (datatype == ncclChar)
      return ncclReduceScatterWithType<char>(sendbuff, recvbuff, recvcount,
          op, comm, stream);
    if (datatype == ncclInt)
      return ncclReduceScatterWithType<int>(sendbuff, recvbuff, recvcount,
          op, comm, stream);
#ifdef CUDA_HAS_HALF
    if (datatype == ncclHalf)
      return ncclReduceScatterWithType<half>(sendbuff, recvbuff, recvcount,
          op, comm, stream);
#endif
    if (datatype == ncclFloat)
      return ncclReduceScatterWithType<float>(sendbuff, recvbuff, recvcount,
          op, comm, stream);
    if (datatype == ncclDouble)
      return ncclReduceScatterWithType<double>(sendbuff, recvbuff, recvcount,
          op, comm, stream);
    if (datatype == ncclInt64)
      return ncclReduceScatterWithType<long long>(sendbuff, recvbuff,
          recvcount, op, comm, stream);
    if (datatype == ncclUint64)
      return ncclReduceScatterWithType<unsigned long long>(sendbuff, recvbuff,
          recvcount, op, comm, stream);

    return ncclInvalidType;
  }
};

// Public C entry point. The request is handed to enqueue() together with the
// type/operator dispatcher. Reduce-scatter has no root rank, so the root slot
// is filled with 0, which ReduceScatterFunctor ignores.
extern "C" DSOGLOBAL
ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
    int recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm,
    cudaStream_t stream) {
  const int unusedRoot = 0;
  return enqueue(ReduceScatterFunctor(), sendbuff, recvbuff, recvcount,
      datatype, op, unusedRoot, comm, stream);
}