Welcome to mirror list, hosted at ThFree Co, Russian Federation.

THCTensorSort.cu « generic « THC « lib - github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 01165223702bbc22b1e782b264892dcc5d12d0d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCTensorSort.cu"
#else

// In alignment with default sort on a c++ map, this function
// will permute key and value tensors identically, and
// in such a way that the 'key' tensor is ordered numerically.
//
// Each slice of `key` along `dim` is sorted independently by one CUDA
// block running the bitonicSortKVInPlace kernel on the state's current
// stream; `value` receives the same permutation.
//
// Arguments:
//   state - THC state; supplies the stream the kernel is launched on
//   key   - tensor sorted in place; must have the same size as `value`
//   value - long tensor permuted in place alongside `key`
//   dim   - dimension along which slices are sorted
//   dir   - true selects the GTComp<real> comparator, false selects
//           LTComp<real> (presumably descending vs. ascending; confirm
//           against the comparator definitions)
//
// Errors (raised through THArgCheck / THError):
//   - `key` and `value` differ in size
//   - either tensor has more than MAX_CUTORCH_DIMS dimensions
//   - the slice size rounded up to a power of 2 exceeds 2048
//   - a launch grid cannot be built for the number of slices
// A zero-dimensional `key` is a no-op.
THC_API void THCTensor_(sortKeyValueInplace)(THCState* state,
                                           THCTensor* key,
                                           THCudaLongTensor* value,
                                           int dim, bool dir) {
  THLongStorage *valueSize = THCudaLongTensor_newSizeOf(state, value);
  THArgCheck(THCTensor_(isSize)(state, key, valueSize), 2,
             "Key tensor must have same size as value tensor");
  THLongStorage_free(valueSize);
  long dims = THCudaLongTensor_nDimension(state, value);
  THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);
  dims = THCTensor_(nDimension)(state, key);
  THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);

  // Zero-dim tensor; do nothing. This must be checked before any
  // per-dimension query below: there is no dimension `dim` to ask the
  // size of, and the slice count would otherwise be computed by
  // dividing by that size.
  if (dims == 0) {
    return;
  }

  long inElements = THCTensor_(nElement)(state, key);
  long keySliceSize = THCTensor_(size)(state, key, dim);
  long keySlices = inElements / keySliceSize;

  // The amount of shared memory and block size is based on
  // 2^ceil(lg(n)); we choose that sorting implementation for a given
  // size.
  long ceilPowerOf2 = nextHighestPowerOf2(keySliceSize);

  // FIXME: We'd have to find some other trick with Thrust to perform a
  // vectorized (key, value) sort by slice segment
  if (ceilPowerOf2 > 2048) {
    THError("sortKeyValueInplace only works for sizes <= 2048 at present");
  }

  // The grid is based on the number of independent slices that we
  // have to sort; one block per slice
  dim3 grid;
  if (!THC_getGridFromTiles(keySlices, grid)) {
    THError("Slice to sort is too large");
  }

  // HANDLE_CASE launches the kernel instantiated for index type TYPE,
  // key dimensionality tag A and compile-time sort size SIZE. The block
  // has SIZE / 2 threads (at least 1). It relies on `keyInfo`,
  // `valueInfo`, `collapseKeyDim` and `collapseValueDim` being in scope
  // at the point of expansion.
#define HANDLE_CASE(TYPE, A, SIZE)                                      \
  do {                                                                  \
    int blockSize = SIZE / 2;                                           \
    if (blockSize < 1) {                                                \
      blockSize = 1;                                                    \
    }                                                                   \
                                                                        \
    dim3 block(blockSize);                                              \
                                                                        \
    if (dir) {                                                          \
      bitonicSortKVInPlace<real, long, A, -1, GTComp<real>, TYPE, SIZE> \
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
          keyInfo,                                                      \
          keySlices,                                                    \
          (TYPE) keySliceSize,                                          \
          (TYPE) keyInfo.strides[collapseKeyDim],                       \
          valueInfo,                                                    \
          (TYPE) valueInfo.strides[collapseValueDim],                   \
          GTComp<real>());                                              \
    } else {                                                            \
      bitonicSortKVInPlace<real, long, A, -1, LTComp<real>, TYPE, SIZE> \
        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
          keyInfo,                                                      \
          keySlices,                                                    \
          (TYPE) keySliceSize,                                          \
          (TYPE) keyInfo.strides[collapseKeyDim],                       \
          valueInfo,                                                    \
          (TYPE) valueInfo.strides[collapseValueDim],                   \
          LTComp<real>());                                              \
    }                                                                   \
  } while (0)

  // HANDLE_SORT_CASE maps the rounded-up slice size onto one of four
  // kernel instantiations (2048, 1024, 128, 32). A slice of size 1 is
  // already sorted and launches nothing.
