diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-12-07 04:07:22 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-12-07 04:07:22 +0300 |
commit | 8d35db45bbb2ad35d3a045d7ebe185b1f9efc505 (patch) | |
tree | 568a42f5d82ff59248e2a1b7db914c79b05c706d | |
parent | eef562a68f44f86dece3967930a7d609b482f7e6 (diff) | |
parent | b544850efadb6314aa189c8193dc359eec519557 (diff) |
Merge pull request #394 from gchanan/volumShapeChecks
Improve Volumetric shape checking.
-rw-r--r-- | lib/THCUNN/generic/SpatialFullConvolution.cu | 3 | ||||
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 8 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricAveragePooling.cu | 85 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricConvolution.cu | 77 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricDilatedConvolution.cu | 111 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricDilatedMaxPooling.cu | 149 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricFullConvolution.cu | 96 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricMaxPooling.cu | 14 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricMaxUnpooling.cu | 74 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricReplicationPadding.cu | 98 |
10 files changed, 565 insertions, 150 deletions
diff --git a/lib/THCUNN/generic/SpatialFullConvolution.cu b/lib/THCUNN/generic/SpatialFullConvolution.cu index 395e3c6..12995d2 100644 --- a/lib/THCUNN/generic/SpatialFullConvolution.cu +++ b/lib/THCUNN/generic/SpatialFullConvolution.cu @@ -12,6 +12,9 @@ static inline void THNN_(SpatialFullConvolution_shapeCheck)( "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); THArgCheck(dW > 0 && dH > 0, 11, "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + THArgCheck(adjW < dW && adjH < dH, 15, + "output adjustment must be smaller than stride, but got adjH: %d adjW: %d dH: %d dW: %d", + adjH, adjW, dH, dW); THCUNN_argCheck(state, weight->nDimension == 2 || weight->nDimension == 4, 5, weight, "2D or 4D weight tensor expected, but got: %s"); diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h index 0941d75..d5b7996 100644 --- a/lib/THCUNN/generic/THCUNN.h +++ b/lib/THCUNN/generic/THCUNN.h @@ -1054,9 +1054,11 @@ TH_API void THNN_(VolumetricDilatedMaxPooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, THCIndexTensor *indices, + int kT, int kW, int kH, int dT, int dW, int dH, int padT, int padW, int padH, - int dilationT, int dilationW, int dilationH); + int dilationT, int dilationW, int dilationH, + bool ceilMode); TH_API void THNN_(VolumetricFullConvolution_updateOutput)( THCState *state, @@ -1111,8 +1113,10 @@ TH_API void THNN_(VolumetricMaxPooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, THCIndexTensor *indices, + int kT, int kW, int kH, int dT, int dW, int dH, - int padT, int padW, int padH); + int padT, int padW, int padH, + bool ceilMode); TH_API void THNN_(VolumetricMaxUnpooling_updateOutput)( THCState *state, diff --git a/lib/THCUNN/generic/VolumetricAveragePooling.cu b/lib/THCUNN/generic/VolumetricAveragePooling.cu index b0eaea4..7a6c595 100644 --- a/lib/THCUNN/generic/VolumetricAveragePooling.cu +++ b/lib/THCUNN/generic/VolumetricAveragePooling.cu @@ -2,25 +2,30 @@ #define THC_GENERIC_FILE "generic/VolumetricAveragePooling.cu" #else -void THNN_(VolumetricAveragePooling_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int kT, int kW, int kH, - int dT, int dW, int dH) -{ - int batchSize; +static inline void THNN_(VolumetricAveragePooling_shapeCheck)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + int kT, + int kW, + int kH, + int dT, + int dW, + int dH) { int inputSlices; int inputTime; int inputHeight; int inputWidth; + int ndim = input->nDimension; + int dimN = 0; int dimt = 1; int dimh = 2; int dimw = 3; if (input->nDimension == 5) { + dimN++; dimt++; dimh++; dimw++; @@ -36,7 +41,6 @@ void THNN_(VolumetricAveragePooling_updateOutput)( kT, kH, kW); /* sizes */ - batchSize = 1; inputSlices = THCTensor_(size)(state, input, 0); inputTime = THCTensor_(size)(state, input, 1); inputHeight = THCTensor_(size)(state, input, 2); @@ -52,7 +56,6 @@ void THNN_(VolumetricAveragePooling_updateOutput)( kT, kH, kW); /* sizes */ - batchSize = THCTensor_(size)(state, input, 0); inputSlices = THCTensor_(size)(state, input, 1); inputTime = THCTensor_(size)(state, input, 2); inputHeight = THCTensor_(size)(state, input, 3); @@ -67,6 +70,64 @@ void THNN_(VolumetricAveragePooling_updateOutput)( int outputHeight = (inputHeight - kH) / dH + 1; int outputWidth = (inputWidth - kW) / dW + 1; + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndim, dimN, inputSlices); + THCUNN_check_dim_size(state, gradOutput, ndim, dimt, outputTime); + THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); + THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth); + } +} + +void THNN_(VolumetricAveragePooling_updateOutput)( + THCState *state, + THCTensor *input, + THCTensor *output, + int kT, int kW, int kH, + int dT, int dW, int dH) +{ + int batchSize; + int inputSlices; + int inputTime; + int inputHeight; + int inputWidth; + + int dimt = 1; + int dimh = 2; + int dimw = 3; + + if (input->nDimension == 5) + { + dimt++; + dimh++; + dimw++; + } + + THNN_(VolumetricAveragePooling_shapeCheck) + (state, input, NULL, kT, kW, kH, dT, dW, dH); + + if (THCTensor_(nDimension)(state, input) == 4) + { + /* sizes */ + batchSize = 1; + inputSlices = THCTensor_(size)(state, input, 0); + inputTime = THCTensor_(size)(state, input, 1); + inputHeight = THCTensor_(size)(state, input, 2); + inputWidth = THCTensor_(size)(state, input, 3); + } + else if (THCTensor_(nDimension)(state, input) == 5) + { + /* sizes */ + batchSize = THCTensor_(size)(state, input, 0); + inputSlices = THCTensor_(size)(state, input, 1); + inputTime = THCTensor_(size)(state, input, 2); + inputHeight = THCTensor_(size)(state, input, 3); + inputWidth = THCTensor_(size)(state, input, 4); + } + + int outputTime = (inputTime - kT) / dT + 1; + int outputHeight = (inputHeight - kH) / dH + 1; + int outputWidth = (inputWidth - kW) / dW + 1; + if (input->nDimension == 4) /* 4D */ { /* resize output */ @@ -139,7 +200,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( int kT, int kW, int kH, int dT, int dW, int dH) { - // TODO: gradOutput shape check + + THNN_(VolumetricAveragePooling_shapeCheck) + (state, input, gradOutput, kT, kW, kH, dT, dW, dH); bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW); // Resize and initialize result tensor. diff --git a/lib/THCUNN/generic/VolumetricConvolution.cu b/lib/THCUNN/generic/VolumetricConvolution.cu index 02b0c1e..a371ac8 100644 --- a/lib/THCUNN/generic/VolumetricConvolution.cu +++ b/lib/THCUNN/generic/VolumetricConvolution.cu @@ -7,9 +7,18 @@ static inline void THNN_(VolumetricConvolution_shapeCheck) THCTensor *input, THCTensor *gradOutput, THCTensor *weight, - THCTensor *gradWeight) { + THCTensor *gradWeight, + THCTensor *bias, + int dT, + int dW, + int dH, + int padT, + int padW, + int padH) { THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, "4D or 5D (batch mode) tensor expected for input, but got: %s"); + THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, + "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); if (gradOutput != NULL) { THCUNN_argCheck(state, gradOutput->nDimension == 4 || gradOutput->nDimension == 5, 3, @@ -28,6 +37,59 @@ static inline void THNN_(VolumetricConvolution_shapeCheck) "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " "expected for gradWeight, but got: %s"); } + + if (weight == NULL) { + weight = gradWeight; + } + int nOutputPlane = (int)weight->size[0]; + int nInputPlane = (int)weight->size[1]; + int kT = (int)weight->size[2]; + int kH = (int)weight->size[3]; + int kW = (int)weight->size[4]; + + THArgCheck(kT > 0 && kW > 0 && kH > 0, 4, + "kernel size should be greater than zero, but got kT: %d kH: %d kW: %d", kT, kH, kW); + int ndim = input->nDimension; + int dimf = 0; + int dimh = 1; + int dimw = 2; + int dimd = 3; + + if (ndim == 5) + { + dimf++; + dimh++; + dimw++; + dimd++; + } + + long inputWidth = input->size[dimw]; + long inputHeight = input->size[dimh]; + long inputDepth = input->size[dimd]; + long outputWidth = (inputWidth + 2*padH - kH) / dH + 1; + long outputHeight = (inputHeight + 2*padT - kT) / dT + 1; + long outputDepth = (inputDepth + 2*padW - kW) / dW + 1; + + if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1) + { + THError( + "Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small", + nInputPlane, inputDepth, inputHeight, inputWidth, + nOutputPlane, outputDepth, outputHeight, outputWidth + ); + } + + if (bias != NULL) { + THCUNN_check_dim_size(state, bias, 1, 0, weight->size[0]); + } + THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane); + + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane); + THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); + THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth); + THCUNN_check_dim_size(state, gradOutput, ndim, dimd, outputDepth); + } } void THNN_(VolumetricConvolution_updateOutput)( @@ -44,7 +106,9 @@ void THNN_(VolumetricConvolution_updateOutput)( THCTensor *columns = finput; THCTensor *ones = fgradInput; THCUNN_assertSameGPU(state, 6, input, output, weight, bias, columns, ones); - THNN_(VolumetricConvolution_shapeCheck)(state, input, NULL, weight, NULL); + THNN_(VolumetricConvolution_shapeCheck)( + state, input, NULL, weight, NULL, + bias, dT, dW, dH, padT, padW, padH); input = THCTensor_(newContiguous)(state, input); int nOutputPlane = (int)weight->size[0]; @@ -191,7 +255,9 @@ void THNN_(VolumetricConvolution_updateGradInput)( THCTensor *gradColumns = finput; THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns, gradInput); - THNN_(VolumetricConvolution_shapeCheck)(state, input, gradOutput, weight, NULL); + THNN_(VolumetricConvolution_shapeCheck)( + state, input, gradOutput, weight, NULL, + NULL, dT, dW, dH, padT, padW, padH); gradOutput = THCTensor_(newContiguous)(state, gradOutput); int batch = 1; @@ -295,6 +361,9 @@ void THNN_(VolumetricConvolution_accGradParameters)( THCTensor *columns = finput; THCTensor *ones = fgradInput; THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight, gradBias, columns, ones); + THNN_(VolumetricConvolution_shapeCheck)( + state, input, gradOutput, NULL, gradWeight, + gradBias, dT, dW, dH, padT, padW, padH); int nOutputPlane = (int)gradWeight->size[0]; int nInputPlane = (int)gradWeight->size[1]; @@ -302,8 +371,6 @@ void THNN_(VolumetricConvolution_accGradParameters)( int kH = (int)gradWeight->size[3]; int kW = (int)gradWeight->size[4]; - THNN_(VolumetricConvolution_shapeCheck)(state, input, gradOutput, NULL, gradWeight); - input = THCTensor_(newContiguous)(state, input); gradOutput = THCTensor_(newContiguous)(state, gradOutput); diff --git a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu index 268d690..422cdc7 100644 --- a/lib/THCUNN/generic/VolumetricDilatedConvolution.cu +++ b/lib/THCUNN/generic/VolumetricDilatedConvolution.cu @@ -2,6 +2,69 @@ #define THC_GENERIC_FILE "generic/VolumetricDilatedConvolution.cu" #else +static inline void THNN_(VolumetricDilatedConvolution_shapeCheck)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + THCTensor *weight, + THCTensor *bias, + int kT, int kH, int kW, + int dT, int dH, int dW, + int padT, int padH, int padW, + int dilationT, int dilationH, int dilationW) { + THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, + "4D or 5D (batch mode) tensor expected for input, but got: %s"); + THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, + "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " + "expected for weight, but got: %s"); + THArgCheck(kT > 0 && kW > 0 && kH > 0, 8, + "kernel size should be greater than zero, but got kT: %d kH: %d kW: %d", kT, kH, kW); + THArgCheck(dT > 0 && dW > 0 && dH > 0, 11, + "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); + THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 15, + "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d", + dilationT, dilationH, dilationW); + + if (bias != NULL) { + THCUNN_check_dim_size(state, bias, 1, 0, weight->size[0]); + } + + int ndim = input->nDimension; + int dimf = 0; + int dimd = 1; + int dimh = 2; + int dimw = 3; + + if (ndim == 5) { + dimf++; + dimd++; + dimh++; + dimw++; + } + + int nInputPlane = weight->size[1]; + int nOutputPlane = weight->size[0]; + long inputDepth = input->size[dimd]; + long inputHeight = input->size[dimh]; + long inputWidth = input->size[dimw]; + long outputDepth = (inputDepth + 2*padT - (dilationT * (kT - 1) + 1)) / dT + 1; + long outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) + THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small", + nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth); + + THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane); + + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane); + THCUNN_check_dim_size(state, gradOutput, ndim, dimd, outputDepth); + THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); + THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth); + } +} + void THNN_(VolumetricDilatedConvolution_updateOutput)( THCState *state, THCTensor *input, @@ -19,15 +82,10 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( if (bias) { THCUNN_assertSameGPU(state, 2, weight, bias); } - THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); - THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, - "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " - "expected for weight, but got: %s"); - THArgCheck(!bias || weight->size[0] == bias->size[0], 4, "nOutputPlane mismatch in weight and bias"); - THArgCheck(kT > 0 && kW > 0 && kH > 0, 8, "kernel size should be greater than zero"); - THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, "stride should be greater than zero"); - THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 16, "dilation should be greater than 0"); + THNN_(VolumetricDilatedConvolution_shapeCheck)( + state, input, NULL, weight, bias, + kT, kH, kW, dT, dH, dW, padT, padH, padW, + dilationT, dilationH, dilationW); // Params: int nInputPlane = weight->size[1]; @@ -36,12 +94,9 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 4) { - THArgCheck(input->size[0] == nInputPlane, 2, "input channels and nInputPlane dont match"); // Force batch batch = 0; THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]); - } else { - THArgCheck(input->size[1] == nInputPlane, 2, "input channels and nInputPlane dont match"); } long inputDepth = input->size[2]; @@ -51,10 +106,6 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( long outputHeight = (inputHeight + 2*padH - (dilationH * (kH - 1) + 1)) / dH + 1; long outputWidth = (inputWidth + 2*padW - (dilationW * (kW - 1) + 1)) / dW + 1; - if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) - THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small", - nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth); - // Batch size + input planes long batchSize = input->size[0]; @@ -174,16 +225,10 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)( THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns, gradInput); - THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); - THCUNN_argCheck(state, gradOutput->nDimension == 4 || gradOutput->nDimension == 5, 3, - gradOutput, - "4D or 5D (batch mode) tensor expected for gradOutput, but got: %s"); - THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, - "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " - "expected for weight, but got: %s"); - THArgCheck(kT > 0 && kW > 0 && kH > 0, 8, "kernel size should be greater than zero"); - THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, "stride should be greater than zero"); + THNN_(VolumetricDilatedConvolution_shapeCheck)( + state, input, gradOutput, weight, NULL, + kT, kH, kW, dT, dH, dW, padT, padH, padW, + dilationT, dilationH, dilationW); // Params int nInputPlane = weight->size[1]; @@ -293,16 +338,10 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)( if (gradBias) { THCUNN_assertSameGPU(state, 2, gradWeight, gradBias); } - THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); - THCUNN_argCheck(state, gradOutput->nDimension == 4 || gradOutput->nDimension == 5, 3, - gradOutput, - "4D or 5D (batch mode) tensor expected for gradOutput, but got: %s"); - THCUNN_argCheck(state, gradWeight->nDimension == 5, 4, gradWeight, - "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " - "expected for gradWeight, but got: %s"); - THArgCheck(kT > 0 && kW > 0 && kH > 0, 8, "kernel size should be greater than zero"); - THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, "stride should be greater than zero"); + THNN_(VolumetricDilatedConvolution_shapeCheck)( + state, input, gradOutput, gradWeight, gradBias, + kT, kH, kW, dT, dH, dW, padT, padH, padW, + dilationT, dilationH, dilationW); // Params int nInputPlane = gradWeight->size[1]; diff --git a/lib/THCUNN/generic/VolumetricDilatedMaxPooling.cu b/lib/THCUNN/generic/VolumetricDilatedMaxPooling.cu index 68879bb..0a1bb81 100644 --- a/lib/THCUNN/generic/VolumetricDilatedMaxPooling.cu +++ b/lib/THCUNN/generic/VolumetricDilatedMaxPooling.cu @@ -9,6 +9,114 @@ dilationT, dilationH, dilationW, offsetZ); \ break +static inline void THNN_(VolumetricDilatedMaxPooling_shapeCheck)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + THCIndexTensor *indices, + int kT, int kW, int kH, + int dT, int dW, int dH, + int padT, int padW, int padH, + int dilationT, int dilationW, int dilationH, + bool ceilMode) { + int ndim = input->nDimension; + int inputSlices; + int inputTime; + int inputHeight; + int inputWidth; + int outputTime; + int outputHeight; + int outputWidth; + int dimf = 0; + int dimt = 1; + int dimh = 2; + int dimw = 3; + + THArgCheck(kT > 0 && kW > 0 && kH > 0, 7, + "kernel size should be greater than zero, but got kT: %d kH: %d kW: %d", + kT, kH, kW); + THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, + "stride should be greater than zero, but got dT: %d dH: %d dW: %d", + dT, dH, dW); + THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 16, + "dilation should be greater than 0, but got dilationT: %d dilationH: %d dilationW: %d", + dilationT, dilationH, dilationW); + + if (input->nDimension == 5) + { + dimf++; + dimt++; + dimh++; + dimw++; + } + + if (THCTensor_(nDimension)(state, input) == 4) + { + /* sizes */ + inputSlices = THCTensor_(size)(state, input, 0); + inputTime = THCTensor_(size)(state, input, 1); + inputHeight = THCTensor_(size)(state, input, 2); + inputWidth = THCTensor_(size)(state, input, 3); + } + else if (THCTensor_(nDimension)(state, input) == 5) + { + /* sizes */ + inputSlices = THCTensor_(size)(state, input, 1); + inputTime = THCTensor_(size)(state, input, 2); + inputHeight = THCTensor_(size)(state, input, 3); + inputWidth = THCTensor_(size)(state, input, 4); + } + else + { + THArgCheck(false, 2, "4D or 5D tensor expected, got %d", THCTensor_(nDimension)(state, input)); + } + + THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 13, + "pad should be smaller than half of kernel size, but got " + "kT: %d kW: %d, kH: %d, padT: %d, padW: %d, padH: %d", + kT, kW, kH, padT, padW, padH); + + if (ceilMode) + { + outputTime = (int)(ceil((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1; + outputHeight = (int)(ceil((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1; + outputWidth = (int)(ceil((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1; + } + else + { + outputTime = (int)(floor((float)(inputTime - (dilationT * (kT - 1) + 1) + 2*padT) / dT)) + 1; + outputHeight = (int)(floor((float)(inputHeight - (dilationH * (kH - 1) + 1) + 2*padH) / dH)) + 1; + outputWidth = (int)(floor((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1; + } + + if (padT || padW || padH) + { + if ((outputTime - 1)*dT >= inputTime + padT) + --outputTime; + if ((outputHeight - 1)*dH >= inputHeight + padH) + --outputHeight; + if ((outputWidth - 1)*dW >= inputWidth + padW) + --outputWidth; + } + + if (outputTime < 1 || outputHeight < 1 || outputWidth < 1) + THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small", + inputSlices,inputTime,inputHeight,inputWidth,inputSlices,outputTime,outputHeight,outputWidth); + + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndim, dimf, inputSlices); + THCUNN_check_dim_size(state, gradOutput, ndim, dimt, outputTime); + THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); + THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth); + } + if (indices != NULL) { + THCUNN_check_dim_size_indices(state, indices, ndim, dimf, inputSlices); + THCUNN_check_dim_size_indices(state, indices, ndim, dimt, outputTime); + THCUNN_check_dim_size_indices(state, indices, ndim, dimh, outputHeight); + THCUNN_check_dim_size_indices(state, indices, ndim, dimw, outputWidth); + } +} + void THNN_(VolumetricDilatedMaxPooling_updateOutput)( THCState *state, THCTensor *input, @@ -41,16 +149,13 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)( } THCUNN_assertSameGPU(state, 3, input, indices, output); + THNN_(VolumetricDilatedMaxPooling_shapeCheck)( + state, input, NULL, NULL, kT, kW, kH, + dT, dW, dH, padT, padW, padH, + dilationT, dilationW, dilationH, ceilMode); if (THCTensor_(nDimension)(state, input) == 4) { - THArgCheck(input->size[dimw] >= kW && input->size[dimh] >= kH - && input->size[dimt] >= kT, 2, - "input image (T: %d H: %d W: %d) smaller than " - "kernel size (kT: %d kH: %d kW: %d)", - input->size[dimt], input->size[dimh], input->size[dimw], - kT, kH, kW); - /* sizes */ batchSize = 1; inputSlices = THCTensor_(size)(state, input, 0); @@ -60,13 +165,6 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)( } else if (THCTensor_(nDimension)(state, input) == 5) { - THArgCheck(input->size[dimw] >= kW && input->size[dimh] >= kH - && input->size[dimt] >= kT, 2, - "input image (T: %d H: %d W: %d) smaller than " - "kernel size (kT: %d kH: %d kW: %d)", - input->size[dimt], input->size[dimh], input->size[dimw], - kT, kH, kW); - /* sizes */ batchSize = THCTensor_(size)(state, input, 0); inputSlices = THCTensor_(size)(state, input, 1); @@ -74,17 +172,6 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)( inputHeight = THCTensor_(size)(state, input, 3); inputWidth = THCTensor_(size)(state, input, 4); } - else - { - THArgCheck(false, 2, "4D or 5D tensor expected, got %d", THCTensor_(nDimension)(state, input)); - } - - THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 2, - "pad should be smaller than half of kernel size" - ); - THArgCheck(dilationT > 0 && dilationW > 0 && dilationH > 0, 14, - "dilation should be greater than 0" - ); if (ceilMode) { @@ -99,10 +186,6 @@ void THNN_(VolumetricDilatedMaxPooling_updateOutput)( outputWidth = (int)(floor((float)(inputWidth - (dilationW * (kW - 1) + 1) + 2*padW) / dW)) + 1; } - if (outputTime < 1 || outputHeight < 1 || outputWidth < 1) - THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small", - inputSlices,inputTime,inputHeight,inputWidth,inputSlices,outputTime,outputHeight,outputWidth); - if (padT || padW || padH) { if ((outputTime - 1)*dT >= inputTime + padT) @@ -205,9 +288,11 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, THCIndexTensor *indices, + int kT, int kW, int kH, int dT, int dW, int dH, int padT, int padW, int padH, - int dilationT, int dilationW, int dilationH) + int dilationT, int dilationW, int dilationH, + bool ceilMode) { // TODO: gradOutput shape check // Resize and initialize result tensor. @@ -222,6 +307,10 @@ void THNN_(VolumetricDilatedMaxPooling_updateGradInput)( int outputWidth; THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput); + THNN_(VolumetricDilatedMaxPooling_shapeCheck)( + state, input, gradOutput, indices, kT, kW, kH, + dT, dW, dH, padT, padW, padH, + dilationT, dilationW, dilationH, ceilMode); if (THCTensor_(nDimension)(state, input) == 4) /* 4D */ { diff --git a/lib/THCUNN/generic/VolumetricFullConvolution.cu b/lib/THCUNN/generic/VolumetricFullConvolution.cu index f48566c..47f4943 100644 --- a/lib/THCUNN/generic/VolumetricFullConvolution.cu +++ b/lib/THCUNN/generic/VolumetricFullConvolution.cu @@ -2,6 +2,71 @@ #define THC_GENERIC_FILE "generic/VolumetricFullConvolution.cu" #else +static inline void THNN_(VolumetricFullConvolution_shapeCheck)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + THCTensor *weight, + THCTensor *bias, + int dT, int dW, int dH, + int padT, int padW, int padH, + int adjT, int adjW, int adjH) { + THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, + "4D or 5D (batch mode) tensor expected for input, but got: %s"); + // number of input & output planes and kernel size is indirectly defined by the weight tensor + THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, + "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " + "expected for weight, but got: %s"); + THArgCheck(dT > 0 && dW > 0 && dH > 0, 8, + "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); + THArgCheck(adjT < dT && adjW < dW && adjH < dH, 14, + "output adjustment must be smaller than stride, but got " + "adjT: %d adjH: %d adjW: %d dT: %d dH: %d dW: %d", + adjT, adjH, adjW, dT, dH, dW); + + int ndim = input->nDimension; + int nInputPlane = THCTensor_(size)(state, weight, 0); + int nOutputPlane = THCTensor_(size)(state, weight, 1); + const int kT = (int)weight->size[2]; + const int kH = (int)weight->size[3]; + const int kW = (int)weight->size[4]; + + if (bias != NULL) { + THCUNN_check_dim_size(state, bias, 1, 0, weight->size[1]); + } + + int dimf = 0; + int dimd = 1; + int dimh = 2; + int dimw = 3; + + if (ndim == 5) { + dimf++; + dimd++; + dimh++; + dimw++; + } + + long inputWidth = input->size[dimw]; + long inputHeight = input->size[dimh]; + long inputDepth = input->size[dimd]; + long outputWidth = (inputWidth - 1) * dW - 2*padW + kW + adjW; + long outputHeight = (inputHeight - 1) * dH - 2*padH + kH + adjH; + long outputDepth = (inputDepth - 1) * dT - 2*padT + kT + adjT; + + if (outputDepth < 1 || outputWidth < 1 || outputHeight < 1) + THError("Given input size: (%dx%dx%dx%d). Calculated output size: (%dx%dx%dx%d). Output size is too small", + nInputPlane,inputDepth,inputHeight,inputWidth,nOutputPlane,outputDepth,outputHeight,outputWidth); + + THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane); + if (gradOutput != NULL) { + THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane); + THCUNN_check_dim_size(state, gradOutput, ndim, dimd, outputDepth); + THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); + THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth); + } +} + void THNN_(VolumetricFullConvolution_updateOutput)( THCState *state, THCTensor *input, @@ -26,22 +91,18 @@ void THNN_(VolumetricFullConvolution_updateOutput)( THCUNN_assertSameGPU(state, 6, input, output, weight, bias, columns, ones); - THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); - THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, - "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " - "expected for weight, but got: %s"); + THNN_(VolumetricFullConvolution_shapeCheck)( + state, input, NULL, weight, bias, + dT, dW, dH, padT, padW, padH, + adjT, adjW, adjH); input = THCTensor_(newContiguous)(state, input); int batch = 1; if (input->nDimension == 4) { - THArgCheck(input->size[0] == nInputPlane, 2, "input channels and nInputPlane dont match"); // Force batch batch = 0; THCTensor_(resize5d)(state, input, 1, input->size[0], input->size[1], input->size[2], input->size[3]); - } else { - THArgCheck(input->size[1] == nInputPlane, 2, "input channels and nInputPlane dont match"); } long inputWidth = input->size[4]; @@ -174,11 +235,10 @@ void THNN_(VolumetricFullConvolution_updateGradInput)( THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, gradColumns, gradInput); - THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); - THCUNN_argCheck(state, weight->nDimension == 5, 4, weight, - "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " - "expected for weight, but got: %s"); + THNN_(VolumetricFullConvolution_shapeCheck)( + state, input, gradOutput, weight, NULL, + dT, dW, dH, padT, padW, padH, + adjT, adjW, adjH); input = THCTensor_(newContiguous)(state, input); gradOutput = THCTensor_(newContiguous)(state, gradOutput); @@ -293,12 +353,10 @@ void THNN_(VolumetricFullConvolution_accGradParameters)( THCUNN_assertSameGPU(state, 6, input, gradOutput, gradWeight, gradBias, columns, ones); - THCUNN_argCheck(state, input->nDimension == 4 || input->nDimension == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); - THCUNN_argCheck(state, gradWeight->nDimension == 5, 4, gradWeight, - "5D (nOutputPlane x nInputPlane x kT x kH x kW) tensor " - "expected for gradWeight, but got: %s"); - + THNN_(VolumetricFullConvolution_shapeCheck)( + state, input, gradOutput, gradWeight, + gradBias, dT, dW, dH, padT, padW, padH, + adjT, adjW, adjH); input = THCTensor_(newContiguous)(state, input); gradOutput = THCTensor_(newContiguous)(state, gradOutput); diff --git a/lib/THCUNN/generic/VolumetricMaxPooling.cu b/lib/THCUNN/generic/VolumetricMaxPooling.cu index 4a55a45..c86be82 100644 --- a/lib/THCUNN/generic/VolumetricMaxPooling.cu +++ b/lib/THCUNN/generic/VolumetricMaxPooling.cu @@ -13,8 +13,9 @@ void THNN_(VolumetricMaxPooling_updateOutput)( bool ceilMode) { THNN_(VolumetricDilatedMaxPooling_updateOutput)( - state, input, output, indices, - kT, kW, kH, dT, dW, dH, padT, padW, padH, 1, 1, 1, ceilMode); + state, input, output, indices, + kT, kW, kH, dT, dW, dH, padT, padW, padH, + 1, 1, 1, ceilMode); } @@ -24,12 +25,15 @@ void THNN_(VolumetricMaxPooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, THCIndexTensor *indices, + int kT, int kW, int kH, int dT, int dW, int dH, - int padT, int padW, int padH) + int padT, int padW, int padH, + bool ceilMode) { THNN_(VolumetricDilatedMaxPooling_updateGradInput)( - state, input, gradOutput, gradInput, indices, - dT, dW, dH, padT, padW, padH, 1, 1, 1); + state, input, gradOutput, gradInput, indices, + kT, kW, kH, dT, dW, dH, padT, padW, padH, + 1, 1, 1, ceilMode); } diff --git a/lib/THCUNN/generic/VolumetricMaxUnpooling.cu b/lib/THCUNN/generic/VolumetricMaxUnpooling.cu index e314298..7e0dca9 100644 --- a/lib/THCUNN/generic/VolumetricMaxUnpooling.cu +++ b/lib/THCUNN/generic/VolumetricMaxUnpooling.cu @@ -2,6 +2,66 @@ #define THC_GENERIC_FILE "generic/VolumetricMaxUnpooling.cu" #else +static inline void THNN_(VolumetricMaxUnpooling_shapeCheck)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + THCIndexTensor *indices, + int oT, + int oW, + int oH, + int dT, + int dW, + int dH, + int pT, + int pW, + int pH) { + int inputSlices; + + THCUNN_check_shape_indices(state, indices, input); + + THArgCheck(dT > 0 && dW > 0 && dH > 0, 10, + "stride should be greater than zero, but got dT: %d dH: %d dW: %d", + dT, dH, dW); + + if (THCTensor_(nDimension)(state, input) == 4) + { + inputSlices = THCTensor_(size)(state, input, 0); + } + else if (THCTensor_(nDimension)(state, input) == 5) + { + inputSlices = THCTensor_(size)(state, input, 1); + } + else + { + THArgCheck(false, 2, "4D or 5D tensor expected, got %d", + THCTensor_(nDimension)(state, input)); + } + + int dimw = 3; + int dimh = 2; + int dimt = 1; + int dimn = 0; + if (input->nDimension == 5) + { + dimt++; + dimw++; + dimh++; + dimn++; + } + + if (gradOutput != NULL) { + if (oT != gradOutput->size[dimt] || oW != gradOutput->size[dimw] || oH != gradOutput->size[dimh]) + { + THError( + "Inconsistent gradOutput size. oT= %d, oH= %d, oW= %d, gradOutput: %dx%dx%d", + oT, oH, oW, gradOutput->size[dimt], gradOutput->size[dimh], gradOutput->size[dimw]); + } + + THCUNN_check_dim_size(state, gradOutput, input->nDimension, dimn, inputSlices); + } +} + void THNN_(VolumetricMaxUnpooling_updateOutput)( THCState *state, THCTensor *input, @@ -17,6 +77,10 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)( int inputHeight; int inputWidth; + THNN_(VolumetricMaxUnpooling_shapeCheck)( + state, input, NULL, indices, + outputTime, outputWidth, outputHeight, + dT, dW, dH, padT, padW, padH); THCUNN_assertSameGPU(state, 3, input, indices, output); if (THCTensor_(nDimension)(state, input) == 4) @@ -37,11 +101,6 @@ void THNN_(VolumetricMaxUnpooling_updateOutput)( inputHeight = THCTensor_(size)(state, input, 3); inputWidth = THCTensor_(size)(state, input, 4); } - else - { - THArgCheck(false, 2, "4D or 5D tensor expected, got %d", - THCTensor_(nDimension)(state, input)); - } if (input->nDimension == 4) /* 4D */ { @@ -117,7 +176,10 @@ void THNN_(VolumetricMaxUnpooling_updateGradInput)( int inputHeight; int inputWidth; - // TODO: check gradOutput shape + THNN_(VolumetricMaxUnpooling_shapeCheck)( + state, input, gradOutput, indices, + outputTime, outputWidth, outputHeight, + dT, dW, dH, padT, padW, padH); THCUNN_assertSameGPU(state, 4, input, indices, gradOutput, gradInput); if (THCTensor_(nDimension)(state, input) == 4) /* 4D */ diff --git a/lib/THCUNN/generic/VolumetricReplicationPadding.cu b/lib/THCUNN/generic/VolumetricReplicationPadding.cu index 35a1c76..aecbd19 100644 --- a/lib/THCUNN/generic/VolumetricReplicationPadding.cu +++ b/lib/THCUNN/generic/VolumetricReplicationPadding.cu @@ -2,6 +2,62 @@ #define THC_GENERIC_FILE "generic/VolumetricReplicationPadding.cu" #else +static inline void THNN_(VolumetricReplicationPadding_shapeCheck)( + THCState *state, + THCTensor *input, + THCTensor *gradOutput, + int pleft, int pright, + int ptop, int pbottom, + int pfront, int pback) { + THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, input), 2, + "input tensor must fit into 32-bit index math"); + int numInputDims = THCTensor_(nDimension)(state, input); + + THCUNN_argCheck(state, numInputDims == 4 || numInputDims == 5, 2, input, + "4D or 5D (batch mode) tensor expected for input, but got: %s"); + + int planeDim = 0; + int dimd = 1; + int dimh = 2; + int dimw = 3; + if (numInputDims == 5) { + planeDim++; + dimd++; + dimh++; + dimw++; + } + + int numPlanes = THCTensor_(size)(state, input, planeDim); + int idepth = input->size[dimd]; + int iheight = input->size[dimh]; + int iwidth = input->size[dimw]; + int odepth = idepth + pfront + pback; + int oheight = iheight + ptop + pbottom; + int owidth = iwidth + pleft + pright; + THArgCheck(owidth >= 1 || oheight >= 1 || odepth >= 1, 2, + "input (D: %d H: %d, W: %d) is too small." + " Calculated output D: %d H: %d W: %d", + idepth, iheight, iwidth, odepth, oheight, owidth); + + if (gradOutput != NULL) { + THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, gradOutput), + 3, "output gradient tensor must fit into 32-bit index math"); + + THArgCheck(numPlanes == THCTensor_(size)(state, gradOutput, planeDim), 3, + "gradOutput width unexpected. Expected: %d, Got: %d", + numPlanes, THCTensor_(size)(state, gradOutput, planeDim)); + THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3, + "gradOutput width unexpected. Expected: %d, Got: %d", + owidth, THCTensor_(size)(state, gradOutput, dimw)); + THArgCheck(oheight == THCTensor_(size)(state, gradOutput, dimh), 3, + "gradOutput height unexpected. Expected: %d, Got: %d", + oheight, THCTensor_(size)(state, gradOutput, dimh)); + THArgCheck(odepth == THCTensor_(size)(state, gradOutput, dimd), 3, + "gradOutput depth unexpected. Expected: %d, Got: %d", + odepth, THCTensor_(size)(state, gradOutput, dimd)); + } +} + void THNN_(VolumetricReplicationPadding_updateOutput)( THCState *state, THCTensor *input, @@ -9,8 +65,9 @@ void THNN_(VolumetricReplicationPadding_updateOutput)( int pleft, int pright, int ptop, int pbottom, int pfront, int pback) { - THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, input), 2, - "input tensor must fit into 32-bit index math"); + THNN_(VolumetricReplicationPadding_shapeCheck)( + state, input, NULL, pleft, pright, ptop, + pbottom, pfront, pback); int planeDim = 0; int dimd = 1; @@ -19,8 +76,6 @@ void THNN_(VolumetricReplicationPadding_updateOutput)( int numBatch = 1; int numInputDims = THCTensor_(nDimension)(state, input); - THCUNN_argCheck(state, numInputDims == 4 || numInputDims == 5, 2, input, - "4D or 5D (batch mode) tensor expected for input, but got: %s"); if (numInputDims == 5) { numBatch = THCTensor_(size)(state, input, 0); @@ -30,18 +85,6 @@ void THNN_(VolumetricReplicationPadding_updateOutput)( dimw++; } - int idepth = input->size[dimd]; - int iheight = input->size[dimh]; - int iwidth = input->size[dimw]; - int odepth = idepth + pfront + pback; - int oheight = iheight + ptop + pbottom; - int owidth = iwidth + pleft + pright; - - THArgCheck(owidth >= 1 || oheight >= 1 || odepth >= 1, 2, - "input (D: %d H: %d, W: %d)is too small." - " Calculated output D: %d H: %d W: %d", - idepth, iheight, iwidth, odepth, oheight, owidth); - int numPlanes = THCTensor_(size)(state, input, planeDim); int inputD = THCTensor_(size)(state, input, dimd); int inputH = THCTensor_(size)(state, input, dimh); @@ -85,10 +128,9 @@ void THNN_(VolumetricReplicationPadding_updateGradInput)( int pleft, int pright, int ptop, int pbottom, int pfront, int pback) { - THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, input), 2, - "input tensor must fit into 32-bit index math"); - THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, gradOutput), - 3, "output gradient tensor must fit into 32-bit index math"); + THNN_(VolumetricReplicationPadding_shapeCheck)( + state, input, gradOutput, pleft, pright, ptop, + pbottom, pfront, pback); int planeDim = 0; int dimd = 1; @@ -103,22 +145,6 @@ void THNN_(VolumetricReplicationPadding_updateGradInput)( dimw++; } - int idepth = input->size[dimd]; - int iheight = input->size[dimh]; - int iwidth = input->size[dimw]; - int odepth = idepth + pfront + pback; - int oheight = iheight + ptop + pbottom; - int owidth = iwidth + pleft + pright; - - THArgCheck(owidth == THCTensor_(size)(state, gradOutput, dimw), 3, - "gradOutput width unexpected. Expected: %d, Got: %d", - owidth, THCTensor_(size)(state, gradOutput, dimw)); - THArgCheck(oheight == THCTensor_(size)(state, gradOutput, dimh), 3, - "gradOutput height unexpected. Expected: %d, Got: %d", - oheight, THCTensor_(size)(state, gradOutput, dimh)); - THArgCheck(odepth == THCTensor_(size)(state, gradOutput, dimd), 3, - "gradOutput depth unexpected. Expected: %d, Got: %d", - odepth, THCTensor_(size)(state, gradOutput, dimd)); THCTensor_(resizeAs)(state, gradInput, input); THCTensor_(zero)(state, gradInput); |