Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-09 00:11:12 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-09 00:11:12 +0400
commitd437475ce27135c35493db46f807370e44139deb (patch)
tree0a9f0d1aa3055db8a9ff37fe98c4cabaca3949dc
parenta44e8ba60e43aaccc019e1d7e3cf1e8bd22795c4 (diff)
Added Sparse (L1) criterion
-rw-r--r--SparseCriterion.lua (renamed from todo/SparseCriterion.lua)10
-rw-r--r--generic/SparseCriterion.c49
-rw-r--r--init.c5
-rw-r--r--init.lua1
-rw-r--r--nnx-1.0-1.rockspec1
5 files changed, 66 insertions, 0 deletions
diff --git a/todo/SparseCriterion.lua b/SparseCriterion.lua
index 79404a0..ddaa75c 100644
--- a/todo/SparseCriterion.lua
+++ b/SparseCriterion.lua
@@ -5,6 +5,16 @@ function SparseCriterion:__init()
self.sizeAverage = true
end
+function SparseCriterion:forward(input)
+ input.nn.SparseCriterion_forward(self, input)
+ return self.output
+end
+
+function SparseCriterion:backward(input)
+ input.nn.SparseCriterion_backward(self, input)
+ return self.gradInput
+end
+
function SparseCriterion:write(file)
parent.write(self, file)
file:writeBool(self.sizeAverage)
diff --git a/generic/SparseCriterion.c b/generic/SparseCriterion.c
new file mode 100644
index 0000000..76da569
--- /dev/null
+++ b/generic/SparseCriterion.c
@@ -0,0 +1,49 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SparseCriterion.c"
+#else
+
+static int nn_(SparseCriterion_forward)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+ real sum = 0;
+
+ TH_TENSOR_APPLY(real, input, sum += fabs(*input_data);)
+
+ if(sizeAverage) sum /= THTensor_(nElement)(input);
+
+ lua_pushnumber(L, sum);
+ lua_setfield(L, 1, "output");
+
+ lua_pushnumber(L, sum);
+ return 1;
+}
+
+static int nn_(SparseCriterion_backward)(lua_State *L)
+{
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage");
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+ real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
+
+ THTensor_(resizeAs)(gradInput, input);
+ TH_TENSOR_APPLY2(real, gradInput, real, input,
+ *gradInput_data = ( *input_data >= 0 ? norm : -norm);)
+
+ return 1;
+}
+
+static const struct luaL_Reg nn_(SparseCriterion__) [] = {
+ {"SparseCriterion_forward", nn_(SparseCriterion_forward)},
+ {"SparseCriterion_backward", nn_(SparseCriterion_backward)},
+ {NULL, NULL}
+};
+
+static void nn_(SparseCriterion_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(SparseCriterion__), "nn");
+ lua_pop(L,1);
+}
+
+#endif
diff --git a/init.c b/init.c
index 50c5482..b8238c4 100644
--- a/init.c
+++ b/init.c
@@ -33,6 +33,9 @@ static const void* torch_DoubleTensor_id = NULL;
#include "generic/SpatialReSampling.c"
#include "THGenerateFloatTypes.h"
+#include "generic/SparseCriterion.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/Threshold.c"
#include "THGenerateFloatTypes.h"
@@ -50,6 +53,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_FloatSpatialMaxPooling_init(L);
nn_FloatSpatialUpSampling_init(L);
nn_FloatSpatialReSampling_init(L);
+ nn_FloatSparseCriterion_init(L);
nn_DoubleSpatialLinear_init(L);
nn_DoubleHardShrink_init(L);
@@ -60,6 +64,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_DoubleSpatialMaxPooling_init(L);
nn_DoubleSpatialUpSampling_init(L);
nn_DoubleSpatialReSampling_init(L);
+ nn_DoubleSparseCriterion_init(L);
return 1;
}
diff --git a/init.lua b/init.lua
index cc53e7e..7d92558 100644
--- a/init.lua
+++ b/init.lua
@@ -82,6 +82,7 @@ torch.include('nnx', 'SpatialFovea.lua')
-- criterions:
torch.include('nnx', 'SuperCriterion.lua')
+torch.include('nnx', 'SparseCriterion.lua')
torch.include('nnx', 'SpatialMSECriterion.lua')
torch.include('nnx', 'SpatialClassNLLCriterion.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 1073da2..7026d4d 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -86,6 +86,7 @@ build = {
install_files(/lua/nnx SpatialFovea.lua)
install_files(/lua/nnx SpatialMSECriterion.lua)
install_files(/lua/nnx SpatialClassNLLCriterion.lua)
+ install_files(/lua/nnx SparseCriterion.lua)
add_subdirectory (test)
install_targets(/lib nnx)
]],