From 200ae7d55a3381a232256223c0694498f8f51df0 Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Fri, 7 Jul 2017 11:37:53 -0400 Subject: Adding Spatial Transformers w/CuDNN support --- lib/THNN/generic/SpatialGridSamplerBilinear.c | 204 ++++++++++++++++++++++++++ lib/THNN/generic/THNN.h | 12 ++ lib/THNN/init.c | 3 + 3 files changed, 219 insertions(+) create mode 100644 lib/THNN/generic/SpatialGridSamplerBilinear.c diff --git a/lib/THNN/generic/SpatialGridSamplerBilinear.c b/lib/THNN/generic/SpatialGridSamplerBilinear.c new file mode 100644 index 0000000..37da51c --- /dev/null +++ b/lib/THNN/generic/SpatialGridSamplerBilinear.c @@ -0,0 +1,204 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/SpatialGridSamplerBilinear.c" +#else + +#undef MIN +#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) ) + +static inline void THNN_(SpatialGridSamplerBilinear_shapeCheck) + (THTensor *input, THTensor *grid, THTensor *gradOutput) { + THNN_ARGCHECK(input->nDimension == 4, 2, input, + "4D input tensor expected but got: %s"); + THNN_ARGCHECK(grid->nDimension == 4, 2, grid, + "4D grid tensor expected but got: %s"); + + int nbatch = THTensor_(size)(input, 0); + int channels = THTensor_(size)(input, 1); + int iheight = THTensor_(size)(input, 2); + int iwidth = THTensor_(size)(input, 3); + int oheight = THTensor_(size)(grid, 1); + int owidth = THTensor_(size)(grid, 2); + + THNN_CHECK_DIM_SIZE(grid, 4, 0, nbatch); + THNN_CHECK_DIM_SIZE(grid, 4, 3, 2); + + if (gradOutput != NULL) { + THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nbatch); + THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, channels); + THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, oheight); + THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, owidth); + } +} + +#define SAFE_GET(input, x, y, n, c, H, W) x >= 0 && x < W && y >=0 \ + && y < H ? THTensor_fastGet4d(input, n, c, y, x) : 0 + +TH_API void THNN_(SpatialGridSamplerBilinear_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *grid, + THTensor *output) { + + THNN_(SpatialGridSamplerBilinear_shapeCheck)(input, grid, NULL); + int N = THTensor_(size)(input, 0); + int C = THTensor_(size)(input, 1); + int IH = THTensor_(size)(input, 2); + int IW = THTensor_(size)(input, 3); + int H = THTensor_(size)(grid, 1); + int W = THTensor_(size)(grid, 2); + + // resize output to the same shape as input + THTensor_(resize4d)(output, N, C, H, W); + + // loop over each output pixel + int n, h, w, c; +#pragma omp parallel for private(n, h, w, c) + for (n = 0; n < N; ++n) { + for (h = 0; h < H; ++h) { + for (w = 0; w < W; ++w) { + // get the corresponding input x, y co-ordinates from grid + real ix = THTensor_fastGet4d(grid, n, h, w, 0); + real iy = THTensor_fastGet4d(grid, n, h, w, 1); + + // normalize ix, iy from [-1, 1] to [0, IH-1] & [0, IW-1] + ix = ((ix + 1) / 2) * (IW-1); + iy = ((iy + 1) / 2) * (IH-1); + + // get NE, NW, SE, SW pixel values from (x, y) + int ix_nw = floor(ix); + int iy_nw = floor(iy); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + real nw = (ix_se - ix) * (iy_se - iy); + real ne = (ix - ix_sw) * (iy_sw - iy); + real sw = (ix_ne - ix) * (iy - iy_ne); + real se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + for (c = 0; c < C; ++c) { + // (c, iy_nw, ix_nw) * nw + (c, iy_ne, ix_ne) * ne + // + (c, iy_sw, ix_sw) * sw + (c, iy_se, ix_se) * se + real nw_val = SAFE_GET(input, ix_nw, iy_nw, n, c, IH, IW); + real ne_val = SAFE_GET(input, ix_ne, iy_ne, n, c, IH, IW); + real sw_val = SAFE_GET(input, ix_sw, iy_sw, n, c, IH, IW); + real se_val = SAFE_GET(input, ix_se, iy_se, n, c, IH, IW); + real out_val = nw_val * nw + ne_val * ne + sw_val * sw + se_val * se; + THTensor_fastSet4d(output, n, c, h, w, out_val); + } + } + } + } +} + +#define SAFE_ADD(input, x, y, n, c, H, W, value) \ + do { \ + if (x >= 0 && x < W && y >=0 && y < H) { \ + real old_value = THTensor_fastGet4d(input, n, c, y, x); \ + THTensor_fastSet4d(input, n, c, y, x, value + old_value); \ + } \ + } while(0) + +TH_API void THNN_(SpatialGridSamplerBilinear_updateGradInput)( + THNNState *state, + THTensor *input, THTensor *gradInput, + THTensor *grid, THTensor *gradGrid, + THTensor *gradOutput) { + + THNN_(SpatialGridSamplerBilinear_shapeCheck)(input, grid, gradOutput); + int N = THTensor_(size)(input, 0); + int C = THTensor_(size)(input, 1); + int IH = THTensor_(size)(input, 2); + int IW = THTensor_(size)(input, 3); + int H = THTensor_(size)(grid, 1); + int W = THTensor_(size)(grid, 2); + + THTensor_(resize4d)(gradInput, N, C, IH, IW); + THTensor_(resize4d)(gradGrid, N, H, W, 2); + THTensor_(zero)(gradInput); + THTensor_(zero)(gradGrid); + + // loop over each output pixel + int n, h, w; +#pragma omp parallel for private(n, h, w) + for (n = 0; n < N; ++n) { + for (h = 0; h < H; ++h) { + for (w = 0; w < W; ++w) { + // get the corresponding input x, y co-ordinates from grid + real ix = THTensor_fastGet4d(grid, n, h, w, 0); + real iy = THTensor_fastGet4d(grid, n, h, w, 1); + + real gix = 0; + real giy = 0; + + // normalize ix, iy from [-1, 1] to [0, H-1] & [0, W-1] + ix = ((ix + 1) / 2) * (IW-1); + iy = ((iy + 1) / 2) * (IH-1); + + // get NE, NW, SE, SW pixel values from (x, y) + int ix_nw = floor(ix); + int iy_nw = floor(iy); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + real nw = (ix_se - ix) * (iy_se - iy); + real ne = (ix - ix_sw) * (iy_sw - iy); + real sw = (ix_ne - ix) * (iy - iy_ne); + real se = (ix - ix_nw) * (iy - iy_nw); + + for (int c = 0; c < C; ++c) { + real gradout = THTensor_fastGet4d(gradOutput, n, c, h, w); + + // calculate and set gradInput + SAFE_ADD(gradInput, ix_nw, iy_nw, n, c, IH, IW, nw * gradout); + SAFE_ADD(gradInput, ix_ne, iy_ne, n, c, IH, IW, ne * gradout); + SAFE_ADD(gradInput, ix_sw, iy_sw, n, c, IH, IW, sw * gradout); + SAFE_ADD(gradInput, ix_se, iy_se, n, c, IH, IW, se * gradout); + + // calculate gradGrid + real nw_val = SAFE_GET(input, ix_nw, iy_nw, n, c, IH, IW); + real ne_val = SAFE_GET(input, ix_ne, iy_ne, n, c, IH, IW); + real sw_val = SAFE_GET(input, ix_sw, iy_sw, n, c, IH, IW); + real se_val = SAFE_GET(input, ix_se, iy_se, n, c, IH, IW); + + gix -= nw_val * (iy_se - iy) * gradout; + gix += ne_val * (iy_sw - iy) * gradout; + gix -= sw_val * (iy - iy_ne) * gradout; + gix += se_val * (iy - iy_nw) * gradout; + + giy -= nw_val * (ix_se - ix) * gradout; + giy -= ne_val * (ix - ix_sw) * gradout; + giy += sw_val * (ix_ne - ix) * gradout; + giy += se_val * (ix - ix_nw) * gradout; + } + + // un-normalize gradGrid values back to [-1, 1] constraints + gix = gix * (IW - 1) / 2; + giy = giy * (IH - 1) / 2; + + real gix_old = THTensor_fastGet4d(gradGrid, n, h, w, 0); + real giy_old = THTensor_fastGet4d(gradGrid, n, h, w, 1); + + THTensor_fastSet4d(gradGrid, n, h, w, 0, gix_old + gix); + THTensor_fastSet4d(gradGrid, n, h, w, 1, giy_old + giy); + + } + } + } +} + +#undef MIN +#undef SAFE_GET +#undef SAFE_ADD + +#endif diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h index 76a28eb..31b0795 100644 --- a/lib/THNN/generic/THNN.h +++ b/lib/THNN/generic/THNN.h @@ -1171,6 +1171,18 @@ TH_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)( int outputHeight, int outputWidth); +TH_API void THNN_(SpatialGridSamplerBilinear_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *grid, + THTensor *output); + +TH_API void THNN_(SpatialGridSamplerBilinear_updateGradInput)( + THNNState *state, + THTensor *input, THTensor *gradInput, + THTensor *grid, THTensor *gradGrid, + THTensor *gradOutput); + TH_API void THNN_(unfolded_acc)( THTensor *finput, THTensor *input, diff --git a/lib/THNN/init.c b/lib/THNN/init.c index 5c8c023..12f3222 100644 --- a/lib/THNN/init.c +++ b/lib/THNN/init.c @@ -236,6 +236,9 @@ #include "generic/SpatialUpSamplingBilinear.c" #include "THGenerateFloatTypes.h" +#include "generic/SpatialGridSamplerBilinear.c" +#include "THGenerateFloatTypes.h" + #include "generic/VolumetricAveragePooling.c" #include "THGenerateFloatTypes.h" -- cgit v1.2.3