diff options
author | Angela Fan <angelafan@fb.com> | 2017-01-27 09:56:33 +0300 |
---|---|---|
committer | Angela Fan <angelafan@fb.com> | 2017-01-27 09:56:33 +0300 |
commit | d253af054affc75d42efc6c33bae5e8673f8a563 (patch) | |
tree | 5ff6a345bcf25041b1850142fd0adb8e835bfe45 | |
parent | 5fa193a84ca8fe112bf6e75487ac96eb1b1239d2 (diff) |
cuda implementation of Gated Linear Unit, fixed issues with genericization
-rw-r--r-- | lib/THCUNN/GatedLinearUnit.cu | 29 | ||||
-rw-r--r-- | lib/THCUNN/generic/GatedLinearUnit.cu | 64 | ||||
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 13 | ||||
-rw-r--r-- | test.lua | 11 |
4 files changed, 117 insertions, 0 deletions
// lib/THCUNN/GatedLinearUnit.cu (new file)
// Pointwise CUDA functors for the Gated Linear Unit:
//   GLU(x) = x1 * sigmoid(x2),
// where the input tensor is split into halves (x1, x2) along one dimension.
// Instantiated for every floating type via THCGenerateFloatTypes.h below.
#include "THCUNN.h"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "common.h"

// Forward functor: *target = sigmoid(*sigTensor) * (*mulTensor).
// Dtype is the storage type (may be half); Acctype is the accumulation type,
// so the sigmoid is evaluated in higher precision before converting back.
template <typename Dtype, typename Acctype>
struct gatedLinearCSigMul_functor
{
  __device__ void operator()(Dtype *target, const Dtype *sigTensor, const Dtype *mulTensor) const
  {
    const Acctype sigNum = Acctype(1)/(Acctype(1)+ exp(ScalarConvert<Dtype, Acctype>::to(-*sigTensor)));
    const Dtype mulNum = *mulTensor;
    *target = ScalarConvert<Acctype, Dtype>::to(sigNum * mulNum);
  }
};

// Backward functor for the gate half:
//   *target *= sigmoid(z) * (1 - sigmoid(z)) * (*mulTensor),  z = *sigTensor.
// The caller seeds *target with x1 first, so the result is the full
// gradient w.r.t. the gate half: x1 * sigmoid'(x2) * gradOutput.
template <typename Dtype, typename Acctype>
struct gatedLinearDerivativeSecondHalf_functor
{
  __device__ void operator()(Dtype *target, const Dtype *sigTensor, const Dtype *mulTensor) const
  {
    const Acctype sigNum = Acctype(1)/(Acctype(1)+ exp(ScalarConvert<Dtype, Acctype>::to(-*sigTensor)));
    const Dtype mulNum = *mulTensor;
    *target *= ScalarConvert<Acctype, Dtype>::to((Acctype(1) - sigNum) * sigNum * mulNum);
  }
};

#include "generic/GatedLinearUnit.cu"
#include "THCGenerateFloatTypes.h"
// lib/THCUNN/generic/GatedLinearUnit.cu (new file)
// Type-generic host-side entry points for the Gated Linear Unit; compiled
// once per floating type through THCGenerateFloatTypes.h.
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/GatedLinearUnit.cu"
#else

// output = x1 :cmul( sigmoid(x2) ), where `input` is split into halves
// (x1, x2) along dimension `dim`. `dim` arrives 1-indexed from Lua and is
// converted to 0-indexed here. Raises an arg error if the halving dimension
// has odd size.
void THNN_(GatedLinear_updateOutput)(
          THCState *state,
          THCTensor *input,
          THCTensor *output,
          int dim)
{
  THCUNN_assertSameGPU(state, 2, input, output);

  // Convert to 0-indexed and size the output to half of the input.
  dim = dim - 1;
  const long nIn = THCTensor_(size)(state, input, dim);
  THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim + 1, nIn);
  const long inputSize = THCTensor_(size)(state, input, dim) / 2;
  THLongStorage *newSizes = THCTensor_(newSizeOf)(state, input);
  THLongStorage_set(newSizes, dim, inputSize);
  THCTensor_(resize)(state, output, newSizes, NULL);

  // Narrowed views over the two halves (no data copies).
  THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
  THCTensor *secondHalf = THCTensor_(newNarrow)(state, input, dim, inputSize, inputSize);

  // output = firstHalf * sigmoid(secondHalf)
  THC_pointwiseApply3(state, output, secondHalf, firstHalf, gatedLinearCSigMul_functor<real, accreal>());

  THLongStorage_free(newSizes);
  THCTensor_(free)(state, firstHalf);
  THCTensor_(free)(state, secondHalf);
}

