Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-10-01 01:53:46 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-10-01 01:53:46 +0400
commit2c6beea7dbfa04e17be5f5b7286a49c4cd137dbd (patch)
treefe5ddcb763b3017b59a0b7bd8479be791cf5cdf0 /generic
parent429f06d8e1ec2f53ed31653e7efa16e66c5bfd99 (diff)
Bprop for SpatialMaxSampling all checked.
Diffstat (limited to 'generic')
-rw-r--r--generic/SpatialMaxSampling.c65
1 files changed, 61 insertions, 4 deletions
diff --git a/generic/SpatialMaxSampling.c b/generic/SpatialMaxSampling.c
index ea9d135..d53b364 100644
--- a/generic/SpatialMaxSampling.c
+++ b/generic/SpatialMaxSampling.c
@@ -20,6 +20,7 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L)
// check dims
luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
+ luaL_argcheck(L, (input->size[1] >= oheight) && (input->size[2] >= owidth), 2, "upsampling not supported");
// dims
int ichannels = input->size[0];
@@ -49,8 +50,8 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L)
// pointers to slices
real *input_p = input_data + k*iwidth*iheight;
real *output_p = output_data + k*owidth*oheight;
- real *indx_p = indices_data + k*owidth*oheight;
- real *indy_p = indices_data + (k+ochannels)*owidth*oheight;
+ real *indy_p = indices_data + k*owidth*oheight;
+ real *indx_p = indices_data + (k+ochannels)*owidth*oheight;
// loop over output
int i,j;
@@ -88,8 +89,8 @@ static int nn_(SpatialMaxSampling_forward)(lua_State *L)
// store location of max (x,y)
long kW = ixe-ixs;
- *indxp = (int)(maxindex / kW)+1;
- *indyp = (maxindex % kW) +1;
+ *indyp = (int)(maxindex / kW)+1;
+ *indxp = (maxindex % kW) +1;
}
}
}
@@ -105,6 +106,62 @@ static int nn_(SpatialMaxSampling_backward)(lua_State *L)
THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id));
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+ THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id));
+ int owidth = luaT_getfieldcheckint(L, 1, "owidth");
+ int oheight = luaT_getfieldcheckint(L, 1, "oheight");
+
+ // sizes
+ int ichannels = input->size[0];
+ int iheight = input->size[1];
+ int iwidth = input->size[2];
+ int ochannels = ichannels;
+ float dW = (float)iwidth/owidth;
+ float dH = (float)iheight/oheight;
+
+ // get contiguous gradOutput
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+
+ // resize input
+ THTensor_(resizeAs)(gradInput, input);
+ THTensor_(zero)(gradInput);
+
+ // get raw pointers
+ real *gradInput_data = THTensor_(data)(gradInput);
+ real *gradOutput_data = THTensor_(data)(gradOutput);
+ real *indices_data = THTensor_(data)(indices);
+
+ // backprop all
+ long k;
+ for (k = 0; k < ichannels; k++) {
+ // pointers to slices
+ real *gradOutput_p = gradOutput_data + k*owidth*oheight;
+ real *gradInput_p = gradInput_data + k*iwidth*iheight;
+ real *indy_p = indices_data + k*owidth*oheight;
+ real *indx_p = indices_data + (k+ochannels)*owidth*oheight;
+
+ // calculate max points
+ int i,j;
+ for(i = 0; i < oheight; i++) {
+ for(j = 0; j < owidth; j++) {
+ // compute nearest offsets
+ long iys = (long)(i*dH+0.5);
+ long ixs = (long)(j*dW+0.5);
+
+ // retrieve position of max
+ real *indyp = indy_p + i*owidth + j;
+ real *indxp = indx_p + i*owidth + j;
+ long maxi = (*indyp) - 1 + iys;
+ long maxj = (*indxp) - 1 + ixs;
+
+ // update gradient
+ *(gradInput_p + maxi*iwidth + maxj) += *(gradOutput_p + i*owidth + j);
+ }
+ }
+ }
+
+ // cleanup
+ THTensor_(free)(gradOutput);
+
return 1;
}