diff options
Diffstat (limited to 'lib/THCUNN/IndexLinear.cu')
-rw-r--r-- | lib/THCUNN/IndexLinear.cu | 41 |
1 file changed, 11 insertions, 30 deletions
diff --git a/lib/THCUNN/IndexLinear.cu b/lib/THCUNN/IndexLinear.cu index fb2dc93..7d97b51 100644 --- a/lib/THCUNN/IndexLinear.cu +++ b/lib/THCUNN/IndexLinear.cu @@ -15,32 +15,11 @@ const long NNZ_PER_BLOCK_MAX = 1024; #define clamp(a, low, high) max(min((a), (high)), (low)) #endif -#ifndef ATOMIC_REAL_MINMAX -#define ATOMIC_REAL_MINMAX(func) \ - __device__ void atomic_##func(double *address, double val) { \ - unsigned long long int* address_as_ull = (unsigned long long int*)address; \ - unsigned long long int old = *address_as_ull; \ - unsigned long long int assumed; \ - do { \ - assumed = old; \ - old = atomicCAS(address_as_ull, assumed, \ - __double_as_longlong(func(val, __longlong_as_double(assumed)))); \ - } while (assumed != old); \ - } \ - __device__ void atomic_##func(float *address, float val) { \ - int* address_as_int = (int*)address; \ - int old = *address_as_int; \ - int assumed; \ - do { \ - assumed = old; \ - old = atomicCAS(address_as_int, assumed, \ - __float_as_int(func(val, __int_as_float(assumed)))); \ - } while (assumed != old); \ - } \ - -ATOMIC_REAL_MINMAX(max) -ATOMIC_REAL_MINMAX(min) -#endif +__device__ double atomicExch(double *address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long res = atomicExch(address_as_ull, __double_as_longlong(val)); + return __longlong_as_double(res); +} template<typename Ty, bool train> __global__ static @@ -113,14 +92,16 @@ void updateOutput( Ty *nWeightCurr = nWeight + nWeightOffset; if (train) { Ty absVal = fabs(val); - Ty maxVal = nWeight[key * weightStride + 0]; + Ty maxVal = nWeightCurr[0]; if (absVal > maxVal) { // Updating maxVal and invMaxVal. Go hogwild! 
- atomic_max(nWeightCurr + 0, absVal); - atomic_min(nWeightCurr + 1, 1.0/absVal); + Ty invAbsVal = 1.0 / absVal; + atomicExch(nWeightCurr + 0, absVal); + atomicExch(nWeightCurr + 1, invAbsVal); } - val = val * nWeightCurr[1] + nWeightCurr[3]; + val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3]; normalizedValues[id + tid] = val; + nWeightCurr[2] = 1; } else { val = clamp(val * nWeightCurr[1], -1.0, 1.0) + nWeightCurr[3]; } |