Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Luca Antiga <luca.antiga@orobix.com>, 2017-05-29 20:02:05 +0300
committer: Soumith Chintala <soumith@gmail.com>, 2017-06-07 18:24:41 +0300
commit: 3d484ecc002a1876e577ba90d326d1b417f54c8d (patch)
tree: b8e3ddfc4fb99077669f051b0b50ea64e238f573
parent: a9c4d64850f2aecb1e29e0b42a3e801c13d192cd (diff)
Add 3D upsampling (nearest and trilinear) with tests
-rw-r--r-- lib/THCUNN/VolumetricUpSamplingNearest.cu 95
-rw-r--r-- lib/THCUNN/VolumetricUpSamplingTrilinear.cu 155
-rw-r--r-- lib/THCUNN/generic/THCUNN.h 34
-rw-r--r-- lib/THCUNN/generic/VolumetricUpSamplingNearest.cu 185
-rw-r--r-- lib/THCUNN/generic/VolumetricUpSamplingTrilinear.cu 118
5 files changed, 587 insertions, 0 deletions
diff --git a/lib/THCUNN/VolumetricUpSamplingNearest.cu b/lib/THCUNN/VolumetricUpSamplingNearest.cu
new file mode 100644
index 0000000..3aacf56
--- /dev/null
+++ b/lib/THCUNN/VolumetricUpSamplingNearest.cu
@@ -0,0 +1,95 @@
+#include "THCUNN.h"
+#include "common.h"
+
+#include <thrust/transform.h>
+#include <thrust/reduce.h>
+#include <thrust/transform_reduce.h>
+#include <thrust/functional.h>
+
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
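+/*
+ * Description: index-translation helpers and kernels for nearest-neighbour
+ *              volumetric (3D) upsampling and its gradient.
+ */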
+
+// Maps a flat index into a volume whose four innermost extents are
+// (d1, d2, d3, d4) to the flat index of the matching element in the volume
+// whose three innermost extents are divided by scale_factor.
+// The vupscale kernel calls this with an output offset to find the input
+// element to copy.
+__device__ int translate_idx(int ii, int d1, int d2, int d3, int d4, int scale_factor)
+{
+  // Peel the coordinates off the flat index, fastest-varying dimension first.
+  const int big_v = ii % d4;
+  ii /= d4;
+  const int big_w = ii % d3;
+  ii /= d3;
+  const int big_z = ii % d2;
+  ii /= d2;
+  const int y = ii % d1;
+  const int x = ii / d1;
+
+  // Integer division collapses each scaled coordinate and extent.
+  const int small_v = big_v / scale_factor;
+  const int small_w = big_w / scale_factor;
+  const int small_z = big_z / scale_factor;
+  const int small_d2 = d2 / scale_factor;
+  const int small_d3 = d3 / scale_factor;
+  const int small_d4 = d4 / scale_factor;
+
+  // Re-linearise with the reduced extents.
+  return ((((x * d1 + y) * small_d2) + small_z) * small_d3 + small_w) * small_d4 + small_v;
+}
+// Maps a flat index into a volume whose four innermost extents are
+// (d1, d2, d3, d4) to a flat index in the volume whose three innermost extents
+// are multiplied by scale_factor. off_x, off_y and off_z select the element
+// inside the scale_factor-sized window: off_x is added along d4, off_y along
+// d3 and off_z along d2.
+__device__ int translate_idx_inv(int ii, int d1, int d2, int d3, int d4, int scale_factor, int off_x, int off_y, int off_z)
+{
+  // Peel the coordinates off the flat index, fastest-varying dimension first.
+  const int small_v = ii % d4;
+  ii /= d4;
+  const int small_w = ii % d3;
+  ii /= d3;
+  const int small_z = ii % d2;
+  ii /= d2;
+  const int y = ii % d1;
+  const int x = ii / d1;
+
+  // Scale the coordinates up and step to the requested window element.
+  const int big_v = small_v * scale_factor + off_x;
+  const int big_w = small_w * scale_factor + off_y;
+  const int big_z = small_z * scale_factor + off_z;
+  const int big_d2 = d2 * scale_factor;
+  const int big_d3 = d3 * scale_factor;
+  const int big_d4 = d4 * scale_factor;
+
+  // Re-linearise with the enlarged extents.
+  return ((((x * d1 + y) * big_d2) + big_z) * big_d3 + big_w) * big_d4 + big_v;
+}
+
+// Nearest-neighbour upsampling: each thread handles one flat output offset and
+// copies the input element that translate_idx() selects for it.
+//   input, output : raw element buffers, indexed with flat offsets
+//   no_elements   : number of output elements; threads past it do nothing
+//   d1..d4        : the four innermost output extents passed to translate_idx
+template <typename Dtype>
+__global__ void vupscale(Dtype *input, Dtype *output, long no_elements,
+                         int scale_factor, int d1, int d2, int d3, int d4)
+{
+  // Flat output offset of this thread over the (x, y) grid of blocks.
+  const long x_part = threadIdx.x + blockDim.x * blockIdx.x;
+  const long y_part = threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
+  const long out_idx = x_part + y_part;
+
+  if (out_idx < no_elements) {
+    const int src_idx = translate_idx(out_idx, d1, d2, d3, d4, scale_factor);
+    output[out_idx] = input[src_idx];
+  }
+}
+
+/*
+ * Description: gradient of nearest-neighbour upsampling. Each thread handles
+ * one flat gradInput offset, sums the scale_factor^3 gradOutput elements that
+ * translate_idx_inv() maps it to, and adds the sum to gradInput.
+ */
+template <typename Dtype, typename Acctype>
+__global__ void vdownscale(Dtype *gradInput_data, Dtype *gradOutput_data, long no_elements,
+                           int scale_factor, int d1, int d2, int d3, int d4)
+{
+  // Flat gradInput offset of this thread over the (x, y) grid of blocks.
+  const long x_part = threadIdx.x + blockDim.x * blockIdx.x;
+  const long y_part = threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
+  const long in_idx = x_part + y_part;
+  if (in_idx >= no_elements) {
+    return;
+  }
+
+  // Accumulate in Acctype; convert back to Dtype only once at the end.
+  Acctype acc = Acctype(0);
+  for (int off_x = 0; off_x < scale_factor; ++off_x) {
+    for (int off_y = 0; off_y < scale_factor; ++off_y) {
+      for (int off_z = 0; off_z < scale_factor; ++off_z) {
+        const int src_idx =
+            translate_idx_inv(in_idx, d1, d2, d3, d4, scale_factor, off_x, off_y, off_z);
+        acc += gradOutput_data[src_idx];
+      }
+    }
+  }
+  gradInput_data[in_idx] += ScalarConvert<Acctype, Dtype>::to(acc);
+}
+
+#include "generic/VolumetricUpSamplingNearest.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/VolumetricUpSamplingTrilinear.cu b/lib/THCUNN/VolumetricUpSamplingTrilinear.cu
new file mode 100644
index 0000000..0d861af
--- /dev/null
+++ b/lib/THCUNN/VolumetricUpSamplingTrilinear.cu
@@ -0,0 +1,155 @@
+// Adapted from interp.cpp from Caffe util by Pauline Luc
+// Originally developed by George Papandreou
+#include "THCUNN.h"
+#include "common.h"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include "THCAtomics.cuh"
+
+// Trilinear interpolation, forward pass: data1 (input) -> data2 (output).
+// Each thread with index < n handles one (t2, h2, w2) output location and
+// loops over every batch entry and channel for it.
+//   n                       : number of output locations (threads that do work)
+//   rdepth, rheight, rwidth : factors multiplied with an output coordinate to
+//                             get the (fractional) input coordinate
+template<typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
+    const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index >= n) {
+    return;
+  }
+
+  const int batchsize = data1.getSize(0);
+  const int channels = data1.getSize(1);
+  const int depth1 = data1.getSize(2);
+  const int height1 = data1.getSize(3);
+  const int width1 = data1.getSize(4);
+  const int depth2 = data2.getSize(2);
+  const int height2 = data2.getSize(3);
+  const int width2 = data2.getSize(4);
+
+  // Split the flat index into output coordinates.
+  const int inPlane = index % (height2 * width2);
+  const int w2 = inPlane % width2;              // 0:width2-1
+  const int h2 = inPlane / width2;              // 0:height2-1
+  const int t2 = index / (height2 * width2);    // 0:depth2-1
+
+  // Special case: identical spatial sizes, so just copy.
+  if (depth1 == depth2 && height1 == height2 && width1 == width2) {
+    for (int b = 0; b < batchsize; b++) {
+      for (int c = 0; c < channels; ++c) {
+        const Dtype val = data1[b][c][t2][h2][w2];
+        data2[b][c][t2][h2][w2] = val;
+      }
+    }
+    return;
+  }
+
+  // Depth: lower neighbour, step to the upper neighbour (0 at the border),
+  // and the two interpolation weights.
+  const Acctype t1r = rdepth * t2;
+  const int tLo = t1r;
+  const int tStep = (tLo >= depth1 - 1) ? 0 : 1;
+  const Acctype tHiW = t1r - tLo;
+  const Acctype tLoW = Acctype(1) - tHiW;
+  // Height.
+  const Acctype h1r = rheight * h2;
+  const int hLo = h1r;
+  const int hStep = (hLo >= height1 - 1) ? 0 : 1;
+  const Acctype hHiW = h1r - hLo;
+  const Acctype hLoW = Acctype(1) - hHiW;
+  // Width.
+  const Acctype w1r = rwidth * w2;
+  const int wLo = w1r;
+  const int wStep = (wLo >= width1 - 1) ? 0 : 1;
+  const Acctype wHiW = w1r - wLo;
+  const Acctype wLoW = Acctype(1) - wHiW;
+
+  // Blend the eight surrounding input values for every batch entry and channel.
+  for (int b = 0; b < batchsize; b++) {
+    for (int c = 0; c < channels; ++c) {
+      const Acctype val =
+          tLoW * (hLoW * (wLoW * data1[b][c][tLo][hLo][wLo]
+                        + wHiW * data1[b][c][tLo][hLo][wLo+wStep])
+                + hHiW * (wLoW * data1[b][c][tLo][hLo+hStep][wLo]
+                        + wHiW * data1[b][c][tLo][hLo+hStep][wLo+wStep]))
+        + tHiW * (hLoW * (wLoW * data1[b][c][tLo+tStep][hLo][wLo]
+                        + wHiW * data1[b][c][tLo+tStep][hLo][wLo+wStep])
+                + hHiW * (wLoW * data1[b][c][tLo+tStep][hLo+hStep][wLo]
+                        + wHiW * data1[b][c][tLo+tStep][hLo+hStep][wLo+wStep]));
+      data2[b][c][t2][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
+    }
+  }
+}
+
+// Backward (adjoint) operation 1 <- 2 (accumulates).
+// Each thread with index < n handles one (t2, h2, w2) location of data2 and,
+// for every batch entry and channel, adds its weighted value to the eight
+// surrounding locations of data1. The adds are atomic because several threads
+// can target the same data1 element.
+template <typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel_backward(const int n,
+    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
+    THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
+  const int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index >= n) {
+    return;
+  }
+
+  const int batchsize = data1.getSize(0);
+  const int channels = data1.getSize(1);
+  const int depth1 = data1.getSize(2);
+  const int height1 = data1.getSize(3);
+  const int width1 = data1.getSize(4);
+  const int depth2 = data2.getSize(2);
+  const int height2 = data2.getSize(3);
+  const int width2 = data2.getSize(4);
+
+  // Split the flat index into data2 coordinates.
+  const int inPlane = index % (height2 * width2);
+  const int w2 = inPlane % width2;              // 0:width2-1
+  const int h2 = inPlane / width2;              // 0:height2-1
+  const int t2 = index / (height2 * width2);    // 0:depth2-1
+
+  // Special case: identical spatial sizes, so accumulate element for element.
+  if (depth1 == depth2 && height1 == height2 && width1 == width2) {
+    for (int b = 0; b < batchsize; b++) {
+      for (int c = 0; c < channels; ++c) {
+        const Dtype val = data2[b][c][t2][h2][w2];
+        data1[b][c][t2][h2][w2] += val;
+      }
+    }
+    return;
+  }
+
+  // Depth: lower neighbour, step to the upper neighbour (0 at the border),
+  // and the two interpolation weights.
+  const Acctype t1r = rdepth * t2;
+  const int tLo = t1r;
+  const int tStep = (tLo >= depth1 - 1) ? 0 : 1;
+  const Acctype tHiW = t1r - tLo;
+  const Acctype tLoW = Acctype(1) - tHiW;
+  // Height.
+  const Acctype h1r = rheight * h2;
+  const int hLo = h1r;
+  const int hStep = (hLo >= height1 - 1) ? 0 : 1;
+  const Acctype hHiW = h1r - hLo;
+  const Acctype hLoW = Acctype(1) - hHiW;
+  // Width.
+  const Acctype w1r = rwidth * w2;
+  const int wLo = w1r;
+  const int wStep = (wLo >= width1 - 1) ? 0 : 1;
+  const Acctype wHiW = w1r - wLo;
+  const Acctype wLoW = Acctype(1) - wHiW;
+
+  for (int b = 0; b < batchsize; b++) {
+    for (int c = 0; c < channels; ++c) {
+      const Dtype d2val = data2[b][c][t2][h2][w2];
+      // Lower depth slice.
+      atomicAdd(data1[b][c][tLo][hLo][wLo].data(),
+                ScalarConvert<Acctype, Dtype>::to(tLoW * hLoW * wLoW * d2val));
+      atomicAdd(data1[b][c][tLo][hLo][wLo+wStep].data(),
+                ScalarConvert<Acctype, Dtype>::to(tLoW * hLoW * wHiW * d2val));
+      atomicAdd(data1[b][c][tLo][hLo+hStep][wLo].data(),
+                ScalarConvert<Acctype, Dtype>::to(tLoW * hHiW * wLoW * d2val));
+      atomicAdd(data1[b][c][tLo][hLo+hStep][wLo+wStep].data(),
+                ScalarConvert<Acctype, Dtype>::to(tLoW * hHiW * wHiW * d2val));
+      // Upper depth slice.
+      atomicAdd(data1[b][c][tLo+tStep][hLo][wLo].data(),
+                ScalarConvert<Acctype, Dtype>::to(tHiW * hLoW * wLoW * d2val));
+      atomicAdd(data1[b][c][tLo+tStep][hLo][wLo+wStep].data(),
+                ScalarConvert<Acctype, Dtype>::to(tHiW * hLoW * wHiW * d2val));
+      atomicAdd(data1[b][c][tLo+tStep][hLo+hStep][wLo].data(),
+                ScalarConvert<Acctype, Dtype>::to(tHiW * hHiW * wLoW * d2val));
+      atomicAdd(data1[b][c][tLo+tStep][hLo+hStep][wLo+wStep].data(),
+                ScalarConvert<Acctype, Dtype>::to(tHiW * hHiW * wHiW * d2val));
+    }
+  }
+}
+
+
+#include "generic/VolumetricUpSamplingTrilinear.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 4b7b64a..f51759f 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -1370,4 +1370,38 @@ TH_API void THNN_(VolumetricReplicationPadding_updateGradInput)(
int ptop, int pbottom,
int pfront, int pback);
+// Nearest-neighbour volumetric upsampling: forward pass.
+TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  int scale_factor);
+
+// Nearest-neighbour volumetric upsampling: gradient w.r.t. the input.
+TH_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  int scale_factor);
+
+// Trilinear volumetric upsampling: forward pass to an explicit output size.
+TH_API void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  int outputDepth,
+                  int outputHeight,
+                  int outputWidth);
+
+// Trilinear volumetric upsampling: gradient w.r.t. the input. The input
+// tensor itself is not passed, so its shape is given explicitly.
+TH_API void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
+                  THCState *state,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  int nbatch,
+                  int nchannels,
+                  int inputDepth,
+                  int inputHeight,
+                  int inputWidth,
+                  int outputDepth,
+                  int outputHeight,
+                  int outputWidth);
+ int outputWidth);
+
#endif
diff --git a/lib/THCUNN/generic/VolumetricUpSamplingNearest.cu b/lib/THCUNN/generic/VolumetricUpSamplingNearest.cu
new file mode 100644
index 0000000..db21bbb
--- /dev/null
+++ b/lib/THCUNN/generic/VolumetricUpSamplingNearest.cu
@@ -0,0 +1,185 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/VolumetricUpSamplingNearest.cu"
+#else
+
+#include "../common.h"
+
+// Validates the arguments of the nearest-neighbour volumetric upsampling
+// entry points.
+//   input        : must be non-NULL and 4D or 5D
+//   gradOutput   : optional (NULL in the forward pass); when given, it must
+//                  have input's shape with the last three dimensions
+//                  multiplied by scale_factor
+//   scale_factor : must be greater than 1
+// Errors are raised through THArgCheck / THCUNN_argCheck / THCUNN_check_dim_size.
+static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
+     (THCState *state, THCTensor *input, THCTensor *gradOutput,
+      int scale_factor) {
+  // 4D and 5D inputs are both accepted below, so the NULL message says so too.
+  THArgCheck(input != NULL, 2, "4D or 5D input tensor expected but got NULL");
+  THArgCheck(scale_factor > 1, 4,
+             "scale_factor must be greater than 1, but got: %d", scale_factor);
+  THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input,
+                  "4D or 5D input tensor expected but got: %s");
+
+  // Nothing more to verify in the forward pass.
+  if (gradOutput == NULL) {
+    return;
+  }
+
+  // Leading dimensions must match the input exactly; the trailing three
+  // (depth, height, width) must be scaled.
+  const int nDim = input->nDimension;
+  for (int dim = 0; dim < nDim; dim++) {
+    int expectedSize = THCTensor_(size)(state, input, dim);
+    if (dim >= nDim - 3) {
+      expectedSize *= scale_factor;
+    }
+    THCUNN_check_dim_size(state, gradOutput, nDim, dim, expectedSize);
+  }
+}
+
+// Forward pass of nearest-neighbour volumetric upsampling.
+// Resizes `output` to input's shape with depth, height and width multiplied by
+// scale_factor, then launches vupscale with one thread per output element.
+//   input  : 4D or 5D tensor (validated by the shape check)
+//   output : resized here; every element is written by the kernel
+void THNN_(VolumetricUpSamplingNearest_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int scale_factor)
+{
+  // Validate before touching `output`.
+  THCUNN_assertSameGPU(state, 2, input, output);
+  THNN_(VolumetricUpSamplingNearest_shapeCheck)(state, input, NULL, scale_factor);
+
+  int inputDepth = THCTensor_(size)(state, input, input->nDimension-3);
+  int inputHeight = THCTensor_(size)(state, input, input->nDimension-2);
+  int inputWidth = THCTensor_(size)(state, input, input->nDimension-1);
+  int outputDepth = inputDepth * scale_factor;
+  int outputHeight = inputHeight * scale_factor;
+  int outputWidth = inputWidth * scale_factor;
+
+  if (input->nDimension == 4) {
+    THCTensor_(resize4d)(state, output,
+                         THCTensor_(size)(state, input, 0),
+                         outputDepth, outputHeight, outputWidth);
+  } else {
+    THCTensor_(resize5d)(state, output,
+                         THCTensor_(size)(state, input, 0),
+                         THCTensor_(size)(state, input, 1),
+                         outputDepth, outputHeight, outputWidth);
+  }
+  // No zero fill: the kernel below writes every one of the no_elements
+  // output elements.
+
+  input = THCTensor_(newContiguous)(state, input);
+
+  // Total number of output elements, i.e. threads that must do work.
+  long no_elements = 1;
+  for (int i = 0; i < output->nDimension; i++) {
+    no_elements *= output->size[i];
+  }
+
+  // The kernel only needs the four innermost output extents; for 5D tensors
+  // the batch dimension is the remaining outermost quotient.
+  const int firstDim = output->nDimension - 4;
+  int d1 = output->size[firstDim];
+  int d2 = output->size[firstDim + 1];
+  int d3 = output->size[firstDim + 2];
+  int d4 = output->size[firstDim + 3];
+
+  real *input_data = THCTensor_(data)(state, input);
+  real *output_data = THCTensor_(data)(state, output);
+
+  // cuda blocks & threads:
+  const long nthreads = 256;
+  // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
+  // 65535 for SM 2.x, 2^32 -1 for >= 3.0
+  // TODO: When we move to SM 3.5 we should update this
+  // Integer ceil-division: float arithmetic rounds counts above 2^24 and can
+  // yield a grid with fewer threads than elements.
+  long n_xblocks = (no_elements + nthreads - 1) / nthreads;
+  if (n_xblocks < 1) {
+    n_xblocks = 1;
+  }
+  if (n_xblocks > 65535) {
+    n_xblocks = 65535;
+  }
+  const long threadsPerRow = n_xblocks * nthreads;
+  const long n_yblocks = (no_elements + threadsPerRow - 1) / threadsPerRow;
+  if (n_yblocks > 65535) {
+    THError("Input size is too large!  aborting");
+  }
+  dim3 blocks(n_xblocks, n_yblocks);
+  dim3 threads(nthreads);
+
+  // kernel:
+  vupscale<<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data, no_elements, scale_factor, d1, d2, d3, d4);
+  THCudaCheck(cudaGetLastError());
+
+  // Release the reference taken by newContiguous.
+  THCTensor_(free)(state, input);
+}
+
+// Backward pass of nearest-neighbour volumetric upsampling.
+// Resizes gradInput like `input`, zeroes it, and launches vdownscale with one
+// thread per gradInput element; each thread sums its scale_factor^3 window of
+// gradOutput.
+//   input      : used only for its shape
+//   gradOutput : must match the upsampled shape (validated by the shape check)
+//   gradInput  : resized and overwritten here
+void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           int scale_factor)
+{
+  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+  THNN_(VolumetricUpSamplingNearest_shapeCheck)(state, input, gradOutput, scale_factor);
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+  THCTensor_(resizeAs)(state, gradInput, input);
+
+  // The kernel accumulates with +=, so start from zero.
+  THCTensor_(zero)(state, gradInput);
+
+  real *gradInput_data = THCTensor_(data)(state, gradInput);
+  real *gradOutput_data = THCTensor_(data)(state, gradOutput);
+
+  // Total number of gradInput elements, i.e. threads that must do work.
+  long no_elements = 1;
+  for (int i = 0; i < gradInput->nDimension; i++) {
+    no_elements *= gradInput->size[i];
+  }
+
+  // The kernel only needs the four innermost gradInput extents; for 5D tensors
+  // the batch dimension is the remaining outermost quotient.
+  const int firstDim = gradInput->nDimension - 4;
+  int d1 = gradInput->size[firstDim];
+  int d2 = gradInput->size[firstDim + 1];
+  int d3 = gradInput->size[firstDim + 2];
+  int d4 = gradInput->size[firstDim + 3];
+
+  // cuda blocks & threads:
+  const long nthreads = 256;
+  // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
+  // 65535 for SM 2.x, 2^32 -1 for >= 3.0
+  // TODO: When we move to SM 3.5 we should update this
+  // Integer ceil-division: float arithmetic rounds counts above 2^24 and can
+  // yield a grid with fewer threads than elements.
+  long n_xblocks = (no_elements + nthreads - 1) / nthreads;
+  if (n_xblocks < 1) {
+    n_xblocks = 1;
+  }
+  if (n_xblocks > 65535) {
+    n_xblocks = 65535;
+  }
+  const long threadsPerRow = n_xblocks * nthreads;
+  const long n_yblocks = (no_elements + threadsPerRow - 1) / threadsPerRow;
+  if (n_yblocks > 65535) {
+    THError("Input size is too large!  aborting");
+  }
+  dim3 blocks(n_xblocks, n_yblocks);
+  dim3 threads(nthreads);
+
+  // kernel:
+  vdownscale<real ,accreal> <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data, no_elements,
+    scale_factor, d1, d2, d3, d4);
+  THCudaCheck(cudaGetLastError());
+
+  // Release the reference taken by newContiguous.
+  THCTensor_(free)(state, gradOutput);
+}
+
+#endif
diff --git a/lib/THCUNN/generic/VolumetricUpSamplingTrilinear.cu b/lib/THCUNN/generic/VolumetricUpSamplingTrilinear.cu
new file mode 100644
index 0000000..58be310
--- /dev/null
+++ b/lib/THCUNN/generic/VolumetricUpSamplingTrilinear.cu
@@ -0,0 +1,118 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/VolumetricUpSamplingTrilinear.cu"
+#else
+
+// Validates the arguments of the trilinear volumetric upsampling entry points.
+//   input      : optional; when given it must be 5D
+//   gradOutput : optional; when given it must be 5D with shape
+//                (nBatch, nChannels, outputDepth, outputHeight, outputWidth)
+//   All six input/output spatial sizes must be strictly positive.
+static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
+                        (THCState *state,
+                         THCTensor *input, THCTensor *gradOutput,
+                         int nBatch, int nChannels,
+                         int inputDepth, int inputHeight, int inputWidth,
+                         int outputDepth, int outputHeight, int outputWidth) {
+  // outputDepth is compared with > 0 like its siblings; a bare truthiness test
+  // would let negative depths through.
+  THArgCheck(inputDepth > 0 && inputHeight > 0 && inputWidth > 0
+             && outputDepth > 0 && outputHeight > 0 && outputWidth > 0, 2,
+             "input and output sizes should be greater than 0,"
+             " but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)",
+             inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth);
+  if (input != NULL) {
+    THCUNN_argCheck(state, input->nDimension == 5, 2, input,
+                    "5D input tensor expected but got: %s");
+  }
+
+  if (gradOutput != NULL) {
+    THCUNN_check_dim_size(state, gradOutput, 5, 0, nBatch);
+    THCUNN_check_dim_size(state, gradOutput, 5, 1, nChannels);
+    THCUNN_check_dim_size(state, gradOutput, 5, 2, outputDepth);
+    THCUNN_check_dim_size(state, gradOutput, 5, 3, outputHeight);
+    THCUNN_check_dim_size(state, gradOutput, 5, 4, outputWidth);
+  }
+}
+
+// Forward pass of trilinear volumetric upsampling.
+// Resizes `output` to (batch, channels, outputDepth, outputHeight, outputWidth)
+// and launches caffe_gpu_interp2_kernel with one thread per output location.
+//   input  : 5D tensor
+//   output : resized here; every element is written by the kernel
+void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int outputDepth,
+           int outputHeight,
+           int outputWidth)
+{
+  // Check rank before reading sizes 0..4, so a wrong-rank input gets the
+  // descriptive message rather than an index-out-of-range error.
+  THArgCheck(input != NULL, 2, "5D input tensor expected but got NULL");
+  THCUNN_argCheck(state, input->nDimension == 5, 2, input,
+                  "5D input tensor expected but got: %s");
+
+  int nbatch = THCTensor_(size)(state, input, 0);
+  int channels = THCTensor_(size)(state, input, 1);
+  int inputDepth = THCTensor_(size)(state, input, 2);
+  int inputHeight = THCTensor_(size)(state, input, 3);
+  int inputWidth = THCTensor_(size)(state, input, 4);
+  THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
+       (state, input, NULL,
+        nbatch, channels,
+        inputDepth, inputHeight, inputWidth,
+        outputDepth, outputHeight, outputWidth);
+  // Kept as a last line of defence, but run before `output` is resized.
+  THAssert(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 && outputDepth > 0 && outputHeight > 0 && outputWidth > 0);
+  THCUNN_assertSameGPU(state, 2, input, output);
+
+  input = THCTensor_(newContiguous)(state, input);
+  THCTensor_(resize5d)(state, output,
+                       nbatch, channels,
+                       outputDepth, outputHeight, outputWidth);
+  // No zero fill: the kernel assigns every (batch, channel, t, h, w) element.
+
+  THCDeviceTensor<real, 5> idata = toDeviceTensor<real, 5>(state, input);
+  THCDeviceTensor<real, 5> odata = toDeviceTensor<real, 5>(state, output);
+
+  // Output coordinate times ratio gives the input coordinate; a size-1 output
+  // axis maps to input coordinate 0.
+  const accreal rdepth = (outputDepth > 1) ? (accreal)(inputDepth - 1)/(outputDepth - 1) : accreal(0);
+  const accreal rheight = (outputHeight > 1) ? (accreal)(inputHeight - 1)/(outputHeight - 1) : accreal(0);
+  const accreal rwidth = (outputWidth > 1) ? (accreal)(inputWidth - 1)/(outputWidth - 1) : accreal(0);
+
+  // One thread per output (depth, height, width) location.
+  const int num_kernels = outputDepth * outputHeight * outputWidth;
+  const int num_threads =
+    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+  cudaStream_t stream = THCState_getCurrentStream(state);
+  caffe_gpu_interp2_kernel<real, accreal> <<<THCCeilDiv(num_kernels, num_threads), num_threads ,
+   0 , stream>>>(num_kernels, rdepth, rheight, rwidth, idata, odata);
+  THCudaCheck(cudaGetLastError());
+
+  // Release the reference taken by newContiguous.
+  THCTensor_(free)(state, input);
+}
+
+
+// Backward pass of trilinear volumetric upsampling.
+// Resizes gradInput to (nbatch, nchannels, inputDepth, inputHeight,
+// inputWidth), zeroes it, and launches caffe_gpu_interp2_kernel_backward with
+// one thread per gradOutput (depth, height, width) location.
+//   gradOutput : 5D, validated by the shape check
+//   gradInput  : resized and overwritten here
+void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
+           THCState *state,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           int nbatch,
+           int nchannels,
+           int inputDepth,
+           int inputHeight,
+           int inputWidth,
+           int outputDepth,
+           int outputHeight,
+           int outputWidth)
+{
+  THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
+       (state, NULL, gradOutput,
+        nbatch, nchannels,
+        inputDepth, inputHeight, inputWidth,
+        outputDepth, outputHeight, outputWidth);
+  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+
+  // Work on the caller's gradInput directly. Taking a newContiguous() copy
+  // first would send the gradients into a temporary (and lose them) whenever
+  // gradInput is not contiguous. The kernel indexes through a THCDeviceTensor
+  // built from the tensor itself, so no contiguous copy is made here.
+  THCTensor_(resize5d)(state, gradInput, nbatch, nchannels, inputDepth, inputHeight, inputWidth);
+  // The kernel accumulates with atomicAdd, so start from zero.
+  THCTensor_(zero)(state, gradInput);
+
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+  THCDeviceTensor<real, 5> data1 = toDeviceTensor<real, 5>(state, gradInput);
+  THCDeviceTensor<real, 5> data2 = toDeviceTensor<real, 5>(state, gradOutput);
+  int depth1 = data1.getSize(2);
+  int height1 = data1.getSize(3);
+  int width1 = data1.getSize(4);
+  int depth2 = data2.getSize(2);
+  int height2 = data2.getSize(3);
+  int width2 = data2.getSize(4);
+  // THAssert (as in the forward pass) instead of a bare assert.
+  THAssert(depth1 > 0 && height1 > 0 && width1 > 0 && depth2 > 0 && height2 > 0 && width2 > 0);
+
+  // Same coordinate ratios as the forward pass.
+  const accreal rdepth = (depth2 > 1) ? (accreal)(depth1 - 1)/(depth2 - 1) : accreal(0);
+  const accreal rheight = (height2 > 1) ? (accreal)(height1 - 1)/(height2 - 1) : accreal(0);
+  const accreal rwidth = (width2 > 1) ? (accreal)(width1 - 1) / (width2 - 1) : accreal(0);
+
+  // One thread per gradOutput (depth, height, width) location.
+  const int num_kernels = depth2 * height2 * width2;
+  const int num_threads =
+    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+  cudaStream_t stream = THCState_getCurrentStream(state);
+  caffe_gpu_interp2_kernel_backward<real ,accreal> <<<THCCeilDiv(num_kernels, num_threads),
+  num_threads, 0, stream>>>(num_kernels, rdepth, rheight, rwidth, data1, data2);
+  THCudaCheck(cudaGetLastError());
+
+  // Release the reference taken by newContiguous.
+  THCTensor_(free)(state, gradOutput);
+}
+
+#endif