// gradInput1 = gradOutput * sigmoid(x2)
// gradInput2 = gradOutput * x1 * sigmoid(x2) * (1 - sigmoid(x2))
void THNN_(GatedLinear_updateGradInput)(
          THCState *state,
          THCTensor *input,
          THCTensor *gradOutput,
          THCTensor *gradInput,
          int dim)
{
  // Check every tensor we touch (input is read below), not just the gradients.
  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
  dim = dim - 1;
  const long nIn = THCTensor_(size)(state, input, dim);
  THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", dim + 1, nIn);

  THCTensor_(resizeAs)(state, gradInput, input);
  const long inputSize = THCTensor_(size)(state, input, dim) / 2;
  THCTensor *firstHalf = THCTensor_(newNarrow)(state, input, dim, 0, inputSize);
  THCTensor *secondHalf = THCTensor_(newNarrow)(state, input, dim, inputSize, inputSize);
  THCTensor *gradInputfirstHalf = THCTensor_(newNarrow)(state, gradInput, dim, 0, inputSize);
  THCTensor *gradInputsecondHalf = THCTensor_(newNarrow)(state, gradInput, dim, inputSize, inputSize);

  // First half of the derivative: gradInput1 = sigmoid(x2) * gradOutput.
  THC_pointwiseApply3(state, gradInputfirstHalf, secondHalf, gradOutput, gatedLinearCSigMul_functor<real, accreal>());
  // Second half: seed with x1, then multiply in sigmoid'(x2) * gradOutput
  // via the in-place (*target *= ...) functor.
  THCTensor_(copy)(state, gradInputsecondHalf, firstHalf);
  THC_pointwiseApply3(state, gradInputsecondHalf, secondHalf, gradOutput, gatedLinearDerivativeSecondHalf_functor<real, accreal>());

  THCTensor_(free)(state, firstHalf);
  THCTensor_(free)(state, secondHalf);
  THCTensor_(free)(state, gradInputfirstHalf);
  THCTensor_(free)(state, gradInputsecondHalf);
}

#endif
\ No newline at end of file
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index bf903b9..8346e59 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -138,6 +138,19 @@ TH_API void THNN_(HardTanh_updateGradInput)(
                   real max_val,
                   bool inplace);

+TH_API void THNN_(GatedLinear_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  int dim);
+
+TH_API void THNN_(GatedLinear_updateGradInput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  int dim);
+
 TH_API void THNN_(LeakyReLU_updateOutput)(
                   THCState *state,
                   THCTensor *input,
@@ -84,6 +84,7 @@ end

 local function pointwise_forward(proto_module, name, max_error)
    local size = math.random(1,100)
+   if name == 'GatedLinearUnit' then size = size*2 end

    for k, typename in ipairs(typenames) do
       local input = torch.randn(size):type(typename)
@@ -105,10 +106,12 @@ end

 local function pointwise_backward(proto_module, name, max_error)
    local size = math.random(1,100)
+   if name == 'GatedLinearUnit' then size = size*2 end

    for k, typename in ipairs(typenames) do
       local input = torch.randn(size):type(typename)
       local gradOutput = torch.randn(size):type(typename)
+      if name == 'GatedLinearUnit' then gradOutput = torch.randn(size/2) end

       local ctype = t2cpu[typename]
       input = makeNonContiguous(input:type(ctype))
@@ -267,6 +270,14 @@ function cunntest.LogSigmoid_transposed()
    pointwise_transposed(nn.LogSigmoid(), 'LogSigmoid', 1e-6)
 end

+function cunntest.GatedLinearUnit_forward()
+   pointwise_forward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_forward)
+end
+
+function cunntest.GatedLinearUnit_backward()
+   pointwise_backward(nn.GatedLinearUnit(), 'GatedLinearUnit', precision_backward)
+end
+
 function cunntest.Threshold_forward()
    pointwise_forward(nn.Threshold(), 'Threshold', precision_forward)
    pointwise_forward(nn.Threshold(nil, nil, true), 'Threshold_inplace', precision_forward)