Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Max Losch <mmlosch@kth.se> 2015-03-03 14:48:53 +0300
committer: Max Losch <mmlosch@kth.se> 2015-03-03 14:48:53 +0300
commit: 5b3d27fa72c7b8731d41c87362ad98b7ebfea245 (patch)
tree: 0d6d3e4e57642a2e0ed668368bd099047237edc0 /generic
parent: 24a1715cd5095b3b92ec10b5f4764c13c7522ec1 (diff)
Add batch mode to VolumetricMaxPooling
Diffstat (limited to 'generic')
-rw-r--r-- generic/VolumetricMaxPooling.c 170
1 files changed, 132 insertions, 38 deletions
diff --git a/generic/VolumetricMaxPooling.c b/generic/VolumetricMaxPooling.c
index 20f9701..28fd5fe 100644
--- a/generic/VolumetricMaxPooling.c
+++ b/generic/VolumetricMaxPooling.c
@@ -52,6 +52,7 @@ static void nn_(VolumetricMaxPooling_updateOutput_frame)(real *input_p, real *ou
}
}
}
+
/* set output to local max */
*op = maxval;
@@ -86,15 +87,27 @@ static int nn_(VolumetricMaxPooling_updateOutput)(lua_State *L)
real *output_data;
real *indices_data;
+ luaL_argcheck(L, input->nDimension == 4 || input->nDimension == 5, 2, "4D or 5D (batch-mode) tensor expected");
+
+ int dimN = 0;
+ int dimt = 1;
+ int dimh = 2;
+ int dimw = 3;
+
+ if (input->nDimension == 5) {
+ dimN++;
+ dimt++;
+ dimh++;
+ dimw++;
+ }
- luaL_argcheck(L, input->nDimension == 4 , 2, "4D tensor expected");
- luaL_argcheck(L, input->size[3] >= kW && input->size[2] >= kH && input->size[1] >= kT, 2, "input image smaller than kernel size");
+ luaL_argcheck(L, input->size[dimw] >= kW && input->size[dimh] >= kH && input->size[dimt] >= kT, 2, "input image smaller than kernel size");
/* sizes */
- nslices = input->size[0];
- itime = input->size[1];
- iheight = input->size[2];
- iwidth = input->size[3];
+ nslices = input->size[dimN];
+ itime = input->size[dimt];
+ iheight = input->size[dimh];
+ iwidth = input->size[dimw];
otime = (itime - kT) / dT + 1;
oheight = (iheight - kH) / dH + 1;
owidth = (iwidth - kW) / dW + 1;
@@ -102,23 +115,66 @@ static int nn_(VolumetricMaxPooling_updateOutput)(lua_State *L)
/* get contiguous input */
input = THTensor_(newContiguous)(input);
- /* resize output */
- THTensor_(resize4d)(output, nslices, otime, oheight, owidth);
- /* indices will contain ti,i,j locations for each output point */
- THTensor_(resize5d)(indices, 3, nslices, otime, oheight, owidth);
-
- input_data = THTensor_(data)(input);
- output_data = THTensor_(data)(output);
- indices_data = THTensor_(data)(indices);
-
- nn_(VolumetricMaxPooling_updateOutput_frame)(input_data, output_data,
- indices_data+nslices*otime*owidth*oheight*2,
- indices_data+nslices*otime*owidth*oheight,
- indices_data,
- nslices,
- itime, iwidth, iheight,
- otime, owidth, oheight,
- kT, kW, kH, dT, dW, dH);
+ if (input->nDimension == 4) { /* non-batch mode */
+ /* resize output */
+ THTensor_(resize4d)(output, nslices, otime, oheight, owidth);
+ /* indices will contain ti,i,j locations for each output point */
+ THTensor_(resize5d)(indices, 3, nslices, otime, oheight, owidth);
+
+ input_data = THTensor_(data)(input);
+ output_data = THTensor_(data)(output);
+ indices_data = THTensor_(data)(indices);
+
+ nn_(VolumetricMaxPooling_updateOutput_frame)(input_data, output_data,
+ indices_data+nslices*otime*owidth*oheight*2,
+ indices_data+nslices*otime*owidth*oheight,
+ indices_data,
+ nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ kT, kW, kH, dT, dW, dH);
+ }
+ else { /* batch mode */
+ long p;
+ long nBatch = input->size[0];
+
+ long istride = nslices*itime*iwidth*iheight;
+ long ostride = nslices*otime*owidth*oheight;
+
+ /* resize output */
+ THTensor_(resize5d)(output, nBatch, nslices, otime, oheight, owidth);
+ /* indices will contain ti,i,j locations for each output point */
+
+ THLongStorage* size = THLongStorage_newWithSize(6);
+ size->data[0] = 3; size->data[1] = nBatch;
+ size->data[2] = nslices; size->data[3] = otime;
+ size->data[4] = oheight; size->data[5] = owidth;
+ THTensor_(resize)(indices, size, NULL); /* resize6d not available */
+ //TODO: Replace with resize6d when available
+ //THTensor_(resize6d)(indices, 3, nBatch, nslices, otime, oheight, owidth);
+
+ input_data = THTensor_(data)(input);
+ output_data = THTensor_(data)(output);
+ indices_data = THTensor_(data)(indices);
+
+#pragma omp parallel for private(p)
+ for (p=0; p < nBatch; p++)
+ {
+ nn_(VolumetricMaxPooling_updateOutput_frame)(
+ input_data+p*istride,
+ output_data+p*ostride,
+ indices_data+(p+nBatch+nBatch)*ostride,
+ indices_data+(p+nBatch)*ostride,
+ indices_data+p*ostride,
+ nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ kT, kW, kH, dT, dW, dH);
+ }
+
+ THLongStorage_free(size);
+ }
+
/* cleanup */
THTensor_(free)(input);
return 1;
@@ -182,6 +238,12 @@ static int nn_(VolumetricMaxPooling_updateGradInput)(lua_State *L)
real *gradOutput_data;
real *indices_data;
+ int dimN = 0;
+ int dimt = 1;
+ int dimh = 2;
+ int dimw = 3;
+
+
/* get contiguous gradOutput */
gradOutput = THTensor_(newContiguous)(gradOutput);
@@ -189,14 +251,21 @@ static int nn_(VolumetricMaxPooling_updateGradInput)(lua_State *L)
THTensor_(resizeAs)(gradInput, input);
THTensor_(zero)(gradInput);
+ if (input->nDimension == 5) {
+ dimN++;
+ dimt++;
+ dimh++;
+ dimw++;
+ }
+
/* sizes */
- nslices = input->size[0];
- itime = input->size[1];
- iheight = input->size[2];
- iwidth = input->size[3];
- otime = gradOutput->size[1];
- oheight = gradOutput->size[2];
- owidth = gradOutput->size[3];
+ nslices = input->size[dimN];
+ itime = input->size[dimt];
+ iheight = input->size[dimh];
+ iwidth = input->size[dimw];
+ otime = gradOutput->size[dimt];
+ oheight = gradOutput->size[dimh];
+ owidth = gradOutput->size[dimw];
/* get raw pointers */
gradInput_data = THTensor_(data)(gradInput);
@@ -204,14 +273,39 @@ static int nn_(VolumetricMaxPooling_updateGradInput)(lua_State *L)
indices_data = THTensor_(data)(indices);
/* backprop */
- nn_(VolumetricMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data,
- indices_data+nslices*otime*owidth*oheight*2,
- indices_data+nslices*otime*owidth*oheight,
- indices_data,
- nslices,
- itime, iwidth, iheight,
- otime, owidth, oheight,
- dT, dW, dH);
+ if (input->nDimension == 4) { /* non-batch mode*/
+
+ nn_(VolumetricMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data,
+ indices_data+nslices*otime*owidth*oheight*2,
+ indices_data+nslices*otime*owidth*oheight,
+ indices_data,
+ nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ dT, dW, dH);
+ }
+ else { /* batch mode */
+ long p;
+ long nBatch = input->size[0];
+
+ long istride = nslices*itime*iwidth*iheight;
+ long ostride = nslices*otime*owidth*oheight;
+
+#pragma omp parallel for private(p)
+ for (p = 0; p < nBatch; p++)
+ {
+ nn_(VolumetricMaxPooling_updateGradInput_frame)(
+ gradInput_data+p*istride,
+ gradOutput_data+p*ostride,
+ indices_data+(p+nBatch+nBatch)*ostride,
+ indices_data+(p+nBatch)*ostride,
+ indices_data+p*ostride,
+ nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ dT, dW, dH);
+ }
+ }
/* cleanup */
THTensor_(free)(gradOutput);