
github.com/torch/cutorch.git
author    Trevor Killeen <killeentm@gmail.com>  2016-10-07 19:57:36 +0300
committer Trevor Killeen <killeentm@gmail.com>  2016-10-07 21:50:28 +0300
commit    63df041cae36863deaf9282de3228e6377f3bcba
tree      6411cd6f6e769888e55b0554b93db1c14d91a836
parent    4f67f808afcfae17df066bac67ff0d457e52b813

[cutorch refactor] cmin/cmax to generic
 TensorMath.lua                            |  12
 lib/THC/CMakeLists.txt                    |   1
 lib/THC/THCTensorMath.h                   |   5
 lib/THC/THCTensorMathPointwise.cu         | 117
 lib/THC/THCTensorMathPointwise.cuh        |  51
 lib/THC/generic/THCTensorMathPointwise.cu |  72
 lib/THC/generic/THCTensorMathPointwise.h  |   4
 test/test.lua                             |  28
 8 files changed, 159 insertions(+), 131 deletions(-)
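
This commit moves the element-wise cmin/cmax operations (and their scalar cminValue/cmaxValue variants) out of the float-only THCTensorMathPointwise.cu and into the generic THC pointwise framework, so the kernels are instantiated for every generated CUDA tensor type instead of only torch.CudaTensor. A minimal sketch of what this enables at the Lua level (assuming a CUDA-enabled torch/cutorch install; the tensor type and values are illustrative only):

    require 'cutorch'

    local a = torch.CudaIntTensor(4, 4):fill(3)   -- non-float CUDA types now get cmin/cmax
    local b = torch.CudaIntTensor(4, 4):fill(5)
    local c = torch.CudaIntTensor()

    c:cmax(a, b)   -- element-wise max written into c (resized to match a)
    a:cmin(b)      -- in-place element-wise min: a[i] = min(a[i], b[i])
    c:cmax(a, 4)   -- a scalar second operand dispatches to the cmaxValue kernel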
diff --git a/TensorMath.lua b/TensorMath.lua
index 53a112a..18c49cb 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -629,6 +629,18 @@ for k, Tensor_ in pairs(handledTypenames) do
{name="index"}})
end
+ for _,name in ipairs({"cmin", "cmax"}) do
+ wrap(name,
+ cname(name),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor, method={default=1}},
+ {name=Tensor}},
+ cname(name .. "Value"),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor, method={default=1}},
+ {name=real}})
+ end
+
if Tensor == 'CudaByteTensor' then
for _,name in pairs({'all', 'any'}) do
wrap(name,
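
The wrap() block added above registers two overloads of each operation for every handled type: a tensor-tensor form bound to the generic cmin/cmax functions and a tensor-scalar form bound to the corresponding cminValue/cmaxValue functions. A hedged sketch of the resulting call forms (the method names come from the declaration above; the tensor type and values are examples only):

    local x = torch.CudaDoubleTensor(5):uniform()
    local y = torch.CudaDoubleTensor(5):uniform()
    local r = torch.CudaDoubleTensor()

    r:cmax(x, y)   -- explicit result tensor; r is resized to match x
    x:cmax(y)      -- method form: x serves as both the first source and the destination
    x:cmax(0.25)   -- a real-valued argument selects the cmaxValue binding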
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 5b9d100..181bc9d 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -145,7 +145,6 @@ SET(src-cuda
THCTensorMathBlas.cu
THCTensorMathMagma.cu
THCTensorMathPairwise.cu
- THCTensorMathPointwise.cu
THCTensorMathReduce.cu
THCTensorMathScan.cu
THCTensorIndex.cu
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index f54f026..7010ee3 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -48,11 +48,6 @@ THC_API void THCudaTensor_addcdiv(THCState *state, THCudaTensor *self, THCudaTen
THC_API void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim);
THC_API void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim);
-THC_API void THCudaTensor_cmin(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2);
-THC_API void THCudaTensor_cmax(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2);
-THC_API void THCudaTensor_cminValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
-THC_API void THCudaTensor_cmaxValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
-
// MAGMA (i.e. CUDA implementation of LAPACK functions)
THC_API void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
THC_API void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
diff --git a/lib/THC/THCTensorMathPointwise.cu b/lib/THC/THCTensorMathPointwise.cu
deleted file mode 100644
index 4e3480c..0000000
--- a/lib/THC/THCTensorMathPointwise.cu
+++ /dev/null
@@ -1,117 +0,0 @@
-#include "THCTensorMathPointwise.cuh"
-
-struct TensorMaxOp {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = max(*out, *in);
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = max(*in1, *in2);
- }
-};
-
-void THCudaTensor_cmax(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2)
-{
- THAssert(THCudaTensor_checkGPU(state, 3, self, src1, src2));
- THArgCheck(THCudaTensor_nElement(state, src1) ==
- THCudaTensor_nElement(state, src2), 2, "sizes do not match");
-
- if (self == src1) {
- if (!THC_pointwiseApply2(state, self, src2, TensorMaxOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCudaTensor_resizeAs(state, self, src1);
- if (!THC_pointwiseApply3(state, self, src1, src2, TensorMaxOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-}
-
-struct TensorMinOp {
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = min(*out, *in);
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
- *out = min(*in1, *in2);
- }
-};
-
-void THCudaTensor_cmin(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2)
-{
- THAssert(THCudaTensor_checkGPU(state, 3, self, src1, src2));
- THArgCheck(THCudaTensor_nElement(state, src1) ==
- THCudaTensor_nElement(state, src2), 2, "sizes do not match");
-
- if (self == src1) {
- if (!THC_pointwiseApply2(state, self, src2, TensorMinOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCudaTensor_resizeAs(state, self, src1);
- if (!THC_pointwiseApply3(state, self, src1, src2, TensorMinOp())) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-}
-
-struct TensorMaxValueOp {
- TensorMaxValueOp(float v) : val(v) {}
-
- __device__ __forceinline__ void operator()(float* out) {
- *out = max(*out, val);
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = max(*in, val);
- }
-
- float val;
-};
-
-void THCudaTensor_cmaxValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self, src));
-
- if (self == src) {
- if (!THC_pointwiseApply1(state, self, TensorMaxValueOp(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCudaTensor_resizeAs(state, self, src);
- if (!THC_pointwiseApply2(state, self, src, TensorMaxValueOp(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-}
-
-struct TensorMinValueOp {
- TensorMinValueOp(float v) : val(v) {}
-
- __device__ __forceinline__ void operator()(float* out) {
- *out = min(*out, val);
- }
-
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = min(*in, val);
- }
-
- float val;
-};
-
-void THCudaTensor_cminValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self, src));
-
- if (self == src) {
- if (!THC_pointwiseApply1(state, self, TensorMinValueOp(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCudaTensor_resizeAs(state, self, src);
- if (!THC_pointwiseApply2(state, self, src, TensorMinValueOp(value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-}
diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh
index 3f99279..5a6de80 100644
--- a/lib/THC/THCTensorMathPointwise.cuh
+++ b/lib/THC/THCTensorMathPointwise.cuh
@@ -455,5 +455,56 @@ struct TensorCrossOp {
const long sx, sy, so;
};
+template <typename T>
+struct TensorMaxOp {
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out = THCNumerics<T>::gt(*out, *in) ? *out : *in;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = THCNumerics<T>::gt(*in1, *in2) ? *in1 : *in2;
+ }
+};
+
+template <typename T>
+struct TensorMinOp {
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out = THCNumerics<T>::lt(*out, *in) ? *out : *in;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
+ *out = THCNumerics<T>::lt(*in1, *in2) ? *in1 : *in2;
+ }
+};
+
+template <typename T>
+struct TensorMaxValueOp {
+ TensorMaxValueOp(T v) : val(v) {}
+
+ __device__ __forceinline__ void operator()(T* out) {
+ *out = THCNumerics<T>::gt(*out, val) ? *out : val;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out = THCNumerics<T>::gt(*in, val) ? *in : val;
+ }
+
+ T val;
+};
+
+template <typename T>
+struct TensorMinValueOp {
+ TensorMinValueOp(T v) : val(v) {}
+
+ __device__ __forceinline__ void operator()(T* out) {
+ *out = THCNumerics<T>::lt(*out, val) ? *out : val;
+ }
+
+ __device__ __forceinline__ void operator()(T* out, T* in) {
+ *out = THCNumerics<T>::lt(*in, val) ? *in : val;
+ }
+
+ T val;
+};
#endif // THC_TENSORMATH_POINTWISE_CUH
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
index 86cea2a..acb2d4c 100644
--- a/lib/THC/generic/THCTensorMathPointwise.cu
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -342,4 +342,76 @@ THCTensor_(cdiv)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *
THCudaCheck(cudaGetLastError());
}
+THC_API void
+THCTensor_(cmax)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self, src1, src2));
+ THArgCheck(THCTensor_(nElement)(state, src1) ==
+ THCTensor_(nElement)(state, src2), 2, "sizes do not match");
+
+ if (self == src1) {
+ if (!THC_pointwiseApply2(state, self, src2, TensorMaxOp<real>())) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self, src1);
+ if (!THC_pointwiseApply3(state, self, src1, src2, TensorMaxOp<real>())) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+}
+
+THC_API void
+THCTensor_(cmin)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2)
+{
+ THAssert(THCTensor_(checkGPU)(state, 3, self, src1, src2));
+ THArgCheck(THCTensor_(nElement)(state, src1) ==
+ THCTensor_(nElement)(state, src2), 2, "sizes do not match");
+
+ if (self == src1) {
+ if (!THC_pointwiseApply2(state, self, src2, TensorMinOp<real>())) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self, src1);
+ if (!THC_pointwiseApply3(state, self, src1, src2, TensorMinOp<real>())) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+}
+
+THC_API void
+THCTensor_(cmaxValue)(THCState *state, THCTensor *self, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+
+ if (self == src) {
+ if (!THC_pointwiseApply1(state, self, TensorMaxValueOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self, src);
+ if (!THC_pointwiseApply2(state, self, src, TensorMaxValueOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+}
+
+THC_API void
+THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+
+ if (self == src) {
+ if (!THC_pointwiseApply1(state, self, TensorMinValueOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ } else {
+ THCTensor_(resizeAs)(state, self, src);
+ if (!THC_pointwiseApply2(state, self, src, TensorMinValueOp<real>(value))) {
+ THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+ }
+ }
+}
+
#endif
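
Each generic function above preserves the two dispatch paths of the original float implementation: when the destination aliases the first source, a two-tensor THC_pointwiseApply2 updates it in place; otherwise the destination is resized with resizeAs and a three-tensor THC_pointwiseApply3 writes into it. In Lua terms (a sketch, assuming a CUDA device is available):

    local a = torch.CudaTensor(3, 3):uniform()
    local b = torch.CudaTensor(3, 3):uniform()

    a:cmax(b)         -- self == src1: in-place update via THC_pointwiseApply2

    local c = torch.CudaTensor()
    c:cmin(a, b)      -- self ~= src1: resizeAs(c, a), then THC_pointwiseApply3
    print(c:size())   -- 3x3 after the implicit resize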
diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h
index 03ab32e..efbe76c 100644
--- a/lib/THC/generic/THCTensorMathPointwise.h
+++ b/lib/THC/generic/THCTensorMathPointwise.h
@@ -44,5 +44,9 @@ THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1,
THC_API void THCTensor_(cmul)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(cpow)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
THC_API void THCTensor_(cdiv)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(cmax)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(cmin)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2);
+THC_API void THCTensor_(cmaxValue)(THCState *state, THCTensor *self, THCTensor *src, real value);
+THC_API void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, real value);
#endif
diff --git a/test/test.lua b/test/test.lua
index 64884f0..2a6e88c 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1041,10 +1041,16 @@ function test.cmax()
local c = torch.FloatTensor(sz1, sz2):zero()
local v = torch.uniform()
- compareFloatAndCudaTensorArgs(c, 'cmax', a, b)
- compareFloatAndCudaTensorArgs(c, 'cmax', a, v)
- compareFloatAndCudaTensorArgs(a, 'cmax', b)
- compareFloatAndCuda(a, 'cmax', v)
+ for _, typename in ipairs(typenames) do
+ local a = a:type(t2cpu[typename])
+ local b = b:type(t2cpu[typename])
+ local c = c:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, c, 'cmax', a, b)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, c, 'cmax', a, v)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cmax', b)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cmax', v)
+ end
+
checkMultiDevice(c, 'cmax', a, b)
checkMultiDevice(c, 'cmax', a, v)
checkMultiDevice(a, 'cmax', b)
@@ -1059,10 +1065,16 @@ function test.cmin()
local c = torch.FloatTensor(sz1, sz2):zero()
local v = torch.uniform()
- compareFloatAndCudaTensorArgs(c, 'cmin', a, b)
- compareFloatAndCudaTensorArgs(c, 'cmin', a, v)
- compareFloatAndCudaTensorArgs(a, 'cmin', b)
- compareFloatAndCuda(a, 'cmin', v)
+ for _, typename in ipairs(typenames) do
+ local a = a:type(t2cpu[typename])
+ local b = b:type(t2cpu[typename])
+ local c = c:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, c, 'cmin', a, b)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, c, 'cmin', a, v)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cmin', b)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, a, 'cmin', v)
+ end
+
checkMultiDevice(c, 'cmin', a, b)
checkMultiDevice(c, 'cmin', a, v)
checkMultiDevice(a, 'cmin', b)
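
The updated tests exercise cmax/cmin across all generated types via compareCPUAndCUDATypeTensorArgs instead of float only. A sketch of how to run just these tests, assuming the standard cutorch test entry point (cutorch.test accepting a list of test names):

    require 'cutorch'
    cutorch.test{'cmax', 'cmin'}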