Welcome to mirror list, hosted at ThFree Co, Russian Federation.

VolumetricUpSamplingTrilinear.cu « THCUNN « lib - github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0d861afded49a6d34401587eecb75b776f4a6dbd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#include "THCUNN.h"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"

// Forward trilinear interpolation: fills data2 from data1.
//
// Launch layout: only the x components of threadIdx / blockIdx / blockDim are
// read, so a 1-D grid of 1-D blocks is expected. Every thread whose flat index
// is below n owns one (t2, h2, w2) position of data2 and iterates over all
// batch entries and channels on its own.
//
//   n        exclusive upper bound on the flat thread index (presumably
//            depth2 * height2 * width2 -- chosen by the caller, confirm there)
//   rdepth, rheight, rwidth
//            per-axis factors mapping a data2 coordinate to a fractional
//            data1 coordinate (source = factor * destination)
//   data1    source, indexed [batch][channel][depth][height][width]
//   data2    destination, same index order; its elements are overwritten
template<typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
    const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the end of the work range have nothing to do.
  if (index >= n) {
    return;
  }

  const int batchsize = data1.getSize(0);
  const int channels  = data1.getSize(1);
  const int depth1    = data1.getSize(2);
  const int height1   = data1.getSize(3);
  const int width1    = data1.getSize(4);
  const int depth2    = data2.getSize(2);
  const int height2   = data2.getSize(3);
  const int width2    = data2.getSize(4);

  // Split the flat index into destination coordinates, width fastest.
  const int plane2  = height2 * width2;
  const int inPlane = index % plane2;
  const int t2 = index / plane2;     // 0:depth2-1
  const int h2 = inPlane / width2;   // 0:height2-1
  const int w2 = inPlane % width2;   // 0:width2-1

  // Identical spatial extents: interpolation degenerates to a plain copy.
  const bool sameExtent =
      (depth1 == depth2) && (height1 == height2) && (width1 == width2);
  if (sameExtent) {
    for (int b = 0; b < batchsize; ++b) {
      for (int c = 0; c < channels; ++c) {
        const Dtype copied = data1[b][c][t2][h2][w2];
        data2[b][c][t2][h2][w2] = copied;
      }
    }
    return;
  }

  // Depth axis: lower source index, step to the upper neighbour (0 when the
  // lower index is already the last slice), and the two blend weights.
  const Acctype tSrc  = rdepth * t2;
  const int     tLo   = tSrc;
  const int     tHi   = tLo + ((tLo < depth1 - 1) ? 1 : 0);
  const Acctype tFrac = tSrc - tLo;
  const Acctype tKeep = Acctype(1) - tFrac;

  // Height axis, same scheme.
  const Acctype hSrc  = rheight * h2;
  const int     hLo   = hSrc;
  const int     hHi   = hLo + ((hLo < height1 - 1) ? 1 : 0);
  const Acctype hFrac = hSrc - hLo;
  const Acctype hKeep = Acctype(1) - hFrac;

  // Width axis, same scheme.
  const Acctype wSrc  = rwidth * w2;
  const int     wLo   = wSrc;
  const int     wHi   = wLo + ((wLo < width1 - 1) ? 1 : 0);
  const Acctype wFrac = wSrc - wLo;
  const Acctype wKeep = Acctype(1) - wFrac;

  for (int b = 0; b < batchsize; ++b) {
    for (int c = 0; c < channels; ++c) {
      // The eight surrounding source samples, named v<t><h><w> with 0 = lower
      // and 1 = upper neighbour along that axis.
      const Dtype v000 = data1[b][c][tLo][hLo][wLo];
      const Dtype v001 = data1[b][c][tLo][hLo][wHi];
      const Dtype v010 = data1[b][c][tLo][hHi][wLo];
      const Dtype v011 = data1[b][c][tLo][hHi][wHi];
      const Dtype v100 = data1[b][c][tHi][hLo][wLo];
      const Dtype v101 = data1[b][c][tHi][hLo][wHi];
      const Dtype v110 = data1[b][c][tHi][hHi][wLo];
      const Dtype v111 = data1[b][c][tHi][hHi][wHi];

      // Blend along width, then height, then depth, accumulating in Acctype.
      const Acctype val =
          tKeep * (hKeep * (wKeep * v000 + wFrac * v001)
                 + hFrac * (wKeep * v010 + wFrac * v011))
        + tFrac * (hKeep * (wKeep * v100 + wFrac * v101)
                 + hFrac * (wKeep * v110 + wFrac * v111));
      data2[b][c][t2][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
    }
  }
}

