diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-10-01 01:53:46 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-10-01 01:53:46 +0400 |
commit | 2c6beea7dbfa04e17be5f5b7286a49c4cd137dbd (patch) | |
tree | fe5ddcb763b3017b59a0b7bd8479be791cf5cdf0 /generic | |
parent | 429f06d8e1ec2f53ed31653e7efa16e66c5bfd99 (diff) |
Bprop for SpatialMaxSampling all checked.
Diffstat (limited to 'generic')
-rw-r--r-- | generic/SpatialMaxSampling.c | 65 |
1 files changed, 61 insertions, 4 deletions
diff --git a/generic/SpatialMaxSampling.c b/generic/SpatialMaxSampling.c index ea9d135..d53b364 100644 --- a/generic/SpatialMaxSampling.c +++ b/generic/SpatialMaxSampling.c @@ -20,6 +20,7 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L) // check dims luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected"); + luaL_argcheck(L, (input->size[1] >= oheight) && (input->size[2] >= owidth), 2, "upsampling not supported"); // dims int ichannels = input->size[0]; @@ -49,8 +50,8 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L) // pointers to slices real *input_p = input_data + k*iwidth*iheight; real *output_p = output_data + k*owidth*oheight; - real *indx_p = indices_data + k*owidth*oheight; - real *indy_p = indices_data + (k+ochannels)*owidth*oheight; + real *indy_p = indices_data + k*owidth*oheight; + real *indx_p = indices_data + (k+ochannels)*owidth*oheight; // loop over output int i,j; @@ -88,8 +89,8 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L) // store location of max (x,y) long kW = ixe-ixs; - *indxp = (int)(maxindex / kW)+1; - *indyp = (maxindex % kW) +1; + *indyp = (int)(maxindex / kW)+1; + *indxp = (maxindex % kW) +1; } } } @@ -105,6 +106,62 @@ static int nn_(SpatialMaxSampling_backward)(lua_State *L) THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id)); THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id)); + THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id)); + int owidth = luaT_getfieldcheckint(L, 1, "owidth"); + int oheight = luaT_getfieldcheckint(L, 1, "oheight"); + + // sizes + int ichannels = input->size[0]; + int iheight = input->size[1]; + int iwidth = input->size[2]; + int ochannels = ichannels; + float dW = (float)iwidth/owidth; + float dH = (float)iheight/oheight; + + // get contiguous gradOutput + gradOutput = THTensor_(newContiguous)(gradOutput); + + // resize input + THTensor_(resizeAs)(gradInput, input); + THTensor_(zero)(gradInput); + + // get raw pointers + real *gradInput_data = THTensor_(data)(gradInput); + real *gradOutput_data = THTensor_(data)(gradOutput); + real *indices_data = THTensor_(data)(indices); + + // backprop all + long k; + for (k = 0; k < ichannels; k++) { + // pointers to slices + real *gradOutput_p = gradOutput_data + k*owidth*oheight; + real *gradInput_p = gradInput_data + k*iwidth*iheight; + real *indy_p = indices_data + k*owidth*oheight; + real *indx_p = indices_data + (k+ochannels)*owidth*oheight; + + // calculate max points + int i,j; + for(i = 0; i < oheight; i++) { + for(j = 0; j < owidth; j++) { + // compute nearest offsets + long iys = (long)(i*dH+0.5); + long ixs = (long)(j*dW+0.5); + + // retrieve position of max + real *indyp = indy_p + i*owidth + j; + real *indxp = indx_p + i*owidth + j; + long maxi = (*indyp) - 1 + iys; + long maxj = (*indxp) - 1 + ixs; + + // update gradient + *(gradInput_p + maxi*iwidth + maxj) += *(gradOutput_p + i*owidth + j); + } + } + } + + // cleanup + THTensor_(free)(gradOutput); + return 1; } |