Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Sergey Zagoruyko <zagoruyko2@gmail.com>, 2016-06-16 18:08:32 +0300
committer: Sergey Zagoruyko <zagoruyko2@gmail.com>, 2016-06-16 18:08:32 +0300
commit: 2ece1b20241b8e4048c1c7ff650ce4953222642d (patch)
tree: 1011a039c4f0571c990b97fab7a658d27efb7ae2
parent: 55dd536a7f734857c30c7af34bdd7c80b332e640 (diff)
inplace hardtanh, remove relu6
-rw-r--r-- lib/THCUNN/HardTanh.cu 58
-rw-r--r-- lib/THCUNN/ReLU6.cu 92
-rw-r--r-- lib/THCUNN/THCUNN.h 18
3 files changed, 55 insertions, 113 deletions
diff --git a/lib/THCUNN/HardTanh.cu b/lib/THCUNN/HardTanh.cu
index 764a3c0..b341f5a 100644
--- a/lib/THCUNN/HardTanh.cu
+++ b/lib/THCUNN/HardTanh.cu
@@ -20,14 +20,36 @@ struct hardtanhupdateOutput_functor
else
*output = max_val_;
}
+
+ __device__ void operator()(float *input) const
+ {
+ if (*input < min_val_)
+ *input = min_val_;
+ else if (*input > max_val_)
+ *input = max_val_;
+ }
};
-void THNN_CudaHardTanh_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, float min_val, float max_val)
+void THNN_CudaHardTanh_updateOutput(
+ THCState *state,
+ THCudaTensor *input,
+ THCudaTensor *output,
+ float min_val,
+ float max_val,
+ bool inplace) // forward pass: clamp input to [min_val, max_val]; inplace selects overwrite-input vs. separate output
{
THCUNN_assertSameGPU(state, 2, input, output);
- THCudaTensor_resizeAs(state, output, input);
- THC_pointwiseApply2(state, output, input,
+ if(inplace)
+ {
+ THCudaTensor_set(state, output, input); // presumably makes output share input's storage - confirm against THC docs
+ THC_pointwiseApply1(state, output, hardtanhupdateOutput_functor(min_val, max_val)); // one-pointer overload: clamps each element where it sits
+ }
+ else
+ {
+ THCudaTensor_resizeAs(state, output, input); // output gets input's shape before it is written
+ THC_pointwiseApply2(state, output, input, // two-pointer overload: reads *input, writes *output
hardtanhupdateOutput_functor(min_val, max_val));
+ }
}
struct hardtanhupdateGradInput_functor
@@ -47,13 +69,35 @@ struct hardtanhupdateGradInput_functor
else
*gradInput = *gradOutput;
}
+
+ __device__ void operator()(float *gradInput, const float *input) const
+ {
+ if (*input <= min_val_ || *input >= max_val_)
+ *gradInput = 0;
+ }
};
-void THNN_CudaHardTanh_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, float min_val, float max_val)
+void THNN_CudaHardTanh_updateGradInput(
+ THCState *state,
+ THCudaTensor *input,
+ THCudaTensor *gradOutput,
+ THCudaTensor *gradInput,
+ float min_val,
+ float max_val,
+ bool inplace) // backward pass: gradInput = gradOutput where min_val < input < max_val, 0 elsewhere
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
- THCudaTensor_resizeAs(state, gradInput, input);
- THC_pointwiseApply3(state, gradInput, input, gradOutput,
- hardtanhupdateGradInput_functor(min_val, max_val));
+ if (inplace) // in-place: gradInput reuses gradOutput, so only the clamped positions need zeroing
+ {
+ THCudaTensor_set(state, gradInput, gradOutput);
+ THC_pointwiseApply2(state, gradInput, input,
+ hardtanhupdateGradInput_functor(min_val, max_val));
+ }
+ else // out-of-place: gradOutput is left untouched and every gradInput element is written
+ {
+ THCudaTensor_resizeAs(state, gradInput, input);
+ THC_pointwiseApply3(state, gradInput, input, gradOutput,
+ hardtanhupdateGradInput_functor(min_val, max_val));
+ }
}
diff --git a/lib/THCUNN/ReLU6.cu b/lib/THCUNN/ReLU6.cu
deleted file mode 100644
index a42f2c9..0000000
--- a/lib/THCUNN/ReLU6.cu
+++ /dev/null
@@ -1,92 +0,0 @@
-#include "THCUNN.h"
-#include "common.h"
-
-struct ReLU6UpdateOutput
-{
- ReLU6UpdateOutput() {}
-
- __device__ __forceinline__ void operator()(float *out, float *in)
- {
- float x = *in;
- *out = (x > 0) ? ((x < 6) ? x : 6) : 0;
- }
-};
-
-// in-place variant
-struct ReLU6UpdateOutputIP
-{
- ReLU6UpdateOutputIP() {}
-
- __device__ __forceinline__ void operator()(float *x)
- {
- *x = (*x > 0) ? ((*x < 6) ? *x : 6) : 0;
- }
-};
-
-void THNN_CudaReLU6_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output,
- bool inplace)
-{
- THCUNN_assertSameGPU(state, 2, input, output);
-
- if (inplace)
- {
- THC_pointwiseApply1(state, input,
- ReLU6UpdateOutputIP()
- );
- THCudaTensor_set(state, output, input);
- }
- else
- {
- THCudaTensor_resizeAs(state, output, input);
- THC_pointwiseApply2(state, output, input,
- ReLU6UpdateOutput()
- );
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
-struct ReLU6UpdateGradInput
-{
- ReLU6UpdateGradInput() {}
-
- __device__ __forceinline__ void operator()(
- float *gradInput, float *input, float *gradOutput) const
- {
- *gradInput = (*input > 0 && *input < 6) ? *gradOutput : 0;
- }
-};
-
-struct ReLU6UpdateGradInputIP
-{
- ReLU6UpdateGradInputIP() {}
-
- __device__ __forceinline__ void operator()(
- float *gradOutput, float *input) const
- {
- *gradOutput = (*input > 0 && *input < 6) ? *gradOutput : 0;
- }
-};
-
-void THNN_CudaReLU6_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput,
- THCudaTensor *gradInput, bool inplace)
-{
- THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
-
- if (inplace)
- {
- THC_pointwiseApply2(state, gradOutput, input,
- ReLU6UpdateGradInputIP()
- );
- THCudaTensor_set(state, gradInput, gradOutput);
- }
- else
- {
- THCudaTensor_resizeAs(state, gradInput, input);
- THC_pointwiseApply3(state, gradInput, input, gradOutput,
- ReLU6UpdateGradInput()
- );
- }
-
- THCudaCheck(cudaGetLastError());
-}
diff --git a/lib/THCUNN/THCUNN.h b/lib/THCUNN/THCUNN.h
index 0b9d661..cbc71b4 100644
--- a/lib/THCUNN/THCUNN.h
+++ b/lib/THCUNN/THCUNN.h
@@ -97,14 +97,16 @@ TH_API void THNN_CudaHardTanh_updateOutput(
THCudaTensor *input,
THCudaTensor *output,
float min_val,
- float max_val);
+ float max_val,
+ bool inplace);
TH_API void THNN_CudaHardTanh_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput,
float min_val,
- float max_val);
+ float max_val,
+ bool inplace);
TH_API void THNN_CudaL1Cost_updateOutput(
THCState *state,
@@ -403,18 +405,6 @@ TH_API void THNN_CudaThreshold_updateGradInput(
double threshold,
bool inplace);
-TH_API void THNN_CudaReLU6_updateOutput(
- THCState *state,
- THCudaTensor *input,
- THCudaTensor *output,
- bool inplace);
-TH_API void THNN_CudaReLU6_updateGradInput(
- THCState *state,
- THCudaTensor *input,
- THCudaTensor *gradOutput,
- THCudaTensor *gradInput,
- bool inplace);
-
TH_API void THNN_CudaTemporalConvolution_updateOutput(
THCState *state,
THCudaTensor *input,