#ifndef THC_HALF_AUTO_NUMERICS_INC
#define THC_HALF_AUTO_NUMERICS_INC

#include "THCHalf.h"
#include "THCNumerics.cuh"

// Half numerics functions defined as free functions, so cunn code can be
// written generically, i.e. without excessive calling of THCNumerics functions.
//
// Everything below is compiled only when CUDA_HALF_TENSOR is defined.
// NOTE(review): the template argument lists on THCNumerics / ScalarConvert are
// written as THCNumerics<half> and ScalarConvert<In, Out>; confirm the
// parameter order against THCNumerics.cuh.

#ifdef CUDA_HALF_TENSOR

// these functions should move to THCNumerics

// Returns x when THCNumerics<half>::ge(x, y) holds, otherwise y.
inline __host__ __device__ half fmaxType(half x, half y) {
  return THCNumerics<half>::ge(x, y) ? x : y;
}

// Mixed overload: y is converted to float and the result is a float.
inline __host__ __device__ float fmaxType(float x, half y) {
  return fmaxf(x, ScalarConvert<half, float>::to(y));
}

inline __host__ __device__ float fmaxType(float x, float y) {
  return fmaxf(x, y);
}

inline __host__ __device__ double fmaxType(double x, double y) {
  return fmax(x, y);
}

// Product of two half values.
//   device, CUDA_HALF_INSTRUCTIONS defined : __hmul intrinsic
//   device, otherwise                      : multiply as float, convert back
//   host                                   : THC_half2float / THC_float2half
inline __host__ __device__ half mul(half a, half b) {
#ifdef __CUDA_ARCH__
#ifdef CUDA_HALF_INSTRUCTIONS
  return __hmul(a, b);
#else
  float fa = __half2float(a);
  float fb = __half2float(b);
  return __float2half(fa * fb);
#endif
#else // __CUDA_ARCH__
  return THC_float2half(THC_half2float(a) * THC_half2float(b));
#endif
}

// Quotient of two half values; same three code paths as mul(), using __hdiv
// when half instructions are available.
inline __host__ __device__ half div(half a, half b) {
#ifdef __CUDA_ARCH__
#ifdef CUDA_HALF_INSTRUCTIONS
  return __hdiv(a, b);
#else
  float fa = __half2float(a);
  float fb = __half2float(b);
  return __float2half(fa / fb);
#endif
#else // __CUDA_ARCH__
  return THC_float2half(THC_half2float(a) / THC_half2float(b));
#endif
}

// arithmetic functions
//
// Convention visible in the overloads below:
//   half  op half   -> half
//   half  op int    -> half   (the int is converted to half first)
//   half  op float  -> float  (the half is converted to float first)
//   half  op double -> double (the half is converted to double first)

// ---- addition ----

inline __host__ __device__ half operator+(half a, half b) {
  return THCNumerics<half>::add(a, b);
}

inline __host__ __device__ float operator+(half a, float b) {
  return ScalarConvert<half, float>::to(a) + b;
}

inline __host__ __device__ float operator+(float a, half b) {
  return a + ScalarConvert<half, float>::to(b);
}

inline __host__ __device__ double operator+(double a, half b) {
  return a + ScalarConvert<half, double>::to(b);
}

// ---- negation / subtraction ----

inline __host__ __device__ half operator-(half a) {
  return THCNumerics<half>::neg(a);
}

// Subtraction is expressed as a + (-b).
inline __host__ __device__ half operator-(half a, half b) {
  return THCNumerics<half>::add(a, THCNumerics<half>::neg(b));
}

inline __host__ __device__ half operator-(half a, int b) {
  return THCNumerics<half>::add(
      a, THCNumerics<half>::neg(ScalarConvert<int, half>::to(b)));
}

inline __host__ __device__ float operator-(half a, float b) {
  return ScalarConvert<half, float>::to(a) - b;
}

inline __host__ __device__ double operator-(half a, double b) {
  return ScalarConvert<half, double>::to(a) - b;
}

inline __host__ __device__ half operator-(int a, half b) {
  return THCNumerics<half>::add(
      ScalarConvert<int, half>::to(a), THCNumerics<half>::neg(b));
}

inline __host__ __device__ float operator-(float a, half b) {
  return a - ScalarConvert<half, float>::to(b);
}

inline __host__ __device__ double operator-(double a, half b) {
  return a - ScalarConvert<half, double>::to(b);
}

// ---- multiplication ----

inline __host__ __device__ half operator*(half a, half b) {
  return mul(a, b);
}

