From 61dcc2dd9b38fb786bed76c6e5ef44ab8eb85ad8 Mon Sep 17 00:00:00 2001
From: Adam Paszke <adam.paszke@gmail.com>
Date: Tue, 18 Apr 2017 10:28:13 -0700
Subject: Remove double precision math from LogSigmoid too

---
 lib/THCUNN/LogSigmoid.cu      | 43 +++++++++++++++++++++++++++++++++++--------
 lib/THCUNN/Sigmoid.cu         | 10 +++++-----
 lib/THCUNN/Tanh.cu            |  8 ++++----
 lib/THCUNN/generic/Sigmoid.cu |  4 ++--
 lib/THCUNN/generic/THCUNN.h   |  4 ++--
 lib/THCUNN/generic/Tanh.cu    |  2 +-
 6 files changed, 49 insertions(+), 22 deletions(-)

diff --git a/lib/THCUNN/LogSigmoid.cu b/lib/THCUNN/LogSigmoid.cu
index d6152ea..bb9753a 100644
--- a/lib/THCUNN/LogSigmoid.cu
+++ b/lib/THCUNN/LogSigmoid.cu
@@ -6,22 +6,49 @@
 template <typename T>
 struct logSigmoid_updateOutput_functor
 {
-  __device__ void operator()(T *output, const T *input) const
-  {
-    T z = exp(-*input);
-    *output = ScalarConvert<double, T>::to(-log(1. + z));
+  __device__ void operator()(T *output, const T *input) const {
+    *output = -THCNumerics<T>::log(1.f + THCNumerics<T>::exp(- *input));
   }
 };
 
 template <typename T>
 struct logSigmoid_updateGradInput_functor
 {
-  __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const
-  {
-    T z = exp(-*input);
-    *gradInput = ScalarConvert<double, T>::to(*gradOutput * z / (1. + z));
+  __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const {
+    const T z = THCNumerics<T>::exp(- *input);
+    *gradInput = *gradOutput * z / (1.f + z);
   }
 };
 
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct logSigmoid_updateOutput_functor<half> {
+  __device__ __forceinline__ void operator()(half* output, const half *input) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+    const half one = __float2half(1.f);
+    *output = __hneg(THCNumerics<half>::log(one + THCNumerics<half>::exp(__hneg(*input))));
+#else
+    float in = __half2float(*input);
+    *output = __float2half(-THCNumerics<float>::log(1.f + THCNumerics<float>::exp(-in)));
+#endif
+  }
+};
+
+template <>
+struct logSigmoid_updateGradInput_functor<half> {
+  __device__ __forceinline__ void operator()(half* gradInput, const half *input, const half *gradOutput) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+    const half one = __float2half(1.f);
+    const half in_exp = THCNumerics<half>::exp(__hneg(*input));
+    *gradInput = hdiv(__hmul(*gradOutput, in_exp), __hadd(one, in_exp));
+#else
+    const float in_exp = THCNumerics<float>::exp(-(__half2float(*input)));
+    const float go = __half2float(*gradOutput);
+    *gradInput = __float2half(go * in_exp / (1.f + in_exp));
+#endif
+  }
+};
+#endif
+
 #include "generic/LogSigmoid.cu"
 #include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/Sigmoid.cu b/lib/THCUNN/Sigmoid.cu
index 84bcf32..85bda93 100644
--- a/lib/THCUNN/Sigmoid.cu
+++ b/lib/THCUNN/Sigmoid.cu
@@ -4,7 +4,7 @@
 #include <THCApply.cuh>
 
 template <typename T>
