Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'generic/Max.c')
-rw-r--r--generic/Max.c100
1 files changed, 100 insertions, 0 deletions
diff --git a/generic/Max.c b/generic/Max.c
new file mode 100644
index 0000000..87f52f1
--- /dev/null
+++ b/generic/Max.c
@@ -0,0 +1,100 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/Max.c"
+#else
+
+static int nn_(Max_updateOutput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int dimension = luaT_getfieldcheckint(L, 1, "dimension")-1;
+ THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id));
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
+
+ THLongStorage *dim;
+ long i;
+
+ luaL_argcheck(L, dimension >= 0 && dimension < input->nDimension, 2, "dimension out of range");
+
+ dim = THLongStorage_newWithSize(input->nDimension);
+ for(i = 0; i < input->nDimension; i++)
+ dim->data[i] = input->size[i];
+ dim->data[dimension] = 1;
+ THTensor_(resize)(output, dim, NULL);
+ THTensor_(resize)(indices, dim, NULL);
+ THLongStorage_free(dim);
+
+ TH_TENSOR_DIM_APPLY3(real, output, real, input, real, indices, dimension,
+ long theIndex = 0;
+ real theMax = input_data[0];
+ for(i = 1; i < input_size; i++)
+ {
+ if(input_data[i*input_stride] > theMax)
+ {
+ theIndex = i;
+ theMax = input_data[i*input_stride];
+ }
+ }
+ *indices_data = theIndex+1;
+ *output_data = theMax;)
+
+ THTensor_(select)(output, NULL, dimension, 0);
+
+ return 1;
+}
+
+static int nn_(Max_updateGradInput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id));
+ THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_(Tensor_id));
+ int dimension = luaT_getfieldcheckint(L, 1, "dimension")-1;
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+
+ THTensor *gradOutputPlusOneDim;
+ THLongStorage *dim, *str;
+ int i, j;
+
+ THTensor_(resizeAs)(gradInput, input);
+ THTensor_(zero)(gradInput);
+
+ dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
+ str = THLongStorage_newWithSize(gradOutput->nDimension+1);
+ for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
+ {
+ if(j == dimension)
+ {
+ dim->data[j] = input->size[dimension];
+ str->data[j] = 0;
+ continue;
+ }
+
+ dim->data[j] = gradOutput->size[i];
+ str->data[j] = gradOutput->stride[i];
+ i++;
+ }
+
+ gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
+ THLongStorage_free(dim);
+ THLongStorage_free(str);
+
+ TH_TENSOR_DIM_APPLY3(real, gradInput, real, gradOutputPlusOneDim, real, indices, dimension,
+ gradInput_data[ ((long)(*indices_data)-1)*gradInput_stride ] = *gradOutputPlusOneDim_data;)
+
+ THTensor_(free)(gradOutputPlusOneDim);
+
+ return 1;
+}
+
+static const struct luaL_Reg nn_(Max__) [] = {
+ {"Max_updateOutput", nn_(Max_updateOutput)},
+ {"Max_updateGradInput", nn_(Max_updateGradInput)},
+ {NULL, NULL}
+};
+
+static void nn_(Max_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(Max__), "nn");
+ lua_pop(L,1);
+}
+
+#endif