Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-04-18 20:28:13 +0300
committerAdam Paszke <adam.paszke@gmail.com>2017-04-18 20:28:13 +0300
commit61dcc2dd9b38fb786bed76c6e5ef44ab8eb85ad8 (patch)
tree264936b1fa67bd2d9213a535c5fc3c9d0a4788cf
parent5b4a2d5b0ec187b2523b0af88c3ade859f78fba7 (diff)
Remove double precision math from LogSigmoid too
-rw-r--r--lib/THCUNN/LogSigmoid.cu43
-rw-r--r--lib/THCUNN/Sigmoid.cu10
-rw-r--r--lib/THCUNN/Tanh.cu8
-rw-r--r--lib/THCUNN/generic/Sigmoid.cu4
-rw-r--r--lib/THCUNN/generic/THCUNN.h4
-rw-r--r--lib/THCUNN/generic/Tanh.cu2
6 files changed, 49 insertions, 22 deletions
diff --git a/lib/THCUNN/LogSigmoid.cu b/lib/THCUNN/LogSigmoid.cu
index d6152ea..bb9753a 100644
--- a/lib/THCUNN/LogSigmoid.cu
+++ b/lib/THCUNN/LogSigmoid.cu
@@ -6,22 +6,49 @@
template <typename T>
struct logSigmoid_updateOutput_functor
{
- __device__ void operator()(T *output, const T *input) const
- {
- T z = exp(-*input);
- *output = ScalarConvert<double, T>::to(-log(1. + z));
+ __device__ void operator()(T *output, const T *input) const {
+ *output = -THCNumerics<T>::log(1.f + THCNumerics<T>::exp(- *input));
}
};
template <typename T>
struct logSigmoid_updateGradInput_functor
{
- __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const
- {
- T z = exp(-*input);
- *gradInput = ScalarConvert<double, T>::to(*gradOutput * z / (1. + z));
+ __device__ void operator()(T *gradInput, const T *input, const T *gradOutput) const {
+ const T z = THCNumerics<T>::exp(- *input);
+ *gradInput = *gradOutput * z / (1.f + z);
}
};
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct logSigmoid_updateOutput_functor<half> {
+ __device__ __forceinline__ void operator()(half* output, const half *input) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ const half one = __float2half(1.f);
+ *output = __hneg(THCNumerics<half>::log(one + THCNumerics<half>::exp(__hneg(*input))));
+#else
+ float in = __half2float(*input);
+ *output = __float2half(-THCNumerics<float>::log(1.f + THCNumerics<float>::exp(-in)));
+#endif
+ }
+};
+
+template <>
+struct logSigmoid_updateGradInput_functor<half> {
+ __device__ __forceinline__ void operator()(half* gradInput, const half *input, const half *gradOutput) const {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ const half one = __float2half(1.f);
+ const half in_exp = THCNumerics<half>::exp(__hneg(*input));
+ *gradInput = hdiv(__hmul(*gradOutput, in_exp), __hadd(one, in_exp));
+#else
+ const float in_exp = THCNumerics<float>::exp(-(__half2float(*input)));
+ const float go = __half2float(*gradOutput);
+ *gradInput = __float2half(go * in_exp / (1.f + in_exp));
+#endif
+ }
+};
+#endif
+
#include "generic/LogSigmoid.cu"
#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/Sigmoid.cu b/lib/THCUNN/Sigmoid.cu
index 84bcf32..85bda93 100644
--- a/lib/THCUNN/Sigmoid.cu
+++ b/lib/THCUNN/Sigmoid.cu
@@ -4,7 +4,7 @@
#include <THC/THCApply.cuh>
template <typename T>
-struct SigmoidGradInputOp {
+struct sigmoid_updateGradInput_functor {
__device__ __forceinline__ void operator()(T* gradInput, const T *output, const T *gradOutput) const {
*gradInput = *gradOutput * (1.f - *output) * (*output);
}
@@ -12,14 +12,14 @@ struct SigmoidGradInputOp {
#ifdef CUDA_HALF_TENSOR
template <>
-struct SigmoidGradInputOp<half> {
+struct sigmoid_updateGradInput_functor<half> {
__device__ __forceinline__ void operator()(half* gradInput, const half *output, const half *gradOutput) const {
#ifdef CUDA_HALF_INSTRUCTIONS
- half one = __float2half(1.f);
+ const half one = __float2half(1.f);
*gradInput = __hmul(*gradOutput, __hmul(__hadd(one, __hneg(*output)), *output));
#else
- float out = __half2float(*output);
- float go = __half2float(*gradOutput);
+ const float out = __half2float(*output);
+ const float go = __half2float(*gradOutput);
*gradInput = __float2half(go * (1.f - out) * out);
#endif
}
diff --git a/lib/THCUNN/Tanh.cu b/lib/THCUNN/Tanh.cu
index 44e93eb..6781f33 100644
--- a/lib/THCUNN/Tanh.cu
+++ b/lib/THCUNN/Tanh.cu
@@ -4,7 +4,7 @@
#include <THC/THCApply.cuh>
template <typename T>
-struct TanhGradInputOp
+struct tanh_updateGradInput_functor
{
__device__ __forceinline__ void operator()(T *gradInput,
const T *output, const T *gradOutput) const {
@@ -14,7 +14,7 @@ struct TanhGradInputOp
#ifdef CUDA_HALF_TENSOR
template <>
-struct TanhGradInputOp<half>
+struct tanh_updateGradInput_functor<half>
{
__device__ __forceinline__ void operator()(half *gradInput,
const half *output, const half *gradOutput) const {
@@ -23,8 +23,8 @@ struct TanhGradInputOp<half>
const half out_square = __hmul(*output, *output);
*gradInput = __hmul(*gradOutput, __hadd(one, __hneg(out_square)));
#else
- float out = __half2float(*output);
- float go = __half2float(*gradOutput);
+ const float out = __half2float(*output);
+ const float go = __half2float(*gradOutput);
*gradInput = __float2half(go * (1.f - out * out));
#endif
}
diff --git a/lib/THCUNN/generic/Sigmoid.cu b/lib/THCUNN/generic/Sigmoid.cu
index 0e352d3..4f73ef2 100644
--- a/lib/THCUNN/generic/Sigmoid.cu
+++ b/lib/THCUNN/generic/Sigmoid.cu
@@ -20,10 +20,10 @@ void THNN_(Sigmoid_updateGradInput)(
THCTensor *gradInput,
THCTensor *output)
{
- THCUNN_check_nElement(state, input, gradOutput);
+ THCUNN_check_nElement(state, output, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
- THC_pointwiseApply3(state, gradInput, output, gradOutput, SigmoidGradInputOp<real>());
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoid_updateGradInput_functor<real>());
}
#endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index da8e41e..79426b7 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -911,7 +911,7 @@ TH_API void THNN_(Sigmoid_updateOutput)(
TH_API void THNN_(Sigmoid_updateGradInput)(
THCState *state,
- THCTensor *input,
+ THCTensor *input, // [OPTIONAL]
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output);
@@ -1002,7 +1002,7 @@ TH_API void THNN_(Tanh_updateOutput)(
TH_API void THNN_(Tanh_updateGradInput)(
THCState *state,
- THCTensor *input,
+ THCTensor *input, // [OPTIONAL]
THCTensor *gradOutput,
THCTensor *gradInput,
THCTensor *output);
diff --git a/lib/THCUNN/generic/Tanh.cu b/lib/THCUNN/generic/Tanh.cu
index e3fd77a..7551d40 100644
--- a/lib/THCUNN/generic/Tanh.cu
+++ b/lib/THCUNN/generic/Tanh.cu
@@ -24,7 +24,7 @@ void THNN_(Tanh_updateGradInput)(
THCUNN_check_shape(state, output, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
- THC_pointwiseApply3(state, gradInput, output, gradOutput, TanhGradInputOp<real>());
+ THC_pointwiseApply3(state, gradInput, output, gradOutput, tanh_updateGradInput_functor<real>());
}
#endif