diff options
author | Angela Fan <angelafan@fb.com> | 2017-01-20 09:18:08 +0300 |
---|---|---|
committer | Angela Fan <angelafan@fb.com> | 2017-01-20 09:18:08 +0300 |
commit | bc90075dcd8117e3ae5aec0ae51499232d24f98d (patch) | |
tree | d4b75369e85f56ab7e1a999bbb89062532632091 | |
parent | 5989f82800a640ed0f5613c8ef3e417c4502661d (diff) |
added c implementation of GatedLinearUnit
-rw-r--r-- | GatedLinearUnit.lua | 45 | ||||
-rw-r--r-- | lib/THNN/generic/GatedLinearUnit.c | 71 | ||||
-rw-r--r-- | lib/THNN/generic/THNN.h | 12 | ||||
-rw-r--r-- | lib/THNN/init.c | 3 |
4 files changed, 101 insertions, 30 deletions
diff --git a/GatedLinearUnit.lua b/GatedLinearUnit.lua index 5f215ca..5273abf 100644 --- a/GatedLinearUnit.lua +++ b/GatedLinearUnit.lua @@ -2,41 +2,26 @@ local GatedLinearUnit, parent = torch.class('nn.GatedLinearUnit', 'nn.Module') function GatedLinearUnit:__init(dim) parent.__init(self) - self.sigmoid = nn.Sigmoid() self.dim = dim end function GatedLinearUnit:updateOutput(input) - local dim = self.dim or input:dim() - local inputSize = input:size(dim) - - assert(inputSize % 2 == 0, "halving dimension needs to be even") - - self.fHalf = input:narrow(dim, 1, inputSize/2) - self.sHalf = input:narrow(dim, inputSize/2 + 1, inputSize/2) - - self.sHalfOut = self.sigmoid:forward(self.sHalf) - self.output:resizeAs(self.fHalf):copy(self.fHalf):cmul(self.sHalfOut) - - return self.output + local dim = self.dim or input:dim() + input.THNN.GatedLinear_updateOutput( + input:cdata(), + self.output:cdata(), + dim + ) + return self.output end function GatedLinearUnit:updateGradInput(input, gradOutput) - local dim = self.dim or input:dim() - local inputSize = input:size(dim) - - assert(inputSize % 2 == 0, "halving dimension needs to be even") - - local fGradInput = self.sHalfOut - local sGradInput = self.sigmoid:backward(self.sHalf, gradOutput) - :cmul(self.fHalf) - - self.gradInput:resizeAs(input) - self.gradInput:narrow(dim, 1, inputSize/2) - :copy(fGradInput) - :cmul(gradOutput) - self.gradInput:narrow(dim, inputSize/2+1, inputSize/2) - :copy(sGradInput) - - return self.gradInput + local dim = self.dim or input:dim() + input.THNN.GatedLinear_updateGradInput( + input:cdata(), + gradOutput:cdata(), + self.gradInput:cdata(), + dim + ) + return self.gradInput end diff --git a/lib/THNN/generic/GatedLinearUnit.c b/lib/THNN/generic/GatedLinearUnit.c new file mode 100644 index 0000000..d412a7b --- /dev/null +++ b/lib/THNN/generic/GatedLinearUnit.c @@ -0,0 +1,71 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/GatedLinearUnit.c" +#else + +void THNN_(GatedLinear_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + int dim) +{ + // size output to half of input + dim = dim - 1; + const long nIn = THTensor_(size)(input, dim); + THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn); + + const long inputSize = THTensor_(size)(input, dim) / 2; + THLongStorage *newSizes = THTensor_(newSizeOf)(input); + THLongStorage_set(newSizes, dim, inputSize); + THTensor_(resize)(output, newSizes, NULL); + + // halve tensor + THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize); + THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize); + + // x = x1:cmul( sigmoid(x2) ) + THTensor_(sigmoid)(output, secondHalf); + THTensor_(cmul)(output, output, firstHalf); + + THLongStorage_free(newSizes); + THTensor_(free)(firstHalf); + THTensor_(free)(secondHalf); +} + +void THNN_(GatedLinear_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor *gradInput, + int dim) +{ + // set up tensors + dim = dim - 1; + const long nIn = THTensor_(size)(input, dim); + THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim+1, nIn); + + THTensor_(resizeAs)(gradInput, input); + const long inputSize = THTensor_(size)(input, dim) / 2; + THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize); + THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize); + THTensor *gradInputfirstHalf = THTensor_(newNarrow)(gradInput, dim, 0, inputSize); + THTensor *gradInputsecondHalf = THTensor_(newNarrow)(gradInput, dim, inputSize, inputSize); + + THTensor_(sigmoid)(gradInputfirstHalf, secondHalf); + + TH_TENSOR_APPLY2(real, gradInputsecondHalf, real, gradInputfirstHalf, + real z = *gradInputfirstHalf_data; + *gradInputsecondHalf_data = (1. - z) * z; + ); + + THTensor_(cmul)(gradInputfirstHalf, gradInputfirstHalf, gradOutput); + + THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, gradOutput); + THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, firstHalf); + + THTensor_(free)(firstHalf); + THTensor_(free)(secondHalf); + THTensor_(free)(gradInputfirstHalf); + THTensor_(free)(gradInputsecondHalf); +} + +#endif diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h index 8fd50f5..447289b 100644 --- a/lib/THNN/generic/THNN.h +++ b/lib/THNN/generic/THNN.h @@ -102,6 +102,18 @@ TH_API void THNN_(DistKLDivCriterion_updateGradInput)( THTensor *gradInput, // [OUT] gradient w.r.t. input bool sizeAverage); // if true, the loss will be normalized **by total number of elements** +TH_API void THNN_(GatedLinear_updateOutput)( + THNNState *state, // library's state + THTensor *input, // input tensor + THTensor *output, // [OUT] output tensor, half size of input along dimension dim + int dim); // dimension for halving operation +TH_API void THNN_(GatedLinear_updateGradInput)( + THNNState *state, // library's state + THTensor *input, // input tensor + THTensor *gradOutput, // gradient w.r.t module's output + THTensor *gradInput, // [OUT] gradient w.r.t input + int dim); // dimension for halving operation + // HardShink outputs 0 on interval of (-lambda; lambda) or original value otherwise. TH_API void THNN_(HardShrink_updateOutput)( THNNState *state, // library's state diff --git a/lib/THNN/init.c b/lib/THNN/init.c index 3a7806d..74adb8d 100644 --- a/lib/THNN/init.c +++ b/lib/THNN/init.c @@ -89,6 +89,9 @@ #include "generic/HardTanh.c" #include "THGenerateFloatTypes.h" +#include "generic/GatedLinearUnit.c" +#include "THGenerateFloatTypes.h" + #include "generic/L1Cost.c" #include "THGenerateFloatTypes.h" |