diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-10-13 01:01:01 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-10-13 01:18:30 +0300 |
commit | 139690901cd17f8c91362aed0256fefceb9d68c1 (patch) | |
tree | fad2a9c44611bd23f0be8e6c4d8fc5707e8b8888 /lib/THC | |
parent | ed056dd2492e872477ac3b4a5001ab7efaa4d260 (diff) |
Make atomicAdd functions static inline.
Diffstat (limited to 'lib/THC')
-rw-r--r-- | lib/THC/THCAtomics.cuh | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/lib/THC/THCAtomics.cuh b/lib/THC/THCAtomics.cuh index 0586edf..4229114 100644 --- a/lib/THC/THCAtomics.cuh +++ b/lib/THC/THCAtomics.cuh @@ -8,7 +8,7 @@ struct AtomicAddIntegerImpl; template<typename T> struct AtomicAddIntegerImpl<T, 1> { - __device__ void operator()(T *address, T val) { + inline __device__ void operator()(T *address, T val) { unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 3)); unsigned int old = *address_as_ui; @@ -27,7 +27,7 @@ struct AtomicAddIntegerImpl<T, 1> { template<typename T> struct AtomicAddIntegerImpl<T, 2> { - __device__ void operator()(T *address, T val) { + inline __device__ void operator()(T *address, T val) { unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); unsigned int old = *address_as_ui; @@ -46,7 +46,7 @@ struct AtomicAddIntegerImpl<T, 2> { template<typename T> struct AtomicAddIntegerImpl<T, 4> { - __device__ void operator()(T *address, T val) { + inline __device__ void operator()(T *address, T val) { unsigned int * address_as_ui = (unsigned int *) (address); unsigned int old = *address_as_ui; unsigned int newval; @@ -62,7 +62,7 @@ struct AtomicAddIntegerImpl<T, 4> { template<typename T> struct AtomicAddIntegerImpl<T, 8> { - __device__ void operator()(T *address, T val) { + inline __device__ void operator()(T *address, T val) { unsigned long long * address_as_ui = (unsigned long long *) (address); unsigned long long old = *address_as_ui; unsigned long long newval; @@ -76,24 +76,24 @@ struct AtomicAddIntegerImpl<T, 8> { } }; -__device__ void atomicAdd(unsigned char *address, unsigned char val) { +static inline __device__ void atomicAdd(unsigned char *address, unsigned char val) { AtomicAddIntegerImpl<unsigned char, sizeof(unsigned char)>()(address, val); } -__device__ void atomicAdd(char *address, char val) { +static inline __device__ void atomicAdd(char *address, char val) { AtomicAddIntegerImpl<char, sizeof(char)>()(address, val); } -__device__ void atomicAdd(short *address, short val) { +static inline __device__ void atomicAdd(short *address, short val) { AtomicAddIntegerImpl<short, sizeof(short)>()(address, val); } -__device__ void atomicAdd(long *address, long val) { +static inline __device__ void atomicAdd(long *address, long val) { AtomicAddIntegerImpl<long, sizeof(long)>()(address, val); } #ifdef CUDA_HALF_TENSOR -__device__ void atomicAdd(half *address, half val) { +static inline __device__ void atomicAdd(half *address, half val) { unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); unsigned int old = *address_as_ui; @@ -112,7 +112,7 @@ __device__ void atomicAdd(half *address, half val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 // from CUDA C Programmic Guide -__device__ void atomicAdd(double *address, double val) { +static inline __device__ void atomicAdd(double *address, double val) { unsigned long long int* address_as_ull = (unsigned long long int*)address; unsigned long long int old = *address_as_ull; unsigned long long int assumed; |