#define HANDLE_SORT_CASE(TYPE, A)                       \
  {                                                     \
    switch (ceilPowerOf2) {                             \
      case 2048:                                        \
      HANDLE_CASE(TYPE, A, 2048);                       \
      break;                                            \
      case 1024:                                        \
      case 512:                                         \
      case 256:                                         \
      HANDLE_CASE(TYPE, A, 1024);                       \
      break;                                            \
      case 128:                                         \
      case 64:                                          \
      HANDLE_CASE(TYPE, A, 128);                        \
      break;                                            \
      case 32:                                          \
      case 16:                                          \
      case 8:                                           \
      case 4:                                           \
      case 2:                                           \
      HANDLE_CASE(TYPE, A, 32);                         \
      break;                                            \
      case 1:                                           \
      /* Nothing to do, data already sorted */          \
      break;                                            \
      default:                                          \
      assert(false);                                    \
    }                                                   \
  }

  // The constructed key/value tensor info is used to select the slice
  // we are sorting on a per-block basis
  if (TensorUtils<THCTensor>::canUse32BitIndexMath(state, key)) {
    TensorInfo<real, unsigned int> keyInfo =
      getTensorInfo<THCTensor, unsigned int>(state, key);
    keyInfo.reduceDim(dim);
    int collapseKeyDim = keyInfo.collapseDims(dim);

    TensorInfo<long, unsigned int> valueInfo =
      getTensorInfo<THCudaLongTensor, unsigned int>(state, value);
    valueInfo.reduceDim(dim);
    int collapseValueDim = valueInfo.collapseDims(dim);

    // Pick the key dimensionality tag: -2 for a contiguous key, 2 for a
    // two-dimensional one, and -1 (the generic version) otherwise.
    if (keyInfo.isContiguous()) {
      HANDLE_SORT_CASE(unsigned int, -2);
    } else {
      switch (keyInfo.dims) {
        case 2:
          HANDLE_SORT_CASE(unsigned int, 2);
          break;
        default:
          HANDLE_SORT_CASE(unsigned int, -1);
          break;
      }
    }
  } else {
    TensorInfo<real, unsigned long> keyInfo =
      getTensorInfo<THCTensor, unsigned long>(state, key);
    keyInfo.reduceDim(dim);
    int collapseKeyDim = keyInfo.collapseDims(dim);

    TensorInfo<long, unsigned long> valueInfo =
      getTensorInfo<THCudaLongTensor, unsigned long>(state, value);
    valueInfo.reduceDim(dim);
    int collapseValueDim = valueInfo.collapseDims(dim);

    // long case is rare, just instantiate the generic version
    HANDLE_SORT_CASE(unsigned long, -1);
  }
#undef HANDLE_CASE
#undef HANDLE_SORT_CASE
  // HANDLE_A_CASE is not defined anywhere in this function; undefining
  // it is a no-op that is kept so the preprocessor state after this
  // function is unchanged.
#undef HANDLE_A_CASE

  // Kernel launches do not report errors directly; surface any launch
  // failure here.
  THCudaCheck(cudaGetLastError());
}

