Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THCReduceAll.cuh « THC « lib - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: cfe40fda242d1b68df2b0126139f8a745c083769 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#ifndef THC_REDUCEALL_INC
#define THC_REDUCEALL_INC

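//
// This file contains the functions and kernels that reduce an entire
// tensor to a single value. They work on both contiguous and
// non-contiguous tensors of arbitrary dimensionality (up to
// MAX_CUTORCH_DIMS) without copying the input. The only temporary
// storage used is the current device scratch space: it holds the
// per-block partials of the two-pass path, and the device-side result
// when the caller wants the value on the host.
//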

#include "THCReduceApplyUtils.cuh"

// Number of threads per block for the single-pass and pass-1 reduction kernels
#define THC_REDUCE_ALL_BLOCK_SIZE 1024L

// Tensors with more than this many elements are reduced in two passes
#define THC_TWO_PASS_REDUCTION_SIZE 2048L

// Reduces the whole tensor with a single thread block, in one pass.
//
// Thread t visits linear indices t, t + blockDim.x, t + 2*blockDim.x, ...
// below totalElements. It maps each index to a memory offset with
// IndexToOffset<IndexType, ADims> and folds modifyOp(value) into its private
// accumulator (seeded with `init`) using reduceOp. The per-thread values are
// then combined by reduceBlock, and thread 0 stores the result to *out.
//
// Launch requirements: the kernel ignores blockIdx, so it is meant to run
// as a single block (every block would otherwise redo the same work and
// write the same *out). It needs blockDim.x floats of dynamic shared memory
// for reduceBlock.
template <typename ModifyOp, typename ReduceOp, typename IndexType, int ADims>
__global__ void
THCudaTensor_reduceAll(TensorInfo<IndexType> in,
                       IndexType totalElements,
                       float init,
                       ModifyOp modifyOp,
                       ReduceOp reduceOp,
                       float* out) {
  extern __shared__ float smem[];

  // Per-thread partial reduction over a block-wide stride.
  float partial = init;
  IndexType linear = threadIdx.x;
  while (linear < totalElements) {
    const IndexType offset = IndexToOffset<IndexType, ADims>::get(linear, in);
    partial = reduceOp(partial, modifyOp(in.data[offset]));
    linear += blockDim.x;
  }

  // Combine the per-thread partials; only thread 0's value is used.
  const float blockResult =
    reduceBlock<float, ReduceOp>(smem, blockDim.x, partial, reduceOp, init);

  if (threadIdx.x == 0) {
    *out = blockResult;
  }
}

// First linear index of the chunk owned by this block.
//
// [0, totalSize) is divided into gridDim.x chunks of
// ceil(totalSize / gridDim.x) indices each, and block b starts at
// b * chunk. For trailing blocks this value can exceed totalSize (e.g.
// totalSize = 5, gridDim.x = 4 gives block 3 a start of 6). getEndIndex
// clamps the end to totalSize, so such blocks end up with an empty range.
template <typename IndexType>
__device__ __forceinline__ IndexType getStartIndex(IndexType totalSize) {
  const IndexType chunk = THCCeilDiv(totalSize, (IndexType) gridDim.x);
  const IndexType first = blockIdx.x * chunk;
  return first;
}

// One past the last linear index of the chunk owned by this block, clamped
// to totalSize. This uses the same chunking as getStartIndex:
// ceil(totalSize / gridDim.x) indices per block.
//
// NOTE(review): (blockIdx.x + 1) * chunk can be as large as
// totalSize + gridDim.x - 1. With a 32-bit IndexType and totalSize close to
// the type's maximum it may wrap before the clamp is applied.
template <typename IndexType>
__device__ __forceinline__ IndexType getEndIndex(IndexType totalSize) {
  const IndexType chunk = THCCeilDiv(totalSize, (IndexType) gridDim.x);
  const IndexType limit = (IndexType) ((blockIdx.x + 1) * chunk);
  return (limit < totalSize) ? limit : totalSize;
}

// First pass of the two-pass whole-tensor reduction.
//
// The linear range [0, totalElements) is divided into gridDim.x contiguous
// chunks (see getStartIndex / getEndIndex), and block b reduces chunk b.
// Inside a block, thread t visits chunk indices chunkBegin + t,
// chunkBegin + t + blockDim.x, ... and folds modifyOp(value) with
// reduceOp. The block then combines its threads' values, and thread 0
// stores the block's partial result in scratchSpace[blockIdx.x].
//
// Launch requirements: blockDim.x floats of dynamic shared memory (used by
// reduceBlock), and scratchSpace must hold at least gridDim.x floats.
template <typename ModifyOp, typename ReduceOp, typename IndexType, int ADims>
__global__ void
THCudaTensor_reduceAllPass1(TensorInfo<IndexType> in,
                            IndexType totalElements,
                            float init,
                            ModifyOp modifyOp,
                            ReduceOp reduceOp,
                            float* scratchSpace) {
  extern __shared__ float smem[];

  const IndexType chunkBegin = getStartIndex<IndexType>(totalElements);
  const IndexType chunkEnd = getEndIndex<IndexType>(totalElements);

  // Per-thread partial over this block's chunk, block-wide stride.
  float partial = init;
  for (IndexType linear = chunkBegin + threadIdx.x;
       linear < chunkEnd;
       linear += blockDim.x) {
    const IndexType offset = IndexToOffset<IndexType, ADims>::get(linear, in);
    partial = reduceOp(partial, modifyOp(in.data[offset]));
  }

  // Combine the per-thread partials; only thread 0's value is used.
  const float blockResult =
    reduceBlock<float, ReduceOp>(smem, blockDim.x, partial, reduceOp, init);

  if (threadIdx.x == 0) {
    scratchSpace[blockIdx.x] = blockResult;
  }
}

// Second pass of the two-pass whole-tensor reduction.
//
// A single block combines the `numPass1Blocks` partial results that pass 1
// left in scratchSpace[0 .. numPass1Blocks), and thread 0 writes the final
// value to *out.
//
// Each thread takes the partial at index threadIdx.x and then folds in any
// further partials at threadIdx.x + blockDim.x, threadIdx.x + 2*blockDim.x,
// and so on. The kernel is therefore correct for any blockDim.x, not only
// when there is one thread per partial. The number of values handed to
// reduceBlock is clamped to blockDim.x, because only that many threads
// hold data and the shared buffer has only that many slots. When
// blockDim.x == numPass1Blocks (the usual launch), this is exactly one
// partial per thread.
//
// Launch requirements: grid = 1, blockDim.x floats of dynamic shared memory.
// `IndexType` is unused here; it is kept so existing explicit template
// arguments at call sites remain valid.
template <typename ReduceOp, typename IndexType>
__global__ void
THCudaTensor_reduceAllPass2(int numPass1Blocks,
                            float init,
                            ReduceOp reduceOp,
                            float* scratchSpace,
                            float* out) {
  const int tid = threadIdx.x;
  const int numThreads = blockDim.x;

  float r = init;
  if (tid < numPass1Blocks) {
    // Take the first partial as-is, then fold in any the block-stride
    // assigns to this thread (none when blockDim.x >= numPass1Blocks).
    r = scratchSpace[tid];
    for (int i = tid + numThreads; i < numPass1Blocks; i += numThreads) {
      r = reduceOp(r, scratchSpace[i]);
    }
  }

  // Only the first min(numPass1Blocks, blockDim.x) threads hold real data.
  const int numVals =
    (numPass1Blocks < numThreads) ? numPass1Blocks : numThreads;

  // Reduce within the block
  extern __shared__ float smem[];
  r = reduceBlock<float, ReduceOp>(smem, numVals, r, reduceOp, init);

  if (tid == 0) {
    *out = r;
  }
}

// Perform a two-pass reduction if the tensor is large enough to
// warrant it.
inline bool isTwoPassReductionSize(long elements) {
  return (elements > THC_TWO_PASS_REDUCTION_SIZE);
}

// Number of blocks (and therefore partial results) to use for pass 1 of a
// two-pass reduction over `elements` values.
//
// The starting point is one block per THC_REDUCE_ALL_BLOCK_SIZE elements,
// which is then capped in two ways:
//  - Pass 1 stores one float per block in the current device's scratch
//    space, so there cannot be more blocks than floats that fit there.
//  - Pass 2 runs as a single block with one thread per pass-1 block (see
//    getPass2ReduceBlockGrid), so the count must also be a valid
//    threads-per-block value. THC_REDUCE_ALL_BLOCK_SIZE is already the
//    pass-1 block size, so it serves as that bound. Without this cap, a
//    large scratch space would produce an invalid pass-2 launch.
inline long getTwoPassBlocks(THCState* state, long elements) {
  long numBlocks = THCCeilDiv(elements, THC_REDUCE_ALL_BLOCK_SIZE);

  // We can only have as many blocks as there is scratch space
  long scratchSpace =
    THCState_getCurrentDeviceScratchSpaceSize(state) / sizeof(float);
  THAssert(scratchSpace > 0);

  // ... and no more than pass 2 can launch as threads of a single block.
  long maxBlocks = scratchSpace;
  if (maxBlocks > THC_REDUCE_ALL_BLOCK_SIZE) {
    maxBlocks = THC_REDUCE_ALL_BLOCK_SIZE;
  }

  if (numBlocks > maxBlocks) {
    numBlocks = maxBlocks;
  }

  return numBlocks;
}

// Launch shape for pass 1 of the two-pass reduction: getTwoPassBlocks(...)
// blocks, each with THC_REDUCE_ALL_BLOCK_SIZE threads.
inline void getPass1ReduceBlockGrid(THCState* state, long elements,
                                    dim3& grid, dim3& block) {
  const long numBlocks = getTwoPassBlocks(state, elements);
  block = dim3(THC_REDUCE_ALL_BLOCK_SIZE);
  grid = dim3(numBlocks);
}

// Launch shape for pass 2 of the two-pass reduction: a single block with
// one thread per pass-1 block. The count comes from getTwoPassBlocks with
// the same arguments pass 1 used, so it presumably matches pass 1's grid
// size (assuming the scratch-space size does not change in between).
inline void getPass2ReduceBlockGrid(THCState* state, long elements,
                                    dim3& grid, dim3& block) {
  const long partials = getTwoPassBlocks(state, elements);
  block = dim3(partials);
  grid = dim3(1);
}

// Launch shape for the single-pass reduction: one block of
// THC_REDUCE_ALL_BLOCK_SIZE threads. `elements` does not currently affect
// the shape.
inline void getSinglePassReduceBlockGrid(long elements,
                                         dim3& grid, dim3& block) {
  (void) elements;
  block = dim3(THC_REDUCE_ALL_BLOCK_SIZE);
  grid = dim3(1);
}

// Enqueues the kernel(s) that reduce `in` to a single float and write it to
// the device address `devOut`. Everything is launched on the current stream.
//
// Tensors larger than THC_TWO_PASS_REDUCTION_SIZE elements use two passes.
// Pass 1 writes one partial per block into the current device's scratch
// space, and pass 2 combines those partials in a single block. Smaller
// tensors are reduced by one block in a single pass. Each launch gets
// blockDim.x floats of dynamic shared memory.
//
// The launches are asynchronous; this function neither checks launch errors
// nor synchronizes.
template <typename ModifyOp, typename ReduceOp, typename IndexType, int ADims>
void callReduceAll(THCState* state,
                   const TensorInfo<IndexType>& in,
                   long totalElements,
                   float init,
                   const ModifyOp& modifyOp,
                   const ReduceOp& reduceOp,
                   float* devOut) {
  cudaStream_t stream = THCState_getCurrentStream(state);
  const IndexType numElements = (IndexType) totalElements;
  dim3 grid;
  dim3 block;

  if (!isTwoPassReductionSize(totalElements)) {
    // Small tensor: one block does all the work.
    getSinglePassReduceBlockGrid(totalElements, grid, block);
    THCudaTensor_reduceAll<ModifyOp, ReduceOp, IndexType, ADims>
      <<<grid, block, block.x * sizeof(float), stream>>>(
        in, numElements, init, modifyOp, reduceOp, devOut);
    return;
  }

  float* scratch = (float*) THCState_getCurrentDeviceScratchSpace(state);

  // Pass 1: one partial result per block into `scratch`.
  getPass1ReduceBlockGrid(state, totalElements, grid, block);
  THCudaTensor_reduceAllPass1<ModifyOp, ReduceOp, IndexType, ADims>
    <<<grid, block, block.x * sizeof(float), stream>>>(
      in, numElements, init, modifyOp, reduceOp, scratch);

  // Pass 2: a single block folds the pass-1 partials into *devOut.
  const int numPass1Blocks = grid.x;
  getPass2ReduceBlockGrid(state, totalElements, grid, block);
  THCudaTensor_reduceAllPass2<ReduceOp, IndexType>
    <<<grid, block, block.x * sizeof(float), stream>>>(
      numPass1Blocks, init, reduceOp, scratch, devOut);
}

// Reduces the entire tensor `in` to a single float.
//
// Every element x is mapped through modifyOp(x), and the mapped values are
// combined with reduceOp, starting from `init`. `init` is also passed to
// reduceBlock, presumably as padding for unused slots, so it likely needs
// to be an identity of reduceOp (e.g. 0 for a sum). NOTE(review): confirm
// this against reduceBlock.
//
// If `outOnDevice` is nonzero, `out` is a device pointer, and the result is
// written there by work enqueued on the current stream. Otherwise `out` is
// a host pointer, and this call blocks until the result has been copied
// back.
//
// Returns false, without doing any work, if `in` has more than
// MAX_CUTORCH_DIMS dimensions. Returns true otherwise.
//
// NOTE(review): the return codes of the CUDA runtime calls below are not
// checked; consider wrapping them in the project's error-check macro.
template <typename ModifyOp, typename ReduceOp>
bool THCudaTensor_reduceAll(THCState* state,
                            THCudaTensor* in,
                            const ModifyOp& modifyOp,
                            const ReduceOp& reduceOp,
                            float init,
                            float* out,
                            int outOnDevice) {
  long inElements = THCudaTensor_nElement(state, in);

  if (THCudaTensor_nDimension(state, in) > MAX_CUTORCH_DIMS) {
    return false;
  }

  // The reduction kernels are launched on this stream, so any transfer of
  // the result must be ordered on the same stream.
  cudaStream_t stream = THCState_getCurrentStream(state);

  if (THCudaTensor_nDimension(state, in) == 0) {
    // Zero-dim tensor: nothing to reduce, so the result is just `init`.
    if (outOnDevice) {
      // `out` is a device pointer and must not be dereferenced on the host.
      // Copy `init` on the current stream, then wait, so the stack-resident
      // source stays valid for the whole transfer.
      cudaMemcpyAsync(out, &init, sizeof(float), cudaMemcpyHostToDevice,
                      stream);
      cudaStreamSynchronize(stream);
    } else {
      *out = init;
    }
    return true;
  }

  float* devOut = out;
  if (!outOnDevice) {
    // Use the current scratch space as the device-side landing slot for the
    // result. In the two-pass path this is the same buffer pass 1 writes
    // its partials to. Pass 2 reads those partials before thread 0 writes
    // the result (presumably ordered by a barrier inside reduceBlock).
    devOut = (float*) THCState_getCurrentDeviceScratchSpace(state);
  }

  // Collapsing dimensions can reduce the tensor to fewer (often one)
  // dimensions. That lets us choose a kernel specialization whose
  // dimension count is known at compile time. The per-element translation
  // from linear index to memory offset (a div/mod per dimension) is the
  // most expensive part of the operation, more so than the memory accesses
  // themselves. That is what this unrolling is for. ADims == -2 selects the
  // contiguous specialization and -1 the fully generic one.
#define HANDLE_CASE(TYPE, IN)                                           \
  callReduceAll<ModifyOp, ReduceOp, TYPE, IN>(                          \
    state, inInfo, inElements, init, modifyOp, reduceOp, devOut);

#define HANDLE_IN_CASE(TYPE, IN)                    \
  {                                                 \
    if (inInfo.isContiguous()) {                    \
      HANDLE_CASE(TYPE, -2);                        \
    } else {                                        \
      switch (IN) {                                 \
        case 1:                                     \
          HANDLE_CASE(TYPE, 1);                     \
          break;                                    \
        case 2:                                     \
          HANDLE_CASE(TYPE, 2);                     \
          break;                                    \
        case 3:                                     \
          HANDLE_CASE(TYPE, 3);                     \
          break;                                    \
        default:                                    \
          HANDLE_CASE(TYPE, -1);                    \
          break;                                    \
      }                                             \
    }                                               \
  }

  if (THC_canUse32BitIndexMath(state, in)) {
    TensorInfo<unsigned int> inInfo(state, in);
    inInfo.collapseDims();

    HANDLE_IN_CASE(unsigned int, inInfo.dims);
  } else {
    TensorInfo<unsigned long long> inInfo(state, in);
    inInfo.collapseDims();

    // For large tensors, we only instantiate the completely contiguous
    // version and the completely generic version, to reduce compilation
    // time. HANDLE_CASE is called directly on purpose: going through
    // HANDLE_IN_CASE would also instantiate the 1/2/3-dim specializations
    // in its switch, even though they can never be selected here.
    if (inInfo.isContiguous()) {
      HANDLE_CASE(unsigned long long, -2);
    } else {
      HANDLE_CASE(unsigned long long, -1);
    }
  }
#undef HANDLE_CASE
#undef HANDLE_IN_CASE

  // If our destination is not on the device, copy the value back to the
  // host (synchronous!). The kernels were enqueued on `stream`, which need
  // not be the legacy default stream. A plain cudaMemcpy is issued on the
  // default stream and is not guaranteed to wait for work on a
  // non-blocking stream. So copy on the same stream and then wait for it.
  if (!outOnDevice) {
    cudaMemcpyAsync(out, devOut, sizeof(float), cudaMemcpyDeviceToHost,
                    stream);
    cudaStreamSynchronize(stream);
  }

  return true;
}

#undef THC_REDUCE_ALL_BLOCK_SIZE
#undef THC_TWO_PASS_REDUCTION_SIZE

#endif // THC_REDUCEALL_INC