From 13fc357d358339423c07cff36d9110cfde58329d Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 27 Sep 2012 12:50:31 +0200 Subject: SpatialMaxPooling: batch parallelized over... the batch --- generic/SpatialMaxPooling.c | 201 +++++++++++++++++++++++++++----------------- 1 file changed, 126 insertions(+), 75 deletions(-) (limited to 'generic') diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c index ca21ce5..7faa0ee 100644 --- a/generic/SpatialMaxPooling.c +++ b/generic/SpatialMaxPooling.c @@ -2,6 +2,59 @@ #define TH_GENERIC_FILE "generic/SpatialMaxPooling.c" #else +static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *output_p, + real *indx_p, real *indy_p, + long nslices, + long iwidth, long iheight, + long owidth, long oheight, + int kW, int kH, int dW, int dH) +{ + long k; +#pragma omp parallel for private(k) + for (k = 0; k < nslices; k++) + { + // loop over output + long i, j; + for(i = 0; i < oheight; i++) + { + for(j = 0; j < owidth; j++) + { + // local pointers + real *ip = input_p + k*iwidth*iheight + i*iwidth*dH + j*dW; + real *op = output_p + k*owidth*oheight + i*owidth + j; + real *indyp = indy_p + k*owidth*oheight + i*owidth + j; + real *indxp = indx_p + k*owidth*oheight + i*owidth + j; + + // compute local max: + long maxindex = -1; + real maxval = -THInf; + long tcntr = 0; + int x,y; + for(y = 0; y < kH; y++) + { + for(x = 0; x < kW; x++) + { + real val = *(ip + y*iwidth + x); + if (val > maxval) + { + maxval = val; + maxindex = tcntr; + } + tcntr++; + } + } + + // set output to local max + *op = maxval; + + // store location of max (x,y) + *indyp = (int)(maxindex / kW)+1; + *indxp = (maxindex % kW) +1; + } + } + } +} + static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) { THTensor *input = luaT_checkudata(L, 2, torch_Tensor); @@ -38,76 +91,79 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) if (input->nDimension == 3) { THTensor_(resize3d)(output, nslices, oheight, owidth); - // indices will contain i,j locatyions for each output point + // indices will contain i,j locations for each output point THTensor_(resize4d)(indices, 2, nslices, oheight, owidth); + + real *input_data = THTensor_(data)(input); + real *output_data = THTensor_(data)(output); + real *indices_data = THTensor_(data)(indices); + + nn_(SpatialMaxPooling_updateOutput_frame)(input_data, output_data, + indices_data+nslices*owidth*oheight, indices_data, + nslices, + iwidth, iheight, + owidth, oheight, + kW, kH, dW, dH); } else { THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth); - // indices will contain i,j locatyions for each output point + // indices will contain i,j locations for each output point THTensor_(resize5d)(indices, 2, nbatch, nslices, oheight, owidth); - } + real *input_data = THTensor_(data)(input); + real *output_data = THTensor_(data)(output); + real *indices_data = THTensor_(data)(indices); - // get raw pointers - real *input_data = THTensor_(data)(input); - real *output_data = THTensor_(data)(output); - real *indices_data = THTensor_(data)(indices); + long p; +#pragma omp parallel for private(p) + for (p = 0; p < nbatch; p++) + { + nn_(SpatialMaxPooling_updateOutput_frame)(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight, + indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight, + nslices, + iwidth, iheight, + owidth, oheight, + kW, kH, dW, dH); + } + } - // compute max pooling for each input slice + // cleanup + THTensor_(free)(input); + return 1; +} + +static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real *gradOutput_p, + real *indx_p, real *indy_p, + long nslices, + long iwidth, long iheight, + long owidth, long oheight, + int dW, int dH) +{ long k; #pragma omp parallel for private(k) for (k = 0; k < nslices; k++) { - long p; - for (p = 0; p < nbatch; p++) + real *gradInput_p_k = gradInput_p + k*iwidth*iheight; + real *gradOutput_p_k = gradOutput_p + k*owidth*oheight; + real *indx_p_k = indx_p + k*owidth*oheight; + real *indy_p_k = indy_p + k*owidth*oheight; + + // calculate max points + long i, j; + for(i = 0; i < oheight; i++) { - // pointers to slices - real *input_p = input_data + p*nslices*iwidth*iheight + k*iwidth*iheight; - real *output_p = output_data + p*nslices*owidth*oheight + k*owidth*oheight; - real *indy_p = indices_data + p*nslices*owidth*oheight + k*owidth*oheight; - real *indx_p = indices_data + (p+nbatch)*nslices*owidth*oheight + k*owidth*oheight; - - // loop over output - int i,j; - for(i = 0; i < oheight; i++) { - for(j = 0; j < owidth; j++) { - // local pointers - real *ip = input_p + i*iwidth*dH + j*dW; - real *op = output_p + i*owidth + j; - real *indyp = indy_p + i*owidth + j; - real *indxp = indx_p + i*owidth + j; - - // compute local max: - long maxindex = -1; - real maxval = -THInf; - long tcntr = 0; - int x,y; - for(y = 0; y < kH; y++) { - for(x = 0; x < kW; x++) { - real val = *(ip + y*iwidth + x); - if (val > maxval) { - maxval = val; - maxindex = tcntr; - } - tcntr++; - } - } - - // set output to local max - *op = maxval; - - // store location of max (x,y) - *indyp = (int)(maxindex / kW)+1; - *indxp = (maxindex % kW) +1; - } + for(j = 0; j < owidth; j++) + { + // retrieve position of max + long maxi = indy_p_k[i*owidth + j] - 1 + i*dH; + long maxj = indx_p_k[i*owidth + j] - 1 + j*dW; + + // update gradient + gradInput_p_k[maxi*iwidth + maxj] += gradOutput_p_k[i*owidth + j]; } } } - // cleanup - THTensor_(free)(input); - - return 1; } static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) @@ -135,7 +191,6 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) dimh++; } - // sizes int nslices = input->size[dimh-1]; int iheight = input->size[dimh]; @@ -149,31 +204,27 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L) real *indices_data = THTensor_(data)(indices); // backprop - long k; -#pragma omp parallel for private(k) - for (k = 0; k < nslices; k++) + if (input->nDimension == 3) + { + nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data, + indices_data+nslices*owidth*oheight, indices_data, + nslices, + iwidth, iheight, + owidth, oheight, + dW, dH); + } + else { long p; +#pragma omp parallel for private(p) for (p = 0; p < nbatch; p++) { - // pointers to slices - real *gradOutput_p = gradOutput_data + p*nslices*owidth*oheight + k*owidth*oheight; - real *gradInput_p = gradInput_data + p*nslices*iwidth*iheight + k*iwidth*iheight; - real *indy_p = indices_data + p*nslices*owidth*oheight + k*owidth*oheight; - real *indx_p = indices_data + (p+nbatch)*nslices*owidth*oheight + k*owidth*oheight; - - // calculate max points - int i,j; - for(i = 0; i < oheight; i++) { - for(j = 0; j < owidth; j++) { - // retrieve position of max - long maxi = *(indy_p + i*owidth + j) - 1 + i*dH; - long maxj = *(indx_p + i*owidth + j) - 1 + j*dW; - - // update gradient - *(gradInput_p + maxi*iwidth + maxj) += *(gradOutput_p + i*owidth + j); - } - } + nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data+p*nslices*iwidth*iheight, gradOutput_data+p*nslices*owidth*oheight, + indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight, + nslices, + iwidth, iheight, + owidth, oheight, + dW, dH); } } -- cgit v1.2.3