Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THCReduce.cuh « THC « lib - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: b7df49b9c13d88de6a98929e405eda88e7b1b165 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
#ifndef THC_REDUCE_INC
#define THC_REDUCE_INC

//
// This file contains dimension reduction operation functions and
// kernels that work on both contiguous and non-contiguous tensor
// arguments of arbitrary (up to MAX_CUTORCH_DIMS) dimensioned
// arguments without copying or temporary storage.
//

#include "THCTensorTypeUtils.cuh"
#include "THCReduceApplyUtils.cuh"

// Threads per thread block for the non-contiguous reduction kernels.
// Parenthesized so the macro expands safely inside larger expressions
// (e.g. `x / THC_NONCONTIG_REDUCE_BLOCK_SIZE` must divide by 512, not by 32
// and then multiply by 16).
#define THC_NONCONTIG_REDUCE_BLOCK_SIZE (32 * 16)

// Global slice index for the per-thread non-contiguous reduction kernel.
// Each thread handles one slice; slices are numbered by the linear block id
// and then by threadIdx.x within the block.
template <typename IndexType>
__device__ __forceinline__ IndexType getReduceNoncontigDimSliceIndex() {
  // Use the actual launched block width rather than assuming it equals
  // THC_NONCONTIG_REDUCE_BLOCK_SIZE, so the mapping stays dense and
  // non-overlapping for whatever 1-D block width the kernel is launched with.
  return getLinearBlockId<IndexType>() * blockDim.x + threadIdx.x;
}

// Kernel that reduces slices of a tensor along a non-contiguous dimension,
// sharing the work for each slice across the y-dimension of the block.
//
// Launch layout (as configured by THC_reduceDim): 2-D blocks with
// blockDim.x * blockDim.y <= THC_NONCONTIG_REDUCE_BLOCK_SIZE (the size of the
// static shared buffer below). Only blockIdx.x / gridDim.x are used, so the
// grid is expected to be 1-D. Each threadIdx.x column owns one slice per
// iteration; the blockDim.y threads of that column each accumulate a strided
// subset of the slice's elements (up to four loads per step) and then combine
// their partial results with a shared-memory tree reduction into row 0,
// which writes the output.
template <typename ModifyOp,
          typename ReduceOp,
          typename T,
          typename IndexType,
          int ADims, int BDims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
                         TensorInfo<T, IndexType> in,
                         IndexType reductionStride,
                         IndexType reductionSize,
                         IndexType totalSlices,
                         T init,
                         ModifyOp modifyOp,
                         ReduceOp reduceOp) {
  const IndexType sliceStride = gridDim.x * blockDim.x;

  __shared__ T local_reduce[THC_NONCONTIG_REDUCE_BLOCK_SIZE];
  // This thread's slot; rows of the same column are blockDim.x apart.
  T* shmem = &local_reduce[threadIdx.x + threadIdx.y * blockDim.x];

  // Iterate over block-uniform slice bases so that every thread of the block
  // runs the same number of iterations and therefore reaches every
  // __syncthreads() below. Threads whose slice is out of range still take
  // part in the barriers but neither read `in` nor write `out`.
  for (IndexType sliceBase = blockIdx.x * blockDim.x;
       sliceBase < totalSlices;
       sliceBase += sliceStride) {
    const IndexType sliceIndex = sliceBase + threadIdx.x;
    const bool active = sliceIndex < totalSlices;

    T local_reg = init;
    IndexType outOffset = 0;

    if (active) {
      outOffset = IndexToOffset<T, IndexType, ADims>::get(sliceIndex, out);
      const IndexType inOffset =
        IndexToOffset<T, IndexType, BDims>::get(sliceIndex, in);

      // Unrolled form of:
      //   for (i = threadIdx.y; i < reductionSize; i += blockDim.y)
      //     local_reg = reduceOp(local_reg, modifyOp(in[inOffset + i * reductionStride]));
      T load_reg[4];
      for (IndexType i = threadIdx.y; i < reductionSize; i += blockDim.y * 4) {
        if (i + blockDim.y * 3 < reductionSize) {
          load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
          load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
          load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
          load_reg[3] = modifyOp(in.data[inOffset + (i + blockDim.y * 3) * reductionStride]);

          local_reg = reduceOp(local_reg,
                               reduceOp(
                                        reduceOp(load_reg[0], load_reg[1]),
                                        reduceOp(load_reg[2], load_reg[3])
                                        )
                               );
        } else if (i + blockDim.y * 2 < reductionSize) {
          load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
          load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);
          load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);

          local_reg = reduceOp(
                               reduceOp(load_reg[0], load_reg[1]),
                               reduceOp(load_reg[2], local_reg)
                               );
        } else if (i + blockDim.y < reductionSize) {
          load_reg[0] = modifyOp(in.data[inOffset + (i + blockDim.y * 0) * reductionStride]);
          load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y * 1) * reductionStride]);

          local_reg = reduceOp(
                               local_reg, reduceOp(load_reg[0], load_reg[1])
                               );
        } else {
          // Only element i remains (i < reductionSize holds inside the loop).
          local_reg = reduceOp(local_reg, modifyOp(in.data[inOffset + i * reductionStride]));
        }
      }
    }

    // All threads must have finished reading shared memory in the previous
    // iteration's tree reduction before any slot is overwritten.
    __syncthreads();
    *shmem = local_reg;

    // Tree-reduce the blockDim.y partials of each column into row 0. An odd
    // count folds its last row into row 0 before halving.
    int dimy = blockDim.y;
    while (dimy > 1) {
      __syncthreads();
      if (threadIdx.y == 0 && (dimy % 2 != 0)) {
        *shmem = reduceOp(*shmem, *(shmem + (dimy - 1) * blockDim.x));
      }
      if (threadIdx.y < dimy / 2) {
        *shmem = reduceOp(*shmem, *(shmem + (dimy / 2) * blockDim.x));
      }
      dimy /= 2;
    }

    if (active && threadIdx.y == 0) {
      out.data[outOffset] = *shmem;
    }
  }
}


