diff options
Diffstat (limited to 'lib/THC/THCTensorMathPairwise.cu')
-rw-r--r-- | lib/THC/THCTensorMathPairwise.cu | 133 |
1 files changed, 69 insertions, 64 deletions
diff --git a/lib/THC/THCTensorMathPairwise.cu b/lib/THC/THCTensorMathPairwise.cu index 383344b..2c081d1 100644 --- a/lib/THC/THCTensorMathPairwise.cu +++ b/lib/THC/THCTensorMathPairwise.cu @@ -1,96 +1,98 @@ #include "THCTensorMath.h" #include "THCGeneral.h" #include "THCBlas.h" +#include "THCHalf.h" #include "THCTensorCopy.h" #include "THCApply.cuh" #include "THCReduce.cuh" +template <typename T> struct TensorAddConstantOp { - TensorAddConstantOp(float v) : val(v) {} - __device__ __forceinline__ void operator()(float* out, float* in) { + TensorAddConstantOp(T v) : val(v) {} + __device__ __forceinline__ void operator()(T* out, T* in) { *out = *in + val; } - __device__ __forceinline__ void operator()(float* v) { + __device__ __forceinline__ void operator()(T* v) { *v += val; } - const float val; + const T val; }; -void THCudaTensor_add(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src_)); - if (self_ == src_) { - if (!THCudaTensor_pointwiseApply1(state, self_, TensorAddConstantOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src_); - - if (!THCudaTensor_pointwiseApply2(state, self_, src_, TensorAddConstantOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } +#ifdef CUDA_HALF_TENSOR +template <> +struct TensorAddConstantOp<half> { + TensorAddConstantOp(half v) : val(v) {} + __device__ __forceinline__ void operator()(half* out, half* in) { +#ifdef CUDA_HALF_INSTRUCTIONS + *out = __hadd(*in, val); +#else + float fin = __half2float(*in); + float fval = __half2float(val); + float fout = fin + fval; + *out = __float2half(fout); +#endif } - THCudaCheck(cudaGetLastError()); -} + __device__ __forceinline__ void operator()(half* v) { +#ifdef CUDA_HALF_INSTRUCTIONS + *v = __hadd(*v, val); +#else + float fv = __half2float(*v); + float fval = __half2float(val); + fv += fval; + *v = __float2half(fv); +#endif + } -void THCudaTensor_sub(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THCudaTensor_add(state, self_, src_, -value); -} + const half val; +}; +#endif // CUDA_HALF_TENSOR +template <typename T> struct TensorMulConstantOp { - TensorMulConstantOp(float v) : val(v) {} - __device__ __forceinline__ void operator()(float* out, float* in) { + TensorMulConstantOp(T v) : val(v) {} + __device__ __forceinline__ void operator()(T* out, T* in) { *out = *in * val; } - __device__ __forceinline__ void operator()(float* v) { + __device__ __forceinline__ void operator()(T* v) { *v *= val; } - const float val; + const T val; }; -void THCudaTensor_mul(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src_)); - if (self_ == src_) { - if (!THCudaTensor_pointwiseApply1(state, self_, TensorMulConstantOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src_); - - if (!THCudaTensor_pointwiseApply2(state, self_, src_, TensorMulConstantOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } +#ifdef CUDA_HALF_TENSOR +template <> +struct TensorMulConstantOp<half> { + TensorMulConstantOp(half v) : val(v) {} + __device__ __forceinline__ void operator()(half* out, half* in) { +#ifdef CUDA_HALF_INSTRUCTIONS + *out = __hmul(*in, val); +#else + float fin = __half2float(*in); + float fval = __half2float(val); + float fout = fin * fval; + *out = __float2half(fout); +#endif } - THCudaCheck(cudaGetLastError()); -} - -void THCudaTensor_div(THCState* state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src_)); - THArgCheck(value != 0.0f, 3, "divide by zero"); - - if (self_ == src_) { - if (!THCudaTensor_pointwiseApply1(state, self_, TensorMulConstantOp(1.0f / value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src_); - - if (!THCudaTensor_pointwiseApply2(state, self_, src_, TensorMulConstantOp(1.0f / value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } + __device__ __forceinline__ void operator()(half* v) { +#ifdef CUDA_HALF_INSTRUCTIONS + *v = __hmul(*v, val); +#else + float fv = __half2float(*v); + float fval = __half2float(val); + fv *= fval; + *v = __float2half(fv); +#endif } - THCudaCheck(cudaGetLastError()); -} + const half val; +}; +#endif // CUDA_HALF_TENSOR template <int Upper> struct TensorTriOp { @@ -143,13 +145,13 @@ void THCudaTensor_tril(THCState *state, THCudaTensor *self_, THCudaTensor *src_, TensorTriOp<0> op(start, stride0, stride1, k); if (self_ == src_) { - if (!THCudaTensor_pointwiseApply1(state, src, op)) { + if (!THC_pointwiseApply1(state, src, op)) { THArgCheck(false, 2, CUTORCH_DIM_WARNING); } } else { THCudaTensor_resizeAs(state, self_, src); - if (!THCudaTensor_pointwiseApply2(state, self_, src, op)) { + if (!THC_pointwiseApply2(state, self_, src, op)) { THArgCheck(false, 2, CUTORCH_DIM_WARNING); } } @@ -176,13 +178,13 @@ void THCudaTensor_triu(THCState *state, THCudaTensor *self_, THCudaTensor *src_, TensorTriOp<1> op(start, stride0, stride1, k); if (self_ == src_) { - if (!THCudaTensor_pointwiseApply1(state, src, op)) { + if (!THC_pointwiseApply1(state, src, op)) { THArgCheck(false, 2, CUTORCH_DIM_WARNING); } } else { THCudaTensor_resizeAs(state, self_, src); - if (!THCudaTensor_pointwiseApply2(state, self_, src, op)) { + if (!THC_pointwiseApply2(state, self_, src, op)) { THArgCheck(false, 2, CUTORCH_DIM_WARNING); } } @@ -192,3 +194,6 @@ void THCudaTensor_triu(THCState *state, THCudaTensor *self_, THCudaTensor *src_, THCudaCheck(cudaGetLastError()); } + +#include "generic/THCTensorMathPairwise.cu" +#include "THCGenerateAllTypes.h" |