diff options
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 31 | ||||
-rw-r--r-- | lib/THC/THCTensorMathPointwise.cuh | 54 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 8 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 17 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 32 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 1 |
7 files changed, 113 insertions, 31 deletions
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 04ce041..55c225f 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -15,37 +15,6 @@ #include <thrust/system/cuda/execution_policy.h> #endif -struct TensorPowOp { - TensorPowOp(float v) : val(v) {} - __device__ __forceinline__ void operator()(float* out, float* in) { - *out = powf(*in, val); - } - - __device__ __forceinline__ void operator()(float* v) { - *v = powf(*v, val); - } - - const float val; -}; - -void THCudaTensor_pow(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); - if (self_ == src) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src); - - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - struct TensorTPowOp { TensorTPowOp(float v) : val(v) {} diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index 72873ad..e378a83 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -264,6 +264,60 @@ struct TensorMulOp<half> { }; #endif // CUDA_HALF_TENSOR +template<typename T> +struct TensorPowOp { + TensorPowOp(T v) : val(v) {} + __device__ __forceinline__ void operator()(T* out, T* in) { + *out = powf((float) *in, (float) val); + } + + __device__ __forceinline__ void operator()(T* v) { + *v = powf((float) *v, (float) val); + } + + const T val; +}; + +template <> +struct TensorPowOp<double> { + TensorPowOp(double v) : val(v) {} + + __device__ __forceinline__ void operator()(double* out, double* in) { + *out = pow(*in, val); + } + + __device__ __forceinline__ void operator()(double* v) { + *v = pow(*v, val); + } + + const double val; +}; + +#ifdef CUDA_HALF_TENSOR +template <> +struct TensorPowOp<half> { + TensorPowOp(half v) : val(v) {} + + __device__ __forceinline__ void operator()(half* out, half* in) { + // No fp16 pow function yet + float fin = __half2float(*in); + float fval = __half2float(val); + float fout = powf(fin, fval); + *out = __float2half(fout); + } + + __device__ __forceinline__ void operator()(half* v) { + // No fp16 pow function yet + float fv = __half2float(*v); + float fval = __half2float(val); + float fout = powf(fv, fval); + *v = __float2half(fout); + } + + const half val; +}; +#endif // CUDA_HALF_TENSOR + template <typename T> struct TensorCPowOp { __device__ __forceinline__ void operator()(T* out, T* in) { diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index a0e550a..9ffc89b 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -36,6 +36,14 @@ THCTensor_(zero)(THCState *state, THCTensor *self_) } THC_API void +THCTensor_(mean)(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self, src)); + THCudaTensor_sum(state, self, src, dim); + THCudaTensor_div(state, self, self, THCudaTensor_size(state, src, dim)); +} + +THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size) { THAssert(THCTensor_(checkGPU)(state, 1, r_)); diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index 5c9e66d..6b59262 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -4,6 +4,7 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, real value); THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); +THC_API void THCTensor_(mean)(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim); THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *size); THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size); diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index c4e763a..7cbd00f 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -101,6 +101,23 @@ void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { THCudaCheck(cudaGetLastError()); } +void THCTensor_pow(THCState *state, THCTensor *self_, THCTensor *src, real value) { + THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); + if (self_ == src) { + if (!THC_pointwiseApply1(state, self_, TensorPowOp<real>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCTensor_(resizeAs)(state, self_, src); + + if (!THC_pointwiseApply2(state, self_, src, TensorPowOp<real>(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + #endif THC_API void diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 4b3d28b..90cc3e3 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -201,4 +201,36 @@ THCTensor_(min)(THCState *state, MinValuePair<typename TensorUtils<THCTensor>::DataType, long>()); } +#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) + +THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self, src)); + if (value == 0.0) { + THC_reduceDim(state, self, src, + TensorNonZeroOp<real>(), thrust::plus<real>(), + 0.0, dimension); + } else if (value == 1.0) { + THC_reduceDim(state, self, src, + TensorNormOp<real, 1>(), thrust::plus<real>(), + 0.0, dimension); + + } else if (value == 2.0) { + THC_reduceDim(state, self, src, + TensorNormOp<real, 2>(), thrust::plus<real>(), + 0.0, dimension); + THCTensor_(pow)(state, self, self, 0.5); + + } else { + THC_reduceDim(state, self, src, + TensorNormOp<real, -1>(), thrust::plus<real>(), + 0.0, dimension); + THCTensor_(pow)(state, self, self, 1.0 / value); + } + + THCudaCheck(cudaGetLastError()); +} + +#endif + #endif diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index 699a89d..d25317e 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -6,6 +6,7 @@ THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension, real max_norm); THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, long dim, int flag); +THC_API void THTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, long dimension); #endif |