diff options
-rw-r--r-- | SpatialMaxPooling.lua | 30 | ||||
-rwxr-xr-x | doc/convolution.md | 13 | ||||
-rw-r--r-- | generic/SpatialMaxPooling.c | 93 | ||||
-rw-r--r-- | test.lua | 60 |
4 files changed, 132 insertions, 64 deletions
diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index 7be85d5..995c213 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -1,6 +1,6 @@ local SpatialMaxPooling, parent = torch.class('nn.SpatialMaxPooling', 'nn.Module') -function SpatialMaxPooling:__init(kW, kH, dW, dH) +function SpatialMaxPooling:__init(kW, kH, dW, dH, padW, padH) parent.__init(self) dW = dW or kW @@ -11,10 +11,28 @@ function SpatialMaxPooling:__init(kW, kH, dW, dH) self.dW = dW self.dH = dH + self.padW = padW or 0 + self.padH = padH or 0 + + self.ceil_mode = false self.indices = torch.Tensor() end +function SpatialMaxPooling:ceil() + self.ceil_mode = true + return self +end + +function SpatialMaxPooling:floor() + self.ceil_mode = false + return self +end + function SpatialMaxPooling:updateOutput(input) + -- backward compatibility + self.ceil_mode = self.ceil_mode or false + self.padW = self.padW or 0 + self.padH = self.padH or 0 input.nn.SpatialMaxPooling_updateOutput(self, input) return self.output end @@ -34,6 +52,12 @@ function SpatialMaxPooling:empty() end function SpatialMaxPooling:__tostring__() - return string.format('%s(%d,%d,%d,%d)', torch.type(self), - self.kW, self.kH, self.dW, self.dH) + local s = string.format('%s(%d,%d,%d,%d', torch.type(self), + self.kW, self.kH, self.dW, self.dH) + if (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then + s = s .. ',' .. self.padW .. ','.. self.padH + end + s = s .. ')' + + return s end diff --git a/doc/convolution.md b/doc/convolution.md index 18da026..8d9e77b 100755 --- a/doc/convolution.md +++ b/doc/convolution.md @@ -361,13 +361,24 @@ Computes the `p` norm in a convolutional manner on a set of 2D input planes. ### SpatialMaxPooling ### ```lua -module = nn.SpatialMaxPooling(kW, kH [, dW, dH]) +module = nn.SpatialMaxPooling(kW, kH [, dW, dH, padW, padH]) ``` Applies 2D max-pooling operation in `kWxkH` regions by step size `dWxdH` steps. The number of output features is equal to the number of input planes. +If the input image is a 3D tensor `nInputPlane x height x width`, the output +image size will be `nOutputPlane x oheight x owidth` where + +```lua +owidth = op((width + 2*padW - kW) / dW + 1) +oheight = op((height + 2*padH - kH) / dH + 1) +``` + +`op` is a rounding operator. By default, it is `floor`. It can be changed +by calling `:ceil()` or `:floor()` methods. + <a name="nn.SpatialAveragePooling"/> ### SpatialAveragePooling ### diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c index 8dd04c9..ef6f554 100644 --- a/generic/SpatialMaxPooling.c +++ b/generic/SpatialMaxPooling.c @@ -3,11 +3,12 @@ #else static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *output_p, - real *indx_p, real *indy_p, + real *ind_p, long nslices, long iwidth, long iheight, long owidth, long oheight, - int kW, int kH, int dW, int dH) + int kW, int kH, int dW, int dH, + int padW, int padH) { long k; #pragma omp parallel for private(k) @@ -15,41 +16,46 @@ static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *outpu { /* loop over output */ long i, j; + real *ip = input_p + k*iwidth*iheight; for(i = 0; i < oheight; i++) { for(j = 0; j < owidth; j++) { + long hstart = i * dH - padH; + long wstart = j * dW - padW; + long hend = fminf(hstart + kH, iheight); + long wend = fminf(wstart + kW, iwidth); + hstart = fmaxf(hstart, 0); + wstart = fmaxf(wstart, 0); + /* local pointers */ - real *ip = input_p + k*iwidth*iheight + i*iwidth*dH + j*dW; real *op = output_p + k*owidth*oheight + i*owidth + j; - real *indyp = indy_p + k*owidth*oheight + i*owidth + j; - real *indxp = indx_p + k*owidth*oheight + i*owidth + j; + real *indp = ind_p + k*owidth*oheight + i*owidth + j; /* compute local max: */ long maxindex = -1; real maxval = -THInf; long tcntr = 0; - int x,y; - for(y = 0; y < kH; y++) + long x,y; + for(y = hstart; y < hend; y++) { - for(x = 0; x < kW; x++) + for(x = wstart; x < wend; x++) { - real val = *(ip + y*iwidth + x); + tcntr = y*iwidth + x; + real val = *(ip + tcntr); if (val > maxval) { maxval = val; maxindex = tcntr; } - tcntr++; } } /* set output to local max */ *op = maxval; - /* store location of max (x,y) */ - *indyp = (int)(maxindex / kW)+1; - *indxp = (maxindex % kW) +1; + /* store location of max */ + *indp = maxindex + 1; } } } @@ -62,6 +68,9 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) int kH = luaT_getfieldcheckint(L, 1, "kH"); int dW = luaT_getfieldcheckint(L, 1, "dW"); int dH = luaT_getfieldcheckint(L, 1, "dH"); + int padW = luaT_getfieldcheckint(L, 1, "padW"); + int padH = luaT_getfieldcheckint(L, 1, "padH"); + int ceil_mode = luaT_getfieldcheckboolean(L,1,"ceil_mode"); THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor); THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); int dimw = 2; @@ -85,14 +94,33 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) dimw++; dimh++; } - luaL_argcheck(L, input->size[dimw] >= kW && input->size[dimh] >= kH, 2, "input image smaller than kernel size"); + luaL_argcheck(L, input->size[dimw] >= kW - padW && input->size[dimh] >= kH - padH, 2, "input image smaller than kernel size"); + + luaL_argcheck(L, kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size"); /* sizes */ nslices = input->size[dimh-1]; iheight = input->size[dimh]; iwidth = input->size[dimw]; - oheight = (iheight - kH) / dH + 1; - owidth = (iwidth - kW) / dW + 1; + if (ceil_mode) + { + oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1; + owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1; + } + else + { + oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1; + owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1; + } + + if (padW || padH) + { + // ensure that the last pooling starts inside the image + if ((oheight - 1)*dH >= iheight + padH) + --oheight; + if ((owidth - 1)*dW >= iwidth + padW) + --owidth; + } /* get contiguous input */ input = THTensor_(newContiguous)(input); @@ -101,27 +129,28 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) if (input->nDimension == 3) { THTensor_(resize3d)(output, nslices, oheight, owidth); - /* indices will contain i,j locations for each output point */ - THTensor_(resize4d)(indices, 2, nslices, oheight, owidth); + /* indices will contain the locations for each output point */ + THTensor_(resize3d)(indices, nslices, oheight, owidth); input_data = THTensor_(data)(input); output_data = THTensor_(data)(output); indices_data = THTensor_(data)(indices); nn_(SpatialMaxPooling_updateOutput_frame)(input_data, output_data, - indices_data+nslices*owidth*oheight, indices_data, + indices_data, nslices, iwidth, iheight, owidth, oheight, - kW, kH, dW, dH); + kW, kH, dW, dH, + padW, padH); } else { long p; THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth); - /* indices will contain i,j locations for each output point */ - THTensor_(resize5d)(indices, 2, nbatch, nslices, oheight, owidth); + /* indices will contain the locations for each output point */ + THTensor_(resize4d)(indices, nbatch, nslices, oheight, owidth); input_data = THTensor_(data)(input); output_data = THTensor_(data)(output); @@ -131,11 +160,12 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) for (p = 0; p < nbatch; p++) { nn_(SpatialMaxPooling_updateOutput_frame)(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight, - indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight, + indices_data+p*nslices*owidth*oheight, nslices, iwidth, iheight, owidth, oheight, - kW, kH, dW, dH); + kW, kH, dW, dH, + padW, padH); } } @@ -145,7 +175,7 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) } static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real *gradOutput_p, - real *indx_p, real *indy_p, + real *ind_p, long nslices, long iwidth, long iheight, long owidth, long oheight, @@ -157,8 +187,7 @@ static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real { real *gradInput_p_k = gradInput_p + k*iwidth*iheight; real *gradOutput_p_k = gradOutput_p + k*owidth*oheight; - real *indx_p_k = indx_p + k*owidth*oheight; - real *indy_p_k = indy_p + k*owidth*oheight; + real *ind_p_k = ind_p + k*owidth*oheight; /* calculate max points */ long i, j; @@ -167,11 +196,9 @@ static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real for(j = 0; j < owidth; j++) { /* retrieve position of max */ - long maxi = indy_p_k[i*owidth + j] - 1 + i*dH; - long maxj = indx_p_k[i*owidth + j] - 1 + j*dW; - + long maxp = ind_p_k[i*owidth + j] - 1; /* update gradient */ - gradInput_p_k[maxi*iwidth + maxj] += gradOutput_p_k[i*owidth + j]; + gradInput_p_k[maxp] += gradOutput_p_k[i*owidth + j]; } } } @@ -226,7 +253,7 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) if (input->nDimension == 3) { nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data, - indices_data+nslices*owidth*oheight, indices_data, + indices_data, nslices, iwidth, iheight, owidth, oheight, @@ -239,7 +266,7 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) for (p = 0; p < nbatch; p++) { nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data+p*nslices*iwidth*iheight, gradOutput_data+p*nslices*owidth*oheight, - indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight, + indices_data+p*nslices*owidth*oheight, nslices, iwidth, iheight, owidth, oheight, @@ -1926,38 +1926,44 @@ function nntest.SpatialSubSampling() end function nntest.SpatialMaxPooling() - local from = math.random(1,5) - local ki = math.random(1,4) - local kj = math.random(1,4) - local si = math.random(1,3) - local sj = math.random(1,3) - local outi = math.random(4,5) - local outj = math.random(4,5) - local ini = (outi-1)*si+ki - local inj = (outj-1)*sj+kj - - local module = nn.SpatialMaxPooling(ki,kj,si,sj) - local input = torch.rand(from,ini,inj) + for _,ceil_mode in pairs({true,false}) do + local from = math.random(1,5) + local ki = math.random(1,4) + local kj = math.random(1,4) + local si = math.random(1,3) + local sj = math.random(1,3) + local outi = math.random(4,5) + local outj = math.random(4,5) + local padW = math.min(math.random(0,1),math.floor(ki/2)) + local padH = math.min(math.random(0,1),math.floor(kj/2)) + local ini = (outi-1)*si+ki-2*padW + local inj = (outj-1)*sj+kj-2*padH + + local ceil_string = ceil_mode and 'ceil' or 'floor' + local module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH) + if ceil_mode then module:ceil() else module:floor() end + local input = torch.rand(from,inj,ini) - local err = jac.testJacobian(module, input) - mytester:assertlt(err, precision, 'error on state ') - - local ferr, berr = jac.testIO(module, input) - mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') - mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state ') - -- batch - local nbatch = math.random(2,5) - input = torch.rand(nbatch,from,ini,inj) - module = nn.SpatialMaxPooling(ki,kj,si,sj) + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') - local err = jac.testJacobian(module, input) - mytester:assertlt(err, precision, 'error on state (Batch) ') + -- batch + local nbatch = math.random(2,5) + input = torch.rand(nbatch,from,inj,ini) + module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH) + if ceil_mode then module:ceil() else module:floor() end - local ferr, berr = jac.testIO(module, input) - mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') - mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state (Batch)') + local ferr, berr = jac.testIO(module, input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') + end end function nntest.SpatialAveragePooling() |