Welcome to mirror list, hosted at ThFree Co, Russian Federation.

VolumetricDilatedMaxPooling.cu « generic « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: bd43bc066fcd452f166c38acefa7efbc2a8907b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/VolumetricDilatedMaxPooling.cu"
#else

/*
 * Expands to one `case KW:` arm of the `switch (kW)` in
 * VolumetricDilatedMaxPooling_updateOutput: it launches
 * cuda_VolumetricDilatedMaxPooling_updateOutput<KW> on the current THC stream
 * and then breaks out of the switch.
 *
 * KW is supplied as a template argument, so kW is NOT passed as a runtime
 * argument here (only kT and kH are), unlike the `default:` launch.
 *
 * The macro is not hygienic: it expects `state`, `grid`, `block`, `cudaInput`,
 * `cudaIndices`, `cudaOutput`, `offsetZ` and the kernel/stride/pad/dilation
 * parameters to be in scope under exactly these names at the expansion site.
 */
#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW:                         \
  cuda_VolumetricDilatedMaxPooling_updateOutput<KW><<<grid, block,             \
    0, THCState_getCurrentStream(state)>>>(                             \
    cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW,\
    dilationT, dilationH, dilationW, offsetZ); \
    break

/*
 * Forward pass of volumetric (time x height x width) dilated max pooling.
 *
 * input    4D (slices, time, height, width) or
 *          5D (batch, slices, time, height, width) tensor.
 * output   resized here to the pooled shape and filled by the kernel.
 * indices  resized here to the same shape as output and filled by the kernel.
 * kT/kW/kH                         kernel extent per dimension (must be > 0).
 * dT/dW/dH                         stride per dimension (must be > 0).
 * padT/padW/padH                   padding per dimension (at most half the kernel).
 * dilationT/dilationW/dilationH    dilation per dimension (must be > 0).
 * ceilMode  round the output extent up instead of down.
 *
 * Raises a TH argument error for a non 4D/5D input, an input smaller than the
 * kernel, non-positive kernel/stride/dilation, or oversized padding, and a TH
 * error when the computed output would be empty.
 */
