diff options
author | Max Losch <mmlosch@kth.se> | 2015-03-03 14:48:53 +0300 |
---|---|---|
committer | Max Losch <mmlosch@kth.se> | 2015-03-03 14:48:53 +0300 |
commit | 5b3d27fa72c7b8731d41c87362ad98b7ebfea245 (patch) | |
tree | 0d6d3e4e57642a2e0ed668368bd099047237edc0 /generic | |
parent | 24a1715cd5095b3b92ec10b5f4764c13c7522ec1 (diff) |
Add batch mode to VolumetricMaxPooling
Diffstat (limited to 'generic')
-rw-r--r-- | generic/VolumetricMaxPooling.c | 170 |
1 files changed, 132 insertions, 38 deletions
diff --git a/generic/VolumetricMaxPooling.c b/generic/VolumetricMaxPooling.c index 20f9701..28fd5fe 100644 --- a/generic/VolumetricMaxPooling.c +++ b/generic/VolumetricMaxPooling.c @@ -52,6 +52,7 @@ static void nn_(VolumetricMaxPooling_updateOutput_frame)(real *input_p, real *ou } } } + /* set output to local max */ *op = maxval; @@ -86,15 +87,27 @@ static int nn_(VolumetricMaxPooling_updateOutput)(lua_State *L) real *output_data; real *indices_data; + luaL_argcheck(L, input->nDimension == 4 || input->nDimension == 5, 2, "4D or 5D (batch-mode) tensor expected"); + + int dimN = 0; + int dimt = 1; + int dimh = 2; + int dimw = 3; + + if (input->nDimension == 5) { + dimN++; + dimt++; + dimh++; + dimw++; + } - luaL_argcheck(L, input->nDimension == 4 , 2, "4D tensor expected"); - luaL_argcheck(L, input->size[3] >= kW && input->size[2] >= kH && input->size[1] >= kT, 2, "input image smaller than kernel size"); + luaL_argcheck(L, input->size[dimw] >= kW && input->size[dimh] >= kH && input->size[dimt] >= kT, 2, "input image smaller than kernel size"); /* sizes */ - nslices = input->size[0]; - itime = input->size[1]; - iheight = input->size[2]; - iwidth = input->size[3]; + nslices = input->size[dimN]; + itime = input->size[dimt]; + iheight = input->size[dimh]; + iwidth = input->size[dimw]; otime = (itime - kT) / dT + 1; oheight = (iheight - kH) / dH + 1; owidth = (iwidth - kW) / dW + 1; @@ -102,23 +115,66 @@ static int nn_(VolumetricMaxPooling_updateOutput)(lua_State *L) /* get contiguous input */ input = THTensor_(newContiguous)(input); - /* resize output */ - THTensor_(resize4d)(output, nslices, otime, oheight, owidth); - /* indices will contain ti,i,j locations for each output point */ - THTensor_(resize5d)(indices, 3, nslices, otime, oheight, owidth); - - input_data = THTensor_(data)(input); - output_data = THTensor_(data)(output); - indices_data = THTensor_(data)(indices); - - nn_(VolumetricMaxPooling_updateOutput_frame)(input_data, output_data, - indices_data+nslices*otime*owidth*oheight*2, - indices_data+nslices*otime*owidth*oheight, - indices_data, - nslices, - itime, iwidth, iheight, - otime, owidth, oheight, - kT, kW, kH, dT, dW, dH); + if (input->nDimension == 4) { /* non-batch mode */ + /* resize output */ + THTensor_(resize4d)(output, nslices, otime, oheight, owidth); + /* indices will contain ti,i,j locations for each output point */ + THTensor_(resize5d)(indices, 3, nslices, otime, oheight, owidth); + + input_data = THTensor_(data)(input); + output_data = THTensor_(data)(output); + indices_data = THTensor_(data)(indices); + + nn_(VolumetricMaxPooling_updateOutput_frame)(input_data, output_data, + indices_data+nslices*otime*owidth*oheight*2, + indices_data+nslices*otime*owidth*oheight, + indices_data, + nslices, + itime, iwidth, iheight, + otime, owidth, oheight, + kT, kW, kH, dT, dW, dH); + } + else { /* batch mode */ + long p; + long nBatch = input->size[0]; + + long istride = nslices*itime*iwidth*iheight; + long ostride = nslices*otime*owidth*oheight; + + /* resize output */ + THTensor_(resize5d)(output, nBatch, nslices, otime, oheight, owidth); + /* indices will contain ti,i,j locations for each output point */ + + THLongStorage* size = THLongStorage_newWithSize(6); + size->data[0] = 3; size->data[1] = nBatch; + size->data[2] = nslices; size->data[3] = otime; + size->data[4] = oheight; size->data[5] = owidth; + THTensor_(resize)(indices, size, NULL); /* resize6d not available */ + //TODO: Replace with resize6d when available + //THTensor_(resize6d)(indices, 3, nBatch, nslices, otime, oheight, owidth); + + input_data = THTensor_(data)(input); + output_data = THTensor_(data)(output); + indices_data = THTensor_(data)(indices); + +#pragma omp parallel for private(p) + for (p=0; p < nBatch; p++) + { + nn_(VolumetricMaxPooling_updateOutput_frame)( + input_data+p*istride, + output_data+p*ostride, + indices_data+(p+nBatch+nBatch)*ostride, + indices_data+(p+nBatch)*ostride, + indices_data+p*ostride, + nslices, + itime, iwidth, iheight, + otime, owidth, oheight, + kT, kW, kH, dT, dW, dH); + } + + THLongStorage_free(size); + } + /* cleanup */ THTensor_(free)(input); return 1; @@ -182,6 +238,12 @@ static int nn_(VolumetricMaxPooling_updateGradInput)(lua_State *L) real *gradOutput_data; real *indices_data; + int dimN = 0; + int dimt = 1; + int dimh = 2; + int dimw = 3; + + /* get contiguous gradOutput */ gradOutput = THTensor_(newContiguous)(gradOutput); @@ -189,14 +251,21 @@ static int nn_(VolumetricMaxPooling_updateGradInput)(lua_State *L) THTensor_(resizeAs)(gradInput, input); THTensor_(zero)(gradInput); + if (input->nDimension == 5) { + dimN++; + dimt++; + dimh++; + dimw++; + } + /* sizes */ - nslices = input->size[0]; - itime = input->size[1]; - iheight = input->size[2]; - iwidth = input->size[3]; - otime = gradOutput->size[1]; - oheight = gradOutput->size[2]; - owidth = gradOutput->size[3]; + nslices = input->size[dimN]; + itime = input->size[dimt]; + iheight = input->size[dimh]; + iwidth = input->size[dimw]; + otime = gradOutput->size[dimt]; + oheight = gradOutput->size[dimh]; + owidth = gradOutput->size[dimw]; /* get raw pointers */ gradInput_data = THTensor_(data)(gradInput); @@ -204,14 +273,39 @@ static int nn_(VolumetricMaxPooling_updateGradInput)(lua_State *L) indices_data = THTensor_(data)(indices); /* backprop */ - nn_(VolumetricMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data, - indices_data+nslices*otime*owidth*oheight*2, - indices_data+nslices*otime*owidth*oheight, - indices_data, - nslices, - itime, iwidth, iheight, - otime, owidth, oheight, - dT, dW, dH); + if (input->nDimension == 4) { /* non-batch mode*/ + + nn_(VolumetricMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data, + indices_data+nslices*otime*owidth*oheight*2, + indices_data+nslices*otime*owidth*oheight, + indices_data, + nslices, + itime, iwidth, iheight, + otime, owidth, oheight, + dT, dW, dH); + } + else { /* batch mode */ + long p; + long nBatch = input->size[0]; + + long istride = nslices*itime*iwidth*iheight; + long ostride = nslices*otime*owidth*oheight; + +#pragma omp parallel for private(p) + for (p = 0; p < nBatch; p++) + { + nn_(VolumetricMaxPooling_updateGradInput_frame)( + gradInput_data+p*istride, + gradOutput_data+p*ostride, + indices_data+(p+nBatch+nBatch)*ostride, + indices_data+(p+nBatch)*ostride, + indices_data+p*ostride, + nslices, + itime, iwidth, iheight, + otime, owidth, oheight, + dT, dW, dH); + } + } /* cleanup */ THTensor_(free)(gradOutput); |