Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Trevor Killeen <killeentm@gmail.com> 2016-10-05 21:55:38 +0300
committer Trevor Killeen <killeentm@gmail.com> 2016-10-07 21:50:28 +0300
commit 89330c02a1c9e13658156bd8941b5b7b48e3b71e (patch)
tree 4bf7ead32d7fe69ad422cf64e6aa7f01e0c492a3
parent af459755c0d2477342aead1a645cb4969a7dd215 (diff)
[cutorch refactor] move clamp(...) to generic
-rw-r--r-- TensorMath.lua 7
-rw-r--r-- lib/THC/THCTensorMath.h 1
-rw-r--r-- lib/THC/THCTensorMath2.cu 33
-rw-r--r-- lib/THC/THCTensorMathPointwise.cuh 18
-rw-r--r-- lib/THC/generic/THCTensorMathPointwise.cu 19
-rw-r--r-- lib/THC/generic/THCTensorMathPointwise.h 1
-rw-r--r-- test/test.lua 52
7 files changed, 95 insertions, 36 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 7ec2872..0e41835 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -589,6 +589,13 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor, method={default=1}},
{name=real}})
+ wrap("clamp", -- registers the Lua-visible name "clamp" for each handled tensor type
+ cname("clamp"), -- presumably resolves to the per-type C symbol THCTensor_(clamp); confirm against cname()
+ {{name=Tensor, default=true, returned=true, method={default='nil'}}, -- destination tensor; flagged as returned
+ {name=Tensor, method={default=1}}, -- source tensor (argument order follows the C prototype: self, src, min_value, max_value)
+ {name=real}, -- lower bound, passed as the tensor's scalar type
+ {name=real}}) -- upper bound, passed as the tensor's scalar type
+
wrap("div",
cname("div"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 21482b7..0c850ab 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -54,7 +54,6 @@ THC_API void THCudaTensor_cminValue(THCState *state, THCudaTensor *self, THCudaT
THC_API void THCudaTensor_cmaxValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
THC_API void THCudaTensor_cross(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2, int dimension);
-THC_API void THCudaTensor_clamp(THCState *state, THCudaTensor *self, THCudaTensor *src, float min_value, float max_value);
// MAGMA (i.e. CUDA implementation of LAPACK functions)
THC_API void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 84a5a1c..64b6af3 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -68,39 +68,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
-struct TensorClampOp {
- TensorClampOp(float min, float max) : minValue(min), maxValue(max) {}
- __device__ __forceinline__ void operator()(float* out, float* in) {
- *out = max(min(*in, maxValue), minValue);
- }
-
- __device__ __forceinline__ void operator()(float* v) {
- *v = max(min(*v, maxValue), minValue);
- }
-
- const float minValue;
- const float maxValue;
-};
-
-void THCudaTensor_clamp(THCState *state, THCudaTensor *self_, THCudaTensor *src, float min_value,
- float max_value)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
- if (self_ == src) {
- if (!THC_pointwiseApply1(state, self_, TensorClampOp(min_value, max_value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCudaTensor_resizeAs(state, self_, src);
-
- if (!THC_pointwiseApply2(state, self_, src, TensorClampOp(min_value, max_value))) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
struct TensorLerpOp {
TensorLerpOp(float w) : w(w) {}
diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh
index e378a83..a690c45 100644
--- a/lib/THC/THCTensorMathPointwise.cuh
+++ b/lib/THC/THCTensorMathPointwise.cuh
@@ -397,4 +397,22 @@ struct TensorDivOp<half> {
};
#endif // CUDA_HALF_TENSOR
+// Pointwise functor that limits a value to the range [minValue, maxValue].
+// All comparisons are routed through THCNumerics<T> so the same code serves
+// every element type the generic files instantiate it with.
+template <typename T>
+struct TensorClampOp {
+  TensorClampOp(T lo, T hi) : minValue(lo), maxValue(hi) {}
+
+  // Two-pointer form: reads *in and stores the limited value in *out.
+  __device__ __forceinline__ void operator()(T* out, T* in) {
+    T limited = *in;
+    // Anything that is not strictly below the upper bound becomes the bound.
+    if (!THCNumerics<T>::lt(limited, maxValue)) {
+      limited = maxValue;
+    }
+    // The lower bound is applied last, so it wins if the bounds cross.
+    if (THCNumerics<T>::gt(minValue, limited)) {
+      limited = minValue;
+    }
+    *out = limited;
+  }
+
+  // One-pointer form: limits *v in place.
+  __device__ __forceinline__ void operator()(T* v) {
+    T limited = *v;
+    if (!THCNumerics<T>::lt(limited, maxValue)) {
+      limited = maxValue;
+    }
+    if (THCNumerics<T>::gt(minValue, limited)) {
+      limited = minValue;
+    }
+    *v = limited;
+  }
+
+  const T minValue;
+  const T maxValue;
+};
+
+
#endif // THC_TENSORMATH_POINTWISE_CUH
diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu
index b2f8950..707fc93 100644
--- a/lib/THC/generic/THCTensorMathPointwise.cu
+++ b/lib/THC/generic/THCTensorMathPointwise.cu
@@ -82,6 +82,25 @@ void THCTensor_(sign)(THCState* state, THCTensor* self_, THCTensor* src) {
THCudaCheck(cudaGetLastError());
}
+// Element-wise clamp of src into [min_value, max_value], written to self_.
+// When self_ and src are the same tensor the one-tensor apply is used;
+// otherwise self_ is first resized to match src and the two-tensor apply runs.
+void THCTensor_(clamp)(THCState *state, THCTensor *self_, THCTensor *src, real min_value,
+  real max_value)
+{
+  THAssert(THCTensor_(checkGPU)(state, 2, self_, src));
+
+  TensorClampOp<real> clampOp(min_value, max_value);
+  bool applied;
+  if (self_ != src) {
+    THCTensor_(resizeAs)(state, self_, src);
+    applied = THC_pointwiseApply2(state, self_, src, clampOp);
+  } else {
+    applied = THC_pointwiseApply1(state, self_, clampOp);
+  }
+
+  // A false return from the apply helper is reported as an argument error.
+  if (!applied) {
+    THArgCheck(false, 2, CUTORCH_DIM_WARNING);
+  }
+
+  THCudaCheck(cudaGetLastError());
+}
+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) {
diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h
index af50278..0c03045 100644
--- a/lib/THC/generic/THCTensorMathPointwise.h
+++ b/lib/THC/generic/THCTensorMathPointwise.h
@@ -36,6 +36,7 @@ THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
+THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value); // self = src limited element-wise to [min_value, max_value]
THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2);
THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2);
diff --git a/test/test.lua b/test/test.lua
index b1e6481..2a88268 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1372,7 +1372,12 @@ function test.clamp1()
if sz2 >= 2 then
x[1][2] = max_val + 1
end
- compareFloatAndCudaTensorArgs(x, 'clamp', min_val, max_val)
+ for _, typename in ipairs(typenames) do
+ if typename ~= 'torch.CudaCharTensor' and typename ~= 'torch.CudaByteTensor' then
+ local x = x:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'clamp', min_val, max_val);
+ end
+ end
checkMultiDevice(x, 'clamp', min_val, max_val)
end
@@ -1387,10 +1392,53 @@ function test.clamp2()
x[1][2] = max_val + 1
end
local y = torch.FloatTensor():resizeAs(x)
- compareFloatAndCudaTensorArgs(y, 'clamp', x, min_val, max_val)
+ for _, typename in ipairs(typenames) do
+ if typename ~= 'torch.CudaCharTensor' and typename ~= 'torch.CudaByteTensor' then
+ local x = x:type(t2cpu[typename])
+ local y = y:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'clamp', x, min_val, max_val);
+ end
+ end
checkMultiDevice(y, 'clamp', x, min_val, max_val)
end
+-- clamp3/clamp4 repeat clamp1/clamp2 with fixed bounds (1 and 3) and inputs
+-- that are never negative, and they run for every entry of `typenames`
+-- without the Char/Byte exclusion used in clamp1/clamp2.
+-- clamp3 covers the in-place form x:clamp(min, max).
+function test.clamp3()
+   local rows = chooseInt(minsize, maxsize)
+   local cols = chooseInt(minsize, maxsize)
+   local lo = 1
+   local hi = 3
+   -- rand() is scaled by 5, so values fall on both sides of the bounds
+   local x = torch.FloatTensor():rand(rows, cols):mul(5)
+   x[1][1] = lo - 1 -- force at least one element below the lower bound
+   if cols >= 2 then
+      x[1][2] = hi + 1 -- and, when there is room, one above the upper bound
+   end
+   for _, typename in ipairs(typenames) do
+      local typed = x:type(t2cpu[typename])
+      compareCPUAndCUDATypeTensorArgs(typename, nil, typed, 'clamp', lo, hi)
+   end
+   checkMultiDevice(x, 'clamp', lo, hi)
+end
+
+-- Out-of-place clamp, y:clamp(x, min_val, max_val), with fixed bounds (1 and 3)
+-- and inputs that are never negative; runs for every entry of `typenames`.
+function test.clamp4()
+   local sz1 = chooseInt(minsize, maxsize)
+   local sz2 = chooseInt(minsize, maxsize)
+   -- rand() is scaled by 5, so values fall on both sides of the bounds
+   local x = torch.FloatTensor():rand(sz1, sz2):mul(5)
+   local min_val = 1
+   local max_val = 3
+   x[1][1] = min_val - 1 -- force at least one element below the lower bound
+   if sz2 >= 2 then
+      x[1][2] = max_val + 1 -- and, when there is room, one above the upper bound
+   end
+   local y = torch.FloatTensor():resizeAs(x)
+   for _, typename in ipairs(typenames) do
+      -- distinct names: the Float originals are still needed after the loop
+      local x_typed = x:type(t2cpu[typename])
+      local y_typed = y:type(t2cpu[typename])
+      compareCPUAndCUDATypeTensorArgs(typename, nil, y_typed, 'clamp', x_typed, min_val, max_val)
+   end
+   -- Check the same out-of-place call the loop exercises (as test.clamp2 does);
+   -- the previous call re-tested the in-place form and never used y.
+   checkMultiDevice(y, 'clamp', x, min_val, max_val)
+end
+
function test.index()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)