// Per-thread reduction kernel for a non-contiguous reduction dimension.
//
// Each thread owns exactly one output slice (index from
// getReduceNoncontigDimSliceIndex) and walks that slice serially, advancing
// `reductionStride` elements through `in` for `reductionSize` steps, starting
// from `init`. Threads whose slice index is not below `totalSlices` do nothing.
template <typename ModifyOp,
          typename ReduceOp,
          typename T,
          typename IndexType,
          int ADims, int BDims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(32 * 16, 4)
#endif
__global__ void
kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
                         TensorInfo<T, IndexType> in,
                         IndexType reductionStride,
                         IndexType reductionSize,
                         IndexType totalSlices,
                         T init,
                         ModifyOp modifyOp,
                         ReduceOp reduceOp) {
  const IndexType slice = getReduceNoncontigDimSliceIndex<IndexType>();

  if (slice < totalSlices) {
    // Destination of this slice's result and position of its first element.
    const IndexType dst =
      IndexToOffset<T, IndexType, ADims>::get(slice, out);
    IndexType src =
      IndexToOffset<T, IndexType, BDims>::get(slice, in);

    T acc = init;
    for (IndexType remaining = reductionSize; remaining > 0; --remaining) {
      acc = reduceOp(acc, modifyOp(in.data[src]));
      src += reductionStride;
    }

    out.data[dst] = acc;
  }
}

template <typename IndexType>
__device__ __forceinline__ IndexType getReduceContigDimSliceIndex() {
  // Contiguous reductions assign a whole block to each slice, so the block's
  // linear id doubles as the slice index.
  const IndexType blockSlice = getLinearBlockId<IndexType>();
  return blockSlice;
}

// Kernel that reduces one slice per block along a contiguous dimension.
//
// The threads of a block stride by blockDim.x through the `reductionSize`
// contiguous elements of their slice (starting at the slice's base offset in
// `in`), then combine their partial results with reduceBlock using the
// dynamic shared memory supplied at launch (THC_reduceDim passes
// blockDim.x * sizeof(T) bytes). Thread 0 writes the final value.
template <typename ModifyOp,
          typename ReduceOp,
          typename T,
          typename IndexType,
          int ADims, int BDims>
