diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-05 21:55:38 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:28 +0300 |
commit | 89330c02a1c9e13658156bd8941b5b7b48e3b71e (patch) | |
tree | 4bf7ead32d7fe69ad422cf64e6aa7f01e0c492a3 | |
parent | af459755c0d2477342aead1a645cb4969a7dd215 (diff) |
[cutorch refactor] move clamp(...) to generic
-rw-r--r-- | TensorMath.lua | 7 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 33 | ||||
-rw-r--r-- | lib/THC/THCTensorMathPointwise.cuh | 18 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.cu | 19 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathPointwise.h | 1 | ||||
-rw-r--r-- | test/test.lua | 52 |
7 files changed, 95 insertions, 36 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 7ec2872..0e41835 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -589,6 +589,13 @@ for k, Tensor_ in pairs(handledTypenames) do {name=Tensor, method={default=1}}, {name=real}}) + wrap("clamp", + cname("clamp"), + {{name=Tensor, default=true, returned=true, method={default='nil'}}, + {name=Tensor, method={default=1}}, + {name=real}, + {name=real}}) + wrap("div", cname("div"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 21482b7..0c850ab 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -54,7 +54,6 @@ THC_API void THCudaTensor_cminValue(THCState *state, THCudaTensor *self, THCudaT THC_API void THCudaTensor_cmaxValue(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); THC_API void THCudaTensor_cross(THCState *state, THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2, int dimension); -THC_API void THCudaTensor_clamp(THCState *state, THCudaTensor *self, THCudaTensor *src, float min_value, float max_value); // MAGMA (i.e. CUDA implementation of LAPACK functions) THC_API void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 84a5a1c..64b6af3 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -68,39 +68,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaCheck(cudaGetLastError()); } -struct TensorClampOp { - TensorClampOp(float min, float max) : minValue(min), maxValue(max) {} - __device__ __forceinline__ void operator()(float* out, float* in) { - *out = max(min(*in, maxValue), minValue); - } - - __device__ __forceinline__ void operator()(float* v) { - *v = max(min(*v, maxValue), minValue); - } - - const float minValue; - const float maxValue; -}; - -void THCudaTensor_clamp(THCState *state, THCudaTensor *self_, THCudaTensor *src, float min_value, - float max_value) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self_, src)); - if (self_ == src) { - if (!THC_pointwiseApply1(state, self_, TensorClampOp(min_value, max_value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCudaTensor_resizeAs(state, self_, src); - - if (!THC_pointwiseApply2(state, self_, src, TensorClampOp(min_value, max_value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - struct TensorLerpOp { TensorLerpOp(float w) : w(w) {} diff --git a/lib/THC/THCTensorMathPointwise.cuh b/lib/THC/THCTensorMathPointwise.cuh index e378a83..a690c45 100644 --- a/lib/THC/THCTensorMathPointwise.cuh +++ b/lib/THC/THCTensorMathPointwise.cuh @@ -397,4 +397,22 @@ struct TensorDivOp<half> { }; #endif // CUDA_HALF_TENSOR +template <typename T> +struct TensorClampOp { + TensorClampOp(T min, T max) : minValue(min), maxValue(max) {} + __device__ __forceinline__ void operator()(T* out, T* in) { + T val = THCNumerics<T>::lt(*in, maxValue) ? *in : maxValue; + *out = THCNumerics<T>::gt(minValue, val) ? minValue : val; + } + + __device__ __forceinline__ void operator()(T* v) { + T val = THCNumerics<T>::lt(*v, maxValue) ? *v : maxValue; + *v = THCNumerics<T>::gt(minValue, val) ? minValue : val; + } + + const T minValue; + const T maxValue; +}; + + #endif // THC_TENSORMATH_POINTWISE_CUH diff --git a/lib/THC/generic/THCTensorMathPointwise.cu b/lib/THC/generic/THCTensorMathPointwise.cu index b2f8950..707fc93 100644 --- a/lib/THC/generic/THCTensorMathPointwise.cu +++ b/lib/THC/generic/THCTensorMathPointwise.cu @@ -82,6 +82,25 @@ void THCTensor_(sign)(THCState* state, THCTensor* self_, THCTensor* src) { THCudaCheck(cudaGetLastError()); } +void THCTensor_(clamp)(THCState *state, THCTensor *self_, THCTensor *src, real min_value, + real max_value) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self_, src)); + if (self_ == src) { + if (!THC_pointwiseApply1(state, self_, TensorClampOp<real>(min_value, max_value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCTensor_(resizeAs)(state, self_, src); + + if (!THC_pointwiseApply2(state, self_, src, TensorClampOp<real>(min_value, max_value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { diff --git a/lib/THC/generic/THCTensorMathPointwise.h b/lib/THC/generic/THCTensorMathPointwise.h index af50278..0c03045 100644 --- a/lib/THC/generic/THCTensorMathPointwise.h +++ b/lib/THC/generic/THCTensorMathPointwise.h @@ -36,6 +36,7 @@ THC_API void THCTensor_(cinv)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src); +THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, real min_value, real max_value); THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2); THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, real value, THCTensor *src2); diff --git a/test/test.lua b/test/test.lua index b1e6481..2a88268 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1372,7 +1372,12 @@ function test.clamp1() if sz2 >= 2 then x[1][2] = max_val + 1 end - compareFloatAndCudaTensorArgs(x, 'clamp', min_val, max_val) + for _, typename in ipairs(typenames) do + if typename ~= 'torch.CudaCharTensor' and typename ~= 'torch.CudaByteTensor' then + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'clamp', min_val, max_val); + end + end checkMultiDevice(x, 'clamp', min_val, max_val) end @@ -1387,10 +1392,53 @@ function test.clamp2() x[1][2] = max_val + 1 end local y = torch.FloatTensor():resizeAs(x) - compareFloatAndCudaTensorArgs(y, 'clamp', x, min_val, max_val) + for _, typename in ipairs(typenames) do + if typename ~= 'torch.CudaCharTensor' and typename ~= 'torch.CudaByteTensor' then + local x = x:type(t2cpu[typename]) + local y = y:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'clamp', x, min_val, max_val); + end + end checkMultiDevice(y, 'clamp', x, min_val, max_val) end +-- same as clamp1, clamp2 but only allow positive values +function test.clamp3() + local sz1 = chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local x = torch.FloatTensor():rand(sz1, sz2):mul(5); + local min_val = 1 + local max_val = 3 + x[1][1] = min_val - 1 + if sz2 >= 2 then + x[1][2] = max_val + 1 + end + for _, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'clamp', min_val, max_val); + end + checkMultiDevice(x, 'clamp', min_val, max_val) +end + +function test.clamp4() + local sz1 = chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local x = torch.FloatTensor():rand(sz1, sz2):mul(5); + local min_val = 1 + local max_val = 3 + x[1][1] = min_val - 1 + if sz2 >= 2 then + x[1][2] = max_val + 1 + end + local y = torch.FloatTensor():resizeAs(x) + for _, typename in ipairs(typenames) do + local x = x:type(t2cpu[typename]) + local y = y:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, y, 'clamp', x, min_val, max_val); + end + checkMultiDevice(x, 'clamp', min_val, max_val) +end + function test.index() local sz1 = chooseInt(minsize, maxsize) local sz2 = chooseInt(minsize, maxsize) |