diff options
author | Luca Antiga <luca.antiga@orobix.com> | 2017-05-29 20:02:05 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-07 18:24:41 +0300 |
commit | 3d484ecc002a1876e577ba90d326d1b417f54c8d (patch) | |
tree | b8e3ddfc4fb99077669f051b0b50ea64e238f573 | |
parent | a9c4d64850f2aecb1e29e0b42a3e801c13d192cd (diff) |
Add 3D upsampling (nearest and trilinear) with tests
-rw-r--r-- | lib/THCUNN/VolumetricUpSamplingNearest.cu | 95 | ||||
-rw-r--r-- | lib/THCUNN/VolumetricUpSamplingTrilinear.cu | 155 | ||||
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 34 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricUpSamplingNearest.cu | 185 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricUpSamplingTrilinear.cu | 118 |
5 files changed, 587 insertions, 0 deletions
// ============================================================================
// File: lib/THCUNN/VolumetricUpSamplingNearest.cu  (new file)
// ============================================================================
#include "THCUNN.h"
#include "common.h"

#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/transform_reduce.h>
#include <thrust/functional.h>

#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"

/*
 * Description:
 *   Maps a flat index `ii` into the UPSCALED tensor, whose four innermost
 *   dimensions are (d1, d2, d3, d4), to the flat index of the element it is
 *   copied from in the original tensor. The three innermost coordinates
 *   (and their extents d2, d3, d4) are divided by `scale_factor`; d1 and the
 *   leading coordinate are left untouched.
 *
 *   NOTE(review): `ii` and the result are `int`, while the kernels below hold
 *   the flat index in a `long`; tensors with more than INT_MAX elements would
 *   be truncated here — confirm whether callers can reach that size.
 */
__device__ int translate_idx(int ii, int d1, int d2, int d3, int d4, int scale_factor)
{
  int x, y, z, w, v;
  // Peel the coordinates off the flat index, innermost dimension first.
  v = ii % d4;
  ii = ii/d4;
  w = ii % d3;
  ii = ii/d3;
  z = ii % d2;
  ii = ii/d2;
  y = ii % d1;
  ii = ii/d1;
  x = ii;
  // Nearest-neighbour source coordinate in the smaller tensor.
  v = v/scale_factor;
  w = w/scale_factor;
  z = z/scale_factor;
  // Extents of the smaller tensor.
  d2 /= scale_factor;
  d3 /= scale_factor;
  d4 /= scale_factor;
  return ((((x*d1+y)*d2)+z)*d3+w)*d4+v;
}

/*
 * Description:
 *   Inverse of translate_idx: maps a flat index `ii` into the SMALL tensor,
 *   whose four innermost dimensions are (d1, d2, d3, d4), to a flat index in
 *   the upscaled tensor. The three innermost coordinates are multiplied by
 *   `scale_factor` and shifted by (off_x, off_y, off_z) respectively, where
 *   off_x applies to the innermost dimension.
 */
__device__ int translate_idx_inv(int ii, int d1, int d2, int d3, int d4, int scale_factor, int off_x, int off_y, int off_z)
{
  int x, y, z, w, v;
  v = ii % d4;
  ii = ii/d4;
  w = ii % d3;
  ii = ii/d3;
  z = ii % d2;
  ii = ii/d2;
  y = ii % d1;
  ii = ii/d1;
  x = ii;
  v = v*scale_factor+off_x;
  w = w*scale_factor+off_y;
  z = z*scale_factor+off_z;
  d2 *= scale_factor;
  d3 *= scale_factor;
  d4 *= scale_factor;
  return ((((x*d1+y)*d2)+z)*d3+w)*d4+v;
}

/*
 * Description:
 *   Nearest-neighbour upscaling, one thread per OUTPUT element.
 *   `no_elements` is the number of output elements; (d1..d4) are the four
 *   innermost OUTPUT dimensions. The flat index is built from a 2D grid of
 *   1D blocks; threads whose index is past `no_elements` exit immediately.
 *   Both pointers are indexed as flat, densely packed arrays.
 */
