Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author soumith <soumith@fb.com> 2016-11-01 07:26:11 +0300
committer soumith <soumith@fb.com> 2016-11-01 07:26:19 +0300
commit ae0973f376218d856d5474c1a8b8ef021e9a497a (patch)
tree 9aea4821f844bd221ec255a0e99fe29cc0ae5230
parent e79ee3e57ad8348ea23c0d749d853f4a8c850626 (diff)
adding multiple types for dist (distfix)
-rw-r--r-- TensorMath.lua 8
-rw-r--r-- lib/THC/THCTensorMath.h 2
-rw-r--r-- lib/THC/THCTensorMath2.cu 30
-rw-r--r-- lib/THC/THCTensorMathReduce.cuh 34
-rw-r--r-- lib/THC/THCTensorMathScan.cu 15
-rw-r--r-- lib/THC/generic/THCTensorMathReduce.cu 24
-rw-r--r-- lib/THC/generic/THCTensorMathReduce.h 3
-rw-r--r-- test/test.lua 7
8 files changed, 70 insertions, 53 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index cfc39ce..802565e 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -944,6 +944,14 @@ for k, Tensor_ in pairs(handledTypenames) do
{name="index"},
{name=real}})
+ wrap("dist",
+ cname("dist"),
+ {{name=Tensor},
+ {name=Tensor},
+ {name=real, default=2},
+ {name=accreal, creturned=true}})
+
+
for _,name in ipairs({"var", "std"}) do
wrap(name,
cname(name .. "all"),
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 86c74b3..759c9a3 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -53,8 +53,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
-THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
-
THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size);
diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu
index 2b80977..9933b7e 100644
--- a/lib/THC/THCTensorMath2.cu
+++ b/lib/THC/THCTensorMath2.cu
@@ -8,14 +8,6 @@
#include "THCTensorMathReduce.cuh"
#include "THCTensorMathPointwise.cuh"
-#include <thrust/device_ptr.h>
-#include <thrust/transform_reduce.h>
-#include <thrust/functional.h>
-#include <thrust/inner_product.h>
-#if CUDA_VERSION >= 7000
-#include <thrust/system/cuda/execution_policy.h>
-#endif
-
struct TensorATan2Op {
__device__ __forceinline__ void operator()(float* out, float* a, float* b) {
*out = atan2f(*a, *b);
@@ -36,28 +28,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}
-float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
-{
- THAssert(THCudaTensor_checkGPU(state, 2, self, src));
- self = THCudaTensor_newContiguous(state, self);
- ptrdiff_t size = THCudaTensor_nElement(state, self);
- src = THCudaTensor_newContiguous(state, src);
- thrust::device_ptr<float> self_data(THCudaTensor_data(state, self));
- thrust::device_ptr<float> src_data(THCudaTensor_data(state, src));
-
- float result = thrust::inner_product(
-#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
-#endif
- self_data, self_data+size, src_data, (float) 0,
- thrust::plus<float>(), TensorDistOp<float>(value));
-
- THCudaTensor_free(state, src);
- THCudaTensor_free(state, self);
-
- return pow(result, (float)1.0/value);
-}
-
void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size)
{
THAssert(THCudaTensor_checkGPU(state, 1, r_));
diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh
index db2e424..77f06ab 100644
--- a/lib/THC/THCTensorMathReduce.cuh
+++ b/lib/THC/THCTensorMathReduce.cuh
@@ -7,6 +7,12 @@
#include "THCReduce.cuh"
#include "THCReduceAll.cuh"
#include <thrust/functional.h>
+#include <thrust/device_ptr.h>
+#include <thrust/transform_reduce.h>
+#include <thrust/inner_product.h>
+#if CUDA_VERSION >= 7000
+#include <thrust/system/cuda/execution_policy.h>
+#endif
// Reduction operators that support `half`, unlike Thrust
template <typename InT, typename AccT>
@@ -239,19 +245,21 @@ struct TensorNormOp<half, StaticExp>
};
#endif
-template <typename T>
+template <typename Tacc, typename T>
struct TensorDistOp
{
- TensorDistOp(T exp) : exponent(exp) {}
+ TensorDistOp(Tacc exp) : exponent(exp) {}
- __host__ __device__ T operator()(T x, T y) const {
- return THCNumerics<T>::pow(
- THCNumerics<T>::abs(THCNumerics<T>::sub(x, y)),
+ __host__ __device__ Tacc operator()(T x, T y) const {
+ Tacc xr = ScalarConvert<T, Tacc>::to(x);
+ Tacc yr = ScalarConvert<T, Tacc>::to(y);
+ return THCNumerics<Tacc>::pow(
+ THCNumerics<Tacc>::abs(THCNumerics<Tacc>::sub(xr, yr)),
exponent
);
}
- const T exponent;
+ const Tacc exponent;
};
#include <thrust/functional.h>
@@ -664,4 +672,18 @@ struct MinValuePair {
}
};
+template <typename T>
+struct AddOp {
+ __device__ __forceinline__ T operator()(T &lhs, T &rhs) {
+ return THCNumerics<T>::add(lhs, rhs);
+ }
+};
+
+template <typename T>
+struct MulOp {
+ __device__ __forceinline__ T operator()(T &lhs, T &rhs) {
+ return THCNumerics<T>::mul(lhs, rhs);
+ }
+};
+
#endif // THC_TENSORMATH_REDUCE_CUH
diff --git a/lib/THC/THCTensorMathScan.cu b/lib/THC/THCTensorMathScan.cu
index ee532bf..3345e25 100644
--- a/lib/THC/THCTensorMathScan.cu
+++ b/lib/THC/THCTensorMathScan.cu
@@ -5,20 +5,7 @@
#include "THCApply.cuh"
#include "THCReduce.cuh"
#include "THCNumerics.cuh"
-
-template <typename T>
-struct AddOp {
- __device__ __forceinline__ T operator()(T &lhs, T &rhs) {
- return THCNumerics<T>::add(lhs, rhs);
- }
-};
-
-template <typename T>
-struct MulOp {
- __device__ __forceinline__ T operator()(T &lhs, T &rhs) {
- return THCNumerics<T>::mul(lhs, rhs);
- }
-};
+#include "THCTensorMathReduce.cuh"
/* Perform an inclusive scan along an outer dimension of a tensor.
*
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index 1e21d03..a8184b7 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -219,6 +219,30 @@ THCTensor_(normall)(THCState *state, THCTensor *self, real value)
return result;
}
+accreal THCTensor_(dist)(THCState *state, THCTensor *self,
+ THCTensor *src, real value)
+{
+ THAssert(THCTensor_(checkGPU)(state, 2, self, src));
+ self = THCTensor_(newContiguous)(state, self);
+ ptrdiff_t size = THCTensor_(nElement)(state, self);
+ src = THCTensor_(newContiguous)(state, src);
+ thrust::device_ptr<real> self_data(THCTensor_(data)(state, self));
+ thrust::device_ptr<real> src_data(THCTensor_(data)(state, src));
+
+ accreal result = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+ thrust::cuda::par.on(THCState_getCurrentStream(state)),
+#endif
+ self_data, self_data+size, src_data, ScalarConvert<int, accreal>::to(0),
+ thrust::plus<accreal>(),
+ TensorDistOp<accreal, real>(ScalarConvert<real, accreal>::to(value)));
+
+ THCTensor_(free)(state, src);
+ THCTensor_(free)(state, self);
+
+ return THCNumerics<accreal>::pow(result, 1.0 / ScalarConvert<real, accreal>::to(value));
+}
+
#endif
THC_API accreal
diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h
index 09a26fc..dc38ed6 100644
--- a/lib/THC/generic/THCTensorMathReduce.h
+++ b/lib/THC/generic/THCTensorMathReduce.h
@@ -35,4 +35,7 @@ THC_API void THCTensor_(max)(THCState *state,
THC_API real THCTensor_(minall)(THCState *state, THCTensor *self);
THC_API real THCTensor_(maxall)(THCState *state, THCTensor *self);
+THC_API accreal THCTensor_(dist)(THCState *state, THCTensor *self, THCTensor *src,
+ real value);
+
#endif
diff --git a/test/test.lua b/test/test.lua
index 882168e..00cfa66 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1877,7 +1877,12 @@ function test.dist()
local sz2 = chooseInt(minsize, maxsize)
local x = torch.FloatTensor():rand(sz1, sz2)
local y = torch.FloatTensor():rand(sz1, sz2)
- compareFloatAndCudaTensorArgs(x, 'dist', y)
+ for _, typename in ipairs(float_typenames) do
+ local x = x:type(t2cpu[typename])
+ local y = y:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'dist', y)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'dist', y, 3)
+ end
checkMultiDevice(x, 'dist', y)
end