__global__ void
kernelReduceContigDim(TensorInfo<T, IndexType> out,
                      TensorInfo<T, IndexType> in,
                      IndexType reductionSize,
                      IndexType totalSlices,
                      T init,
                      ModifyOp modifyOp,
                      ReduceOp reduceOp) {
  const IndexType slice = getReduceContigDimSliceIndex<IndexType>();

  // The slice index is per block, so a block either exits here as a whole or
  // continues as a whole.
  if (slice >= totalSlices) {
    return;
  }

  const IndexType dst = IndexToOffset<T, IndexType, ADims>::get(slice, out);
  const IndexType base = IndexToOffset<T, IndexType, BDims>::get(slice, in);

  // Per-thread partial over elements threadIdx.x, threadIdx.x + blockDim.x, ...
  T partial = init;
  IndexType pos = threadIdx.x;
  while (pos < reductionSize) {
    partial = reduceOp(partial, modifyOp(in.data[base + pos]));
    pos += blockDim.x;
  }

  // Combine the partials across the block.
  // FIXME: extern name
  extern __shared__ char smemChar[];
  T* scratch = reinterpret_cast<T*>(smemChar);
  partial = reduceBlock<T, ReduceOp>(scratch, blockDim.x, partial, reduceOp, init);

  if (threadIdx.x == 0) {
    out.data[dst] = partial;
  }
}

// Launch shape for the per-thread non-contiguous reduction kernel: a 1-D
// block of THC_NONCONTIG_REDUCE_BLOCK_SIZE threads.
inline dim3 getNoncontigReduceBlock() {
  const dim3 block(THC_NONCONTIG_REDUCE_BLOCK_SIZE, 1, 1);
  return block;
}

// Chooses a 1-D block size (a multiple of 32) for the contiguous reduction
// kernel.
//
// numSlices:     number of output points (one block each).
// reductionSize: length of each contiguous slice.
//
// Fewer slices allow more warps per block (up to 32); the block is never
// larger than needed to give each reduction element its own thread, and is
// never smaller than one warp.
inline dim3 getContigReduceBlock(ptrdiff_t numSlices, long reductionSize) {
  // If the number of slices is low but the reduction dimension size
  // is high, then we should increase block size for greater parallelism.
  // Aim for at least 32 warps per SM (assume 15 SMs; don't bother
  // inquiring the real number for now).
  int maxWarps = 4; // better occupancy if many blocks are around
  // For numSlices > 15 * 8, there are > 32 warps active per SM.
  if (numSlices < 15 * 8) {
    maxWarps = 8;
    if (numSlices < 15 * 4) {
      maxWarps = 16;
      if (numSlices < 15 * 2) {
        maxWarps = 32;
      }
    }
  }

  // Scale up block size based on the reduction dimension size
  long warpsInReductionSize = THCCeilDiv(reductionSize, 32L);
  int numWarps = warpsInReductionSize > (long) maxWarps ?
    maxWarps : (int) warpsInReductionSize;

  // An empty reduction dimension would otherwise yield a zero-thread block,
  // which is an invalid launch configuration.
  if (numWarps < 1) {
    numWarps = 1;
  }

  return dim3(numWarps * 32);
}

// Grid for the per-thread non-contiguous kernel. Each thread produces one
// output point, so the tile count is ceil(elements / block size). Returns
// the result of THC_getGridFromTiles.
inline bool getNoncontigReduceGrid(ptrdiff_t elements, dim3& grid) {
  const ptrdiff_t threadsPerBlock = THC_NONCONTIG_REDUCE_BLOCK_SIZE;
  const ptrdiff_t tiles = THCCeilDiv(elements, threadsPerBlock);
  return THC_getGridFromTiles(tiles, grid);
}

// Grid for the contiguous kernel. Each block produces one output point, so
// there is one tile per element. Returns the result of THC_getGridFromTiles.
inline bool getContigReduceGrid(ptrdiff_t elements, dim3& grid) {
  const ptrdiff_t blocksNeeded = elements;
  return THC_getGridFromTiles(blocksNeeded, grid);
}

