diff options
author | Lu Fang <lufang@fb.com> | 2017-07-30 21:22:02 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-08-26 21:44:59 +0300 |
commit | 283539e161cf73955de64ff9a596b9eb8358dfd2 (patch) | |
tree | 48de3dcedf0aa0dfa94e81f961ccec44da29a197 | |
parent | f4869232d90ecc3d022d9a5da303571f2bd3a72c (diff) |
Adding implicit padding for 3d average pooling
-rw-r--r-- | lib/THCUNN/VolumetricAveragePooling.cu | 180 |
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 10 |
-rw-r--r-- | lib/THCUNN/generic/VolumetricAveragePooling.cu | 123 |
3 files changed, 226 insertions, 87 deletions
diff --git a/lib/THCUNN/VolumetricAveragePooling.cu b/lib/THCUNN/VolumetricAveragePooling.cu index f584dcf..979d370 100644 --- a/lib/THCUNN/VolumetricAveragePooling.cu +++ b/lib/THCUNN/VolumetricAveragePooling.cu @@ -9,8 +9,12 @@ template <typename Dtype, typename Acctype> __global__ void cuda_VolumetricAveragePooling_updateOutput( - THCDeviceTensor<Dtype, 4> input, THCDeviceTensor<Dtype, 4> output, - int kT, int kH, int kW, int dT, int dH, int dW, Acctype normFactor, int offsetZ) + THCDeviceTensor<Dtype, 4> input, + THCDeviceTensor<Dtype, 4> output, + int kT, int kH, int kW, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, int offsetZ) { int oCol = blockIdx.x * blockDim.x + threadIdx.x; int oRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -21,32 +25,40 @@ __global__ void cuda_VolumetricAveragePooling_updateOutput( { Acctype sum = 0.0; - int iColumn = oCol * dW; - int iRow = oRow * dH; - int iFrame = oFrame * dT; + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, input.getSize(1) + padT); + int hend = min(hstart + kH, input.getSize(2) + padH); + int wend = min(wstart + kW, input.getSize(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, input.getSize(1)); + hend = min(hend, input.getSize(2)); + wend = min(wend, input.getSize(3)); + + Acctype divide_factor; + if (count_include_pad) + divide_factor = static_cast<Acctype>(pool_size); + else + divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart)); - for (int frame = 0; frame < kT; ++frame) + int ti, hi, wi; + for (ti = tstart; ti < tend; ++ti) { - if (iFrame + frame < input.getSize(1)) + for (hi = hstart; hi < hend; ++hi) { - for (int row = 0; row < kH; ++row) + for (wi = wstart; wi < wend; ++wi) { - if (iRow + row < input.getSize(2)) - { - 
for (int column = 0; column < kW; ++column) - { - if (iColumn + column < input.getSize(3)) - { - Dtype val = input[slice][iFrame + frame][iRow + row][iColumn + column]; - sum += val; - } - } - } + Dtype val = input[slice][ti][hi][wi]; + sum += val; } } } - output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum * normFactor); + output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum / divide_factor); } } @@ -54,9 +66,13 @@ __global__ void cuda_VolumetricAveragePooling_updateOutput( // performance reasons. // template<int KERNEL_WIDTH, typename Dtype, typename Acctype> -__global__ void cuda_VolumetricAveragePooling_updateOutput( - THCDeviceTensor<Dtype, 4> input, THCDeviceTensor<Dtype, 4> output, - int kT, int kH, int dT, int dH, int dW, Acctype normFactor, int offsetZ) +__global__ void cuda_VolumetricAveragePooling_updateOutput_fixedKW( + THCDeviceTensor<Dtype, 4> input, + THCDeviceTensor<Dtype, 4> output, + int kT, int kH, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, int offsetZ) { int oCol = blockIdx.x * blockDim.x + threadIdx.x; int oRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -67,45 +83,54 @@ __global__ void cuda_VolumetricAveragePooling_updateOutput( { Acctype sum = 0.0; - int iColumn = oCol * dW; - int iRow = oRow * dH; - int iFrame = oFrame * dT; + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, input.getSize(1) + padT); + int hend = min(hstart + kH, input.getSize(2) + padH); + int wend = min(wstart + KERNEL_WIDTH, input.getSize(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, input.getSize(1)); + hend = min(hend, input.getSize(2)); + wend = min(wend, input.getSize(3)); - for (int frame = 0; frame < kT; ++frame) + Acctype divide_factor; + if (count_include_pad) + 
divide_factor = static_cast<Acctype>(pool_size); + else + divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart)); + + int ti, hi, wi; + for (ti = tstart; ti < tend; ++ti) { - if (iFrame + frame < input.getSize(1)) + for (hi = hstart; hi < hend; ++hi) { - for (int row = 0; row < kH; ++row) + for (wi = wstart; wi < wend; ++wi) { - if (iRow + row < input.getSize(2)) - { - for (int column = 0; column < KERNEL_WIDTH; ++column) - { - if (iColumn + column < input.getSize(3)) - { - Dtype val = input[slice][iFrame + frame][iRow + row][iColumn + column]; - sum += val; - } - } - } + Dtype val = input[slice][ti][hi][wi]; + sum += val; } } } - output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum * normFactor); + output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum / divide_factor); } } -#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ - cuda_VolumetricAveragePooling_updateOutput<KW><<<grid, block>>>( \ - cudaInput, cudaOutput, kT, kH, dT, dH, dW, normFactor, offsetZ); \ +#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ + cuda_VolumetricAveragePooling_updateOutput_fixedKW<KW, real, accreal><<<grid, block>>>( \ + cudaInput, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, count_include_pad, offsetZ); \ break template <typename Dtype, typename Acctype> __global__ void cuda_VolumetricAveragePooling_updateGradInput_Stride1( THCDeviceTensor<Dtype, 4> gradOutput, THCDeviceTensor<Dtype, 4> gradInput, - int kT, int kH, int kW, Acctype normFactor, int offsetZ) + int kT, int kH, int kW, + Acctype normFactor, int offsetZ) { int iCol = blockIdx.x * blockDim.x + threadIdx.x; int iRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -148,7 +173,10 @@ template <typename Dtype, typename Acctype> __global__ void cuda_VolumetricAveragePooling_updateGradInput_atomicAdd( THCDeviceTensor<Dtype, 4> gradOutput, THCDeviceTensor<Dtype, 4> gradInput, - int kT, int kH, int kW, int dT, int dH, int dW, int offsetZ) + 
int kT, int kH, int kW, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, int offsetZ) { int oCol = blockIdx.x * blockDim.x + threadIdx.x; int oRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -158,13 +186,33 @@ __global__ void cuda_VolumetricAveragePooling_updateGradInput_atomicAdd( // guard against over-tiled threads if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3)) { + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, gradInput.getSize(1) + padT); + int hend = min(hstart + kH, gradInput.getSize(2) + padH); + int wend = min(wstart + kW, gradInput.getSize(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, gradInput.getSize(1)); + hend = min(hend, gradInput.getSize(2)); + wend = min(wend, gradInput.getSize(3)); + + Acctype divide_factor; + if (count_include_pad) + divide_factor = static_cast<Acctype>(pool_size); + else + divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart)); + Dtype val = ScalarConvert<Acctype, Dtype>::to( - ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / (kT * kH * kW)); - for (int iFrame = oFrame * dT; iFrame < oFrame * dT + kT; ++iFrame) + ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor); + for (int iFrame = tstart; iFrame < tend; ++iFrame) { - for (int iRow = oRow * dH; iRow < oRow * dH + kH; ++iRow) + for (int iRow = hstart; iRow < hend; ++iRow) { - for (int iCol = oCol * dW; iCol < oCol * dW + kW; ++iCol) + for (int iCol = wstart; iCol < wend; ++iCol) { atomicAdd(&gradInput[slice][iFrame][iRow][iCol], val); } @@ -178,7 +226,9 @@ __global__ void cuda_VolumetricAveragePooling_updateGradInput( THCDeviceTensor<Dtype, 4> gradOutput, THCDeviceTensor<Dtype, 4> gradInput, int kT, int kH, int 
kW, - int dT, int dH, int dW, int offsetZ) + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, int offsetZ) { int oCol = blockIdx.x * blockDim.x + threadIdx.x; int oRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -188,13 +238,33 @@ __global__ void cuda_VolumetricAveragePooling_updateGradInput( // guard against over-tiled threads if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3)) { + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, gradInput.getSize(1) + padT); + int hend = min(hstart + kH, gradInput.getSize(2) + padH); + int wend = min(wstart + kW, gradInput.getSize(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, gradInput.getSize(1)); + hend = min(hend, gradInput.getSize(2)); + wend = min(wend, gradInput.getSize(3)); + + Acctype divide_factor; + if (count_include_pad) + divide_factor = static_cast<Acctype>(pool_size); + else + divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart)); + Dtype val = ScalarConvert<Acctype, Dtype>::to( - ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / (kT * kH * kW)); - for (int iFrame = oFrame * dT; iFrame < oFrame * dT + kT; ++iFrame) + ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor); + for (int iFrame = tstart; iFrame < tend; ++iFrame) { - for (int iRow = oRow * dH; iRow < oRow * dH + kH; ++iRow) + for (int iRow = hstart; iRow < hend; ++iRow) { - for (int iCol = oCol * dW; iCol < oCol * dW + kW; ++iCol) + for (int iCol = wstart; iCol < wend; ++iCol) { gradInput[slice][iFrame][iRow][iCol] = val; } diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h index 1a4464f..df186b1 100644 --- a/lib/THCUNN/generic/THCUNN.h +++ b/lib/THCUNN/generic/THCUNN.h @@ -1239,7 
+1239,10 @@ TH_API void THNN_(VolumetricAveragePooling_updateOutput)( THCTensor *input, THCTensor *output, int kT, int kW, int kH, - int dT, int dW, int dH); + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, + bool count_include_pad); TH_API void THNN_(VolumetricAveragePooling_updateGradInput)( THCState *state, @@ -1247,7 +1250,10 @@ TH_API void THNN_(VolumetricAveragePooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, int kT, int kW, int kH, - int dT, int dW, int dH); + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, + bool count_include_pad); TH_API void THNN_(VolumetricConvolution_updateOutput)( THCState *state, diff --git a/lib/THCUNN/generic/VolumetricAveragePooling.cu b/lib/THCUNN/generic/VolumetricAveragePooling.cu index 7a6c595..828a0f6 100644 --- a/lib/THCUNN/generic/VolumetricAveragePooling.cu +++ b/lib/THCUNN/generic/VolumetricAveragePooling.cu @@ -6,12 +6,11 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)( THCState *state, THCTensor *input, THCTensor *gradOutput, - int kT, - int kW, - int kH, - int dT, - int dW, - int dH) { + int kT, int kW, int kH, + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode) +{ int inputSlices; int inputTime; int inputHeight; @@ -66,11 +65,42 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)( THArgCheck(false, 2, "4D or 5D tensor expected, but got: %d", input->nDimension); } - int outputTime = (inputTime - kT) / dT + 1; - int outputHeight = (inputHeight - kH) / dH + 1; - int outputWidth = (inputWidth - kW) / dW + 1; + // The second argument is the index of padH. 
+ THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11, + "pad should not be greater than half of kernel size, but got " + "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d", + padT, padW, padH, kT, kW, kH); + + int outputTime; + int outputHeight; + int outputWidth; + + if (ceil_mode) + { + outputTime = ceil(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = ceil(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = ceil(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + else + { + outputTime = floor(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = floor(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = floor(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + if (padT || padW || padH) + { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((outputTime - 1)*dT >= inputTime + padT) + --outputTime; + if ((outputHeight - 1)*dH >= inputHeight + padH) + --outputHeight; + if ((outputWidth - 1)*dW >= inputWidth + padW) + --outputWidth; + } - if (gradOutput != NULL) { + if (gradOutput != NULL) + { THCUNN_check_dim_size(state, gradOutput, ndim, dimN, inputSlices); THCUNN_check_dim_size(state, gradOutput, ndim, dimt, outputTime); THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); @@ -83,7 +113,10 @@ void THNN_(VolumetricAveragePooling_updateOutput)( THCTensor *input, THCTensor *output, int kT, int kW, int kH, - int dT, int dW, int dH) + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, + bool count_include_pad) { int batchSize; int inputSlices; @@ -103,7 +136,8 @@ void THNN_(VolumetricAveragePooling_updateOutput)( } THNN_(VolumetricAveragePooling_shapeCheck) - (state, input, NULL, kT, kW, kH, dT, dW, dH); + (state, input, NULL, kT, kW, kH, dT, dW, dH, + padT, padW, padH, ceil_mode); if (THCTensor_(nDimension)(state, input) == 4) { @@ -124,9 +158,33 @@ void 
THNN_(VolumetricAveragePooling_updateOutput)( inputWidth = THCTensor_(size)(state, input, 4); } - int outputTime = (inputTime - kT) / dT + 1; - int outputHeight = (inputHeight - kH) / dH + 1; - int outputWidth = (inputWidth - kW) / dW + 1; + int outputTime; + int outputHeight; + int outputWidth; + + if (ceil_mode) + { + outputTime = ceil(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = ceil(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = ceil(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + else + { + outputTime = floor(float(inputTime - kT + 2*padT) / float(dT)) + 1; + outputHeight = floor(float(inputHeight - kH + 2*padH) / float(dH)) + 1; + outputWidth = floor(float(inputWidth - kW + 2*padW) / float(dW)) + 1; + } + if (padT || padH || padW) + { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((outputTime - 1)*dT >= inputTime + padT) + --outputTime; + if ((outputHeight - 1)*dH >= inputHeight + padH) + --outputHeight; + if ((outputWidth - 1)*dW >= inputWidth + padW) + --outputWidth; + } if (input->nDimension == 4) /* 4D */ { @@ -164,7 +222,6 @@ void THNN_(VolumetricAveragePooling_updateOutput)( THCCeilDiv(outputHeight, static_cast<int>(block.y)), totalZ > 65535 ? 
65535 : totalZ); - accreal normFactor = ScalarConvert<int, accreal>::to(1) / static_cast<accreal>(kT * kH * kW); switch (kW) { LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1); @@ -180,7 +237,8 @@ void THNN_(VolumetricAveragePooling_updateOutput)( cudaOutput, kT, kH, kW, dT, dH, dW, - normFactor, + padT, padH, padW, + count_include_pad, offsetZ ); break; @@ -198,11 +256,14 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( THCTensor *gradOutput, THCTensor *gradInput, int kT, int kW, int kH, - int dT, int dW, int dH) + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, + bool count_include_pad) { - THNN_(VolumetricAveragePooling_shapeCheck) - (state, input, gradOutput, kT, kW, kH, dT, dW, dH); + (state, input, gradOutput, kT, kW, kH, dT, dW, dH, + padT, padW, padH, ceil_mode); bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW); // Resize and initialize result tensor. @@ -266,7 +327,8 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( // Optimizing for stride 1 is probably only of limited value, but this // specialization yields 3x speedup over the atomicAdd implementation. - if (dT == 1 && dH == 1 && dW == 1) + // Padding must be 0, otherwise, pool size may change. + if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0) { int totalZ = inputTime * inputSlices * batchSize; int offsetZ = 0; @@ -286,20 +348,21 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( int totalZ = outputTime * inputSlices * batchSize; int offsetZ = 0; while (totalZ > 0) { - dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)), THCCeilDiv(outputHeight, static_cast<int>(block.y)), totalZ > 65535 ? 
65535 : totalZ); if (kernelsOverlap) - { - cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<real, accreal><<<grid, block>>>( - cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, offsetZ); - } + { + cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<real, accreal><<<grid, block>>>( + cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, + padT, padH, padW, count_include_pad, offsetZ); + } else - { - cuda_VolumetricAveragePooling_updateGradInput<real, accreal><<<grid, block>>>( - cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, offsetZ); - } + { + cuda_VolumetricAveragePooling_updateGradInput<real, accreal><<<grid, block>>>( + cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, + padT, padH, padW, count_include_pad, offsetZ); + } THCudaCheck(cudaGetLastError()); totalZ -= 65535; offsetZ += 65535; |