Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAngela Fan <angelafan@fb.com>2017-01-20 09:18:08 +0300
committerAngela Fan <angelafan@fb.com>2017-01-20 09:18:08 +0300
commitbc90075dcd8117e3ae5aec0ae51499232d24f98d (patch)
treed4b75369e85f56ab7e1a999bbb89062532632091
parent5989f82800a640ed0f5613c8ef3e417c4502661d (diff)
added c implementation of GatedLinearUnit
-rw-r--r--GatedLinearUnit.lua45
-rw-r--r--lib/THNN/generic/GatedLinearUnit.c71
-rw-r--r--lib/THNN/generic/THNN.h12
-rw-r--r--lib/THNN/init.c3
4 files changed, 101 insertions, 30 deletions
diff --git a/GatedLinearUnit.lua b/GatedLinearUnit.lua
index 5f215ca..5273abf 100644
--- a/GatedLinearUnit.lua
+++ b/GatedLinearUnit.lua
@@ -2,41 +2,26 @@ local GatedLinearUnit, parent = torch.class('nn.GatedLinearUnit', 'nn.Module')
function GatedLinearUnit:__init(dim)
parent.__init(self)
- self.sigmoid = nn.Sigmoid()
self.dim = dim
end
function GatedLinearUnit:updateOutput(input)
- local dim = self.dim or input:dim()
- local inputSize = input:size(dim)
-
- assert(inputSize % 2 == 0, "halving dimension needs to be even")
-
- self.fHalf = input:narrow(dim, 1, inputSize/2)
- self.sHalf = input:narrow(dim, inputSize/2 + 1, inputSize/2)
-
- self.sHalfOut = self.sigmoid:forward(self.sHalf)
- self.output:resizeAs(self.fHalf):copy(self.fHalf):cmul(self.sHalfOut)
-
- return self.output
+ local dim = self.dim or input:dim()
+ input.THNN.GatedLinear_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ dim
+ )
+ return self.output
end
function GatedLinearUnit:updateGradInput(input, gradOutput)
- local dim = self.dim or input:dim()
- local inputSize = input:size(dim)
-
- assert(inputSize % 2 == 0, "halving dimension needs to be even")
-
- local fGradInput = self.sHalfOut
- local sGradInput = self.sigmoid:backward(self.sHalf, gradOutput)
- :cmul(self.fHalf)
-
- self.gradInput:resizeAs(input)
- self.gradInput:narrow(dim, 1, inputSize/2)
- :copy(fGradInput)
- :cmul(gradOutput)
- self.gradInput:narrow(dim, inputSize/2+1, inputSize/2)
- :copy(sGradInput)
-
- return self.gradInput
+ local dim = self.dim or input:dim()
+ input.THNN.GatedLinear_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ dim
+ )
+ return self.gradInput
end
diff --git a/lib/THNN/generic/GatedLinearUnit.c b/lib/THNN/generic/GatedLinearUnit.c
new file mode 100644
index 0000000..d412a7b
--- /dev/null
+++ b/lib/THNN/generic/GatedLinearUnit.c
@@ -0,0 +1,71 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/GatedLinearUnit.c"
+#else
+
+void THNN_(GatedLinear_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int dim)
+{
+ // size output to half of input
+ dim = dim - 1;
+ const long nIn = THTensor_(size)(input, dim);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+
+ const long inputSize = THTensor_(size)(input, dim) / 2;
+ THLongStorage *newSizes = THTensor_(newSizeOf)(input);
+ THLongStorage_set(newSizes, dim, inputSize);
+ THTensor_(resize)(output, newSizes, NULL);
+
+ // halve tensor
+ THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize);
+ THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize);
+
+ // x = x1:cmul( sigmoid(x2) )
+ THTensor_(sigmoid)(output, secondHalf);
+ THTensor_(cmul)(output, output, firstHalf);
+
+ THLongStorage_free(newSizes);
+ THTensor_(free)(firstHalf);
+ THTensor_(free)(secondHalf);
+}
+
+void THNN_(GatedLinear_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int dim)
+{
+ // set up tensors
+ dim = dim - 1;
+ const long nIn = THTensor_(size)(input, dim);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn);
+
+ THTensor_(resizeAs)(gradInput, input);
+ const long inputSize = THTensor_(size)(input, dim) / 2;
+ THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize);
+ THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize);
+ THTensor *gradInputfirstHalf = THTensor_(newNarrow)(gradInput, dim, 0, inputSize);
+ THTensor *gradInputsecondHalf = THTensor_(newNarrow)(gradInput, dim, inputSize, inputSize);
+
+ THTensor_(sigmoid)(gradInputfirstHalf, secondHalf);
+
+ TH_TENSOR_APPLY2(real, gradInputsecondHalf, real, gradInputfirstHalf,
+ real z = *gradInputfirstHalf_data;
+ *gradInputsecondHalf_data = (1. - z) * z;
+ );
+
+ THTensor_(cmul)(gradInputfirstHalf, gradInputfirstHalf, gradOutput);
+
+ THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, gradOutput);
+ THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, firstHalf);
+
+ THTensor_(free)(firstHalf);
+ THTensor_(free)(secondHalf);
+ THTensor_(free)(gradInputfirstHalf);
+ THTensor_(free)(gradInputsecondHalf);
+}
+
+#endif
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index 8fd50f5..447289b 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -102,6 +102,18 @@ TH_API void THNN_(DistKLDivCriterion_updateGradInput)(
THTensor *gradInput, // [OUT] gradient w.r.t. input
bool sizeAverage); // if true, the loss will be normalized **by total number of elements**
+TH_API void THNN_(GatedLinear_updateOutput)(
+ THNNState *state, // library's state
+ THTensor *input, // input tensor
+ THTensor *output, // [OUT] output tensor, half size of input along dimension dim
+ int dim); // dimension for halving operation
+TH_API void THNN_(GatedLinear_updateGradInput)(
+ THNNState *state, // library's state
+ THTensor *input, // input tensor
+ THTensor *gradOutput, // gradient w.r.t module's output
+ THTensor *gradInput, // [OUT] gradient w.r.t input
+ int dim); // dimension for halving operation
+
// HardShrink outputs 0 on interval of (-lambda; lambda) or original value otherwise.
TH_API void THNN_(HardShrink_updateOutput)(
THNNState *state, // library's state
diff --git a/lib/THNN/init.c b/lib/THNN/init.c
index 3a7806d..74adb8d 100644
--- a/lib/THNN/init.c
+++ b/lib/THNN/init.c
@@ -89,6 +89,9 @@
#include "generic/HardTanh.c"
#include "THGenerateFloatTypes.h"
+#include "generic/GatedLinearUnit.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/L1Cost.c"
#include "THGenerateFloatTypes.h"