// Performs a reduction out[..., 0, ...] = reduce_i(modify(in[..., i, ...])) for
// all in where i and the out's 0 are indexed at dimension `dim`.
//
// `out` is resized to the sizes of `in` with dimension `dim` set to 1; unless
// `keepdim` is non-zero, that dimension is then squeezed out. Every output
// point starts from `init`. Kernels are launched on
// THCState_getCurrentStream(state); no synchronization is performed here.
//
// Returns false (before `out` is resized) if either tensor has more than
// MAX_CUTORCH_DIMS dimensions or no launch grid could be formed. A zero-dim
// `in` is a no-op that returns true.
template <typename TensorType, typename ModifyOp, typename ReduceOp>
bool THC_reduceDim(THCState* state,
                   TensorType* out,
                   TensorType* in,
                   const ModifyOp& modifyOp,
                   const ReduceOp& reduceOp,
                   typename TensorUtils<TensorType>::DataType init,
                   int dim,
                   int keepdim) {
  if (TensorUtils<TensorType>::getDims(state, out) > MAX_CUTORCH_DIMS ||
      TensorUtils<TensorType>::getDims(state, in) > MAX_CUTORCH_DIMS) {
    return false;
  }

  const int inDims = TensorUtils<TensorType>::getDims(state, in);
  if (inDims == 0) {
    // Zero-dim tensor; do nothing. Checked before querying the size/stride
    // of `dim`, which a zero-dim tensor does not have.
    return true;
  }

  long reductionSize = TensorUtils<TensorType>::getSize(state, in, dim);
  long reductionStride = TensorUtils<TensorType>::getStride(state, in, dim);

  // Number of output points = product of all non-reduced sizes. Unlike
  // numElements(in) / reductionSize, this does not divide by zero when the
  // reduction dimension is empty.
  ptrdiff_t outElements = 1;
  for (int d = 0; d < inDims; ++d) {
    if (d != dim) {
      outElements *= TensorUtils<TensorType>::getSize(state, in, d);
    }
  }

  // Is the reduction dimension contiguous? If so, then we can use a
  // shared memory reduction kernel to increase performance. An empty
  // reduction dimension takes the non-contiguous path instead, whose
  // per-thread kernel just writes `init` for every output point.
  const bool contigReduction = (reductionStride == 1) && (reductionSize > 0);
  // With no output points there is nothing to launch; `out` is still resized.
  const bool hasWork = (outElements > 0);

  dim3 block;
  dim3 grid;
  int smemSize = 0; // contiguous reduction uses smem
  if (hasWork) {
    if (contigReduction) {
      if (!getContigReduceGrid(outElements, grid)) {
        return false;
      }

      block = getContigReduceBlock(outElements, reductionSize);
      smemSize = sizeof(typename TensorUtils<TensorType>::DataType) * block.x;
    } else {
      if (!getNoncontigReduceGrid(outElements, grid)) {
        return false;
      }

      block = getNoncontigReduceBlock();

      if (outElements <= 4096) {
        // x dim does different columns;
        // y dim helps with the same reduction.
        // If we only have 8 loops, don't bother sharing work across ydim.
        unsigned long ydim = THCCeilDiv(reductionSize, 8L);

        // Don't want y dim any bigger than 16, leaving min x dim to 32.
        ydim = min((unsigned long) 16, ydim);

        block = dim3(THC_NONCONTIG_REDUCE_BLOCK_SIZE, 1, 1);
        while (ydim > 1) {
          block.x /= 2;
          block.y *= 2;
          ydim /= 2;
        }
        if (!THC_getGridFromTiles(THCCeilDiv(outElements, (long) block.x), grid)) {
          return false;
        }
      }
    }
  }

  // Resize out to correspond to the reduced size
  THLongStorage* sizes = TensorUtils<TensorType>::newSizeOf(state, in);
  THLongStorage_set(sizes, dim, 1);
  TensorUtils<TensorType>::resize(state, out, sizes, NULL);
  THLongStorage_free(sizes);

  // It is possible that the tensor dimensions are able to be collapsed,
  // and thus we can reduce the actual code complexity of the reduction by
  // exploiting this knowledge statically, since the div/mod is the
  // most expensive part of the operation, more so than memory accesses.
  // For instance, when reducing a contiguous tensor, it can be collapsed to
  // one dimension, and the loop to translate the linear index to the array
  // index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, OUT, IN)                                      \
  if (contigReduction) {                                                \
    kernelReduceContigDim<ModifyOp, ReduceOp,                           \
                          typename TensorUtils<TensorType>::DataType,   \
                          TYPE, OUT, IN>                                \
      <<<grid, block, smemSize, THCState_getCurrentStream(state)>>>(    \
        outInfo, inInfo, reductionSize,                                 \
        (TYPE) outElements, init, modifyOp, reduceOp);                  \
  } else {                                                              \
    if (block.y == 1) {                                                 \
      kernelReduceNoncontigDim<ModifyOp, ReduceOp,                      \
                         typename TensorUtils<TensorType>::DataType,    \
                         TYPE, OUT, IN>                                 \
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
          outInfo, inInfo, reductionStride, reductionSize,              \
          (TYPE) outElements, init, modifyOp, reduceOp);                \
    } else {                                                            \
      kernelReduceNoncontigDim_shared<ModifyOp, ReduceOp,               \
                         typename TensorUtils<TensorType>::DataType,    \
                         TYPE, OUT, IN>                                 \
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
          outInfo, inInfo, reductionStride, reductionSize,              \
          (TYPE) outElements, init, modifyOp, reduceOp);                \
    }                                                                   \
  }

