Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--SpatialLogSoftMax.lua17
-rw-r--r--generic/SpatialLogSoftMax.c116
-rw-r--r--init.c5
-rw-r--r--init.lua1
-rw-r--r--nnx-1.0-1.rockspec1
5 files changed, 140 insertions, 0 deletions
diff --git a/SpatialLogSoftMax.lua b/SpatialLogSoftMax.lua
new file mode 100644
index 0000000..d145caa
--- /dev/null
+++ b/SpatialLogSoftMax.lua
@@ -0,0 +1,17 @@
+local SpatialLogSoftMax, parent = torch.class('nn.SpatialLogSoftMax', 'nn.Module')
+
+function SpatialLogSoftMax:__init()
+ parent.__init(self)
+end
+
+function SpatialLogSoftMax:forward(input)
+ self.output:resizeAs(input)
+ input.nn.SpatialLogSoftMax_forward(self, input)
+ return self.output
+end
+
+function SpatialLogSoftMax:backward(input, gradOutput)
+ self.gradInput:resizeAs(input)
+ input.nn.SpatialLogSoftMax_backward(self, input, gradOutput)
+ return self.gradInput
+end
diff --git a/generic/SpatialLogSoftMax.c b/generic/SpatialLogSoftMax.c
new file mode 100644
index 0000000..0fc3d9d
--- /dev/null
+++ b/generic/SpatialLogSoftMax.c
@@ -0,0 +1,116 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/SpatialLogSoftMax.c"
+#else
+
+static int nn_(SpatialLogSoftMax_forward)(lua_State *L)
+{
+ // get all params
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
+
+ // dims
+ int width = input->size[0];
+ int height = input->size[1];
+
+ // select planes
+ THTensor *input_row = THTensor_(new)();
+ THTensor *input_point = THTensor_(new)();
+ THTensor *output_row = THTensor_(new)();
+ THTensor *output_point = THTensor_(new)();
+
+ // process the whole plane
+ int x,y;
+ for (y=0; y<height; y++) {
+ THTensor_(select)(input_row, input, 1, y);
+ THTensor_(select)(output_row, output, 1, y);
+ for (x=0; x<width; x++) {
+ THTensor_(select)(input_point, input_row, 1, x);
+ THTensor_(select)(output_point, output_row, 1, x);
+
+ real sum = THLogZero;
+
+ TH_TENSOR_APPLY2(real, output_point, real, input_point, \
+ real z = *input_point_data; \
+ *output_point_data = z; \
+ sum = THLogAdd(sum, z);)
+
+ THTensor_(add)(output_point, -sum);
+ }
+ }
+
+ // cleanup
+ THTensor_(free)(input_row);
+ THTensor_(free)(input_point);
+ THTensor_(free)(output_row);
+ THTensor_(free)(output_point);
+ return 1;
+}
+
+static int nn_(SpatialLogSoftMax_backward)(lua_State *L)
+{
+ // get all params
+ THTensor *input = luaT_checkudata(L, 2, torch_(Tensor_id));
+ THTensor *gradOutput = luaT_checkudata(L, 3, torch_(Tensor_id));
+ THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_(Tensor_id));
+ THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_(Tensor_id));
+
+ // dims
+ int width = input->size[0];
+ int height = input->size[1];
+
+ // zero gradInput
+ THTensor_(zero)(gradInput);
+
+ // select planes
+ THTensor *gradOutput_row = THTensor_(new)();
+ THTensor *gradOutput_point = THTensor_(new)();
+ THTensor *gradInput_row = THTensor_(new)();
+ THTensor *gradInput_point = THTensor_(new)();
+ THTensor *output_row = THTensor_(new)();
+ THTensor *output_point = THTensor_(new)();
+
+ // compute gradients for each point
+ int x,y;
+ for (y=0; y<height; y++) {
+ THTensor_(select)(gradInput_row, gradInput, 1, y);
+ THTensor_(select)(gradOutput_row, gradOutput, 1, y);
+ THTensor_(select)(output_row, output, 1, y);
+ for (x=0; x<width; x++) {
+ THTensor_(select)(gradInput_point, gradInput_row, 1, x);
+ THTensor_(select)(gradOutput_point, gradOutput_row, 1, x);
+ THTensor_(select)(output_point, output_row, 1, x);
+
+ real sum = THTensor_(sum)(gradOutput_point);
+
+ TH_TENSOR_APPLY3(real, gradInput_point, \
+ real, gradOutput_point, \
+ real, output_point, \
+ *gradInput_point_data = *gradOutput_point_data - exp(*output_point_data)*sum;);
+ }
+ }
+
+ // cleanup
+ THTensor_(free)(gradInput_row);
+ THTensor_(free)(gradInput_point);
+ THTensor_(free)(gradOutput_row);
+ THTensor_(free)(gradOutput_point);
+ THTensor_(free)(output_row);
+ THTensor_(free)(output_point);
+
+ return 1;
+}
+
+static const struct luaL_Reg nn_(SpatialLogSoftMax__) [] = {
+ {"SpatialLogSoftMax_forward", nn_(SpatialLogSoftMax_forward)},
+ {"SpatialLogSoftMax_backward", nn_(SpatialLogSoftMax_backward)},
+ {NULL, NULL}
+};
+
+static void nn_(SpatialLogSoftMax_init)(lua_State *L)
+{
+ luaT_pushmetaclass(L, torch_(Tensor_id));
+ luaT_registeratname(L, nn_(SpatialLogSoftMax__), "nn");
+ lua_pop(L,1);
+}
+
+#endif
diff --git a/init.c b/init.c
index 2d09c91..5f25a71 100644
--- a/init.c
+++ b/init.c
@@ -20,6 +20,9 @@ static const void* torch_DoubleTensor_id = NULL;
#include "generic/SpatialConvolutionTable.c"
#include "THGenerateFloatTypes.h"
+#include "generic/SpatialLogSoftMax.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/Threshold.c"
#include "THGenerateFloatTypes.h"
@@ -33,12 +36,14 @@ DLL_EXPORT int luaopen_libnnx(lua_State *L)
nn_FloatAbs_init(L);
nn_FloatThreshold_init(L);
nn_FloatSpatialConvolutionTable_init(L);
+ nn_FloatSpatialLogSoftMax_init(L);
nn_DoubleSpatialLinear_init(L);
nn_DoubleHardShrink_init(L);
nn_DoubleAbs_init(L);
nn_DoubleThreshold_init(L);
nn_DoubleSpatialConvolutionTable_init(L);
+ nn_DoubleSpatialLogSoftMax_init(L);
return 1;
}
diff --git a/init.lua b/init.lua
index 096642b..ec814d3 100644
--- a/init.lua
+++ b/init.lua
@@ -60,6 +60,7 @@ torch.include('nnx', 'Narrow.lua')
-- spatial (images) operators:
torch.include('nnx', 'SpatialLinear.lua')
+torch.include('nnx', 'SpatialLogSoftMax.lua')
torch.include('nnx', 'SpatialConvolutionTable.lua')
-- criterions:
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index 49284a8..44ae451 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -57,6 +57,7 @@ build = {
install_files(/lua/nnx Sqrt.lua)
install_files(/lua/nnx Threshold.lua)
install_files(/lua/nnx SpatialConvolutionTable.lua)
+ install_files(/lua/nnx SpatialLogSoftMax.lua)
install_files(/lua/nnx SpatialLinear.lua)
install_files(/lua/nnx SuperCriterion.lua)
add_subdirectory (test)