-struct SigmoidGradInputOp {
+struct sigmoid_updateGradInput_functor {
   __device__ __forceinline__ void operator()(T* gradInput, const T *output, const T *gradOutput) const {
     *gradInput = *gradOutput * (1.f - *output) * (*output);
   }
@@ -12,14 +12,14 @@
 
 #ifdef CUDA_HALF_TENSOR
 template <>
-struct SigmoidGradInputOp<half> {
+struct sigmoid_updateGradInput_functor<half> {
   __device__ __forceinline__ void operator()(half* gradInput, const half *output, const half *gradOutput) const {
 #ifdef CUDA_HALF_INSTRUCTIONS
-    half one = __float2half(1.f);
+    const half one = __float2half(1.f);
     *gradInput = __hmul(*gradOutput, __hmul(__hadd(one, __hneg(*output)), *output));
 #else
-    float out = __half2float(*output);
-    float go = __half2float(*gradOutput);
+    const float out = __half2float(*output);
+    const float go = __half2float(*gradOutput);
     *gradInput = __float2half(go * (1.f - out) * out);
 #endif
   }
diff --git a/lib/THCUNN/Tanh.cu b/lib/THCUNN/Tanh.cu
index 44e93eb..6781f33 100644
--- a/lib/THCUNN/Tanh.cu
+++ b/lib/THCUNN/Tanh.cu
@@ -4,7 +4,7 @@
 #include <THCApply.cuh>
 
 template <typename T>
-struct TanhGradInputOp
+struct tanh_updateGradInput_functor
 {
   __device__ __forceinline__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
   {
@@ -14,7 +14,7 @@ struct TanhGradInputOp
 
 #ifdef CUDA_HALF_TENSOR
 template <>
-struct TanhGradInputOp<half>
+struct tanh_updateGradInput_functor<half>
 {
   __device__ __forceinline__ void operator()(half *gradInput, const half *output, const half *gradOutput) const
   {
@@ -23,8 +23,8 @@ struct TanhGradInputOp
     const half out_square = __hmul(*output, *output);
     *gradInput = __hmul(*gradOutput, __hadd(one, __hneg(out_square)));
 #else
-    float out = __half2float(*output);
-    float go = __half2float(*gradOutput);
+    const float out = __half2float(*output);
+    const float go = __half2float(*gradOutput);
     *gradInput = __float2half(go * (1.f - out * out));
 #endif
   }
diff --git a/lib/THCUNN/generic/Sigmoid.cu b/lib/THCUNN/generic/Sigmoid.cu
index 0e352d3..4f73ef2 100644
--- a/lib/THCUNN/generic/Sigmoid.cu
+++ b/lib/THCUNN/generic/Sigmoid.cu
@@ -20,10 +20,10 @@ void THNN_(Sigmoid_updateGradInput)(
           THCTensor *gradInput,
           THCTensor *output)
 {
-  THCUNN_check_nElement(state, input, gradOutput);
+  THCUNN_check_nElement(state, output, gradOutput);
   THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
   THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, SigmoidGradInputOp<real>());
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoid_updateGradInput_functor<real>());
 }
 
 #endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index da8e41e..79426b7 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -911,7 +911,7 @@ TH_API void THNN_(Sigmoid_updateOutput)(
 
 TH_API void THNN_(Sigmoid_updateGradInput)(
           THCState *state,
-          THCTensor *input,
+          THCTensor *input,             // [OPTIONAL]
           THCTensor *gradOutput,
           THCTensor *gradInput,
           THCTensor *output);
@@ -1002,7 +1002,7 @@ TH_API void THNN_(Tanh_updateOutput)(
 
 TH_API void THNN_(Tanh_updateGradInput)(
           THCState *state,
-          THCTensor *input,
+          THCTensor *input,             // [OPTIONAL]
          THCTensor *gradOutput,
           THCTensor *gradInput,
           THCTensor *output);
diff --git a/lib/THCUNN/generic/Tanh.cu b/lib/THCUNN/generic/Tanh.cu
index e3fd77a..7551d40 100644
--- a/lib/THCUNN/generic/Tanh.cu
+++ b/lib/THCUNN/generic/Tanh.cu
@@ -24,7 +24,7 @@ void THNN_(Tanh_updateGradInput)(
   THCUNN_check_shape(state, output, gradOutput);
   THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
   THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3(state, gradInput, output, gradOutput, TanhGradInputOp<real>());
+  THC_pointwiseApply3(state, gradInput, output, gradOutput, tanh_updateGradInput_functor<real>());
 }
 
 #endif
-- 
cgit v1.2.3