Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'lib/THCUNN/Tanh.cu')
-rw-r--r--lib/THCUNN/Tanh.cu8
1 files changed, 4 insertions, 4 deletions
diff --git a/lib/THCUNN/Tanh.cu b/lib/THCUNN/Tanh.cu
index 44e93eb..6781f33 100644
--- a/lib/THCUNN/Tanh.cu
+++ b/lib/THCUNN/Tanh.cu
@@ -4,7 +4,7 @@
#include <THC/THCApply.cuh>
template <typename T>
-struct TanhGradInputOp
+struct tanh_updateGradInput_functor
{
__device__ __forceinline__ void operator()(T *gradInput,
const T *output, const T *gradOutput) const {
@@ -14,7 +14,7 @@ struct TanhGradInputOp
#ifdef CUDA_HALF_TENSOR
template <>
-struct TanhGradInputOp<half>
+struct tanh_updateGradInput_functor<half>
{
__device__ __forceinline__ void operator()(half *gradInput,
const half *output, const half *gradOutput) const {
@@ -23,8 +23,8 @@ struct TanhGradInputOp<half>
const half out_square = __hmul(*output, *output);
*gradInput = __hmul(*gradOutput, __hadd(one, __hneg(out_square)));
#else
- float out = __half2float(*output);
- float go = __half2float(*gradOutput);
+ const float out = __half2float(*output);
+ const float go = __half2float(*gradOutput);
*gradInput = __float2half(go * (1.f - out * out));
#endif
}