// Backward (adjoint) operation 1 <- 2: adds contributions from data2 into the
// existing contents of data1.
//
// Launch layout: only the x components of threadIdx / blockIdx / blockDim are
// read, so a 1-D grid of 1-D blocks is expected. Every thread whose flat index
// is below n owns one (t2, h2, w2) position of data2 and iterates over all
// batch entries and channels on its own.
//
//   n        exclusive upper bound on the flat thread index (presumably
//            depth2 * height2 * width2 -- chosen by the caller, confirm there)
//   rdepth, rheight, rwidth
//            per-axis factors mapping a data2 coordinate to a fractional
//            data1 coordinate (source = factor * destination)
//   data1    tensor being accumulated into, indexed
//            [batch][channel][depth][height][width]
//   data2    tensor being read, same index order
template <typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel_backward(const int n,
    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
    THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the end of the work range have nothing to do.
  if (index >= n) {
    return;
  }

  const int batchsize = data1.getSize(0);
  const int channels  = data1.getSize(1);
  const int depth1    = data1.getSize(2);
  const int height1   = data1.getSize(3);
  const int width1    = data1.getSize(4);
  const int depth2    = data2.getSize(2);
  const int height2   = data2.getSize(3);
  const int width2    = data2.getSize(4);

  // Split the flat index into data2 coordinates, width fastest.
  const int plane2  = height2 * width2;
  const int inPlane = index % plane2;
  const int t2 = index / plane2;     // 0:depth2-1
  const int h2 = inPlane / width2;   // 0:height2-1
  const int w2 = inPlane % width2;   // 0:width2-1

  // Identical spatial extents: each thread only touches elements at its own
  // (t2, h2, w2), so a plain (non-atomic) += is used on this path.
  const bool sameExtent =
      (depth1 == depth2) && (height1 == height2) && (width1 == width2);
  if (sameExtent) {
    for (int b = 0; b < batchsize; ++b) {
      for (int c = 0; c < channels; ++c) {
        const Dtype incoming = data2[b][c][t2][h2][w2];
        data1[b][c][t2][h2][w2] += incoming;
      }
    }
    return;
  }

  // Depth axis: lower target index, step to the upper neighbour (0 when the
  // lower index is already the last slice), and the two blend weights.
  const Acctype tSrc  = rdepth * t2;
  const int     tLo   = tSrc;
  const int     tHi   = tLo + ((tLo < depth1 - 1) ? 1 : 0);
  const Acctype tFrac = tSrc - tLo;
  const Acctype tKeep = Acctype(1) - tFrac;

  // Height axis, same scheme.
  const Acctype hSrc  = rheight * h2;
  const int     hLo   = hSrc;
  const int     hHi   = hLo + ((hLo < height1 - 1) ? 1 : 0);
  const Acctype hFrac = hSrc - hLo;
  const Acctype hKeep = Acctype(1) - hFrac;

  // Width axis, same scheme.
  const Acctype wSrc  = rwidth * w2;
  const int     wLo   = wSrc;
  const int     wHi   = wLo + ((wLo < width1 - 1) ? 1 : 0);
  const Acctype wFrac = wSrc - wLo;
  const Acctype wKeep = Acctype(1) - wFrac;

  // Corner weights do not depend on batch or channel. They are formed with
  // the same left-to-right products (depth * height * width) that are later
  // multiplied by the data2 value.
  const Acctype k000 = tKeep * hKeep * wKeep;
  const Acctype k001 = tKeep * hKeep * wFrac;
  const Acctype k010 = tKeep * hFrac * wKeep;
  const Acctype k011 = tKeep * hFrac * wFrac;
  const Acctype k100 = tFrac * hKeep * wKeep;
  const Acctype k101 = tFrac * hKeep * wFrac;
  const Acctype k110 = tFrac * hFrac * wKeep;
  const Acctype k111 = tFrac * hFrac * wFrac;

  for (int b = 0; b < batchsize; ++b) {
    for (int c = 0; c < channels; ++c) {
      const Dtype incoming = data2[b][c][t2][h2][w2];
      // Different threads may target the same data1 element, hence atomicAdd.
      atomicAdd(data1[b][c][tLo][hLo][wLo].data(),
                ScalarConvert<Acctype, Dtype>::to(k000 * incoming));
      atomicAdd(data1[b][c][tLo][hLo][wHi].data(),
                ScalarConvert<Acctype, Dtype>::to(k001 * incoming));
      atomicAdd(data1[b][c][tLo][hHi][wLo].data(),
                ScalarConvert<Acctype, Dtype>::to(k010 * incoming));
      atomicAdd(data1[b][c][tLo][hHi][wHi].data(),
                ScalarConvert<Acctype, Dtype>::to(k011 * incoming));
      atomicAdd(data1[b][c][tHi][hLo][wLo].data(),
                ScalarConvert<Acctype, Dtype>::to(k100 * incoming));
      atomicAdd(data1[b][c][tHi][hLo][wHi].data(),
                ScalarConvert<Acctype, Dtype>::to(k101 * incoming));
      atomicAdd(data1[b][c][tHi][hHi][wLo].data(),
                ScalarConvert<Acctype, Dtype>::to(k110 * incoming));
      atomicAdd(data1[b][c][tHi][hHi][wHi].data(),
                ScalarConvert<Acctype, Dtype>::to(k111 * incoming));
    }
  }
}


#include "generic/VolumetricUpSamplingTrilinear.cu"
#include "THCGenerateFloatTypes.h"