Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristian Sarofeen <csarofeen@nvidia.com>2017-06-26 18:38:18 +0300
committerSoumith Chintala <soumith@gmail.com>2017-06-26 18:38:18 +0300
commit653811fa40a780dfa8f0e110f9febf5dfed8f0f0 (patch)
treeaf31264eb554f86333d9a9258266649da21c5c56
parent9db5057877c6ffa7df59727cbada13318d7e3eaf (diff)
Fp16 fixes for CUDA 9 (#783)
-rw-r--r--lib/THC/THCAtomics.cuh7
-rw-r--r--lib/THC/THCHalf.h6
-rw-r--r--lib/THC/THCNumerics.cuh5
-rw-r--r--lib/THC/THCTensorTypeUtils.cuh20
4 files changed, 37 insertions, 1 deletions
diff --git a/lib/THC/THCAtomics.cuh b/lib/THC/THCAtomics.cuh
index 7a0be48..400875c 100644
--- a/lib/THC/THCAtomics.cuh
+++ b/lib/THC/THCAtomics.cuh
@@ -102,9 +102,16 @@ static inline __device__ void atomicAdd(half *address, half val) {
do {
assumed = old;
+#if CUDA_VERSION < 9000
half hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
hsum = THCNumerics<half>::add(hsum, val);
+#else
+ __half_raw hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ half tmpres = THCNumerics<half>::add(hsum, val);
+ hsum = __half_raw(tmpres);
+#endif
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
diff --git a/lib/THC/THCHalf.h b/lib/THC/THCHalf.h
index 7c055e7..d5bd5c1 100644
--- a/lib/THC/THCHalf.h
+++ b/lib/THC/THCHalf.h
@@ -13,6 +13,12 @@
#include <cuda_fp16.h>
#include <stdint.h>
+#if CUDA_VERSION >= 9000
+#ifndef __cplusplus
+ typedef __half_raw half;
+#endif
+#endif
+
THC_EXTERNC void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len);
THC_EXTERNC void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len);
THC_API half THC_float2half(float a);
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index b6d1dac..ba86e8f 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -111,8 +111,13 @@ struct THCNumerics<long> {
#ifdef CUDA_HALF_TENSOR
template <>
struct THCNumerics<half> {
+#if CUDA_VERSION < 9000
static inline __host__ __device__ half min() { half h; h.x = 0xfbff; return h; }
static inline __host__ __device__ half max() { half h; h.x = 0x7bff; return h; }
+#else
+ static inline __host__ __device__ half min() { __half_raw h; h.x = 0xfbff; return h; }
+ static inline __host__ __device__ half max() { __half_raw h; h.x = 0x7bff; return h; }
+#endif
static inline __host__ __device__ bool lt(half a, half b) {
#ifdef __CUDA_ARCH__
diff --git a/lib/THC/THCTensorTypeUtils.cuh b/lib/THC/THCTensorTypeUtils.cuh
index 37edb76..16a0cde 100644
--- a/lib/THC/THCTensorTypeUtils.cuh
+++ b/lib/THC/THCTensorTypeUtils.cuh
@@ -149,7 +149,11 @@ struct ScalarNegate<half> {
return __float2half(-__half2float(v));
#endif
#else
+#if CUDA_VERSION < 9000
half out = v;
+#else
+ __half_raw out = __half_raw(v);
+#endif
out.x ^= 0x8000; // toggle sign bit
return out;
#endif
@@ -170,11 +174,25 @@ struct ScalarInv<half> {
};
inline bool operator==(half a, half b) {
+#if CUDA_VERSION < 9000
return a.x == b.x;
+#else
+ __half_raw araw, braw;
+ araw = __half_raw(a);
+ braw = __half_raw(b);
+ return araw.x == braw.x;
+#endif
}
inline bool operator!=(half a, half b) {
- return a.x != b.x;
+#if CUDA_VERSION < 9000
+ return a.x != b.x;
+#else
+ __half_raw araw, braw;
+ araw = __half_raw(a);
+ braw = __half_raw(b);
+ return araw.x != braw.x;
+#endif
}
#endif // CUDA_HALF_TENSOR