
github.com/torch/cutorch.git
Diffstat (limited to 'lib/THC/THCTensorMathPointwise.cu')
-rw-r--r--  lib/THC/THCTensorMathPointwise.cu  429
1 file changed, 259 insertions(+), 170 deletions(-)
diff --git a/lib/THC/THCTensorMathPointwise.cu b/lib/THC/THCTensorMathPointwise.cu
index 80b1ead..72b16e8 100644
--- a/lib/THC/THCTensorMathPointwise.cu
+++ b/lib/THC/THCTensorMathPointwise.cu
@@ -1,6 +1,7 @@
#include "THCTensorMath.h"
#include "THCGeneral.h"
#include "THCBlas.h"
+#include "THCHalf.h"
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCReduce.cuh"
@@ -19,13 +20,13 @@
void THCudaTensor_##NAME(THCState* state, THCudaTensor* self_, THCudaTensor* src) { \
THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); \
if (self_ == src) { \
- if (!THCudaTensor_pointwiseApply1(state, self_, Tensor##NAME##Op())) { \
+ if (!THC_pointwiseApply1(state, self_, Tensor##NAME##Op())) { \
THArgCheck(false, 2, CUTORCH_DIM_WARNING); \
} \
} else { \
THCudaTensor_resizeAs(state, self_, src); \
\
- if (!THCudaTensor_pointwiseApply2(state, self_, src, Tensor##NAME##Op())) { \
+ if (!THC_pointwiseApply2(state, self_, src, Tensor##NAME##Op())) { \
THArgCheck(false, 2, CUTORCH_DIM_WARNING); \
} \
} \
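
The hunk above edits the body of a name-generating macro whose opening line (and name) lies outside the hunk context. Expanding it for a hypothetical NAME, say sqrt with a TensorsqrtOp functor defined elsewhere, gives roughly the function below, which shows the in-place (apply1) versus out-of-place (apply2) split used throughout this file:

void THCudaTensor_sqrt(THCState* state, THCudaTensor* self_, THCudaTensor* src) {
  THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
  if (self_ == src) {
    // in-place: walk a single tensor with the unary overload of the functor
    if (!THC_pointwiseApply1(state, self_, TensorsqrtOp())) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
    }
  } else {
    // out-of-place: size the destination, then read src and write self_
    THCudaTensor_resizeAs(state, self_, src);
    if (!THC_pointwiseApply2(state, self_, src, TensorsqrtOp())) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
    }
  }
}
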
@@ -76,171 +77,13 @@ struct TensorSigmoidOp {
void THCudaTensor_sigmoid(THCState* state, THCudaTensor* self_, THCudaTensor* src) {
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
if (self_ == src) {
- if (!THCudaTensor_pointwiseApply1(state, self_, TensorSigmoidOp())) {
+ if (!THC_pointwiseApply1(state, self_, TensorSigmoidOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self_, src);
- if (!THCudaTensor_pointwiseApply2(state, self_, src, TensorSigmoidOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
-struct TensorAddOp {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out += *in;
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = *in1 + *in2;
- }
-};
-
-struct TensorCAddOp {
- TensorCAddOp(float v) : val(v) {}
-
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out += val * *in;
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = *in1 + val * *in2;
- }
-
- float val;
-};
-
-void THCudaTensor_cadd(THCState *state, THCudaTensor *self_, THCudaTensor* src1, float value, THCudaTensor *src2)
-{
- THAssert(THCudaTensor_checkGPU(state, 3, self_, src1, src2));
- THArgCheck(THCudaTensor_nElement(state, src1) ==
- THCudaTensor_nElement(state, src2), 3, "sizes do not match");
-
- if (self_ == src1) {
- if (value == 1.0f) {
- // self += src2
- if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorAddOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- // self += value * src2
- if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorCAddOp(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
- } else {
- THCudaTensor_resizeAs(state, self_, src1);
-
- if (value == 1.0f) {
- // self = src1 + src2
- if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorAddOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- // self = src1 + value * src2
- if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorCAddOp(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
-struct TensorSubOp {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out -= *in;
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = *in1 - *in2;
- }
-};
-
-
-struct TensorCSubOp {
- TensorCSubOp(float v) : val(v) {}
-
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out -= val * *in;
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = *in1 - val * *in2;
- }
-
- float val;
-};
-
-
-void THCudaTensor_csub(THCState *state, THCudaTensor *self_, THCudaTensor* src1, float value, THCudaTensor *src2)
-{
- THAssert(THCudaTensor_checkGPU(state, 3, self_, src1, src2));
- THArgCheck(THCudaTensor_nElement(state, src1) ==
- THCudaTensor_nElement(state, src2), 3, "sizes do not match");
-
- if (self_ == src1) {
- if (value == 1.0f) {
- // self -= src2
- if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorSubOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- // self += -value * src2
- if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorCAddOp(-value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
- } else {
- THCudaTensor_resizeAs(state, self_, src1);
-
- if (value == 1.0f) {
- // self = src1 - src2
- if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorSubOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- // self = src1 - value * src2
- if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorCAddOp(-value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
-
-struct TensorMulOp {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out *= *in;
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = *in1 * *in2;
- }
-};
-
-void THCudaTensor_cmul(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2)
-{
- THAssert(THCudaTensor_checkGPU(state, 3, self_, src1, src2));
- THArgCheck(THCudaTensor_nElement(state, src1) ==
- THCudaTensor_nElement(state, src2), 3, "sizes do not match");
-
- if (self_ == src1) {
- // self *= src2
- if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorMulOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCudaTensor_resizeAs(state, self_, src1);
-
- // self = src1 * src2
- if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorMulOp())) {
+ if (!THC_pointwiseApply2(state, self_, src, TensorSigmoidOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
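
The float-only arithmetic functors and their THCudaTensor_* wrappers deleted in this hunk are re-introduced at the end of the patch as type-generic templates, with the wrapper bodies presumably relocated into generic/THCTensorMathPointwise.cu (included at the bottom of the patch, but outside this diffstat). A hedged sketch of what the relocated cadd wrapper plausibly looks like there, assuming cutorch's usual real/THCTensor_(NAME) generic macros:

void THCTensor_(cadd)(THCState *state, THCTensor *self_, THCTensor *src1, real value, THCTensor *src2)
{
  THAssert(THCTensor_(checkGPU)(state, 3, self_, src1, src2));
  THArgCheck(THCTensor_(nElement)(state, src1) ==
             THCTensor_(nElement)(state, src2), 3, "sizes do not match");

  if (self_ == src1) {
    // self += value * src2
    if (!THC_pointwiseApply2(state, self_, src2, TensorCAddOp<real>(value))) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
    }
  } else {
    // self = src1 + value * src2
    THCTensor_(resizeAs)(state, self_, src1);
    if (!THC_pointwiseApply3(state, self_, src1, src2, TensorCAddOp<real>(value))) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
    }
  }

  THCudaCheck(cudaGetLastError());
}
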
@@ -265,12 +108,12 @@ void THCudaTensor_cmax(THCState *state, THCudaTensor *self, THCudaTensor *src1,
THCudaTensor_nElement(state, src2), 2, "sizes do not match");
if (self == src1) {
- if (!THCudaTensor_pointwiseApply2(state, self, src2, TensorMaxOp())) {
+ if (!THC_pointwiseApply2(state, self, src2, TensorMaxOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self, src1);
- if (!THCudaTensor_pointwiseApply3(state, self, src1, src2, TensorMaxOp())) {
+ if (!THC_pointwiseApply3(state, self, src1, src2, TensorMaxOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
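
TensorMaxOp itself is defined above this hunk's context window; following the two-overload functor convention visible throughout the file, it plausibly reads as the sketch below (elementwise maximum; the binary overload serves the in-place path, the ternary overload the out-of-place path):

struct TensorMaxOp {
  __device__ __forceinline__ void operator()(float* out, float* in) {
    *out = max(*out, *in);    // self = max(self, src2)
  }

  __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
    *out = max(*in1, *in2);   // self = max(src1, src2)
  }
};
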
@@ -293,12 +136,12 @@ void THCudaTensor_cmin(THCState *state, THCudaTensor *self, THCudaTensor *src1,
THCudaTensor_nElement(state, src2), 2, "sizes do not match");
if (self == src1) {
- if (!THCudaTensor_pointwiseApply2(state, self, src2, TensorMinOp())) {
+ if (!THC_pointwiseApply2(state, self, src2, TensorMinOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self, src1);
- if (!THCudaTensor_pointwiseApply3(state, self, src1, src2, TensorMinOp())) {
+ if (!THC_pointwiseApply3(state, self, src1, src2, TensorMinOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
@@ -323,12 +166,12 @@ void THCudaTensor_cmaxValue(THCState *state, THCudaTensor *self, THCudaTensor *s
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
if (self == src) {
- if (!THCudaTensor_pointwiseApply1(state, self, TensorMaxValueOp(value))) {
+ if (!THC_pointwiseApply1(state, self, TensorMaxValueOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self, src);
- if (!THCudaTensor_pointwiseApply2(state, self, src, TensorMaxValueOp(value))) {
+ if (!THC_pointwiseApply2(state, self, src, TensorMaxValueOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
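
The final hunk below adds type-generic functors together with half specializations. Wherever native fp16 arithmetic (the CUDA_HALF_INSTRUCTIONS intrinsics __hadd, __hsub, __hmul) is unavailable, every specialization uses the same fallback: promote to float with __half2float, compute, demote with __float2half. A minimal standalone CUDA sketch of that pattern, with illustrative names rather than code from the patch:

#include <cuda_fp16.h>

__global__ void half_add_fallback(half* out, const half* a, const half* b, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float fa = __half2float(a[i]);   // promote both operands to float
    float fb = __half2float(b[i]);
    out[i] = __float2half(fa + fb);  // compute in float, demote the result
  }
}
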
@@ -353,13 +196,259 @@ void THCudaTensor_cminValue(THCState *state, THCudaTensor *self, THCudaTensor *s
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
if (self == src) {
- if (!THCudaTensor_pointwiseApply1(state, self, TensorMinValueOp(value))) {
+ if (!THC_pointwiseApply1(state, self, TensorMinValueOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self, src);
- if (!THCudaTensor_pointwiseApply2(state, self, src, TensorMinValueOp(value))) {
+ if (!THC_pointwiseApply2(state, self, src, TensorMinValueOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
}
+
+template <typename T>
+struct TensorAddOp {
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out += *in;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = *in1 + *in2;
+ }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorAddOp<half> {
+ __device__ __forceinline__ void operator()(half* out, half* in) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hadd(*out, *in);
+#else
+ float fout = __half2float(*out);
+ float fin = __half2float(*in);
+ fout += fin;
+ *out = __float2half(fout);
+#endif
+ }
+
+ __device__ __forceinline__ void operator()(half* out, half* in1, half* in2) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hadd(*in1, *in2);
+#else
+ float fin1 = __half2float(*in1);
+ float fin2 = __half2float(*in2);
+ float fout = fin1 + fin2;
+ *out = __float2half(fout);
+#endif
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename T>
+struct TensorCAddOp {
+ TensorCAddOp(T v) : val(v) {}
+
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out += val * *in;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = *in1 + val * *in2;
+ }
+
+ T val;
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorCAddOp<half> {
+ TensorCAddOp(half v) : val(v) {}
+
+ __device__ __forceinline__ void operator()(half* out, half* in) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hadd(*out, __hmul(val, *in));
+#else
+ float fout = __half2float(*out);
+ float fval = __half2float(val);
+ float fin = __half2float(*in);
+
+ fout += fval * fin;
+ *out = __float2half(fout);
+#endif
+ }
+
+ __device__ __forceinline__ void operator()(half* out, half* in1, half* in2) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hadd(*in1, __hmul(val, *in2));
+#else
+ float fin1 = __half2float(*in1);
+ float fin2 = __half2float(*in2);
+ float fval = __half2float(val);
+
+ float fout = fin1 + fval * fin2;
+ *out = __float2half(fout);
+#endif
+ }
+
+ half val;
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename T>
+struct TensorSubOp {
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out -= *in;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = *in1 - *in2;
+ }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorSubOp<half> {
+ __device__ __forceinline__ void operator()(half* out, half* in) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hsub(*out, *in);
+#else
+ float fout = __half2float(*out);
+ float fin = __half2float(*in);
+ fout -= fin;
+ *out = __float2half(fout);
+#endif
+ }
+
+ __device__ __forceinline__ void operator()(half* out, half* in1, half* in2) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hsub(*in1, *in2);
+#else
+ float fin1 = __half2float(*in1);
+ float fin2 = __half2float(*in2);
+ float fout = fin1 - fin2;
+ *out = __float2half(fout);
+#endif
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename T>
+struct TensorMulOp {
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out *= *in;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = *in1 * *in2;
+ }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorMulOp<half> {
+ __device__ __forceinline__ void operator()(half* out, half* in) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hmul(*out, *in);
+#else
+ float fout = __half2float(*out);
+ float fin = __half2float(*in);
+ fout *= fin;
+ *out = __float2half(fout);
+#endif
+ }
+
+ __device__ __forceinline__ void operator()(half* out, half* in1, half* in2) {
+#ifdef CUDA_HALF_INSTRUCTIONS
+ *out = __hmul(*in1, *in2);
+#else
+ float fin1 = __half2float(*in1);
+ float fin2 = __half2float(*in2);
+ float fout = fin1 * fin2;
+ *out = __float2half(fout);
+#endif
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename T>
+struct TensorCPowOp {
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out = powf((float) *out, (float) *in);
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = powf((float) *in1, (float) *in2);
+ }
+};
+
+template <>
+struct TensorCPowOp<double> {
+ __device__ __forceinline__ void operator()(double* out, double* in) {
+ *out = pow(*out, *in);
+ }
+
+ __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) {
+ *out = pow(*in1, *in2);
+ }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorCPowOp<half> {
+ __device__ __forceinline__ void operator()(half* out, half* in) {
+ // No fp16 pow function yet
+ float fout = __half2float(*out);
+ float fin = __half2float(*in);
+ fout = powf(fout, fin);
+ *out = __float2half(fout);
+ }
+
+ __device__ __forceinline__ void operator()(half* out, half* in1, half* in2) {
+ // No fp16 pow function yet
+ float fin1 = __half2float(*in1);
+ float fin2 = __half2float(*in2);
+ float fout = powf(fin1, fin2);
+ *out = __float2half(fout);
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+template <typename T>
+struct TensorDivOp {
+ __device__ __forceinline__ void
+ operator()(T* out, T* in) {
+ *out /= *in;
+ }
+
+ __device__ __forceinline__ void
+ operator()(T* out, T* in1, T* in2) {
+ *out = *in1 / *in2;
+ }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct TensorDivOp<half> {
+ __device__ __forceinline__ void
+ operator()(half* out, half* in) {
+ // No fp16 div instruction yet
+ float fout = __half2float(*out);
+ float fin = __half2float(*in);
+ fout /= fin;
+ *out = __float2half(fout);
+ }
+
+ __device__ __forceinline__ void
+ operator()(half* out, half* in1, half* in2) {
+ // No fp16 div instruction yet
+ float fin1 = __half2float(*in1);
+ float fin2 = __half2float(*in2);
+ float fout = fin1 / fin2;
+ *out = __float2half(fout);
+ }
+};
+#endif // CUDA_HALF_TENSOR
+
+#include "generic/THCTensorMathPointwise.cu"
+#include "THCGenerateAllTypes.h"