#define HANDLE_IN_CASE(TYPE, OUT, IN)                     \
  {                                                       \
    if (inInfo.isContiguous()) {                          \
      HANDLE_CASE(TYPE, OUT, -2);                         \
    } else {                                              \
      switch (IN) {                                       \
        case 1:                                           \
          HANDLE_CASE(TYPE, OUT, 1);                      \
          break;                                          \
        case 2:                                           \
          HANDLE_CASE(TYPE, OUT, 2);                      \
          break;                                          \
        default:                                          \
          HANDLE_CASE(TYPE, OUT, -1);                     \
          break;                                          \
      }                                                   \
    }                                                     \
  }

#define HANDLE_OUT_CASE(TYPE, OUT, IN)                 \
  {                                                    \
    if (outInfo.isContiguous()) {                      \
      HANDLE_IN_CASE(TYPE, -2, IN);                    \
    } else {                                           \
      switch (OUT) {                                   \
        case 1:                                        \
          HANDLE_IN_CASE(TYPE, 1, IN);                 \
          break;                                       \
        case 2:                                        \
          HANDLE_IN_CASE(TYPE, 2, IN);                 \
          break;                                       \
        default:                                       \
          HANDLE_IN_CASE(TYPE, -1, IN);                \
          break;                                       \
      }                                                \
    }                                                  \
  }

  if (hasWork) {
    if (TensorUtils<TensorType>::canUse32BitIndexMath(state, out) &&
        TensorUtils<TensorType>::canUse32BitIndexMath(state, in)) {
      TensorInfo<typename TensorUtils<TensorType>::DataType,
                 unsigned int> outInfo =
        getTensorInfo<TensorType, unsigned int>(state, out);
      outInfo.collapseDims();

      TensorInfo<typename TensorUtils<TensorType>::DataType,
                 unsigned int> inInfo =
        getTensorInfo<TensorType, unsigned int>(state, in);
      inInfo.reduceDim(dim);
      inInfo.collapseDims();

      HANDLE_OUT_CASE(unsigned int, outInfo.dims, inInfo.dims);
    } else {
      TensorInfo<typename TensorUtils<TensorType>::DataType,
                 unsigned long> outInfo =
        getTensorInfo<TensorType, unsigned long>(state, out);
      outInfo.collapseDims();

      TensorInfo<typename TensorUtils<TensorType>::DataType,
                 unsigned long> inInfo =
        getTensorInfo<TensorType, unsigned long>(state, in);
      inInfo.reduceDim(dim);
      inInfo.collapseDims();

      // For large tensors, we only compile the completely contiguous
      // version and the completely generic version, to reduce
      // compilation time.
      if (outInfo.isContiguous() && inInfo.isContiguous()) {
        HANDLE_CASE(unsigned long, -2, -2);
      } else {
        HANDLE_CASE(unsigned long, -1, -1);
      }
    }
  }
#undef HANDLE_CASE
#undef HANDLE_IN_CASE
#undef HANDLE_OUT_CASE


  if (!keepdim) {
    TensorUtils<TensorType>::squeeze1d(state, out, out, dim);
  }
  return true;
}

#undef THC_NONCONTIG_REDUCE_BLOCK_SIZE

#endif // THC_REDUCE_INC