diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-10-01 00:21:12 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-10-01 00:21:12 +0400 |
commit | f50fcafb4736c3106cfcdb3c1d87f8c04fde3be1 (patch) | |
tree | ad1aff8597273a2b0f026cb248a27cd397c93239 /generic | |
parent | 1ed6ecfee94b48fd6a92bae3ef982a41a55c0fde (diff) |
Added SpatialMaxSampling module, for flexible competitive resampling.
Diffstat (limited to 'generic')
-rw-r--r-- | generic/SpatialMaxSampling.c | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/generic/SpatialMaxSampling.c b/generic/SpatialMaxSampling.c new file mode 100644 index 0000000..ea9d135 --- /dev/null +++ b/generic/SpatialMaxSampling.c @@ -0,0 +1,124 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/SpatialMaxSampling.c" +#else + +#ifndef MAX +#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) ) +#endif +#ifndef MIN +#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) ) +#endif + +static int nn_(SpatialMaxSampling_forward)(lua_State *L) +{ + // get all params + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + int owidth = luaT_getfieldcheckint(L, 1, "owidth"); + int oheight = luaT_getfieldcheckint(L, 1, "oheight"); + THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id)); + THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id)); + + // check dims + luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected"); + + // dims + int ichannels = input->size[0]; + int iheight = input->size[1]; + int iwidth = input->size[2]; + int ochannels = ichannels; + float dW = (float)iwidth/owidth; + float dH = (float)iheight/oheight; + + // get contiguous input + input = THTensor_(newContiguous)(input); + + // resize output + THTensor_(resize3d)(output, ochannels, oheight, owidth); + + // indices will contain i,j locations for each output point + THTensor_(resize4d)(indices, 2, ochannels, oheight, owidth); + + // get raw pointers + real *input_data = THTensor_(data)(input); + real *output_data = THTensor_(data)(output); + real *indices_data = THTensor_(data)(indices); + + // compute max pooling for each input slice + long k; + for (k = 0; k < ochannels; k++) { + // pointers to slices + real *input_p = input_data + k*iwidth*iheight; + real *output_p = output_data + k*owidth*oheight; + real *indx_p = indices_data + k*owidth*oheight; + real *indy_p = indices_data + (k+ochannels)*owidth*oheight; + + // loop over output + int i,j; + for(i = 0; i < oheight; i++) { + for(j = 0; j < owidth; j++) { + // compute nearest offsets + long ixs = (long)(j*dW+0.5); + long iys = (long)(i*dH+0.5); + long ixe = (long)((j+1)*dW+0.5); + long iye = (long)((i+1)*dH+0.5); + + // local pointers + real *op = output_p + i*owidth + j; + real *indxp = indx_p + i*owidth + j; + real *indyp = indy_p + i*owidth + j; + + // compute local max: + long maxindex = -1; + real maxval = -THInf; + long tcntr = 0; + int x,y; + for(y = iys; y < iye; y++) { + for(x = ixs; x < ixe; x++) { + real val = *(input_p + y*iwidth + x); + if (val > maxval) { + maxval = val; + maxindex = tcntr; + } + tcntr++; + } + } + + // set output to local max + *op = maxval; + + // store location of max (x,y) + long kW = ixe-ixs; + *indxp = (int)(maxindex / kW)+1; + *indyp = (maxindex % kW) +1; + } + } + } + + // cleanup + THTensor_(free)(input); + return 1; +} + +static int nn_(SpatialMaxSampling_backward)(lua_State *L) +{ + // get all params + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id)); + THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id)); + return 1; +} + +static const struct luaL_Reg nn_(SpatialMaxSampling__) [] = { + {"SpatialMaxSampling_forward", nn_(SpatialMaxSampling_forward)}, + {"SpatialMaxSampling_backward", nn_(SpatialMaxSampling_backward)}, + {NULL, NULL} +}; + +static void nn_(SpatialMaxSampling_init)(lua_State *L) +{ + luaT_pushmetaclass(L, torch_(Tensor_id)); + luaT_registeratname(L, nn_(SpatialMaxSampling__), "nn"); + lua_pop(L,1); +} + +#endif |