Welcome to mirror list, hosted at ThFree Co, Russian Federation.

VolumetricUpSamplingTrilinear.cu « generic « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 58be3109a39e1ee4193b652ebb21264e23bb1305 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/VolumetricUpSamplingTrilinear.cu"
#else

/*
 * Validates the sizes handed to the trilinear up-sampling entry points.
 *
 *  - All six spatial extents (input and output depth/height/width) must be
 *    strictly positive; otherwise THArgCheck fails on argument 2 and reports
 *    the offending sizes.
 *  - If `input` is non-NULL it must be a 5D tensor.
 *  - If `gradOutput` is non-NULL each of its five dimensions is checked
 *    against (nBatch, nChannels, outputDepth, outputHeight, outputWidth)
 *    through THCUNN_check_dim_size.
 *
 * Either tensor pointer may be NULL to skip the corresponding check.
 */
static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
                        (THCState *state,
                         THCTensor *input, THCTensor *gradOutput,
                         int nBatch, int nChannels,
                         int inputDepth, int inputHeight, int inputWidth,
                         int outputDepth, int outputHeight, int outputWidth) {
  // Every extent is compared with `> 0` explicitly: a bare `outputDepth`
  // truthiness test would let negative depths through.
  THArgCheck(inputDepth > 0 && inputHeight > 0 && inputWidth > 0
             && outputDepth > 0 && outputHeight > 0 && outputWidth > 0, 2,
             "input and output sizes should be greater than 0,"
             " but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)",
             inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth);
  if (input != NULL) {
     THCUNN_argCheck(state, input->nDimension == 5, 2, input,
                     "5D input tensor expected but got: %s");
  }

  if (gradOutput != NULL) {
    // Arguments are (state, tensor, expected rank, dimension index, expected size).
    THCUNN_check_dim_size(state, gradOutput, 5, 0, nBatch);
    THCUNN_check_dim_size(state, gradOutput, 5, 1, nChannels);
    THCUNN_check_dim_size(state, gradOutput, 5, 2, outputDepth);
    THCUNN_check_dim_size(state, gradOutput, 5, 3, outputHeight);
    THCUNN_check_dim_size(state, gradOutput, 5, 4, outputWidth);
  }
}

/*
 * Forward pass of volumetric trilinear up-sampling.
 *
 * Reads the batch, channel and spatial sizes from dimensions 0..4 of `input`,
 * validates them together with the requested output extents, resizes
 * `output` to (batch, channels, outputDepth, outputHeight, outputWidth),
 * zeroes it, and launches caffe_gpu_interp2_kernel on the current stream.
 *
 * state        - THC state (current device properties and stream are used)
 * input        - source tensor; a contiguous version is used internally and
 *                released before returning
 * output       - destination tensor, resized and overwritten
 * outputDepth, outputHeight, outputWidth - requested output extents
 */
void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           int outputDepth,
           int outputHeight,
           int outputWidth)
{
  int batchSize = THCTensor_(size)(state, input, 0);
  int planeCount = THCTensor_(size)(state, input, 1);
  int inputDepth = THCTensor_(size)(state, input, 2);
  int inputHeight = THCTensor_(size)(state, input, 3);
  int inputWidth = THCTensor_(size)(state, input, 4);

  THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
       (state, input, NULL,
        batchSize, planeCount,
        inputDepth, inputHeight, inputWidth,
        outputDepth, outputHeight, outputWidth);

  // Contiguous handle on the input; paired with the free at the end.
  THCTensor *src = THCTensor_(newContiguous)(state, input);
  THCUNN_assertSameGPU(state, 2, src, output);

  THCTensor_(resize5d)(state, output,
                       THCTensor_(size)(state, src, 0),
                       THCTensor_(size)(state, src, 1),
                       outputDepth, outputHeight, outputWidth);
  THCTensor_(zero)(state, output);

  THCDeviceTensor<real, 5> srcView = toDeviceTensor<real, 5>(state, src);
  THCDeviceTensor<real, 5> dstView = toDeviceTensor<real, 5>(state, output);

  THAssert(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 && outputDepth > 0 && outputHeight > 0 && outputWidth > 0);

  // Per-axis ratio (in - 1) / (out - 1); zero when the output extent is 1,
  // which also avoids dividing by zero.
  accreal depthRatio = accreal(0);
  if (outputDepth > 1) {
    depthRatio = (accreal)(inputDepth - 1)/(outputDepth - 1);
  }
  accreal heightRatio = accreal(0);
  if (outputHeight > 1) {
    heightRatio = (accreal)(inputHeight - 1)/(outputHeight - 1);
  }
  accreal widthRatio = accreal(0);
  if (outputWidth > 1) {
    widthRatio = (accreal)(inputWidth - 1)/(outputWidth - 1);
  }

  // Work-item count handed to the kernel: one per output spatial position.
  const int workItems = outputDepth * outputHeight * outputWidth;
  const int blockSize =
    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
  const int gridSize = THCCeilDiv(workItems, blockSize);
  cudaStream_t stream = THCState_getCurrentStream(state);

  caffe_gpu_interp2_kernel<real, accreal>
    <<<gridSize, blockSize, 0, stream>>>(
      workItems, depthRatio, heightRatio, widthRatio, srcView, dstView);
  // Surfaces launch-configuration errors from the launch above.
  THCudaCheck(cudaGetLastError());

  THCTensor_(free)(state, src);
}


