Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Ronan Collobert <ronan@collobert.com>, 2012-09-27 14:50:31 +0400
committer: Ronan Collobert <ronan@collobert.com>, 2012-09-27 14:50:31 +0400
commit: 13fc357d358339423c07cff36d9110cfde58329d (patch)
tree: 594ddd347cef12c81d3e3fcd5fa38309a6602345 /generic
parent: ada8bcb9c6251b54f2241ca0aa4b45742fb6768a (diff)
SpatialMaxPooling: batch parallelized over... the batch
Diffstat (limited to 'generic')
-rw-r--r-- generic/SpatialMaxPooling.c 201
1 files changed, 126 insertions, 75 deletions
diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c
index ca21ce5..7faa0ee 100644
--- a/generic/SpatialMaxPooling.c
+++ b/generic/SpatialMaxPooling.c
@@ -2,6 +2,59 @@
#define TH_GENERIC_FILE "generic/SpatialMaxPooling.c"
#else
+static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *output_p,
+ real *indx_p, real *indy_p,
+ long nslices,
+ long iwidth, long iheight,
+ long owidth, long oheight,
+ int kW, int kH, int dW, int dH)
+{
+ long k;
+#pragma omp parallel for private(k)
+ for (k = 0; k < nslices; k++)
+ {
+ // loop over output
+ long i, j;
+ for(i = 0; i < oheight; i++)
+ {
+ for(j = 0; j < owidth; j++)
+ {
+ // local pointers
+ real *ip = input_p + k*iwidth*iheight + i*iwidth*dH + j*dW;
+ real *op = output_p + k*owidth*oheight + i*owidth + j;
+ real *indyp = indy_p + k*owidth*oheight + i*owidth + j;
+ real *indxp = indx_p + k*owidth*oheight + i*owidth + j;
+
+ // compute local max:
+ long maxindex = -1;
+ real maxval = -THInf;
+ long tcntr = 0;
+ int x,y;
+ for(y = 0; y < kH; y++)
+ {
+ for(x = 0; x < kW; x++)
+ {
+ real val = *(ip + y*iwidth + x);
+ if (val > maxval)
+ {
+ maxval = val;
+ maxindex = tcntr;
+ }
+ tcntr++;
+ }
+ }
+
+ // set output to local max
+ *op = maxval;
+
+ // store location of max (x,y)
+ *indyp = (int)(maxindex / kW)+1;
+ *indxp = (maxindex % kW) +1;
+ }
+ }
+ }
+}
+
static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
@@ -38,76 +91,79 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
if (input->nDimension == 3)
{
THTensor_(resize3d)(output, nslices, oheight, owidth);
- // indices will contain i,j locatyions for each output point
+ // indices will contain i,j locations for each output point
THTensor_(resize4d)(indices, 2, nslices, oheight, owidth);
+
+ real *input_data = THTensor_(data)(input);
+ real *output_data = THTensor_(data)(output);
+ real *indices_data = THTensor_(data)(indices);
+
+ nn_(SpatialMaxPooling_updateOutput_frame)(input_data, output_data,
+ indices_data+nslices*owidth*oheight, indices_data,
+ nslices,
+ iwidth, iheight,
+ owidth, oheight,
+ kW, kH, dW, dH);
}
else
{
THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth);
- // indices will contain i,j locatyions for each output point
+ // indices will contain i,j locations for each output point
THTensor_(resize5d)(indices, 2, nbatch, nslices, oheight, owidth);
- }
+ real *input_data = THTensor_(data)(input);
+ real *output_data = THTensor_(data)(output);
+ real *indices_data = THTensor_(data)(indices);
- // get raw pointers
- real *input_data = THTensor_(data)(input);
- real *output_data = THTensor_(data)(output);
- real *indices_data = THTensor_(data)(indices);
+ long p;
+#pragma omp parallel for private(p)
+ for (p = 0; p < nbatch; p++)
+ {
+ nn_(SpatialMaxPooling_updateOutput_frame)(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight,
+ indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight,
+ nslices,
+ iwidth, iheight,
+ owidth, oheight,
+ kW, kH, dW, dH);
+ }
+ }
- // compute max pooling for each input slice
+ // cleanup
+ THTensor_(free)(input);
+ return 1;
+}
+
+static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real *gradOutput_p,
+ real *indx_p, real *indy_p,
+ long nslices,
+ long iwidth, long iheight,
+ long owidth, long oheight,
+ int dW, int dH)
+{
long k;
#pragma omp parallel for private(k)
for (k = 0; k < nslices; k++)
{
- long p;
- for (p = 0; p < nbatch; p++)
+ real *gradInput_p_k = gradInput_p + k*iwidth*iheight;
+ real *gradOutput_p_k = gradOutput_p + k*owidth*oheight;
+ real *indx_p_k = indx_p + k*owidth*oheight;
+ real *indy_p_k = indy_p + k*owidth*oheight;
+
+ // calculate max points
+ long i, j;
+ for(i = 0; i < oheight; i++)
{
- // pointers to slices
- real *input_p = input_data + p*nslices*iwidth*iheight + k*iwidth*iheight;
- real *output_p = output_data + p*nslices*owidth*oheight + k*owidth*oheight;
- real *indy_p = indices_data + p*nslices*owidth*oheight + k*owidth*oheight;
- real *indx_p = indices_data + (p+nbatch)*nslices*owidth*oheight + k*owidth*oheight;
-
- // loop over output
- int i,j;
- for(i = 0; i < oheight; i++) {
- for(j = 0; j < owidth; j++) {
- // local pointers
- real *ip = input_p + i*iwidth*dH + j*dW;
- real *op = output_p + i*owidth + j;
- real *indyp = indy_p + i*owidth + j;
- real *indxp = indx_p + i*owidth + j;
-
- // compute local max:
- long maxindex = -1;
- real maxval = -THInf;
- long tcntr = 0;
- int x,y;
- for(y = 0; y < kH; y++) {
- for(x = 0; x < kW; x++) {
- real val = *(ip + y*iwidth + x);
- if (val > maxval) {
- maxval = val;
- maxindex = tcntr;
- }
- tcntr++;
- }
- }
-
- // set output to local max
- *op = maxval;
-
- // store location of max (x,y)
- *indyp = (int)(maxindex / kW)+1;
- *indxp = (maxindex % kW) +1;
- }
+ for(j = 0; j < owidth; j++)
+ {
+ // retrieve position of max
+ long maxi = indy_p_k[i*owidth + j] - 1 + i*dH;
+ long maxj = indx_p_k[i*owidth + j] - 1 + j*dW;
+
+ // update gradient
+ gradInput_p_k[maxi*iwidth + maxj] += gradOutput_p_k[i*owidth + j];
}
}
}
- // cleanup
- THTensor_(free)(input);
-
- return 1;
}
static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L)
@@ -135,7 +191,6 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L)
dimh++;
}
-
// sizes
int nslices = input->size[dimh-1];
int iheight = input->size[dimh];
@@ -149,31 +204,27 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L)
real *indices_data = THTensor_(data)(indices);
// backprop
- long k;
-#pragma omp parallel for private(k)
- for (k = 0; k < nslices; k++)
+ if (input->nDimension == 3)
+ {
+ nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data,
+ indices_data+nslices*owidth*oheight, indices_data,
+ nslices,
+ iwidth, iheight,
+ owidth, oheight,
+ dW, dH);
+ }
+ else
{
long p;
+#pragma omp parallel for private(p)
for (p = 0; p < nbatch; p++)
{
- // pointers to slices
- real *gradOutput_p = gradOutput_data + p*nslices*owidth*oheight + k*owidth*oheight;
- real *gradInput_p = gradInput_data + p*nslices*iwidth*iheight + k*iwidth*iheight;
- real *indy_p = indices_data + p*nslices*owidth*oheight + k*owidth*oheight;
- real *indx_p = indices_data + (p+nbatch)*nslices*owidth*oheight + k*owidth*oheight;
-
- // calculate max points
- int i,j;
- for(i = 0; i < oheight; i++) {
- for(j = 0; j < owidth; j++) {
- // retrieve position of max
- long maxi = *(indy_p + i*owidth + j) - 1 + i*dH;
- long maxj = *(indx_p + i*owidth + j) - 1 + j*dW;
-
- // update gradient
- *(gradInput_p + maxi*iwidth + maxj) += *(gradOutput_p + i*owidth + j);
- }
- }
+ nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data+p*nslices*iwidth*iheight, gradOutput_data+p*nslices*owidth*oheight,
+ indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight,
+ nslices,
+ iwidth, iheight,
+ owidth, oheight,
+ dW, dH);
}
}