void THNN_(VolumetricDilatedMaxPooling_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           THCIndexTensor *indices,
           int kT, int kW, int kH,
           int dT, int dW, int dH,
           int padT, int padW, int padH,
           int dilationT, int dilationW, int dilationH,
           bool ceilMode)
{
  int batchSize;
  int inputSlices;
  int inputTime;
  int inputHeight;
  int inputWidth;
  int outputTime;
  int outputHeight;
  int outputWidth;

  THCUNN_assertSameGPU_generic(state, 3, input, indices, output);

  /* Reject degenerate hyper-parameters up front: a zero stride would divide by
     zero in the output-size formula below (and the resulting inf/NaN -> int
     conversion is undefined), and a non-positive kernel or dilation has no
     meaningful pooling window. */
  THArgCheck(kT > 0 && kW > 0 && kH > 0, 5,
    "kernel size should be greater than zero"
  );
  THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
    "stride should be greater than zero"
  );
  THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 14,
    "dilation should be greater than zero"
  );

  /* newContiguous() further down does not change the rank, so it is read once. */
  const int nDim = THCTensor_(nDimension)(state, input);

  if (nDim == 4)
  {
    THArgCheck(
      THCTensor_(size)(state, input, 1) >= kT &&
      THCTensor_(size)(state, input, 2) >= kH &&
      THCTensor_(size)(state, input, 3) >= kW, 2,
      "input image smaller than kernel size"
    );

    /* sizes */
    batchSize   = 1;
    inputSlices = THCTensor_(size)(state, input, 0);
    inputTime   = THCTensor_(size)(state, input, 1);
    inputHeight = THCTensor_(size)(state, input, 2);
    inputWidth  = THCTensor_(size)(state, input, 3);
  }
  else if (nDim == 5)
  {
    THArgCheck(
      THCTensor_(size)(state, input, 4) >= kW &&
      THCTensor_(size)(state, input, 3) >= kH &&
      THCTensor_(size)(state, input, 2) >= kT, 2,
      "input image smaller than kernel size"
    );

    /* sizes */
    batchSize   = THCTensor_(size)(state, input, 0);
    inputSlices = THCTensor_(size)(state, input, 1);
    inputTime   = THCTensor_(size)(state, input, 2);
    inputHeight = THCTensor_(size)(state, input, 3);
    inputWidth  = THCTensor_(size)(state, input, 4);
  }
  else
  {
    THArgCheck(false, 2, "4D or 5D tensor expected");
  }

  THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 2,
    "pad should be smaller than half of kernel size"
  );

  /* Output extent per dimension:
       round((in - (dilation * (k - 1) + 1) + 2 * pad) / stride) + 1
     where round is ceil or floor depending on ceilMode. */
  if (ceilMode)
  {
    outputTime   = (int)(ceil((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1;
    outputHeight = (int)(ceil((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1;
    outputWidth  = (int)(ceil((float)(inputWidth  - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1;
  }
  else
  {
    outputTime   = (int)(floor((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1;
    outputHeight = (int)(floor((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1;
    outputWidth  = (int)(floor((float)(inputWidth  - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1;
  }

  if (outputTime < 1 || outputHeight < 1 || outputWidth < 1)
    THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small",
            inputSlices,inputTime,inputHeight,inputWidth,inputSlices,outputTime,outputHeight,outputWidth);

  if (padT || padW || padH)
  {
    /* With padding, drop a trailing window whose start offset
       (out - 1) * stride lies at or beyond in + pad. */
    if ((outputTime - 1)*dT >= inputTime + padT)
      --outputTime;
    if ((outputHeight - 1)*dH >= inputHeight + padH)
      --outputHeight;
    if ((outputWidth  - 1)*dW >= inputWidth  + padW)
      --outputWidth;
  }

  /* Resize output and indices to the pooled shape.
     NOTE(review): older comments here described indices as uchar offsets
     packed into floats; the tensor is a THCIndexTensor viewed as THCIndex_t
     below, so confirm the encoding against the kernel. */
  if (nDim == 4) /* 4D */
  {
    THCTensor_(resize4d)(state, output, inputSlices,
                          outputTime, outputHeight, outputWidth);
    THCIndexTensor_(resize4d)(state, indices, inputSlices,
                          outputTime, outputHeight, outputWidth);
  }
  else
  { /* 5D */
    THCTensor_(resize5d)(state, output, batchSize, inputSlices,
                          outputTime, outputHeight, outputWidth);
    THCIndexTensor_(resize5d)(state, indices, batchSize, inputSlices,
                          outputTime, outputHeight, outputWidth);
  }

  /* Owned reference from here on; released at the end of the function. */
  input = THCTensor_(newContiguous)(state, input);

  // Collapse batch and feature dimensions
  THCDeviceTensor<real, 4> cudaInput;
  THCDeviceTensor<real, 4> cudaOutput;
  if (nDim == 4)
  {
    cudaInput  = toDeviceTensor<real, 4>(state, input);
    cudaOutput = toDeviceTensor<real, 4>(state, output);
  }
  else
  {
    cudaInput  = toDeviceTensor<real, 5>(state, input).downcastOuter<4>();
    cudaOutput = toDeviceTensor<real, 5>(state, output).downcastOuter<4>();
  }

  /* 4D view over the indices storage with batch and slices merged, so the
     kernel can address indices the same way as the collapsed output. */
  THLongStorage *indicesSize = THLongStorage_newWithSize(4);
  long indicesSizeRaw[4] = { batchSize * inputSlices,
                            outputTime, outputHeight, outputWidth };
  THLongStorage_rawCopy(indicesSize, indicesSizeRaw);

  THCIndexTensor *indices1 = THCIndexTensor_(newWithStorage)(
    state, THCIndexTensor_(storage)(state, indices),
    THCIndexTensor_(storageOffset)(state, indices),
    indicesSize, NULL);

  THLongStorage_free(indicesSize);

  THCDeviceTensor<THCIndex_t, 4> cudaIndices =
    toDeviceTensor<THCIndex_t, 4>(state, indices1);

  /* The grid z extent is capped at 65535 per launch (the CUDA limit for
     gridDim.z), so the work is issued in chunks; offsetZ tells the kernel
     where each chunk starts. */
  const int maxGridZ = 65535;
  int totalZ = outputTime * inputSlices * batchSize;
  int offsetZ = 0;
  dim3 block(32, 8);

  while (totalZ > 0) {
    dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
              THCCeilDiv(outputHeight, static_cast<int>(block.y)),
              totalZ > maxGridZ ? maxGridZ : totalZ);

    /* Kernel widths 1..7 use kernels specialised on the width; any other
       width uses the generic kernel with kW passed at run time. */
    switch (kW)
      {
        UPDATE_OUTPUT_KERNEL_WIDTH(1);
        UPDATE_OUTPUT_KERNEL_WIDTH(2);
        UPDATE_OUTPUT_KERNEL_WIDTH(3);
        UPDATE_OUTPUT_KERNEL_WIDTH(4);
        UPDATE_OUTPUT_KERNEL_WIDTH(5);
        UPDATE_OUTPUT_KERNEL_WIDTH(6);
        UPDATE_OUTPUT_KERNEL_WIDTH(7);
      default:
        cuda_VolumetricDilatedMaxPooling_updateOutput<<<grid, block,
          0, THCState_getCurrentStream(state)>>>(
                             cudaInput, cudaIndices, cudaOutput,
                             kT, kH, kW, dT, dH, dW,
                             padT, padH, padW, dilationT, dilationH, dilationW, offsetZ);
      }
    THCudaCheck(cudaGetLastError());
    totalZ -= maxGridZ;
    offsetZ += maxGridZ;
  }

  THCTensor_(free)(state, input);
  THCIndexTensor_(free)(state, indices1);
}

#undef UPDATE_OUTPUT_KERNEL_WIDTH

/*
 * Backward pass of volumetric dilated max pooling.
 *
 * input       the forward input (4D or 5D); only its shape is read here.
 * gradOutput  gradient w.r.t. the forward output, same rank as input.
 * gradInput   resized to input's shape, zeroed, then filled by the kernel.
 * indices     index tensor produced by the forward pass; viewed here with the
 *             shape of gradOutput (batch and slices merged).
 * dT/dW/dH, padT/padW/padH, dilationT/dilationW/dilationH
 *             forwarded unchanged to the kernel.
 *
 * All arguments are validated before gradInput is touched, so a rejected call
 * leaves gradInput unmodified.
 */
void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           THCIndexTensor *indices,
           int dT, int dW, int dH,
           int padT, int padW, int padH,
           int dilationT, int dilationW, int dilationH)
{
  int batchSize;
  int inputSlices;

  int outputTime;
  int outputHeight;
  int outputWidth;

  THCUNN_assertSameGPU_generic(state, 4, input, indices, gradOutput, gradInput);

  /* Same rank contract as updateOutput; without it any non-4D input would be
     treated as 5D and dimensions 2..4 of gradOutput read blindly. */
  const int nDim = THCTensor_(nDimension)(state, input);
  THArgCheck(nDim == 4 || nDim == 5, 2, "4D or 5D tensor expected");
  THArgCheck(THCTensor_(nDimension)(state, gradOutput) == nDim, 3,
    "gradOutput must have the same number of dimensions as input"
  );

  // Resize and initialize result tensor.
  // Presumably the kernel only writes the positions recorded in indices, so
  // everything else must start at zero - confirm against the kernel.
  THCTensor_(resizeAs)(state, gradInput, input);
  THCTensor_(zero)(state, gradInput);

  if (nDim == 4) /* 4D */
  {
    batchSize = 1;
    inputSlices  = THCTensor_(size)(state, input, 0);

    outputTime   = THCTensor_(size)(state, gradOutput, 1);
    outputHeight = THCTensor_(size)(state, gradOutput, 2);
    outputWidth  = THCTensor_(size)(state, gradOutput, 3);
  }
  else /* 5D */
  {
    batchSize    = THCTensor_(size)(state, input, 0);
    inputSlices  = THCTensor_(size)(state, input, 1);

    outputTime   = THCTensor_(size)(state, gradOutput, 2);
    outputHeight = THCTensor_(size)(state, gradOutput, 3);
    outputWidth  = THCTensor_(size)(state, gradOutput, 4);
  }

  /* Owned reference from here on; released in the cleanup section. */
  gradOutput = THCTensor_(newContiguous)(state, gradOutput);

  // Collapse batch and feature dimensions
  THCDeviceTensor<real, 4> cudaGradInput;
  THCDeviceTensor<real, 4> cudaGradOutput;
  if (nDim == 4)
  {
    cudaGradInput  = toDeviceTensor<real, 4>(state, gradInput);
    cudaGradOutput = toDeviceTensor<real, 4>(state, gradOutput);
  }
  else
  {
    cudaGradInput =
      toDeviceTensor<real, 5>(state, gradInput).downcastOuter<4>();
    cudaGradOutput =
      toDeviceTensor<real, 5>(state, gradOutput).downcastOuter<4>();
  }

  /* 4D view over the indices storage with batch and slices merged, matching
     the collapsed gradOutput layout. */
  THLongStorage *indicesSize = THLongStorage_newWithSize(4);
  long indicesSizeRaw[4] = { batchSize * inputSlices,
                           outputTime, outputHeight, outputWidth };
  THLongStorage_rawCopy(indicesSize, indicesSizeRaw);
  THCIndexTensor *indices1 = THCIndexTensor_(newWithStorage)(
    state, THCIndexTensor_(storage)(state, indices),
    THCIndexTensor_(storageOffset)(state, indices), indicesSize, NULL);
  THLongStorage_free(indicesSize);

  THCDeviceTensor<THCIndex_t, 4> cudaIndices =
    toDeviceTensor<THCIndex_t, 4>(state, indices1);

  /* The grid z extent is capped at 65535 per launch (the CUDA limit for
     gridDim.z), so the work is issued in chunks; offsetZ tells the kernel
     where each chunk starts. */
  const int maxGridZ = 65535;
  int totalZ = outputTime * inputSlices * batchSize;
  int offsetZ = 0;
  dim3 block(32, 8);

  while (totalZ > 0) {
    dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
              THCCeilDiv(outputHeight, static_cast<int>(block.y)),
              totalZ > maxGridZ ? maxGridZ : totalZ);

    cuda_VolumetricDilatedMaxPooling_updateGradInput<<<grid, block,
      0, THCState_getCurrentStream(state)>>>(
                                             cudaGradOutput,
                                             cudaIndices,
                                             cudaGradInput,
                                             dT, dH, dW,
                                             padT, padH, padW,
                                             dilationT, dilationH, dilationW, offsetZ);
    THCudaCheck(cudaGetLastError());
    totalZ -= maxGridZ;
    offsetZ += maxGridZ;
  }

  // cleanup
  THCTensor_(free)(state, gradOutput);
  THCIndexTensor_(free)(state, indices1);
}

#endif