template <typename Dtype>
__global__ void vupscale(Dtype *input, Dtype *output, long no_elements,
                         int scale_factor, int d1, int d2, int d3, int d4)
{
  // output offset:
  long ii = threadIdx.x + blockDim.x * blockIdx.x;
  ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
  if (ii >= no_elements) return;
  int ipidx = translate_idx(ii, d1, d2, d3, d4, scale_factor);
  output[ii]=input[ipidx];
}

/*
 * Description:
 *   Backward pass of nearest-neighbour upscaling, one thread per gradInput
 *   element. Each thread sums the scale_factor^3 gradOutput elements that
 *   were copied from its input element and accumulates the sum into
 *   gradInput. `no_elements` is the number of gradInput elements; (d1..d4)
 *   are the four innermost gradInput dimensions.
 */
template <typename Dtype, typename Acctype>
__global__ void vdownscale(Dtype *gradInput_data, Dtype *gradOutput_data, long no_elements,
                           int scale_factor, int d1, int d2, int d3, int d4)
{
  // output offset:
  long ii = threadIdx.x + blockDim.x * blockIdx.x;
  ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
  if (ii >= no_elements) return;
  Acctype sum = Acctype(0);
  for (int i=0; i < scale_factor; i++){
    for(int j=0; j < scale_factor; j++){
      for(int k=0; k < scale_factor; k++){
        int ipidx = translate_idx_inv(ii, d1, d2, d3, d4, scale_factor, i, j, k);
        sum += gradOutput_data[ipidx];
      }
    }
  }
  gradInput_data[ii] += ScalarConvert<Acctype, Dtype>::to(sum);
}

#include "generic/VolumetricUpSamplingNearest.cu"
#include "THCGenerateFloatTypes.h"

// ============================================================================
// File: lib/THCUNN/VolumetricUpSamplingTrilinear.cu  (new file)
// ============================================================================
// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#include "THCUNN.h"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"

