
github.com/torch/cutorch.git
author    Trevor Killeen <killeentm@gmail.com>    2016-10-06 18:45:38 +0300
committer Trevor Killeen <killeentm@gmail.com>    2016-10-07 21:50:28 +0300
commit    4f67f808afcfae17df066bac67ff0d457e52b813 (patch)
tree      5dad472ad78eedca0d60d0bb44dd262e8956f1ee
parent    29e5059c3980bb2a905a7f7eacc4add1123b1b93 (diff)
[cutorch refactor] make dist(...)'s op generic, add missing unit test
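
The float-only dist_functor in THCTensorMath2.cu is removed and the call site now uses the templated TensorDistOp<T> added to THCTensorMathReduce.cuh, which routes sub/abs/pow through THCNumerics<T> instead of hard-coding float math. The sketch below (plain C++, not part of the patch; DistOp and the sample data are illustrative) shows the shape of the computation the Thrust call performs: a pairwise |x - y|^p term folded with a plus-reduction, with the p-th root of the sum giving the p-norm distance.

#include <cmath>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative stand-in for TensorDistOp: computes |x - y|^p for one pair.
template <typename T>
struct DistOp {
  explicit DistOp(T exp) : exponent(exp) {}
  T operator()(T x, T y) const { return std::pow(std::fabs(x - y), exponent); }
  const T exponent;
};

int main() {
  std::vector<float> a = {1.0f, 2.0f, 3.0f};
  std::vector<float> b = {4.0f, 6.0f, 3.0f};
  const float p = 2.0f;
  // Same composition as the thrust::inner_product call in THCudaTensor_dist:
  // combine the pairwise terms with plus<>, then take the p-th root of the sum.
  float sum = std::inner_product(a.begin(), a.end(), b.begin(), 0.0f,
                                 std::plus<float>(), DistOp<float>(p));
  std::printf("dist = %f\n", std::pow(sum, 1.0f / p));  // 5.000000 for this data
  return 0;
}
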
-rw-r--r--  lib/THC/THCTensorMath2.cu        | 14
-rw-r--r--  lib/THC/THCTensorMathReduce.cuh  | 15
-rw-r--r--  test/test.lua                    | 11
3 files changed, 27 insertions(+), 13 deletions(-)
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index e0e3255..afd262d 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -68,18 +68,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
-struct dist_functor
-{
- const float exponent;
-
- dist_functor(float exponent_) : exponent(exponent_) {}
-
- __host__ __device__ float operator()(const float& x, const float& y) const
- {
- return pow(fabs(x-y), exponent);
- }
-};
-
float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
@@ -94,7 +82,7 @@ float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src,
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, src_data, (float) 0,
- thrust::plus<float>(), dist_functor(value));
+ thrust::plus<float>(), TensorDistOp<float>(value));
THCudaTensor_free(state, src);
THCudaTensor_free(state, self);
diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh
index 4b66ac3..8e368be 100644
--- a/lib/THC/THCTensorMathReduce.cuh
+++ b/lib/THC/THCTensorMathReduce.cuh
@@ -239,6 +239,21 @@ struct TensorNormOp<half, StaticExp>
};
#endif
+template <typename T>
+struct TensorDistOp
+{
+ TensorDistOp(T exp) : exponent(exp) {}
+
+ __host__ __device__ T operator()(T x, T y) const {
+ return THCNumerics<T>::pow(
+ THCNumerics<T>::abs(THCNumerics<T>::sub(x, y)),
+ exponent
+ );
+ }
+
+ const T exponent;
+};
+
#include <thrust/functional.h>
// Given the sum of values and the sum of squares, compute the variance or standard deviation.
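
The new functor stays type-generic because every arithmetic step goes through THCNumerics<T> rather than float intrinsics; cutorch supplies the per-type specializations, so TensorDistOp<T> compiles for whichever element type instantiates it. Below is a rough self-contained sketch of that traits pattern, assuming a hypothetical Numerics stand-in rather than the real THCNumerics header.

#include <cmath>
#include <cstdio>

// Primary template is only declared; each supported type adds a specialization.
template <typename T>
struct Numerics;

template <>
struct Numerics<float> {
  static float sub(float a, float b) { return a - b; }
  static float abs(float a)          { return std::fabs(a); }
  static float pow(float a, float b) { return std::pow(a, b); }
};

// Mirrors TensorDistOp from the patch: the body never names float directly.
template <typename T>
struct GenericDistOp {
  explicit GenericDistOp(T e) : exponent(e) {}
  T operator()(T x, T y) const {
    return Numerics<T>::pow(Numerics<T>::abs(Numerics<T>::sub(x, y)), exponent);
  }
  const T exponent;
};

int main() {
  GenericDistOp<float> op(2.0f);
  std::printf("%f\n", op(3.0f, 1.0f));  // |3 - 1|^2 = 4
  return 0;
}

Supporting another element type here would only require another Numerics specialization, which is the role THCNumerics plays in cutorch.
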
diff --git a/test/test.lua b/test/test.lua
index f26effd..64884f0 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1754,6 +1754,17 @@ function test.renorm()
checkMultiDevice(x, 'renorm', 4, 2, maxnorm)
end
+function test.dist()
+ local minsize = 5
+ local maxsize = 10
+ local sz1 = chooseInt(minsize, maxsize)
+ local sz2 = chooseInt(minsize, maxsize)
+ local x = torch.FloatTensor():rand(sz1, sz2)
+ local y = torch.FloatTensor():rand(sz1, sz2)
+ compareFloatAndCudaTensorArgs(x, 'dist', y)
+ checkMultiDevice(x, 'dist', y)
+end
+
function test.indexCopy2()
for tries = 1, 5 do
local t = createTestTensor(1000000)