From 9614cd41480f7d2c1382f33924ad168c32b03828 Mon Sep 17 00:00:00 2001 From: Koray Kavukcuoglu Date: Wed, 26 Sep 2012 11:01:45 -0400 Subject: add batch mode to SpatialMaxPooling and openmpize. --- generic/SpatialMaxPooling.c | 179 ++++++++++++++++++++++++++------------------ 1 file changed, 108 insertions(+), 71 deletions(-) (limited to 'generic') diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c index 234e843..ca21ce5 100644 --- a/generic/SpatialMaxPooling.c +++ b/generic/SpatialMaxPooling.c @@ -12,13 +12,22 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor); THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor); - luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected"); - luaL_argcheck(L, input->size[2] >= kW && input->size[1] >= kH, 2, "input image smaller than kernel size"); + luaL_argcheck(L, input->nDimension == 3 || input->nDimension == 4 , 2, "3D or 4D (batch mode) tensor expected"); + int dimw = 2; + int dimh = 1; + long nbatch = 1; + if (input->nDimension == 4) + { + nbatch = input->size[0]; + dimw++; + dimh++; + } + luaL_argcheck(L, input->size[dimw] >= kW && input->size[dimh] >= kH, 2, "input image smaller than kernel size"); // sizes - long nslices = input->size[0]; - long iheight = input->size[1]; - long iwidth = input->size[2]; + long nslices = input->size[dimh-1]; + long iheight = input->size[dimh]; + long iwidth = input->size[dimw]; long oheight = (iheight - kH) / dH + 1; long owidth = (iwidth - kW) / dW + 1; @@ -26,10 +35,19 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) input = THTensor_(newContiguous)(input); // resize output - THTensor_(resize3d)(output, nslices, oheight, owidth); + if (input->nDimension == 3) + { + THTensor_(resize3d)(output, nslices, oheight, owidth); + // indices will contain i,j locatyions for each output point + THTensor_(resize4d)(indices, 2, nslices, oheight, owidth); + } + else + { + THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth); + // indices will contain i,j locatyions for each output point + THTensor_(resize5d)(indices, 2, nbatch, nslices, oheight, owidth); + } - // indices will contain i,j locatyions for each output point - THTensor_(resize4d)(indices, 2, nslices, oheight, owidth); // get raw pointers real *input_data = THTensor_(data)(input); @@ -39,49 +57,53 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) // compute max pooling for each input slice long k; #pragma omp parallel for private(k) - for (k = 0; k < nslices; k++) { - // pointers to slices - real *input_p = input_data + k*iwidth*iheight; - real *output_p = output_data + k*owidth*oheight; - real *indy_p = indices_data + k*owidth*oheight; - real *indx_p = indices_data + (k+nslices)*owidth*oheight; - - // loop over output - int i,j; - for(i = 0; i < oheight; i++) { - for(j = 0; j < owidth; j++) { - // local pointers - real *ip = input_p + i*iwidth*dH + j*dW; - real *op = output_p + i*owidth + j; - real *indyp = indy_p + i*owidth + j; - real *indxp = indx_p + i*owidth + j; - - // compute local max: - long maxindex = -1; - real maxval = -THInf; - long tcntr = 0; - int x,y; - for(y = 0; y < kH; y++) { - for(x = 0; x < kW; x++) { - real val = *(ip + y*iwidth + x); - if (val > maxval) { - maxval = val; - maxindex = tcntr; - } - tcntr++; - } - } - - // set output to local max - *op = maxval; - - // store location of max (x,y) - *indyp = (int)(maxindex / kW)+1; - *indxp = (maxindex % kW) +1; + for (k = 0; k < nslices; k++) + { + long p; + for (p = 0; p < nbatch; p++) + { + // pointers to slices + real *input_p = input_data + p*nslices*iwidth*iheight + k*iwidth*iheight; + real *output_p = output_data + p*nslices*owidth*oheight + k*owidth*oheight; + real *indy_p = indices_data + p*nslices*owidth*oheight + k*owidth*oheight; + real *indx_p = indices_data + (p+nbatch)*nslices*owidth*oheight + k*owidth*oheight; + + // loop over output + int i,j; + for(i = 0; i < oheight; i++) { + for(j = 0; j < owidth; j++) { + // local pointers + real *ip = input_p + i*iwidth*dH + j*dW; + real *op = output_p + i*owidth + j; + real *indyp = indy_p + i*owidth + j; + real *indxp = indx_p + i*owidth + j; + + // compute local max: + long maxindex = -1; + real maxval = -THInf; + long tcntr = 0; + int x,y; + for(y = 0; y < kH; y++) { + for(x = 0; x < kW; x++) { + real val = *(ip + y*iwidth + x); + if (val > maxval) { + maxval = val; + maxindex = tcntr; + } + tcntr++; + } + } + + // set output to local max + *op = maxval; + + // store location of max (x,y) + *indyp = (int)(maxindex / kW)+1; + *indxp = (maxindex % kW) +1; + } } } } - // cleanup THTensor_(free)(input); @@ -104,13 +126,22 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) THTensor_(resizeAs)(gradInput, input); THTensor_(zero)(gradInput); + int dimw = 2; + int dimh = 1; + long nbatch = 1; + if (input->nDimension == 4) { + nbatch = input->size[0]; + dimw++; + dimh++; + } + + // sizes - int ichannels = input->size[0]; - int iheight = input->size[1]; - int iwidth = input->size[2]; - int ochannels = ichannels; - int oheight = gradOutput->size[1]; - int owidth = gradOutput->size[2]; + int nslices = input->size[dimh-1]; + int iheight = input->size[dimh]; + int iwidth = input->size[dimw]; + int oheight = gradOutput->size[dimh]; + int owidth = gradOutput->size[dimw]; // get raw pointers real *gradInput_data = THTensor_(data)(gradInput); @@ -119,23 +150,29 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) // backprop long k; - for (k = 0; k < input->size[0]; k++) { - // pointers to slices - real *gradOutput_p = gradOutput_data + k*owidth*oheight; - real *gradInput_p = gradInput_data + k*iwidth*iheight; - real *indy_p = indices_data + k*owidth*oheight; - real *indx_p = indices_data + (k+ochannels)*owidth*oheight; - - // calculate max points - int i,j; - for(i = 0; i < oheight; i++) { - for(j = 0; j < owidth; j++) { - // retrieve position of max - long maxi = *(indy_p + i*owidth + j) - 1 + i*dH; - long maxj = *(indx_p + i*owidth + j) - 1 + j*dW; - - // update gradient - *(gradInput_p + maxi*iwidth + maxj) += *(gradOutput_p + i*owidth + j); +#pragma omp parallel for private(k) + for (k = 0; k < nslices; k++) + { + long p; + for (p = 0; p < nbatch; p++) + { + // pointers to slices + real *gradOutput_p = gradOutput_data + p*nslices*owidth*oheight + k*owidth*oheight; + real *gradInput_p = gradInput_data + p*nslices*iwidth*iheight + k*iwidth*iheight; + real *indy_p = indices_data + p*nslices*owidth*oheight + k*owidth*oheight; + real *indx_p = indices_data + (p+nbatch)*nslices*owidth*oheight + k*owidth*oheight; + + // calculate max points + int i,j; + for(i = 0; i < oheight; i++) { + for(j = 0; j < owidth; j++) { + // retrieve position of max + long maxi = *(indy_p + i*owidth + j) - 1 + i*dH; + long maxj = *(indx_p + i*owidth + j) - 1 + j*dW; + + // update gradient + *(gradInput_p + maxi*iwidth + maxj) += *(gradOutput_p + i*owidth + j); + } } } } -- cgit v1.2.3