inline __host__ __device__ float operator*(half a, float b) {
  return ScalarConvert<half, float>::to(a) * b;
}

inline __host__ __device__ double operator*(half a, double b) {
  return ScalarConvert<half, double>::to(a) * b;
}

// Forwards to operator*(half, half) after converting b.
inline __host__ __device__ half operator*(half a, int b) {
  return a * ScalarConvert<int, half>::to(b);
}

inline __host__ __device__ float operator*(float a, half b) {
  return a * ScalarConvert<half, float>::to(b);
}

inline __host__ __device__ double operator*(double a, half b) {
  return a * ScalarConvert<half, double>::to(b);
}

// ---- division ----

inline __host__ __device__ half operator/(half a, half b) {
  return div(a, b);
}

inline __host__ __device__ float operator/(float a, half b) {
  return a / ScalarConvert<half, float>::to(b);
}

inline __host__ __device__ double operator/(double a, half b) {
  return a / ScalarConvert<half, double>::to(b);
}

// Forwards to operator/(half, half) after converting a.
inline __host__ __device__ half operator/(int a, half b) {
  return ScalarConvert<int, half>::to(a) / b;
}

inline __host__ __device__ float operator/(half a, float b) {
  return ScalarConvert<half, float>::to(a) / b;
}

inline __host__ __device__ double operator/(half a, double b) {
  return ScalarConvert<half, double>::to(a) / b;
}

// Forwards to operator/(half, half) after converting b.
inline __host__ __device__ half operator/(half a, int b) {
  return a / ScalarConvert<int, half>::to(b);
}

// ---- compound assignment ----
// Each one evaluates the matching binary operator above, stores the result in
// lhs, and returns a reference to lhs.

inline __host__ __device__ half& operator+=(half &lhs, const half &rhs) {
  lhs = lhs + rhs;
  return lhs;
}

inline __host__ __device__ float& operator+=(float &lhs, const half &rhs) {
  lhs = lhs + rhs;
  return lhs;
}

inline __host__ __device__ float& operator-=(float &lhs, const half &rhs) {
  lhs = lhs - rhs;
  return lhs;
}

inline __host__ __device__ half& operator*=(half &lhs, const half &rhs) {
  lhs = lhs * rhs;
  return lhs;
}

inline __host__ __device__ half& operator/=(half &lhs, const int &rhs) {
  lhs = lhs / rhs;
  return lhs;
}

inline __host__ __device__ half& operator/=(half &lhs, const half &rhs) {
  lhs = lhs / rhs;
  return lhs;
}

// ---- math functions: thin wrappers over THCNumerics<half> ----

inline __host__ __device__ half abs(half a) {
  return THCNumerics<half>::abs(a);
}

inline __host__ __device__ half exp(half a) {
  return THCNumerics<half>::exp(a);
}

inline __host__ __device__ half log1p(half a) {
  return THCNumerics<half>::log1p(a);
}

inline __host__ __device__ half pow(half a, half b) {
  return THCNumerics<half>::pow(a, b);
}

inline __host__ __device__ half sqrt(half a) {
  return THCNumerics<half>::sqrt(a);
}

inline __host__ __device__ half tanh(half a) {
  return THCNumerics<half>::tanh(a);
}

// comparison functions
//
// The (half, int) overloads convert the int to half and then compare through
// THCNumerics<half>.

inline __host__ __device__ bool operator<(half a, half b) {
  return THCNumerics<half>::lt(a, b);
}

inline __host__ __device__ bool operator<=(half a, half b) {
  return THCNumerics<half>::le(a, b);
}

inline __host__ __device__ bool operator<=(half a, int b) {
  return THCNumerics<half>::le(a, ScalarConvert<int, half>::to(b));
}

inline __host__ __device__ bool operator<(half a, int b) {
  return THCNumerics<half>::lt(a, ScalarConvert<int, half>::to(b));
}

inline __host__ __device__ bool operator>(half a, half b) {
  return THCNumerics<half>::gt(a, b);
}

inline __host__ __device__ bool operator>(half a, int b) {
  return THCNumerics<half>::gt(a, ScalarConvert<int, half>::to(b));
}

inline __host__ __device__ bool operator>=(half a, half b) {
  return THCNumerics<half>::ge(a, b);
}

inline __host__ __device__ bool operator>=(half a, int b) {
  return THCNumerics<half>::ge(a, ScalarConvert<int, half>::to(b));
}

#endif // CUDA_HALF_TENSOR

#endif // THC_HALF_AUTO_NUMERICS_INC