Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/generic/SpatialFullDilatedConvolution.cu')
-rw-r--r--lib/THCUNN/generic/SpatialFullDilatedConvolution.cu8
1 files changed, 4 insertions, 4 deletions
diff --git a/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu b/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
index 322a213..aafd07e 100644
--- a/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
+++ b/lib/THCUNN/generic/SpatialFullDilatedConvolution.cu
@@ -13,12 +13,12 @@ static inline void THNN_(SpatialFullDilatedConvolution_shapeCheck)(
"kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW);
THArgCheck(dW > 0 && dH > 0, 11,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
- THArgCheck(adjW < dW && adjH < dH, 15,
- "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d",
- adjH, adjW, dH, dW);
THArgCheck(dilationW > 0 && dilationH > 0, 15,
"dilation should be greater than zero, but got dilationH: %d, dilationW: %d",
dilationH, dilationW);
+ THArgCheck((adjW < dW || adjW < dilationW) && (adjH < dH || adjH < dilationH), 15,
+ "output padding must be smaller than either stride or dilation, but got adjH: %d adjW: %d dH: %d dW: %d dilationH: %d dilationW: %d",
+ adjH, adjW, dH, dW, dilationH, dilationW);
THArgCheck(THCTensor_(isContiguous)(state, weight), 4,
"weight tensor has to be contiguous");
THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5,
@@ -160,7 +160,7 @@ void THNN_(SpatialFullDilatedConvolution_updateOutput)(
col2im<real, accreal>(
THCState_getCurrentStream(state),
THCTensor_(data)(state, columns),
- nOutputPlane, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW,
+ nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, THCTensor_(data)(state, output_n)
);