/*
 * Backward pass of volumetric trilinear up-sampling.
 *
 * Validates `gradOutput` against the expected
 * (nbatch, nchannels, outputDepth, outputHeight, outputWidth) shape, resizes
 * `gradInput` to (nbatch, nchannels, inputDepth, inputHeight, inputWidth),
 * zeroes it, and launches caffe_gpu_interp2_kernel_backward on the current
 * stream with `gradInput` as the first tensor and `gradOutput` as the second.
 *
 * state       - THC state (current device properties and stream are used)
 * gradOutput  - gradient w.r.t. the up-sampled output; a contiguous version
 *               is used internally and released before returning
 * gradInput   - gradient w.r.t. the input; resized, zeroed and written in place
 * nbatch, nchannels                      - leading two dimensions
 * inputDepth, inputHeight, inputWidth    - spatial extents of gradInput
 * outputDepth, outputHeight, outputWidth - spatial extents of gradOutput
 */
void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
           THCState *state,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           int nbatch,
           int nchannels,
           int inputDepth,
           int inputHeight,
           int inputWidth,
           int outputDepth,
           int outputHeight,
           int outputWidth)
{
  THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
       (state, NULL, gradOutput,
        nbatch, nchannels,
        inputDepth, inputHeight, inputWidth,
        outputDepth, outputHeight, outputWidth);
  // gradInput is the result tensor, so it must NOT be swapped for
  // THCTensor_(newContiguous)(gradInput): for a non-contiguous tensor that
  // call yields a temporary copy, and the gradient would be written into the
  // copy and then freed, leaving the caller's tensor untouched. The kernel
  // receives gradInput as a THCDeviceTensor view of the tensor itself.
  // NOTE(review): this assumes the backward kernel (defined outside this
  // file) addresses data1 through that view rather than via flat offsets.
  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
  THCTensor_(resize5d)(state, gradInput, nbatch, nchannels, inputDepth, inputHeight, inputWidth);
  THCTensor_(zero)(state, gradInput);
  THCDeviceTensor<real, 5> data1 = toDeviceTensor<real, 5>(state, gradInput);
  THCDeviceTensor<real, 5> data2 = toDeviceTensor<real, 5>(state, gradOutput);
  int depth1 = data1.getSize(2);
  int height1 = data1.getSize(3);
  int width1 = data1.getSize(4);
  int depth2 = data2.getSize(2);
  int height2 = data2.getSize(3);
  int width2 = data2.getSize(4);
  // THAssert (as used by updateOutput) instead of a bare assert.
  THAssert(depth1 > 0 && height1 > 0 && width1 > 0 && depth2 > 0 && height2 > 0 && width2 > 0);
  // Per-axis ratio (in - 1) / (out - 1); zero when the output extent is 1,
  // which also avoids dividing by zero.
  const accreal rdepth= (depth2 > 1) ? (accreal)(depth1 - 1)/(depth2 - 1) : accreal(0);
  const accreal rheight= (height2 > 1) ? (accreal)(height1 - 1)/(height2 - 1) : accreal(0);
  const accreal rwidth = (width2 > 1) ? (accreal)(width1 - 1) / (width2 - 1) : accreal(0);
  // Work-item count handed to the kernel: one per gradOutput spatial position.
  const int num_kernels = depth2 * height2 * width2;
  const int num_threads =
    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
  cudaStream_t stream = THCState_getCurrentStream(state);
  caffe_gpu_interp2_kernel_backward<real ,accreal> <<<THCCeilDiv(num_kernels, num_threads),
  num_threads, 0, stream>>>(num_kernels, rdepth, rheight, rwidth, data1, data2);
  // Surfaces launch-configuration errors from the launch above.
  THCudaCheck(cudaGetLastError());
  // Only gradOutput was re-acquired via newContiguous, so only it is released.
  THCTensor_(free)(state, gradOutput);
}

#endif