// Forward trilinear interpolation, data1 (input) -> data2 (output).
// One thread per output voxel (t2, h2, w2); `n` is the number of such voxels
// (depth2 * height2 * width2). Each thread loops over every batch entry and
// channel. rdepth / rheight / rwidth are the input/output coordinate ratios
// computed by the host wrapper.
template<typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
    const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  const int batchsize = data1.getSize(0);
  const int channels = data1.getSize(1);
  const int depth1 = data1.getSize(2);
  const int height1 = data1.getSize(3);
  const int width1 = data1.getSize(4);
  const int depth2 = data2.getSize(2);
  const int height2 = data2.getSize(3);
  const int width2 = data2.getSize(4);

  if (index < n) {
    const int w2 = (index % (height2*width2)) % width2; // 0:width2-1
    const int h2 = (index % (height2*width2)) / width2; // 0:height2-1
    const int t2 = index / (height2*width2);            // 0:depth2-1
    // special case: just copy
    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
      const int t1 = t2;
      const int h1 = h2;
      const int w1 = w2;
      // Loop variable is `b` so it does not shadow the kernel parameter `n`.
      for (int b = 0; b < batchsize ; b++){
        for (int c = 0; c < channels; ++c) {
          const Dtype val = data1[b][c][t1][h1][w1];
          data2[b][c][t2][h2][w2] = val;
        }
      }
      return;
    }
    // Source coordinate along depth: integer part, +1 neighbour offset
    // (0 at the last slice so reads stay in bounds) and the two weights.
    const Acctype t1r = rdepth * t2;
    const int t1 = t1r;
    const int t1p = (t1 < depth1 - 1) ? 1 : 0;
    const Acctype t1lambda = t1r - t1;
    const Acctype t0lambda = Acctype(1) - t1lambda;
    // Same along height.
    const Acctype h1r = rheight * h2;
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const Acctype h1lambda = h1r - h1;
    const Acctype h0lambda = Acctype(1) - h1lambda;
    // Same along width.
    const Acctype w1r = rwidth * w2;
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const Acctype w1lambda = w1r - w1;
    const Acctype w0lambda = Acctype(1) - w1lambda;
    //
    for (int b = 0; b < batchsize ; b++){
      for (int c = 0; c < channels; ++c) {
        const Acctype val = t0lambda * (h0lambda * (w0lambda * data1[b][c][t1][h1][w1]
                                                  + w1lambda * data1[b][c][t1][h1][w1+w1p])
                                      + h1lambda * (w0lambda * data1[b][c][t1][h1+h1p][w1]
                                                  + w1lambda * data1[b][c][t1][h1+h1p][w1+w1p]))
                          + t1lambda * (h0lambda * (w0lambda * data1[b][c][t1+t1p][h1][w1]
                                                  + w1lambda * data1[b][c][t1+t1p][h1][w1+w1p])
                                      + h1lambda * (w0lambda * data1[b][c][t1+t1p][h1+h1p][w1]
                                                  + w1lambda * data1[b][c][t1+t1p][h1+h1p][w1+w1p]));
        data2[b][c][t2][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
      }
    }
  }
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
// One thread per gradOutput voxel (t2, h2, w2); `n` is the number of such
// voxels. Each thread scatters its value into the (up to) eight neighbouring
// data1 cells with atomicAdd, because several threads may target the same
// data1 cell. data1 must already hold the values to accumulate onto (the host
// wrapper zeroes it).
template <typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel_backward(const int n,
    const Acctype rdepth, const Acctype rheight, const Acctype rwidth,
    THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  const int batchsize = data1.getSize(0);
  const int channels = data1.getSize(1);
  const int depth1 = data1.getSize(2);
  const int height1 = data1.getSize(3);
  const int width1 = data1.getSize(4);
  const int depth2 = data2.getSize(2);
  const int height2 = data2.getSize(3);
  const int width2 = data2.getSize(4);
  if (index < n) {
    const int w2 = (index % (height2*width2)) % width2; // 0:width2-1
    const int h2 = (index % (height2*width2)) / width2; // 0:height2-1
    const int t2 = index / (height2*width2);            // 0:depth2-1
    // special case: just copy
    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
      const int t1 = t2;
      const int h1 = h2;
      const int w1 = w2;
      // Loop variable is `b` so it does not shadow the kernel parameter `n`.
      for (int b = 0; b < batchsize ; b++){
        for (int c = 0; c < channels; ++c) {
          const Dtype val = data2[b][c][t1][h1][w1];
          data1[b][c][t2][h2][w2] += val;
        }
      }
      return;
    }
    //
    const Acctype t1r = rdepth * t2;
    const int t1 = t1r;
    const int t1p = (t1 < depth1 - 1) ? 1 : 0;
    const Acctype t1lambda = t1r - t1;
    const Acctype t0lambda = Acctype(1) - t1lambda;
    //
    const Acctype h1r = rheight * h2;
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const Acctype h1lambda = h1r - h1;
    const Acctype h0lambda = Acctype(1) - h1lambda;
    //
    const Acctype w1r = rwidth * w2;
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const Acctype w1lambda = w1r - w1;
    const Acctype w0lambda = Acctype(1) - w1lambda;
    //
    for (int b = 0; b < batchsize ; b++){
      for (int c = 0; c < channels; ++c) {
        const Dtype d2val = data2[b][c][t2][h2][w2];
        atomicAdd(data1[b][c][t1][h1][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h0lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1][h1][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h0lambda * w1lambda * d2val));
        atomicAdd(data1[b][c][t1][h1+h1p][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h1lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1][h1+h1p][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t0lambda * h1lambda * w1lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h0lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h0lambda * w1lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1+h1p][w1].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h1lambda * w0lambda * d2val));
        atomicAdd(data1[b][c][t1+t1p][h1+h1p][w1+w1p].data(),
                  ScalarConvert<Acctype, Dtype>::to(t1lambda * h1lambda * w1lambda * d2val));
      }
    }
  }
  /////////////////////////////////////////////////////////
}


#include "generic/VolumetricUpSamplingTrilinear.cu"
#include "THCGenerateFloatTypes.h"

// ============================================================================
// File: lib/THCUNN/generic/THCUNN.h  (existing header)
// Declarations added after VolumetricReplicationPadding_updateGradInput and
// before the header's closing #endif.
// ============================================================================
TH_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *gradOutput,
                  THCTensor *gradInput,
                  int scale_factor);

TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *output,
                  int scale_factor);

TH_API void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *output,
                  int outputDepth,
                  int outputHeight,
                  int outputWidth);

TH_API void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
                  THCState *state,
                  THCTensor *gradOutput,
                  THCTensor *gradInput,
                  int nbatch,
                  int nchannels,
                  int inputDepth,
                  int inputHeight,
                  int inputWidth,
                  int outputDepth,
                  int outputHeight,
                  int outputWidth);

// ============================================================================
// File: lib/THCUNN/generic/VolumetricUpSamplingNearest.cu  (new file)
// ============================================================================
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/VolumetricUpSamplingNearest.cu"
#else

#include "../common.h"

// Validates arguments of the nearest-neighbour volumetric upsampling module:
// input must be non-NULL and 4D or 5D, scale_factor must be > 1 and, when
// gradOutput is given, each of its dimensions must match the input dimension
// (spatial dimensions multiplied by scale_factor).
static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
                        (THCState *state,THCTensor *input, THCTensor *gradOutput,
                         int scale_factor) {
  // Message fixed: both 4D and 5D inputs are accepted below.
  THArgCheck(input != NULL, 2, "4D or 5D input tensor expected but got NULL");
  THArgCheck(scale_factor > 1, 4,
             "scale_factor must be greater than 1, but got: %d", scale_factor);
  THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input,
                  "4D or 5D input tensor expected but got: %s");
  if (input->nDimension == 4) {
    int nChannels = THCTensor_(size)(state, input, 0);
    int inputDepth = THCTensor_(size)(state, input, 1);
    int inputHeight = THCTensor_(size)(state, input, 2);
    int inputWidth = THCTensor_(size)(state, input, 3);
    int outputDepth = inputDepth * scale_factor;
    int outputHeight = inputHeight * scale_factor;
    int outputWidth = inputWidth * scale_factor;
    if (gradOutput != NULL) {
      THCUNN_check_dim_size(state, gradOutput, 4, 0, nChannels);
      THCUNN_check_dim_size(state, gradOutput, 4, 1, outputDepth);
      THCUNN_check_dim_size(state, gradOutput, 4, 2, outputHeight);
      THCUNN_check_dim_size(state, gradOutput, 4, 3, outputWidth);
    }
  } else {
    int nBatch = THCTensor_(size)(state, input, 0);
    int nChannels = THCTensor_(size)(state, input, 1);
    int inputDepth = THCTensor_(size)(state, input, 2);
    int inputHeight = THCTensor_(size)(state, input, 3);
    int inputWidth = THCTensor_(size)(state, input, 4);
    int outputDepth = inputDepth * scale_factor;
    int outputHeight = inputHeight * scale_factor;
    int outputWidth = inputWidth * scale_factor;
    if (gradOutput != NULL) {
      THCUNN_check_dim_size(state, gradOutput, 5, 0, nBatch);
      THCUNN_check_dim_size(state, gradOutput, 5, 1, nChannels);
      THCUNN_check_dim_size(state, gradOutput, 5, 2, outputDepth);
      THCUNN_check_dim_size(state, gradOutput, 5, 3, outputHeight);
      THCUNN_check_dim_size(state, gradOutput, 5, 4, outputWidth);
    }
  }
}

