diff options
Diffstat (limited to 'generic/SpatialMaxPooling.c')
-rw-r--r-- | generic/SpatialMaxPooling.c | 109 |
1 files changed, 64 insertions, 45 deletions
diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c index 7faa0ee..8dd04c9 100644 --- a/generic/SpatialMaxPooling.c +++ b/generic/SpatialMaxPooling.c @@ -13,19 +13,19 @@ static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *outpu #pragma omp parallel for private(k) for (k = 0; k < nslices; k++) { - // loop over output + /* loop over output */ long i, j; for(i = 0; i < oheight; i++) { for(j = 0; j < owidth; j++) { - // local pointers + /* local pointers */ real *ip = input_p + k*iwidth*iheight + i*iwidth*dH + j*dW; real *op = output_p + k*owidth*oheight + i*owidth + j; real *indyp = indy_p + k*owidth*oheight + i*owidth + j; real *indxp = indx_p + k*owidth*oheight + i*owidth + j; - // compute local max: + /* compute local max: */ long maxindex = -1; real maxval = -THInf; long tcntr = 0; @@ -44,10 +44,10 @@ static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *outpu } } - // set output to local max + /* set output to local max */ *op = maxval; - // store location of max (x,y) + /* store location of max (x,y) */ *indyp = (int)(maxindex / kW)+1; *indxp = (maxindex % kW) +1; } @@ -64,11 +64,21 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) int dH = luaT_getfieldcheckint(L, 1, "dH"); THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor); THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); - - luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4 , 2, "3D or 4D (batch mode) tensor expected"); int dimw = 2; int dimh = 1; long nbatch = 1; + long nslices; + long iheight; + long iwidth; + long oheight; + long owidth; + real *input_data; + real *output_data; + real *indices_data; + + + luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4 , 2, "3D or 4D (batch mode) tensor expected"); + if (input->nDimension == 4) { nbatch = input->size[0]; @@ -77,26 +87,26 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) } luaL_argcheck(L, input->size[dimw] >= kW && input->size[dimh] >= kH, 2, "input image smaller than kernel size"); - // sizes - long nslices = input->size[dimh-1]; - long iheight = input->size[dimh]; - long iwidth = input->size[dimw]; - long oheight = (iheight - kH) / dH + 1; - long owidth = (iwidth - kW) / dW + 1; + /* sizes */ + nslices = input->size[dimh-1]; + iheight = input->size[dimh]; + iwidth = input->size[dimw]; + oheight = (iheight - kH) / dH + 1; + owidth = (iwidth - kW) / dW + 1; - // get contiguous input + /* get contiguous input */ input = THTensor_(newContiguous)(input); - // resize output + /* resize output */ if (input->nDimension == 3) { THTensor_(resize3d)(output, nslices, oheight, owidth); - // indices will contain i,j locations for each output point + /* indices will contain i,j locations for each output point */ THTensor_(resize4d)(indices, 2, nslices, oheight, owidth); - real *input_data = THTensor_(data)(input); - real *output_data = THTensor_(data)(output); - real *indices_data = THTensor_(data)(indices); + input_data = THTensor_(data)(input); + output_data = THTensor_(data)(output); + indices_data = THTensor_(data)(indices); nn_(SpatialMaxPooling_updateOutput_frame)(input_data, output_data, indices_data+nslices*owidth*oheight, indices_data, @@ -107,15 +117,16 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) } else { + long p; + THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth); - // indices will contain i,j locations for each output point + /* indices will contain i,j locations for each output point */ THTensor_(resize5d)(indices, 2, nbatch, nslices, oheight, owidth); - real *input_data = THTensor_(data)(input); - real *output_data = THTensor_(data)(output); - real *indices_data = THTensor_(data)(indices); + input_data = THTensor_(data)(input); + output_data = THTensor_(data)(output); + indices_data = THTensor_(data)(indices); - long p; #pragma omp parallel for private(p) for (p = 0; p < nbatch; p++) { @@ -128,7 +139,7 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) } } - // cleanup + /* cleanup */ THTensor_(free)(input); return 1; } @@ -149,17 +160,17 @@ static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real real *indx_p_k = indx_p + k*owidth*oheight; real *indy_p_k = indy_p + k*owidth*oheight; - // calculate max points + /* calculate max points */ long i, j; for(i = 0; i < oheight; i++) { for(j = 0; j < owidth; j++) { - // retrieve position of max + /* retrieve position of max */ long maxi = indy_p_k[i*owidth + j] - 1 + i*dH; long maxj = indx_p_k[i*owidth + j] - 1 + j*dW; - // update gradient + /* update gradient */ gradInput_p_k[maxi*iwidth + maxj] += gradOutput_p_k[i*owidth + j]; } } @@ -174,36 +185,44 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) int dH = luaT_getfieldcheckint(L, 1, "dH"); THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor); THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor); + int dimw = 2; + int dimh = 1; + long nbatch = 1; + int nslices; + int iheight; + int iwidth; + int oheight; + int owidth; + real *gradInput_data; + real *gradOutput_data; + real *indices_data; - // get contiguous gradOutput + /* get contiguous gradOutput */ gradOutput = THTensor_(newContiguous)(gradOutput); - // resize + /* resize */ THTensor_(resizeAs)(gradInput, input); THTensor_(zero)(gradInput); - int dimw = 2; - int dimh = 1; - long nbatch = 1; if (input->nDimension == 4) { nbatch = input->size[0]; dimw++; dimh++; } - // sizes - int nslices = input->size[dimh-1]; - int iheight = input->size[dimh]; - int iwidth = input->size[dimw]; - int oheight = gradOutput->size[dimh]; - int owidth = gradOutput->size[dimw]; + /* sizes */ + nslices = input->size[dimh-1]; + iheight = input->size[dimh]; + iwidth = input->size[dimw]; + oheight = gradOutput->size[dimh]; + owidth = gradOutput->size[dimw]; - // get raw pointers - real *gradInput_data = THTensor_(data)(gradInput); - real *gradOutput_data = THTensor_(data)(gradOutput); - real *indices_data = THTensor_(data)(indices); + /* get raw pointers */ + gradInput_data = THTensor_(data)(gradInput); + gradOutput_data = THTensor_(data)(gradOutput); + indices_data = THTensor_(data)(indices); - // backprop + /* backprop */ if (input->nDimension == 3) { nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data, @@ -228,7 +247,7 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) } } - // cleanup + /* cleanup */ THTensor_(free)(gradOutput); return 1; |