diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-11-10 20:12:25 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-10 20:12:25 +0300 |
commit | f27dd4f4a6381af42c1d87accf514c98ed67caa5 (patch) | |
tree | 820b5b31a8acda19ad8b2fa82dd36c5a31bbf8b5 | |
parent | 09d1a715f817115a4e9eabf8eaae53d1bc594a99 (diff) | |
parent | c2fb6078edbfc2933491e91ca99aa315cdb834c3 (diff) |
Merge pull request #368 from gchanan/halfGenericIfdef
Protect half operations with CUDA_HALF_TENSOR guards in generic modules.
-rw-r--r-- | lib/THCUNN/RReLU.cu | 2 | ||||
-rw-r--r-- | lib/THCUNN/SharedMem.cuh | 2 | ||||
-rw-r--r-- | lib/THCUNN/SparseLinear.cu | 2 | ||||
-rw-r--r-- | lib/THCUNN/THCHalfAutoNumerics.cuh | 6 |
4 files changed, 10 insertions, 2 deletions
diff --git a/lib/THCUNN/RReLU.cu b/lib/THCUNN/RReLU.cu index 55a34ec..44cd322 100644 --- a/lib/THCUNN/RReLU.cu +++ b/lib/THCUNN/RReLU.cu @@ -13,10 +13,12 @@ template<typename T> inline T __device__ curand_uniform_type(curandStateMtgp32 *state); +#ifdef CUDA_HALF_TENSOR template <> inline half __device__ curand_uniform_type<half>(curandStateMtgp32 *state) { return ScalarConvert<float, half>::to(curand_uniform(state)); } +#endif template <> inline float __device__ curand_uniform_type<float>(curandStateMtgp32 *state) { diff --git a/lib/THCUNN/SharedMem.cuh b/lib/THCUNN/SharedMem.cuh index 8d83d9f..070d269 100644 --- a/lib/THCUNN/SharedMem.cuh +++ b/lib/THCUNN/SharedMem.cuh @@ -13,6 +13,7 @@ struct SharedMem { } }; +#ifdef CUDA_HALF_TENSOR template <> struct SharedMem<half> { @@ -21,6 +22,7 @@ struct SharedMem<half> return s_half; } }; +#endif template <> struct SharedMem<float> diff --git a/lib/THCUNN/SparseLinear.cu b/lib/THCUNN/SparseLinear.cu index 9435735..a7ffa1e 100644 --- a/lib/THCUNN/SparseLinear.cu +++ b/lib/THCUNN/SparseLinear.cu @@ -16,6 +16,7 @@ static void init_cusparse() { } } +#ifdef CUDA_HALF_TENSOR void THNN_CudaHalfSparseLinear_updateOutput( THCState *state, THCudaHalfTensor *input, @@ -78,6 +79,7 @@ void THNN_CudaHalfSparseLinear_updateParameters( double learningRate) { THError("THCudaHalfTensor not supported with SparseLinear"); } +#endif #include "generic/SparseLinear.cu" #include "THCGenerateFloatType.h" diff --git a/lib/THCUNN/THCHalfAutoNumerics.cuh b/lib/THCUNN/THCHalfAutoNumerics.cuh index 89a9602..183c71e 100644 --- a/lib/THCUNN/THCHalfAutoNumerics.cuh +++ b/lib/THCUNN/THCHalfAutoNumerics.cuh @@ -7,10 +7,9 @@ // Half numerics functions defined as free functions, so cunn code can be //written generically, i.e. without excessive calling of THCNumerics<half> functions. 
-#ifdef CUDA_HALF_TENSOR - // these functions should move to THCNumerics +#ifdef CUDA_HALF_TENSOR inline __host__ __device__ half fmaxType(half x, half y) { return THCNumerics<half>::ge(x, y) ? x : y; } @@ -18,6 +17,7 @@ inline __host__ __device__ half fmaxType(half x, half y) { inline __host__ __device__ float fmaxType(float x, half y) { return fmaxf(x, ScalarConvert<half, float>::to(y)); } +#endif inline __host__ __device__ float fmaxType(float x, float y) { return fmaxf(x, y); @@ -27,6 +27,8 @@ inline __host__ __device__ double fmaxType(double x, double y) { return fmax(x, y); } +#ifdef CUDA_HALF_TENSOR + inline __host__ __device__ half mul(half a, half b) { #ifdef __CUDA_ARCH__ #ifdef CUDA_HALF_INSTRUCTIONS |