Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-07-25 15:50:37 +0400
committerRonan Collobert <ronan@collobert.com>2012-07-25 15:50:37 +0400
commit0d9a4a50c9253b032a41ad114c41aa82d370d102 (patch)
tree220fffbdeac73edbf7bb85bf39666a4d33cd2ee5 /generic
parent37b738909113cd43b2d9a15ff279d737134a9462 (diff)
added temporal max pooling
Diffstat (limited to 'generic')
-rw-r--r--generic/TemporalMaxPooling.c127
1 files changed, 127 insertions, 0 deletions
diff --git a/generic/TemporalMaxPooling.c b/generic/TemporalMaxPooling.c
new file mode 100644
index 0000000..56d0ef6
--- /dev/null
+++ b/generic/TemporalMaxPooling.c
@@ -0,0 +1,127 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/TemporalMaxPooling.c"
+#else
+
+static int nn_(TemporalMaxPooling_updateOutput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int kW = luaT_getfieldcheckint(L, 1, "kW");
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id));
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
+
+ luaL_argcheck(L, input->nDimension == 2, 2, "2D tensor expected");
+ luaL_argcheck(L, input->size[0] >= kW, 2, "input sequence smaller than kernel size");
+
+ // sizes
+ long niframe = input->size[0];
+ long framesize = input->size[1];
+ long noframe = (niframe - kW) / dW + 1;
+
+ // get contiguous input
+ input = THTensor_(newContiguous)(input);
+
+ // resize output
+ THTensor_(resize2d)(output, noframe, framesize);
+
+ // indices will contain index locations for each output point
+ THTensor_(resize2d)(indices, noframe, framesize);
+
+ // get raw pointers
+ real *input_data = THTensor_(data)(input);
+ real *output_data = THTensor_(data)(output);
+ real *indices_data = THTensor_(data)(indices);
+
+ long t, x, y;
+ for(t = 0; t < noframe; t++)
+ {
+ real *ip = input_data + t*framesize*dW;
+ real *op = output_data + t*framesize;
+ real *xp = indices_data + t*framesize;
+#pragma omp parallel for private(y)
+ for(y = 0; y < framesize; y++)
+ {
+ // compute local max:
+ long maxindex = -1;
+ real maxval = -THInf;
+ for(x = 0; x < kW; x++)
+ {
+ real val = ip[x*framesize+y];
+ if (val > maxval)
+ {
+ maxval = val;
+ maxindex = x;
+ }
+ }
+
+ // set output to local max
+ op[y] = maxval;
+ xp[y] = (real)maxindex;
+ }
+ }
+
+ // cleanup
+ THTensor_(free)(input);
+
+ return 1;
+}
+
+static int nn_(TemporalMaxPooling_updateGradInput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id));
+ int dW = luaT_getfieldcheckint(L, 1, "dW");
+ THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id));
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+
+ // get contiguous gradOutput
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+
+ // resize and zero
+ THTensor_(resizeAs)(gradInput, input);
+ THTensor_(zero)(gradInput);
+
+ // sizes
+ int noframe = gradOutput->size[0];
+ long framesize = gradOutput->size[1];
+
+ // get raw pointers
+ real *gradInput_data = THTensor_(data)(gradInput);
+ real *gradOutput_data = THTensor_(data)(gradOutput);
+ real *indices_data = THTensor_(data)(indices);
+
+ long t, y;
+ for(t = 0; t < noframe; t++)
+ {
+ real *gip = gradInput_data + t*framesize*dW;
+ real *gop = gradOutput_data + t*framesize;
+ real *xp = indices_data + t*framesize;
+#pragma omp parallel for private(y)
+ for(y = 0; y < framesize; y++)
+ {
+ // compute local max:
+ long maxindex = (long)xp[y];
+ gip[maxindex*framesize+y] += gop[y];
+ }
+ }
+
+ // cleanup
+ THTensor_(free)(gradOutput);
+
+ return 1;
+}
+
+static const struct luaL_Reg nn_(TemporalMaxPooling__) [] = {
+ {"TemporalMaxPooling_updateOutput", nn_(TemporalMaxPooling_updateOutput)},
+ {"TemporalMaxPooling_updateGradInput", nn_(TemporalMaxPooling_updateGradInput)},
+ {NULL, NULL}
+};
+
+static void nn_(TemporalMaxPooling_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(TemporalMaxPooling__), "nn");
+ lua_pop(L,1);
+}
+
+#endif