Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-10-01 00:21:12 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-10-01 00:21:12 +0400
commitf50fcafb4736c3106cfcdb3c1d87f8c04fde3be1 (patch)
treead1aff8597273a2b0f026cb248a27cd397c93239 /generic
parent1ed6ecfee94b48fd6a92bae3ef982a41a55c0fde (diff)
Added SpatialMaxSampling module, for flexible competitive resampling.
Diffstat (limited to 'generic')
-rw-r--r--generic/SpatialMaxSampling.c124
1 files changed, 124 insertions, 0 deletions
diff --git a/generic/SpatialMaxSampling.c b/generic/SpatialMaxSampling.c
new file mode 100644
index 0000000..ea9d135
--- /dev/null
+++ b/generic/SpatialMaxSampling.c
@@ -0,0 +1,124 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SpatialMaxSampling.c"
+#else
+
+#ifndef MAX
+#define MAX(a,b) ( ((a)>(b)) ? (a) : (b) )
+#endif
+#ifndef MIN
+#define MIN(a,b) ( ((a)<(b)) ? (a) : (b) )
+#endif
+
+static int nn_(SpatialMaxSampling_forward)(lua_State *L)
+{
+ // get all params
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int owidth = luaT_getfieldcheckint(L, 1, "owidth");
+ int oheight = luaT_getfieldcheckint(L, 1, "oheight");
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
+ THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id));
+
+ // check dims
+ luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
+
+ // dims
+ int ichannels = input->size[0];
+ int iheight = input->size[1];
+ int iwidth = input->size[2];
+ int ochannels = ichannels;
+ float dW = (float)iwidth/owidth;
+ float dH = (float)iheight/oheight;
+
+ // get contiguous input
+ input = THTensor_(newContiguous)(input);
+
+ // resize output
+ THTensor_(resize3d)(output, ochannels, oheight, owidth);
+
+ // indices will contain i,j locations for each output point
+ THTensor_(resize4d)(indices, 2, ochannels, oheight, owidth);
+
+ // get raw pointers
+ real *input_data = THTensor_(data)(input);
+ real *output_data = THTensor_(data)(output);
+ real *indices_data = THTensor_(data)(indices);
+
+ // compute max pooling for each input slice
+ long k;
+ for (k = 0; k < ochannels; k++) {
+ // pointers to slices
+ real *input_p = input_data + k*iwidth*iheight;
+ real *output_p = output_data + k*owidth*oheight;
+ real *indx_p = indices_data + k*owidth*oheight;
+ real *indy_p = indices_data + (k+ochannels)*owidth*oheight;
+
+ // loop over output
+ int i,j;
+ for(i = 0; i < oheight; i++) {
+ for(j = 0; j < owidth; j++) {
+ // compute nearest offsets
+ long ixs = (long)(j*dW+0.5);
+ long iys = (long)(i*dH+0.5);
+ long ixe = (long)((j+1)*dW+0.5);
+ long iye = (long)((i+1)*dH+0.5);
+
+ // local pointers
+ real *op = output_p + i*owidth + j;
+ real *indxp = indx_p + i*owidth + j;
+ real *indyp = indy_p + i*owidth + j;
+
+ // compute local max:
+ long maxindex = -1;
+ real maxval = -THInf;
+ long tcntr = 0;
+ int x,y;
+ for(y = iys; y < iye; y++) {
+ for(x = ixs; x < ixe; x++) {
+ real val = *(input_p + y*iwidth + x);
+ if (val > maxval) {
+ maxval = val;
+ maxindex = tcntr;
+ }
+ tcntr++;
+ }
+ }
+
+ // set output to local max
+ *op = maxval;
+
+ // store location of max (x,y)
+ long kW = ixe-ixs;
+ *indxp = (int)(maxindex / kW)+1;
+ *indyp = (maxindex % kW) +1;
+ }
+ }
+ }
+
+ // cleanup
+ THTensor_(free)(input);
+ return 1;
+}
+
+static int nn_(SpatialMaxSampling_backward)(lua_State *L)
+{
+ // get all params
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id));
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+ return 1;
+}
+
+static const struct luaL_Reg nn_(SpatialMaxSampling__) [] = {
+ {"SpatialMaxSampling_forward", nn_(SpatialMaxSampling_forward)},
+ {"SpatialMaxSampling_backward", nn_(SpatialMaxSampling_backward)},
+ {NULL, NULL}
+};
+
+static void nn_(SpatialMaxSampling_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(SpatialMaxSampling__), "nn");
+ lua_pop(L,1);
+}
+
+#endif