// Forward pass: resizes `output` to the input shape with the three spatial
// dimensions multiplied by `scale_factor`, then fills it by launching
// vupscale with one thread per output element.
void THNN_(VolumetricUpSamplingNearest_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           int scale_factor)
{
  // Validate the arguments before touching `output`.
  THCUNN_assertSameGPU(state, 2, input, output);
  THNN_(VolumetricUpSamplingNearest_shapeCheck)(state, input, NULL, scale_factor);

  THCTensor_(zero)(state, output);

  int inputDepth = THCTensor_(size)(state, input, input->nDimension-3);
  int inputHeight = THCTensor_(size)(state, input, input->nDimension-2);
  int inputWidth  = THCTensor_(size)(state, input,  input->nDimension-1);
  int outputDepth = inputDepth * scale_factor;
  int outputHeight = inputHeight * scale_factor;
  int outputWidth = inputWidth * scale_factor;

  if (input->nDimension == 4) {
    THCTensor_(resize4d)(state, output,
                         THCTensor_(size)(state, input, 0),
                         outputDepth, outputHeight, outputWidth);
  } else {
    THCTensor_(resize5d)(state, output,
                         THCTensor_(size)(state, input, 0),
                         THCTensor_(size)(state, input, 1),
                         outputDepth, outputHeight, outputWidth);
  }

  // The kernel indexes the input through a raw pointer, so it must be dense.
  input = THCTensor_(newContiguous)(state, input);
  // Number of OUTPUT elements: input element count times scale_factor^3.
  long no_elements = 1;
  for(int i = 0; i < input->nDimension; i++){
    no_elements *= input->size[i];
  }
  no_elements *= scale_factor * scale_factor * scale_factor;

  // Four innermost OUTPUT dimensions, as expected by translate_idx.
  int d1;
  int d2;
  int d3;
  int d4;

  if (input->nDimension == 4) {
    d1 = output->size[0];
    d2 = output->size[1];
    d3 = output->size[2];
    d4 = output->size[3];
  } else {
    d1 = output->size[1];
    d2 = output->size[2];
    d3 = output->size[3];
    d4 = output->size[4];
  }

  real *input_data = THCTensor_(data)(state, input);
  real *output_data = THCTensor_(data)(state, output);

  // cuda blocks & threads:
  long nthreads = 256;
  // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
  // 65535 for SM 2.x, 2^32 -1 for >= 3.0
  // TODO: When we move to SM 3.5 we should update this
  // Integer ceil-division: the previous float-based ceil() could round
  // no_elements down once it exceeded 2^24 and launch too few blocks.
  long n_xblocks = (no_elements + nthreads - 1) / nthreads;
  if (n_xblocks < 1) n_xblocks = 1;
  if (n_xblocks > 65535) n_xblocks = 65535;
  long threads_per_row = n_xblocks * nthreads;
  long n_yblocks = (no_elements + threads_per_row - 1) / threads_per_row;
  if (n_yblocks > 65535) {
    THError("Input size is too large!  aborting");
  }
  dim3 blocks(n_xblocks, n_yblocks);
  dim3 threads(nthreads);

  // kernel:
  vupscale<<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data, no_elements, scale_factor, d1, d2, d3, d4);
  THCudaCheck(cudaGetLastError());

  // final cut: release the reference taken by newContiguous
  THCTensor_(free)(state, input);
}

// Backward pass: resizes `gradInput` like `input`, zeroes it and launches
// vdownscale with one thread per gradInput element; each thread sums the
// scale_factor^3 gradOutput elements that originate from its input element.
void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           int scale_factor)
{

  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
  THNN_(VolumetricUpSamplingNearest_shapeCheck)(state, input, gradOutput, scale_factor);
  // The kernel indexes gradOutput through a raw pointer, so it must be dense.
  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
  THCTensor_(resizeAs)(state, gradInput, input);

  // vdownscale accumulates with +=, so start from zero.
  THCTensor_(zero)(state, gradInput);

  real *gradInput_data = THCTensor_(data)(state, gradInput);
  real *gradOutput_data = THCTensor_(data)(state, gradOutput);

  long no_elements = 1;
  for(int i = 0; i < gradInput->nDimension; i++){
    no_elements *= gradInput->size[i];
  }

  // Four innermost gradInput dimensions, as expected by translate_idx_inv.
  int d1;
  int d2;
  int d3;
  int d4;

  if (gradInput->nDimension == 4) {
    d1 = gradInput->size[0];
    d2 = gradInput->size[1];
    d3 = gradInput->size[2];
    d4 = gradInput->size[3];
  } else {
    d1 = gradInput->size[1];
    d2 = gradInput->size[2];
    d3 = gradInput->size[3];
    d4 = gradInput->size[4];
  }

  // cuda blocks & threads:
  long nthreads = 256;
  // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
  // 65535 for SM 2.x, 2^32 -1 for >= 3.0
  // TODO: When we move to SM 3.5 we should update this
  // Integer ceil-division: the previous float-based ceil() could round
  // no_elements down once it exceeded 2^24 and launch too few blocks.
  long n_xblocks = (no_elements + nthreads - 1) / nthreads;
  if (n_xblocks < 1) n_xblocks = 1;
  if (n_xblocks > 65535) n_xblocks = 65535;
  long threads_per_row = n_xblocks * nthreads;
  long n_yblocks = (no_elements + threads_per_row - 1) / threads_per_row;
  if (n_yblocks > 65535) {
    THError("Input size is too large!  aborting");
  }
  dim3 blocks(n_xblocks, n_yblocks);
  dim3 threads(nthreads);

  // kernel:
  vdownscale<real ,accreal> <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data, no_elements,
    scale_factor, d1, d2, d3, d4);
  THCudaCheck(cudaGetLastError());
  // release the reference taken by newContiguous
  THCTensor_(free)(state, gradOutput);
}

