diff options
Diffstat (limited to 'lib/THCUNN/generic/VolumetricAveragePooling.cu')
-rw-r--r-- | lib/THCUNN/generic/VolumetricAveragePooling.cu | 123 |
1 files changed, 93 insertions, 30 deletions
diff --git a/lib/THCUNN/generic/VolumetricAveragePooling.cu b/lib/THCUNN/generic/VolumetricAveragePooling.cu index 7a6c595..828a0f6 100644 --- a/lib/THCUNN/generic/VolumetricAveragePooling.cu +++ b/lib/THCUNN/generic/VolumetricAveragePooling.cu @@ -6,12 +6,11 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)( THCState *state, THCTensor *input, THCTensor *gradOutput, - int kT, - int kW, - int kH, - int dT, - int dW, - int dH) { + int kT, int kW, int kH, + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode) +{ int inputSlices; int inputTime; int inputHeight; @@ -66,11 +65,42 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)( THArgCheck(false, 2, "4D or 5D tensor expected, but got: %d", input->nDimension); } - int outputTime = (inputTime - kT) / dT + 1; - int outputHeight = (inputHeight - kH) / dH + 1; - int outputWidth = (inputWidth - kW) / dW + 1; + // The second argument is the index of padH. + THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11, + "pad should not be greater than half of kernel size, but got " + "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d", + padT, padW, padH, kT, kW, kH); + + int outputTime; + int outputHeight; + int outputWidth; + + if (ceil_mode) + { + outputTime = ceil(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = ceil(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = ceil(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + else + { + outputTime = floor(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = floor(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = floor(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + if (padT || padW || padH) + { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((outputTime - 1)*dT >= inputTime + padT) + --outputTime; + if ((outputHeight - 1)*dH >= inputHeight + padH) + --outputHeight; + if ((outputWidth - 1)*dW >= inputWidth + padW) + --outputWidth; + } - if (gradOutput != NULL) { + if (gradOutput != NULL) + { THCUNN_check_dim_size(state, gradOutput, ndim, dimN, inputSlices); THCUNN_check_dim_size(state, gradOutput, ndim, dimt, outputTime); THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); @@ -83,7 +113,10 @@ void THNN_(VolumetricAveragePooling_updateOutput)( THCTensor *input, THCTensor *output, int kT, int kW, int kH, - int dT, int dW, int dH) + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, + bool count_include_pad) { int batchSize; int inputSlices; @@ -103,7 +136,8 @@ void THNN_(VolumetricAveragePooling_updateOutput)( } THNN_(VolumetricAveragePooling_shapeCheck) - (state, input, NULL, kT, kW, kH, dT, dW, dH); + (state, input, NULL, kT, kW, kH, dT, dW, dH, + padT, padW, padH, ceil_mode); if (THCTensor_(nDimension)(state, input) == 4) { @@ -124,9 +158,33 @@ void THNN_(VolumetricAveragePooling_updateOutput)( inputWidth = THCTensor_(size)(state, input, 4); } - int outputTime = (inputTime - kT) / dT + 1; - int outputHeight = (inputHeight - kH) / dH + 1; - int outputWidth = (inputWidth - kW) / dW + 1; + int outputTime; + int outputHeight; + int outputWidth; + + if (ceil_mode) + { + outputTime = ceil(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = ceil(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = ceil(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + else + { + outputTime = floor(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = floor(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = floor(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + if (padT || padH || padW) + { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((outputTime - 1)*dT >= inputTime + padT) + --outputTime; + if ((outputHeight - 1)*dH >= inputHeight + padH) + --outputHeight; + if ((outputWidth - 1)*dW >= inputWidth + padW) + --outputWidth; + } if (input->nDimension == 4) /* 4D */ { @@ -164,7 +222,6 @@ void THNN_(VolumetricAveragePooling_updateOutput)( THCCeilDiv(outputHeight, static_cast<int>(block.y)), totalZ > 65535 ? 65535 : totalZ); - accreal normFactor = ScalarConvert<int, accreal>::to(1) / static_cast<accreal>(kT * kH * kW); switch (kW) { LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1); @@ -180,7 +237,8 @@ void THNN_(VolumetricAveragePooling_updateOutput)( cudaOutput, kT, kH, kW, dT, dH, dW, - normFactor, + padT, padH, padW, + count_include_pad, offsetZ ); break; @@ -198,11 +256,14 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, int kT, int kW, int kH, - int dT, int dW, int dH) + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, + bool count_include_pad) { - THNN_(VolumetricAveragePooling_shapeCheck) - (state, input, gradOutput, kT, kW, kH, dT, dW, dH); + (state, input, gradOutput, kT, kW, kH, dT, dW, dH, + padT, padW, padH, ceil_mode); bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW); // Resize and initialize result tensor. @@ -266,7 +327,8 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( // Optimizing for stride 1 is probably only of limited value, but this // specialization yields 3x speedup over the atomicAdd implementation. - if (dT == 1 && dH == 1 && dW == 1) + // Padding must be 0, otherwise, pool size may change. + if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0) { int totalZ = inputTime * inputSlices * batchSize; int offsetZ = 0; @@ -286,20 +348,21 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( int totalZ = outputTime * inputSlices * batchSize; int offsetZ = 0; while (totalZ > 0) { - dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)), THCCeilDiv(outputHeight, static_cast<int>(block.y)), totalZ > 65535 ? 65535 : totalZ); if (kernelsOverlap) - { - cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<real, accreal><<<grid, block>>>( - cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, offsetZ); - } + { + cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<real, accreal><<<grid, block>>>( + cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, + padT, padH, padW, count_include_pad, offsetZ); + } else - { - cuda_VolumetricAveragePooling_updateGradInput<real, accreal><<<grid, block>>>( - cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, offsetZ); - } + { + cuda_VolumetricAveragePooling_updateGradInput<real, accreal><<<grid, block>>>( + cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, + padT, padH, padW, count_include_pad, offsetZ); + } THCudaCheck(cudaGetLastError()); totalZ -= 65535; offsetZ += 65535; |