// Sorts `input` along `dim` with Thrust, for slices of any size.
//
// The values are copied into `sorted`; contiguous working copies of
// `sorted` and `indices` (with `dim` moved innermost) are sorted as one
// flat array and then written back to `sorted` and `indices` through
// freeCopyTo.
//
// Arguments:
//   state   - THC state; on CUDA >= 7.0 Thrust runs on its current stream
//   sorted  - receives the sorted values
//   indices - receives, per element, the 1-based position it had within
//             its slice of `input`
//   input   - tensor to sort; not modified unless it aliases `sorted`
//   dim     - dimension along which slices are sorted
//   dir     - true selects ThrustGTOp<real>, false ThrustLTOp<real>
//             (presumably descending vs. ascending; confirm against the
//             comparator definitions)
//
// NOTE(review): `sorted` and `indices` are assumed to already have the
// same size as `input`; the visible caller resizes them first.
void sortViaThrust(THCState* state,
                   THCTensor* sorted,
                   THCudaLongTensor* indices,
                   THCTensor* input,
                   int dim, bool dir) {
  long nDims = THCTensor_(nDimension)(state, input);

  long totalElements = THCTensor_(nElement)(state, input);
  long sliceSize = THCTensor_(size)(state, input, dim);

  // We perform a vectorized segmented sort in Thrust.
  // Say we are sorting a (2, 3) tensor. We have in flattened form:
  // values 0.4 1.2 5.3 6.2 1.3 2.3
  // indices  0   1   2   3   4   5
  // where indices is a global index (across all slices)

  // First we sort by values, globally:
  // values 6.2 5.3 2.3 1.2 1.3 0.4
  // indices  3   2   5   1   4   0

  // Then we stable sort by segment, which is index / 3:
  // values 5.3 1.2 0.4 6.2 2.3 1.3
  // indices  2   1   0   3   5   4

  // Then we translate the global index to a per-slice Lua index
  // (index % 3) + 1:
  // values 5.3 1.2 0.4 6.2 2.3 1.3
  // indices  3   2   1   1   3   2

  // This method can only work if the slice we are sorting (`dim`) is
  // innermost, and both values and indices are contiguous. We do this
  // by re-arranging the input into this form as needed, which will
  // unfortunately allocate memory if the request is not in this form.
  // Vectorized sort is slower than iterated sort if the number of
  // slices is small (since we're sorting twice, instead of invoking a
  // smaller sort `numSlices` times), but the Thrust sort
  // implementation here is a catch-all, so we're not looking for
  // efficiency, but instead correctness.
  THCTensor_(copy)(state, sorted, input);
  THCTensor* trKeys = THCTensor_(newWithTensor)(state, sorted);
  THCudaLongTensor* trIndices = THCudaLongTensor_newWithTensor(state, indices);

  // Transpose dim to innermost
  if (dim != nDims - 1) {
    THCTensor_(transpose)(state, trKeys, NULL, dim, nDims - 1);
    THCudaLongTensor_transpose(state, trIndices, NULL, dim, nDims - 1);
  }

  // Thrust must operate on a contiguous layout
  THCTensor* trContigKey = THCTensor_(newContiguous)(state, trKeys);
  THCudaLongTensor* trContigIndices = THCudaLongTensor_newContiguous(state, trIndices);

  // The transposed views were only needed to build the contiguous copies
  THCTensor_(free)(state, trKeys);
  THCudaLongTensor_free(state, trIndices);

  thrust::device_ptr<real> keyIter(THCTensor_(data)(state, trContigKey));

  // Since we are composing a global index across all segments rather
  // than a per-segment index, we keep the indices in integer (`long`)
  // memory so we don't have problems sorting slices < 2^24 but where
  // the entire tensor has more than 2^24 elements
  thrust::device_ptr<long>
    indexIter((long*) THCudaLongTensor_data(state, trContigIndices));

  // Fill the indices with a global index across all slices
  thrust::counting_iterator<long> countIter(0);

  thrust::copy(
#if CUDA_VERSION >= 7000
    thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
    countIter, countIter + totalElements, indexIter);

  // First, we sort globally (across all slices) according to key
  // (the values we're sorting)
  if (dir) {
    thrust::stable_sort_by_key(
#if CUDA_VERSION >= 7000
      thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
      keyIter, keyIter + totalElements, indexIter, ThrustGTOp<real>());
  } else {
    thrust::stable_sort_by_key(
#if CUDA_VERSION >= 7000
      thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
      keyIter, keyIter + totalElements, indexIter, ThrustLTOp<real>());
  }

  // Then, re-sort according to slice that each index is
  // in. This completes the segment sort in Thrust, since we're
  // stably sorting here, preserving the relative order of values
  // per each slice
  thrust::stable_sort_by_key(
#if CUDA_VERSION >= 7000
    thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
    indexIter, indexIter + totalElements, keyIter,
    SliceComp(sliceSize));

  // Translate the global integer 0-based index to a per-slice real
  // Lua index
  thrust::for_each(
#if CUDA_VERSION >= 7000
    thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
    indexIter, indexIter + totalElements,
    GlobalIndexToPerSliceIndex(sliceSize));

  // Reverse the transposition as needed
  if (dim != nDims - 1) {
    THCTensor_(transpose)(state, trContigKey, NULL, dim, nDims - 1);
    THCudaLongTensor_transpose(state, trContigIndices, NULL, dim, nDims - 1);
  }

  // Then copy back to the expected output
  THCTensor_(freeCopyTo)(state, trContigKey, sorted);
  THCudaLongTensor_freeCopyTo(state, trContigIndices, indices);
}