#endif

// ============================================================================
// File: lib/THCUNN/generic/VolumetricUpSamplingTrilinear.cu  (new file)
// ============================================================================
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/VolumetricUpSamplingTrilinear.cu"
#else

// Validates arguments of the trilinear volumetric upsampling module: all six
// spatial sizes must be strictly positive, `input` (when given) must be 5D
// and `gradOutput` (when given) must be
// nBatch x nChannels x outputDepth x outputHeight x outputWidth.
static inline void THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
                        (THCState *state,
                         THCTensor *input, THCTensor *gradOutput,
                         int nBatch, int nChannels,
                         int inputDepth, int inputHeight, int inputWidth,
                         int outputDepth, int outputHeight, int outputWidth) {
  // outputDepth is now compared with > 0 like the other sizes; it was
  // previously only tested for non-zero, which let negative values through.
  THArgCheck(inputDepth > 0 && inputHeight > 0 && inputWidth > 0
             && outputDepth > 0 && outputHeight > 0 && outputWidth > 0, 2,
             "input and output sizes should be greater than 0,"
             " but got input (D: %d, H: %d, W: %d) output (D: %d, H: %d, W: %d)",
             inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth);
  if (input != NULL) {
     THCUNN_argCheck(state, input->nDimension == 5, 2, input,
                     "5D input tensor expected but got: %s");
  }

  if (gradOutput != NULL) {
    THCUNN_check_dim_size(state, gradOutput, 5, 0, nBatch);
    THCUNN_check_dim_size(state, gradOutput, 5, 1, nChannels);
    THCUNN_check_dim_size(state, gradOutput, 5, 2, outputDepth);
    THCUNN_check_dim_size(state, gradOutput, 5, 3, outputHeight);
    THCUNN_check_dim_size(state, gradOutput, 5, 4, outputWidth);
  }
}

