diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-09 00:11:12 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-09 00:11:12 +0400 |
commit | d437475ce27135c35493db46f807370e44139deb (patch) | |
tree | 0a9f0d1aa3055db8a9ff37fe98c4cabaca3949dc | |
parent | a44e8ba60e43aaccc019e1d7e3cf1e8bd22795c4 (diff) |
Added Sparse (L1) criterion
-rw-r--r-- | SparseCriterion.lua (renamed from todo/SparseCriterion.lua) | 10 | ||||
-rw-r--r-- | generic/SparseCriterion.c | 49 | ||||
-rw-r--r-- | init.c | 5 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | nnx-1.0-1.rockspec | 1 |
5 files changed, 66 insertions, 0 deletions
diff --git a/todo/SparseCriterion.lua b/SparseCriterion.lua index 79404a0..ddaa75c 100644 --- a/todo/SparseCriterion.lua +++ b/SparseCriterion.lua @@ -5,6 +5,16 @@ function SparseCriterion:__init() self.sizeAverage = true end +function SparseCriterion:forward(input) + input.nn.SparseCriterion_forward(self, input) + return self.output +end + +function SparseCriterion:backward(input) + input.nn.SparseCriterion_backward(self, input) + return self.gradInput +end + function SparseCriterion:write(file) parent.write(self, file) file:writeBool(self.sizeAverage) diff --git a/generic/SparseCriterion.c b/generic/SparseCriterion.c new file mode 100644 index 0000000..76da569 --- /dev/null +++ b/generic/SparseCriterion.c @@ -0,0 +1,49 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/SparseCriterion.c" +#else + +static int nn_(SparseCriterion_forward)(lua_State *L) +{ + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage"); + real sum = 0; + + TH_TENSOR_APPLY(real, input, sum += fabs(*input_data);) + + if(sizeAverage) sum /= THTensor_(nElement)(input); + + lua_pushnumber(L, sum); + lua_setfield(L, 1, "output"); + + lua_pushnumber(L, sum); + return 1; +} + +static int nn_(SparseCriterion_backward)(lua_State *L) +{ + THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id)); + int sizeAverage = luaT_getfieldcheckboolean(L, 1, "sizeAverage"); + THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id)); + real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.); + + THTensor_(resizeAs)(gradInput, input); + TH_TENSOR_APPLY2(real, gradInput, real, input, + *gradInput_data = ( *input_data >= 0 ? norm : -norm);) + + return 1; +} + +static const struct luaL_Reg nn_(SparseCriterion__) [] = { + {"SparseCriterion_forward", nn_(SparseCriterion_forward)}, + {"SparseCriterion_backward", nn_(SparseCriterion_backward)}, + {NULL, NULL} +}; + +static void nn_(SparseCriterion_init)(lua_State *L) +{ + luaT_pushmetaclass(L, torch_(Tensor_id)); + luaT_registeratname(L, nn_(SparseCriterion__), "nn"); + lua_pop(L,1); +} + +#endif @@ -33,6 +33,9 @@ static const void* torch_DoubleTensor_id = NULL; #include "generic/SpatialReSampling.c" #include "THGenerateFloatTypes.h" +#include "generic/SparseCriterion.c" +#include "THGenerateFloatTypes.h" + #include "generic/Threshold.c" #include "THGenerateFloatTypes.h" @@ -50,6 +53,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L) nn_FloatSpatialMaxPooling_init(L); nn_FloatSpatialUpSampling_init(L); nn_FloatSpatialReSampling_init(L); + nn_FloatSparseCriterion_init(L); nn_DoubleSpatialLinear_init(L); nn_DoubleHardShrink_init(L); @@ -60,6 +64,7 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L) nn_DoubleSpatialMaxPooling_init(L); nn_DoubleSpatialUpSampling_init(L); nn_DoubleSpatialReSampling_init(L); + nn_DoubleSparseCriterion_init(L); return 1; } @@ -82,6 +82,7 @@ torch.include('nnx', 'SpatialFovea.lua') -- criterions: torch.include('nnx', 'SuperCriterion.lua') +torch.include('nnx', 'SparseCriterion.lua') torch.include('nnx', 'SpatialMSECriterion.lua') torch.include('nnx', 'SpatialClassNLLCriterion.lua') diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec index 1073da2..7026d4d 100644 --- a/nnx-1.0-1.rockspec +++ b/nnx-1.0-1.rockspec @@ -86,6 +86,7 @@ build = { install_files(/lua/nnx SpatialFovea.lua) install_files(/lua/nnx SpatialMSECriterion.lua) install_files(/lua/nnx SpatialClassNLLCriterion.lua) + install_files(/lua/nnx SparseCriterion.lua) add_subdirectory (test) install_targets(/lib nnx) ]], |