diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-08-03 05:03:45 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-08-03 05:46:04 +0300 |
commit | 19c849d07fd4a4047496f86ac2cfa6bc3219b7ca (patch) | |
tree | 2d20cd81f99335e6597a458a15429c6731aba8cb | |
parent | d87d7c7619a008a27ad0a1d03dfaf978ccfb1719 (diff) |
remove limitations on output_padding in Conv* routines
-rw-r--r-- | lib/THNN/generic/SpatialDilatedConvolution.c | 3 | ||||
-rw-r--r-- | lib/THNN/generic/SpatialFullDilatedConvolution.c | 18 | ||||
-rw-r--r-- | lib/THNN/generic/VolumetricDilatedConvolution.c | 1 | ||||
-rw-r--r-- | lib/THNN/generic/VolumetricFullDilatedConvolution.c | 19 |
4 files changed, 24 insertions, 17 deletions
diff --git a/lib/THNN/generic/SpatialDilatedConvolution.c b/lib/THNN/generic/SpatialDilatedConvolution.c index 897cc0d..1fcc742 100644 --- a/lib/THNN/generic/SpatialDilatedConvolution.c +++ b/lib/THNN/generic/SpatialDilatedConvolution.c @@ -260,7 +260,8 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)( // Unpack columns back into input: THNN_(col2im)( THTensor_(data)(gradColumns), - nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + kH, kW, padH, padW, dH, dW, dilationH, dilationW, THTensor_(data)(gradInput_n) ); diff --git a/lib/THNN/generic/SpatialFullDilatedConvolution.c b/lib/THNN/generic/SpatialFullDilatedConvolution.c index 4d5a3fc..71e01a7 100644 --- a/lib/THNN/generic/SpatialFullDilatedConvolution.c +++ b/lib/THNN/generic/SpatialFullDilatedConvolution.c @@ -30,16 +30,16 @@ static void THNN_(im2col)(const real* data_im, const int channels, } static void THNN_(col2im)(const real* data_col, const int channels, - const int height, const int width, const int kernel_h, const int kernel_w, + const int height, const int width, + const int output_height, const int output_width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, real* data_im) { memset(data_im, 0, sizeof(real) * height * width * channels); - const int height_col = (height + 2 * pad_h - - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_col = (width + 2 * pad_w - - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + const int height_col = output_height; + const int width_col = output_width; const int channels_col = channels * kernel_h * kernel_w; for (int c_col = 0; c_col < channels_col; ++c_col) { int w_offset = c_col % kernel_w; @@ -67,12 +67,12 @@ static inline void THNN_(SpatialFullDilatedConvolution_shapeCheck)( "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); - THArgCheck(adjW < dW && adjH < dH, 15, - "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d", - adjH, adjW, dH, dW); THArgCheck(dilationW > 0 && dilationH > 0, 15, "dilation should be greater than zero, but got dilationH: %d, dilationW: %d", dilationH, dilationW); + THArgCheck((adjW < dW || adjW < dilationW) && (adjH < dH || adjH < dilationH), 15, + "output padding must be smaller than either stride or dilation, but got adjH: %d adjW: %d dH: %d dW: %d dilationH: %d dilationW: %d", + adjH, adjW, dH, dW, dilationH, dilationW); THNN_ARGCHECK(weight->nDimension == 2 || weight->nDimension == 4, 5, weight, "2D or 4D weight tensor expected, but got: %s"); @@ -201,7 +201,7 @@ void THNN_(SpatialFullDilatedConvolution_updateOutput)( // Unpack columns back into input: THNN_(col2im)( THTensor_(data)(columns), - nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, + nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, THTensor_(data)(output_n) ); diff --git a/lib/THNN/generic/VolumetricDilatedConvolution.c b/lib/THNN/generic/VolumetricDilatedConvolution.c index 5627e6e..53dfcb9 100644 --- a/lib/THNN/generic/VolumetricDilatedConvolution.c +++ b/lib/THNN/generic/VolumetricDilatedConvolution.c @@ -272,6 +272,7 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)( THNN_(col2vol)( THTensor_(data)(gradColumns), nInputPlane, inputDepth, inputHeight, inputWidth, + outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW, dilationT, dilationH, dilationW, THTensor_(data)(gradInput_n) diff --git a/lib/THNN/generic/VolumetricFullDilatedConvolution.c b/lib/THNN/generic/VolumetricFullDilatedConvolution.c index 4e22d38..2e40b62 100644 --- a/lib/THNN/generic/VolumetricFullDilatedConvolution.c +++ b/lib/THNN/generic/VolumetricFullDilatedConvolution.c @@ -47,6 +47,7 @@ static void THNN_(vol2col)( static void THNN_(col2vol)( const real* data_col, const int channels, const int depth, const int height, const int width, + const int out_depth, const int out_height, const int out_width, const int kT, const int kH, const int kW, const int pT, const int pH, const int pW, const int dT, const int dH, const int dW, @@ -55,9 +56,9 @@ static void THNN_(col2vol)( { int c, t, h, w; memset(data_vol, 0, sizeof(real) * depth * height * width * channels); - int depth_col = (depth + 2 * pT - (dilationT * (kT - 1) + 1)) / dT + 1; - int height_col = (height + 2 * pH - (dilationH * (kH - 1) + 1)) / dH + 1; - int width_col = (width + 2 * pW - (dilationW * (kW - 1) + 1)) / dW + 1; + int depth_col = out_depth; + int height_col = out_height; + int width_col = out_width; int channels_col = channels * kT * kH * kW; for (c = 0; c < channels_col; ++c) { @@ -99,13 +100,16 @@ static inline void THNN_(VolumetricFullDilatedConvolution_shapeCheck)( "expected for weight, but got: %s"); THArgCheck(dT > 0 && dW > 0 && dH > 0, 11, "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); - THArgCheck(aT < dT && aW < dW && aH < dH, 15, - "output adjustment must be smaller than stride, but got " - "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d", - aT, aH, aW, dT, dH, dW); THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15, "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d", dilationT, dilationH, dilationW); + THArgCheck((aT < dT || aT < dilationT) + && (aW < dW || aW < dilationW) + && (aH < dH || aH < dilationH), 15, + "output padding must be smaller than either stride or dilation," + " but got aT: %d aH: %d aW: %d dT: %d dH: %d dW: %d " + "dilationT: %d dilationH: %d dilationW: %d", + aT, aH, aW, dT, dH, dW, dilationT, dilationH, dilationW); int ndim = input->nDimension; const int nInputPlane = (int)weight->size[0]; @@ -247,6 +251,7 @@ void THNN_(VolumetricFullDilatedConvolution_updateOutput)( THNN_(col2vol)( THTensor_(data)(columns), nOutputPlane, outputDepth, outputHeight, outputWidth, + inputDepth, inputHeight, inputWidth, kT, kH, kW, pT, pH, pW, dT, dH, dW, |