Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-08-03 05:03:45 +0300
committerSoumith Chintala <soumith@gmail.com>2017-08-03 05:46:24 +0300
commite9ef2d5281dec554724b816b520413c437fb1772 (patch)
tree461c2bb3f0a76822b8ea276a5d005fe1e7081ec6
parentbbebfdc88c6cd0e533d10a08fc48565c7452e2e6 (diff)
remove limitations on output_padding in Conv* routines
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionLocal.cu2
-rw-r--r--lib/THCUNN/generic/SpatialConvolutionMM.cu2
-rw-r--r--lib/THCUNN/generic/SpatialDepthWiseConvolution.cu2
-rw-r--r--lib/THCUNN/generic/SpatialDilatedConvolution.cu2
-rw-r--r--lib/THCUNN/generic/SpatialFullDilatedConvolution.cu8
-rw-r--r--lib/THCUNN/generic/VolumetricDilatedConvolution.cu1
-rw-r--r--lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu15
-rw-r--r--lib/THCUNN/im2col.h7
-rw-r--r--lib/THCUNN/vol2col.h6
9 files changed, 23 insertions, 22 deletions
diff --git a/lib/THCUNN/generic/SpatialConvolutionLocal.cu b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
index 4a0563d..06bf1f8 100644
--- a/lib/THCUNN/generic/SpatialConvolutionLocal.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionLocal.cu
@@ -266,7 +266,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, fgradInput_n),
- nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
1, 1, THCTensor_(data)(state, gradInput_n)
);
diff --git a/lib/THCUNN/generic/SpatialConvolutionMM.cu b/lib/THCUNN/generic/SpatialConvolutionMM.cu
index b4ae8e5..4db0406 100644
--- a/lib/THCUNN/generic/SpatialConvolutionMM.cu
+++ b/lib/THCUNN/generic/SpatialConvolutionMM.cu
@@ -302,7 +302,7 @@ void THNN_(SpatialConvolutionMM_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
- nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
1, 1, THCTensor_(data)(state, gradInput_n)
);
}
diff --git a/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu b/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
index 1fc365f..68077ed 100644
--- a/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDepthWiseConvolution.cu
@@ -388,7 +388,7 @@ void THNN_(SpatialDepthWiseConvolution_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
- 1, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ 1, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
1, 1, THCTensor_(data)(state, gradInput_i)
);
}
diff --git a/lib/THCUNN/generic/SpatialDilatedConvolution.cu b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
index 01c97c9..ff764f6 100644
--- a/lib/THCUNN/generic/SpatialDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialDilatedConvolution.cu
@@ -296,7 +296,7 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
- nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW,
THCTensor_(data)(state, gradInput_n)
);
diff --git a/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu b/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
index 322a213..aafd07e 100644
--- a/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
@@ -13,12 +13,12 @@ static inline void THNN_(SpatialFullDilatedConvolution_shapeCheck)(
"kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
THArgCheck(dW > 0 && dH > 0, 11,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
- THArgCheck(adjW < dW && adjH < dH, 15,
- "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
- adjH, adjW, dH, dW);
THArgCheck(dilationW > 0 && dilationH > 0, 15,
"dilation should be greater than zero, but got dilationH: %d, dilationW: %d",
dilationH, dilationW);
+ THArgCheck((adjW < dW || adjW < dilationW) && (adjH < dH || adjH < dilationH), 15,
+ "output padding must be smaller than either stride or dilation, but got adjH: %d adjW: %d dH: %d dW: %d dilationH: %d dilationW: %d",
+ adjH, adjW, dH, dW, dilationH, dilationW);
THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
"weight tensor has to be contiguous");
THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
@@ -160,7 +160,7 @@ void THNN_(SpatialFullDilatedConvolution_updateOutput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, columns),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, THCTensor_(data)(state, output_n)
);
diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
index 45bb0f6..1203733 100644
--- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu
@@ -310,6 +310,7 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)(
THCState_getCurrentStream(state),
THCTensor_(data)(state, gradColumns),
nInputPlane, inputDepth, inputHeight, inputWidth,
+ outputDepth, outputHeight, outputWidth,
kT, kH, kW, padT, padH, padW, dT, dH, dW,
dilationT, dilationH, dilationW,
THCTensor_(data)(state, gradInput_n)
diff --git a/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu
index bda0b59..94fcc6c 100644
--- a/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu
+++ b/lib/THCUNN/generic/VolumetricFullDilatedConvolution.cu
@@ -24,13 +24,16 @@ static inline void THNN_(VolumetricFullDilatedConvolution_shapeCheck)(
"bias tensor has to be contiguous");
THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
"stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
- THArgCheck(adjT < dT && adjW < dW && adjH < dH, 14,
- "output adjustment must be smaller than stride, but got "
- "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d",
- adjT, adjH, adjW, dT, dH, dW);
THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
"dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
dilationT, dilationH, dilationW);
+ THArgCheck((adjT < dT || adjT < dilationT)
+ && (adjW < dW || adjW < dilationW)
+ && (adjH < dH || adjH < dilationH), 15,
+ "output padding must be smaller than either stride or dilation,"
+ " but got adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d "
+ "dilationT: %d dilationH: %d dilationW: %d",
+ adjT, adjH, adjW, dT, dH, dW, dilationT, dilationH, dilationW);
int ndim = input->nDimension;
int nInputPlane = THCTensor_(size)(state, weight, 0);
@@ -178,7 +181,9 @@ void THNN_(VolumetricFullDilatedConvolution_updateOutput)(
col2vol<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, columns),
- nOutputPlane, outputDepth, outputHeight, outputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
+ nOutputPlane, outputDepth, outputHeight, outputWidth,
+ inputDepth, inputHeight, inputWidth,
+ kT, kH, kW, padT, padH, padW, dT, dH, dW,
dilationT, dilationH, dilationW,
THCTensor_(data)(state, output_n)
);
diff --git a/lib/THCUNN/im2col.h b/lib/THCUNN/im2col.h
index ba57263..060525f 100644
--- a/lib/THCUNN/im2col.h
+++ b/lib/THCUNN/im2col.h
@@ -104,13 +104,10 @@ __global__ void col2im_kernel(const int n, const Dtype* data_col,
template <typename Dtype, typename Acctype>
void col2im(cudaStream_t stream, const Dtype* data_col, const int channels,
const int height, const int width,
+ const int output_height, const int output_width,
const int patch_h, const int patch_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, Dtype* data_im) {
- int height_col = (height + 2 * pad_h - (dilation_h * (patch_h - 1) + 1))
- / stride_h + 1;
- int width_col = (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1))
- / stride_w + 1;
int num_kernels = channels * height * width;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
@@ -118,7 +115,7 @@ void col2im(cudaStream_t stream, const Dtype* data_col, const int channels,
num_kernels, data_col, height, width, channels,
patch_h, patch_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w,
- height_col, width_col, data_im
+ output_height, output_width, data_im
);
THCudaCheck(cudaGetLastError());
}
diff --git a/lib/THCUNN/vol2col.h b/lib/THCUNN/vol2col.h
index 15b110e..12c4838 100644
--- a/lib/THCUNN/vol2col.h
+++ b/lib/THCUNN/vol2col.h
@@ -120,14 +120,12 @@ __global__ void vol2im_kernel(const int n, const Dtype* data_col,
template <typename Dtype, typename Acctype>
void col2vol(cudaStream_t stream, const Dtype* data_col, const int channels,
const int depth, const int height, const int width,
+ const int output_depth, const int output_height, const int output_width,
const int patch_t, const int patch_h, const int patch_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
Dtype* data_vol) {
- int depth_col = (depth + 2 * pad_t - (dilation_t * (patch_t - 1) + 1)) / stride_t + 1;
- int height_col = (height + 2 * pad_h - (dilation_h * (patch_h - 1) + 1)) / stride_h + 1;
- int width_col = (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * depth * height * width;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
@@ -135,7 +133,7 @@ void col2vol(cudaStream_t stream, const Dtype* data_col, const int channels,
num_kernels, data_col, depth, height, width, channels,
patch_t, patch_h, patch_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w,
dilation_t, dilation_h, dilation_w,
- depth_col, height_col, width_col, data_vol
+ output_depth, output_height, output_width, data_vol
);
THCudaCheck(cudaGetLastError());
}