diff options
author | Lu Fang <lufang@fb.com> | 2017-07-30 21:22:02 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-08-26 21:44:49 +0300 |
commit | 6cfd1dd69bc95813439596724f60513e269d6537 (patch) | |
tree | 0fb0f2fc8ae47a2a53793dc770e290e533b935de | |
parent | 5dea92d257258f5437de5544960463666ca51bae (diff) |
Adding implicit padding for 3d average pooling
-rw-r--r-- | lib/THNN/generic/THNN.h | 8 | ||||
-rw-r--r-- | lib/THNN/generic/VolumetricAveragePooling.c | 221 |
2 files changed, 180 insertions, 49 deletions
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h index 37b094b..2c4aabf 100644 --- a/lib/THNN/generic/THNN.h +++ b/lib/THNN/generic/THNN.h @@ -1249,14 +1249,18 @@ TH_API void THNN_(VolumetricAveragePooling_updateOutput)( THTensor *input, THTensor *output, int kT, int kW, int kH, - int dT, int dW, int dH); + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, bool count_include_pad); TH_API void THNN_(VolumetricAveragePooling_updateGradInput)( THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, int kT, int kW, int kH, - int dT, int dW, int dH); + int dT, int dW, int dH, + int padT, int padW, int padH, + bool ceil_mode, bool count_include_pad); TH_API void THNN_(VolumetricConvolution_updateOutput)( THNNState *state, diff --git a/lib/THNN/generic/VolumetricAveragePooling.c b/lib/THNN/generic/VolumetricAveragePooling.c index 91c870e..2c69305 100644 --- a/lib/THNN/generic/VolumetricAveragePooling.c +++ b/lib/THNN/generic/VolumetricAveragePooling.c @@ -11,7 +11,12 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)( int kH, int dT, int dW, - int dH) { + int dH, + int padT, + int padW, + int padH, + bool ceil_mode) +{ long nslices; long itime; long iheight; @@ -49,14 +54,46 @@ static inline void THNN_(VolumetricAveragePooling_shapeCheck)( input->size[dimt], input->size[dimh], input->size[dimw], kT, kH, kW); + // The second argument is argNumber... here is the index of padH. + THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11, + "pad should not be greater than half of kernel size, but got " + "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d", + padT, padW, padH, kT, kW, kH); + /* sizes */ nslices = input->size[dimN]; itime = input->size[dimt]; iheight = input->size[dimh]; iwidth = input->size[dimw]; - otime = (itime - kT) / dT + 1; - oheight = (iheight - kH) / dH + 1; - owidth = (iwidth - kW) / dW + 1; + + if (ceil_mode) { + otime = (long)(ceil((float)(itime - kT + 2*padT) / dT)) + 1; + oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1; + owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1; + } + else + { + otime = (long)(floor((float)(itime - kT + 2*padT) / dT)) + 1; + oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1; + owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1; + } + + if (padT || padW || padH) + { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((otime - 1)*dT >= itime + padT) + --otime; + if ((oheight - 1)*dH >= iheight + padH) + --oheight; + if ((owidth - 1)*dW >= iwidth + padW) + --owidth; + } + + if (otime < 1 || owidth < 1 || oheight < 1) + THError("Given input size: (%dx%dx%dx%d). " + "Calculated output size: (%dx%dx%dx%d). Output size is too small", + nslices,itime,iheight,iwidth,nslices,otime,oheight,owidth); if (gradOutput != NULL) { THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimN, nslices); @@ -81,35 +118,61 @@ static void THNN_(VolumetricAveragePooling_updateOutput_frame)( int kH, int dT, int dW, - int dH) + int dH, + int padT, + int padW, + int padH, + bool count_include_pad) { long k; #pragma omp parallel for private(k) for (k = 0; k < nslices; k++) { - /* loop over output */ long i, j, ti; + + /* local pointers. */ + real *ip = input_p + k * itime * iwidth * iheight; + real *op = output_p + k * otime * owidth * oheight; + for (i = 0; i < otime * oheight * owidth; ++i) + *(op + i) = 0; + + /* loop over output */ for (ti = 0; ti < otime; ti++) { for (i = 0; i < oheight; i++) { for (j = 0; j < owidth; j++) { - /* local pointers */ - real *ip = input_p + k * itime * iwidth * iheight - + ti * iwidth * iheight * dT + i * iwidth * dH + j * dW; - real *op = output_p + k * otime * owidth * oheight - + ti * owidth * oheight + i * owidth + j; + /* compute pool range. */ + long tstart = ti * dT - padT; + long hstart = i * dH - padH; + long wstart = j * dW - padW; + long tend = fminf(tstart + kT, itime + padT); + long hend = fminf(hstart + kH, iheight + padH); + long wend = fminf(wstart + kW, iwidth + padW); + long pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = fmaxf(tstart, 0); + hstart = fmaxf(hstart, 0); + wstart = fmaxf(wstart, 0); + tend = fmin(tend, itime); + hend = fmin(hend, iheight); + wend = fmin(wend, iwidth); + + int divide_factor; + if (count_include_pad) + divide_factor = pool_size; + else + divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart); /* compute local sum: */ real sum = 0.0; - int x, y, z; + long x, y, z; - for (z=0; z < kT; z++) + for (z = tstart; z < tend; z++) { - for (y = 0; y < kH; y++) + for (y = hstart; y < hend; y++) { - for (x = 0; x < kW; x++) + for (x = wstart; x < wend; x++) { sum += *(ip + z * iwidth * iheight + y * iwidth + x); } @@ -117,7 +180,7 @@ static void THNN_(VolumetricAveragePooling_updateOutput_frame)( } /* set output to local max */ - *op = sum / (kT * kW * kH); + *op++ += sum / divide_factor; } } } @@ -133,7 +196,12 @@ void THNN_(VolumetricAveragePooling_updateOutput)( int kH, int dT, int dW, - int dH) + int dH, + int padT, + int padW, + int padH, + bool ceil_mode, + bool count_include_pad) { long nslices; long itime; @@ -147,7 +215,7 @@ void THNN_(VolumetricAveragePooling_updateOutput)( THNN_(VolumetricAveragePooling_shapeCheck)( state, input, NULL, kT, kW, kH, - dT, dW, dH); + dT, dW, dH, padT, padW, padH, ceil_mode); int dimN = 0; int dimt = 1; @@ -167,9 +235,29 @@ void THNN_(VolumetricAveragePooling_updateOutput)( itime = input->size[dimt]; iheight = input->size[dimh]; iwidth = input->size[dimw]; - otime = (itime - kT) / dT + 1; - oheight = (iheight - kH) / dH + 1; - owidth = (iwidth - kW) / dW + 1; + if (ceil_mode) + { + otime = (long)(ceil((float)(itime - kT + 2*padT) / dT)) + 1; + oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1; + owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1; + } + else + { + otime = (long)(floor((float)(itime - kT + 2*padT) / dT)) + 1; + oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1; + owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1; + } + if (padT || padH || padW) + { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((otime - 1)*dT >= itime + padT) + --otime; + if ((oheight - 1)*dH >= iheight + padH) + --oheight; + if ((owidth - 1)*dW >= iwidth + padW) + --owidth; + } /* get contiguous input */ input = THTensor_(newContiguous)(input); @@ -187,7 +275,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)( itime, iwidth, iheight, otime, owidth, oheight, kT, kW, kH, - dT, dW, dH + dT, dW, dH, + padT, padW, padH, + count_include_pad ); } else /* batch mode */ @@ -212,7 +302,9 @@ void THNN_(VolumetricAveragePooling_updateOutput)( itime, iwidth, iheight, otime, owidth, oheight, kT, kW, kH, - dT, dW, dH + dT, dW, dH, + padT, padW, padH, + count_include_pad ); } } @@ -236,36 +328,62 @@ static void THNN_(VolumetricAveragePooling_updateGradInput_frame)( int kH, int dT, int dW, - int dH) + int dH, + int padT, + int padW, + int padH, + bool count_include_pad) { long k; #pragma omp parallel for private(k) for (k = 0; k < nslices; k++) { - /* loop over output */ long i, j, ti; + + /* local pointers */ + real *ip = gradInput_p + k * itime * iwidth * iheight; + real *op = gradOutput_p + k * otime * owidth * oheight; + for (i = 0; i < itime*iwidth*iheight; i++) + *(ip + i) = 0; + + /* loop over output */ for (ti = 0; ti < otime; ti++) { for (i = 0; i < oheight; i++) { for (j = 0; j < owidth; j++) { - /* local pointers */ - real *ip = gradInput_p + k * itime * iwidth * iheight - + ti * iwidth * iheight * dT + i * iwidth * dH + j * dW; - real *op = gradOutput_p + k * otime * owidth * oheight - + ti * owidth * oheight + i * owidth + j; + long tstart = ti * dT - padT; + long hstart = i * dH - padH; + long wstart = j * dW - padW; + long tend = fminf(tstart + kT, itime + padT); + long hend = fminf(hstart + kH, iheight + padH); + long wend = fminf(wstart + kW, iwidth + padW); + long pool_size = (tend -tstart) * (hend - hstart) * (wend - wstart); + tstart = fmaxf(tstart, 0); + hstart = fmaxf(hstart, 0); + wstart = fmaxf(wstart, 0); + tend = fminf(tend, itime); + hend = fminf(hend, iheight); + wend = fminf(wend, iwidth); + + long divide_factor; + if (count_include_pad) + divide_factor = pool_size; + else + divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart); /* scatter gradients out to footprint: */ - real val = *op / (kT * kW * kH); - int x,y,z; - for (z=0; z < kT; z++) + real val = *op++; + + long x,y,z; + for (z = tstart; z < tend; z++) { - for (y = 0; y < kH; y++) + for (y = hstart; y < hend; y++) { - for (x = 0; x < kW; x++) + for (x = wstart; x < wend; x++) { - *(ip + z * iwidth * iheight + y * iwidth + x) += val; + *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor; } } } @@ -285,15 +403,20 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( int kH, int dT, int dW, - int dH) + int dH, + int padT, + int padW, + int padH, + bool ceil_mode, + bool count_include_pad) { - int nslices; - int itime; - int iheight; - int iwidth; - int otime; - int oheight; - int owidth; + long nslices; + long itime; + long iheight; + long iwidth; + long otime; + long oheight; + long owidth; real *gradInput_data; real *gradOutput_data; @@ -304,7 +427,7 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( THNN_(VolumetricAveragePooling_shapeCheck)( state, input, gradOutput, kT, kW, kH, - dT, dW, dH); + dT, dW, dH, padT, padW, padH, ceil_mode); /* get contiguous gradOutput */ gradOutput = THTensor_(newContiguous)(gradOutput); @@ -342,7 +465,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( itime, iwidth, iheight, otime, owidth, oheight, kT, kW, kH, - dT, dW, dH + dT, dW, dH, + padT, padW, padH, + count_include_pad ); } else /* batch mode */ @@ -361,7 +486,9 @@ void THNN_(VolumetricAveragePooling_updateGradInput)( itime, iwidth, iheight, otime, owidth, oheight, kT, kW, kH, - dT, dW, dH + dT, dW, dH, + padT, padW, padH, + count_include_pad ); } } |