// Forward pass: resizes `output` to
// nbatch x channels x outputDepth x outputHeight x outputWidth and fills it
// by launching caffe_gpu_interp2_kernel with one thread per output voxel.
void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           int outputDepth,
           int outputHeight,
           int outputWidth)
{
  int nbatch = THCTensor_(size)(state, input, 0);
  int channels = THCTensor_(size)(state, input, 1);
  int inputDepth = THCTensor_(size)(state, input, 2);
  int inputHeight = THCTensor_(size)(state, input, 3);
  int inputWidth = THCTensor_(size)(state, input, 4);
  THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
       (state, input, NULL,
        nbatch, channels,
        inputDepth, inputHeight, inputWidth,
        outputDepth, outputHeight, outputWidth);
  input = THCTensor_(newContiguous)(state, input);
  THCUNN_assertSameGPU(state, 2, input, output);
  THCTensor_(resize5d)(state, output,
                       THCTensor_(size)(state, input, 0),
                       THCTensor_(size)(state, input, 1),
                       outputDepth, outputHeight, outputWidth);
  THCTensor_(zero)(state, output);
  THCDeviceTensor<real, 5> idata = toDeviceTensor<real, 5>(state, input);
  THCDeviceTensor<real, 5> odata = toDeviceTensor<real, 5>(state, output);
  THAssert(inputDepth > 0 && inputHeight > 0 && inputWidth > 0 && outputDepth > 0 && outputHeight > 0 && outputWidth > 0);
  // Ratio mapping the first/last output samples onto the first/last input
  // samples; 0 when the output has a single sample along that axis.
  const accreal rdepth= (outputDepth > 1) ? (accreal)(inputDepth - 1)/(outputDepth - 1) : accreal(0);
  const accreal rheight= (outputHeight > 1) ? (accreal)(inputHeight - 1)/(outputHeight - 1) : accreal(0);
  const accreal rwidth = (outputWidth > 1) ? (accreal)(inputWidth - 1)/(outputWidth - 1) : accreal(0);
  const int num_kernels = outputDepth * outputHeight * outputWidth;
  const int num_threads =
    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
  cudaStream_t stream = THCState_getCurrentStream(state);
  caffe_gpu_interp2_kernel<real, accreal> <<<THCCeilDiv(num_kernels, num_threads), num_threads ,
   0 , stream>>>(num_kernels, rdepth, rheight, rwidth, idata, odata);
  THCudaCheck(cudaGetLastError());
  // release the reference taken by newContiguous
  THCTensor_(free)(state, input);
}


// Backward pass: resizes `gradInput` to
// nbatch x nchannels x inputDepth x inputHeight x inputWidth, zeroes it and
// launches caffe_gpu_interp2_kernel_backward with one thread per gradOutput
// voxel, which accumulates into gradInput with atomicAdd.
void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
           THCState *state,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           int nbatch,
           int nchannels,
           int inputDepth,
           int inputHeight,
           int inputWidth,
           int outputDepth,
           int outputHeight,
           int outputWidth)
{
  THNN_(VolumetricUpSamplingTrilinear_shapeCheck)
       (state, NULL, gradOutput,
        nbatch, nchannels,
        inputDepth, inputHeight, inputWidth,
        outputDepth, outputHeight, outputWidth);
  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
  // gradInput is written in place through a THCDeviceTensor, which is built
  // from the tensor's own sizes and strides. The previous
  // newContiguous(gradInput) could substitute a temporary copy for a
  // non-contiguous gradInput, so the gradient was computed into the copy and
  // then freed without ever reaching the caller's tensor.
  THCTensor_(resize5d)(state, gradInput, nbatch, nchannels, inputDepth, inputHeight, inputWidth);
  THCTensor_(zero)(state, gradInput);
  THCDeviceTensor<real, 5> data1 = toDeviceTensor<real, 5>(state, gradInput);
  THCDeviceTensor<real, 5> data2 = toDeviceTensor<real, 5>(state, gradOutput);
  int depth1 = data1.getSize(2);
  int height1 = data1.getSize(3);
  int width1 = data1.getSize(4);
  int depth2 = data2.getSize(2);
  int height2 = data2.getSize(3);
  int width2 = data2.getSize(4);
  // THAssert for consistency with the forward pass (was a bare assert).
  THAssert(depth1 > 0 && height1 > 0 && width1 > 0 && depth2 > 0 && height2 > 0 && width2 > 0);
  const accreal rdepth= (depth2 > 1) ? (accreal)(depth1 - 1)/(depth2 - 1) : accreal(0);
  const accreal rheight= (height2 > 1) ? (accreal)(height1 - 1)/(height2 - 1) : accreal(0);
  const accreal rwidth = (width2 > 1) ? (accreal)(width1 - 1) / (width2 - 1) : accreal(0);
  const int num_kernels = depth2 * height2 * width2;
  const int num_threads =
    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
  cudaStream_t stream = THCState_getCurrentStream(state);
  caffe_gpu_interp2_kernel_backward<real ,accreal> <<<THCCeilDiv(num_kernels, num_threads),
  num_threads, 0, stream>>>(num_kernels, rdepth, rheight, rwidth, data1, data2);
  THCudaCheck(cudaGetLastError());
  // release the reference taken by newContiguous
  THCTensor_(free)(state, gradOutput);
}

#endif