// Sorts `input` along dimension `dim`.
//
// `sorted` and `indices` are resized to the shape of `input` and then
// filled with the sorted values and the matching per-slice indices.
// Slices of at most 2048 elements go through the in-place key/value
// kernel; larger slices fall back to sortViaThrust.
//
// Arguments:
//   state   - THC state
//   sorted  - output tensor for the sorted values
//   indices - output long tensor for the indices
//   input   - tensor to sort
//   dim     - dimension along which to sort
//   order   - sort direction flag; any non-zero value is forwarded as
//             `true` to the underlying sort routine
//
// Asserts that all tensors pass the GPU check and raises an argument
// error if any of them has more than MAX_CUTORCH_DIMS dimensions.
THC_API void THCTensor_(sort)(THCState* state,
                               THCTensor *sorted,
                               THCudaLongTensor *indices,
                               THCTensor *input,
                               int dim, int order) {
  THAssert(THCTensor_(checkGPU)(state, 2, sorted, input));
  THAssert(THCudaLongTensor_checkGPU(state, 1, indices));

  // Validate the dimensionality of each tensor in turn; the argument
  // number passed to THArgCheck identifies the offending parameter.
  const long sortedDims = THCTensor_(nDimension)(state, sorted);
  THArgCheck(sortedDims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING);

  const long inputDims = THCTensor_(nDimension)(state, input);
  THArgCheck(inputDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING);

  const long indexDims = THCudaLongTensor_nDimension(state, indices);
  THArgCheck(indexDims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING);

  // Give both outputs the shape of the input
  THCTensor_(resizeAs)(state, sorted, input);
  THLongStorage *targetShape = THCTensor_(newSizeOf)(state, input);
  THCudaLongTensor_resize(state, indices, targetShape, NULL);
  THLongStorage_free(targetShape);

  // Length of each slice being sorted decides which implementation runs
  const long sliceSize = THCTensor_(size)(state, input, dim);
  const bool dir = (order != 0);

  if (sliceSize > 2048) {
    // Too large for the in-place kernel: use Thrust, which handles all
    // other cases (potentially slowly, with extra copies/memory
    // allocations)
    sortViaThrust(state, sorted, indices, input, dim, dir);
  } else {
    // Seed `indices` (the values of the k/v sort) with the
    // slice-relative index of every element.
    THCudaLongTensor_fillSliceWithIndex(state, indices, dim);

    // The k/v sort works in place, so it operates on a copy of the
    // unsorted input held in the output tensor.
    THCTensor_(copy)(state, sorted, input);

    // In-place k/v kernel; supports arbitrary layout
    THCTensor_(sortKeyValueInplace)(state, sorted, indices, dim, dir);
  }

  THCudaCheck(cudaGetLastError());
}

#endif