diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-06 18:45:38 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:28 +0300 |
commit | 4f67f808afcfae17df066bac67ff0d457e52b813 (patch) | |
tree | 5dad472ad78eedca0d60d0bb44dd262e8956f1ee | |
parent | 29e5059c3980bb2a905a7f7eacc4add1123b1b93 (diff) |
[cutorch refactor] make dist(...)'s op generic, add missing unit test
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 14 | ||||
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 15 | ||||
-rw-r--r-- | test/test.lua | 11 |
3 files changed, 27 insertions, 13 deletions
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index e0e3255..afd262d 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -68,18 +68,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaCheck(cudaGetLastError()); } -struct dist_functor -{ - const float exponent; - - dist_functor(float exponent_) : exponent(exponent_) {} - - __host__ __device__ float operator()(const float& x, const float& y) const - { - return pow(fabs(x-y), exponent); - } -}; - float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value) { THAssert(THCudaTensor_checkGPU(state, 2, self, src)); @@ -94,7 +82,7 @@ float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, thrust::cuda::par.on(THCState_getCurrentStream(state)), #endif self_data, self_data+size, src_data, (float) 0, - thrust::plus<float>(), dist_functor(value)); + thrust::plus<float>(), TensorDistOp<float>(value)); THCudaTensor_free(state, src); THCudaTensor_free(state, self); diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index 4b66ac3..8e368be 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -239,6 +239,21 @@ struct TensorNormOp<half, StaticExp> }; #endif +template <typename T> +struct TensorDistOp +{ + TensorDistOp(T exp) : exponent(exp) {} + + __host__ __device__ T operator()(T x, T y) const { + return THCNumerics<T>::pow( + THCNumerics<T>::abs(THCNumerics<T>::sub(x, y)), + exponent + ); + } + + const T exponent; +}; + #include <thrust/functional.h> // Given the sum of values and the sum of squares, compute the variance or standard deviation. diff --git a/test/test.lua b/test/test.lua index f26effd..64884f0 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1754,6 +1754,17 @@ function test.renorm() checkMultiDevice(x, 'renorm', 4, 2, maxnorm) end +function test.dist() + local minsize = 5 + local maxsize = 10 + local sz1 = chooseInt(minsize, maxsize) + local sz2 = chooseInt(minsize, maxsize) + local x = torch.FloatTensor():rand(sz1, sz2) + local y = torch.FloatTensor():rand(sz1, sz2) + compareFloatAndCudaTensorArgs(x, 'dist', y) + checkMultiDevice(x, 'dist', y) +end + function test.indexCopy2() for tries = 1, 5 do local t = createTestTensor(1000000) |