Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'generic/L1Cost.c')
-rw-r--r--generic/L1Cost.c49
1 files changed, 49 insertions, 0 deletions
diff --git a/generic/L1Cost.c b/generic/L1Cost.c
new file mode 100644
index 0000000..a450e06
--- /dev/null
+++ b/generic/L1Cost.c
@@ -0,0 +1,49 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/L1Cost.c"
+#else
+
+static int nn_(L1Cost_updateOutput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ accreal sum;
+
+ sum = 0;
+ TH_TENSOR_APPLY(real, input, sum += fabs(*input_data););
+
+ lua_pushnumber(L, sum);
+ lua_setfield(L, 1, "output");
+
+ lua_pushnumber(L, sum);
+ return 1;
+}
+
+static int nn_(L1Cost_updateGradInput)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+
+ THTensor_(resizeAs)(gradInput, input);
+ TH_TENSOR_APPLY2(real, gradInput, real, input,
+ if (*input_data > 0)
+ *gradInput_data = 1;
+ else if (*input_data < 0)
+ *gradInput_data = -1;
+ else
+ *gradInput_data = 0;);
+ return 1;
+}
+
+static const struct luaL_Reg nn_(L1Cost__) [] = {
+ {"L1Cost_updateOutput", nn_(L1Cost_updateOutput)},
+ {"L1Cost_updateGradInput", nn_(L1Cost_updateGradInput)},
+ {NULL, NULL}
+};
+
+static void nn_(L1Cost_init)(lua_State *L)
+{
+ luaT_pushmetatable(L, torch_Tensor);
+ luaT_registeratname(L, nn_(L1Cost__), "nn");
+ lua_pop(L,1);
+}
+
+#endif