diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-04-18 19:45:39 +0300
---|---|---
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-04-18 19:55:11 +0300
commit | 9c6905e5a4a432b79314983beda5c932990659c6 (patch)
tree | 6c08060e0cf810c94c0abfb4cf4e8e1d1d2a69f3
parent | e2469527b1f85802e93ace4ce709bdabd7f1010a (diff)
Update ops for Sigmoid and Tanh
-rw-r--r-- | lib/THCUNN/Sigmoid.cu | 27
-rw-r--r-- | lib/THCUNN/Tanh.cu | 28
-rw-r--r-- | lib/THCUNN/generic/Sigmoid.cu | 5
-rw-r--r-- | lib/THCUNN/generic/Tanh.cu | 4
4 files changed, 39 insertions, 25 deletions
diff --git a/lib/THCUNN/Sigmoid.cu b/lib/THCUNN/Sigmoid.cu index 88d217f..84bcf32 100644 --- a/lib/THCUNN/Sigmoid.cu +++ b/lib/THCUNN/Sigmoid.cu @@ -4,22 +4,27 @@ #include <THC/THCApply.cuh> template <typename T> -struct sigmoidupdateOutput_functor -{ - __device__ void operator()(T *output, const T *input) const - { - *output = ScalarConvert<double, T>::to(1./(1.+ exp(-*input))); +struct SigmoidGradInputOp { + __device__ __forceinline__ void operator()(T* gradInput, const T *output, const T *gradOutput) const { + *gradInput = *gradOutput * (1.f - *output) * (*output); } }; -template <typename T> -struct sigmoidupdateGradInput_functor -{ - __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const - { - *gradInput = ScalarConvert<double, T>::to(*gradOutput * (1.-*output) * (*output)); +#ifdef CUDA_HALF_TENSOR +template <> +struct SigmoidGradInputOp<half> { + __device__ __forceinline__ void operator()(half* gradInput, const half *output, const half *gradOutput) const { +#ifdef CUDA_HALF_INSTRUCTIONS + half one = __float2half(1.f); + *gradInput = __hmul(*gradOutput, __hmul(__hadd(one, __hneg(*output)), *output)); +#else + float out = __half2float(*output); + float go = __half2float(*gradOutput); + *gradInput = __float2half(go * (1.f - out) * out); +#endif } }; +#endif #include "generic/Sigmoid.cu" #include "THCGenerateFloatTypes.h" diff --git a/lib/THCUNN/Tanh.cu b/lib/THCUNN/Tanh.cu index fa423b7..44e93eb 100644 --- a/lib/THCUNN/Tanh.cu +++ b/lib/THCUNN/Tanh.cu @@ -4,22 +4,32 @@ #include <THC/THCApply.cuh> template <typename T> -struct tanhupdateOutput_functor +struct TanhGradInputOp { - __device__ void operator()(T *output, const T *input) const - { - *output = tanh(*input); + __device__ __forceinline__ void operator()(T *gradInput, + const T *output, const T *gradOutput) const { + *gradInput = *gradOutput * (1.f - *output * *output); } }; -template <typename T> -struct tanhupdateGradInput_functor +#ifdef CUDA_HALF_TENSOR +template <> 
+struct TanhGradInputOp<half> { - __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const - { - *gradInput = *gradOutput * (1 - *output * *output); + __device__ __forceinline__ void operator()(half *gradInput, + const half *output, const half *gradOutput) const { +#ifdef CUDA_HALF_INSTRUCTIONS + const half one = __float2half(1.f); + const half out_square = __hmul(*output, *output); + *gradInput = __hmul(*gradOutput, __hadd(one, __hneg(out_square))); +#else + float out = __half2float(*output); + float go = __half2float(*gradOutput); + *gradInput = __float2half(go * (1.f - out * out)); +#endif } }; +#endif #include "generic/Tanh.cu" #include "THCGenerateFloatTypes.h" diff --git a/lib/THCUNN/generic/Sigmoid.cu b/lib/THCUNN/generic/Sigmoid.cu index a4f7328..0e352d3 100644 --- a/lib/THCUNN/generic/Sigmoid.cu +++ b/lib/THCUNN/generic/Sigmoid.cu @@ -10,8 +10,7 @@ void THNN_(Sigmoid_updateOutput)( THCTensor *output) { THCUNN_assertSameGPU(state, 2, input, output); - THCTensor_(resizeAs)(state, output, input); - THC_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor<real>()); + THCTensor_(sigmoid)(state, output, input); } void THNN_(Sigmoid_updateGradInput)( @@ -24,7 +23,7 @@ void THNN_(Sigmoid_updateGradInput)( THCUNN_check_nElement(state, input, gradOutput); THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput); THCTensor_(resizeAs)(state, gradInput, output); - THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor<real>()); + THC_pointwiseApply3(state, gradInput, output, gradOutput, SigmoidGradInputOp<real>()); } #endif diff --git a/lib/THCUNN/generic/Tanh.cu b/lib/THCUNN/generic/Tanh.cu index 6cc7d86..e3fd77a 100644 --- a/lib/THCUNN/generic/Tanh.cu +++ b/lib/THCUNN/generic/Tanh.cu @@ -11,7 +11,7 @@ void THNN_(Tanh_updateOutput)( { THCUNN_assertSameGPU(state, 2, input, output); THCTensor_(resizeAs)(state, output, input); - THC_pointwiseApply2(state, output, input, 
tanhupdateOutput_functor<real>()); + THCTensor_(tanh)(state, output, input); } void THNN_(Tanh_updateGradInput)( @@ -24,7 +24,7 @@ void THNN_(Tanh_updateGradInput)( THCUNN_check_shape(state, output, gradOutput); THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput); THCTensor_(resizeAs)(state, gradInput, output); - THC_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor<real>()); + THC_pointwiseApply3(state, gradInput, output, gradOutput, TanhGradInputOp<real>()); } #endif |