diff options
author | soumith <soumith@fb.com> | 2016-11-01 07:26:11 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-11-01 07:26:19 +0300 |
commit | ae0973f376218d856d5474c1a8b8ef021e9a497a (patch) | |
tree | 9aea4821f844bd221ec255a0e99fe29cc0ae5230 | |
parent | e79ee3e57ad8348ea23c0d749d853f4a8c850626 (diff) |
adding multiple types for distdistfix
-rw-r--r-- | TensorMath.lua | 8 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 2 | ||||
-rw-r--r-- | lib/THC/THCTensorMath2.cu | 30 | ||||
-rw-r--r-- | lib/THC/THCTensorMathReduce.cuh | 34 | ||||
-rw-r--r-- | lib/THC/THCTensorMathScan.cu | 15 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 24 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.h | 3 | ||||
-rw-r--r-- | test/test.lua | 7 |
8 files changed, 70 insertions, 53 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index cfc39ce..802565e 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -944,6 +944,14 @@ for k, Tensor_ in pairs(handledTypenames) do {name="index"}, {name=real}}) + wrap("dist", + cname("dist"), + {{name=Tensor}, + {name=Tensor}, + {name=real, default=2}, + {name=accreal, creturned=true}}) + + for _,name in ipairs({"var", "std"}) do wrap(name, cname(name .. "all"), diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 86c74b3..759c9a3 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -53,8 +53,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b); THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a); -THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); - THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size); THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size); diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu index 2b80977..9933b7e 100644 --- a/lib/THC/THCTensorMath2.cu +++ b/lib/THC/THCTensorMath2.cu @@ -8,14 +8,6 @@ #include "THCTensorMathReduce.cuh" #include "THCTensorMathPointwise.cuh" -#include <thrust/device_ptr.h> -#include <thrust/transform_reduce.h> -#include <thrust/functional.h> -#include <thrust/inner_product.h> -#if CUDA_VERSION >= 7000 -#include <thrust/system/cuda/execution_policy.h> -#endif - struct TensorATan2Op { __device__ __forceinline__ void operator()(float* out, float* a, float* b) { *out = atan2f(*a, *b); @@ -36,28 +28,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaCheck(cudaGetLastError()); } -float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value) -{ - THAssert(THCudaTensor_checkGPU(state, 2, self, src)); - self = THCudaTensor_newContiguous(state, self); - ptrdiff_t size = THCudaTensor_nElement(state, self); - src = THCudaTensor_newContiguous(state, src); - thrust::device_ptr<float> self_data(THCudaTensor_data(state, self)); - thrust::device_ptr<float> src_data(THCudaTensor_data(state, src)); - - float result = thrust::inner_product( -#if CUDA_VERSION >= 7000 - thrust::cuda::par.on(THCState_getCurrentStream(state)), -#endif - self_data, self_data+size, src_data, (float) 0, - thrust::plus<float>(), TensorDistOp<float>(value)); - - THCudaTensor_free(state, src); - THCudaTensor_free(state, self); - - return pow(result, (float)1.0/value); -} - void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size) { THAssert(THCudaTensor_checkGPU(state, 1, r_)); diff --git a/lib/THC/THCTensorMathReduce.cuh b/lib/THC/THCTensorMathReduce.cuh index db2e424..77f06ab 100644 --- a/lib/THC/THCTensorMathReduce.cuh +++ b/lib/THC/THCTensorMathReduce.cuh @@ -7,6 +7,12 @@ #include "THCReduce.cuh" #include "THCReduceAll.cuh" #include <thrust/functional.h> +#include <thrust/device_ptr.h> +#include <thrust/transform_reduce.h> +#include <thrust/inner_product.h> +#if CUDA_VERSION >= 7000 +#include <thrust/system/cuda/execution_policy.h> +#endif // Reduction operators that support `half`, unlike Thrust template <typename InT, typename AccT> @@ -239,19 +245,21 @@ struct TensorNormOp<half, StaticExp> }; #endif -template <typename T> +template <typename Tacc, typename T> struct TensorDistOp { - TensorDistOp(T exp) : exponent(exp) {} + TensorDistOp(Tacc exp) : exponent(exp) {} - __host__ __device__ T operator()(T x, T y) const { - return THCNumerics<T>::pow( - THCNumerics<T>::abs(THCNumerics<T>::sub(x, y)), + __host__ __device__ Tacc operator()(T x, T y) const { + Tacc xr = ScalarConvert<T, Tacc>::to(x); + Tacc yr = ScalarConvert<T, Tacc>::to(y); + return THCNumerics<Tacc>::pow( + THCNumerics<Tacc>::abs(THCNumerics<Tacc>::sub(xr, yr)), exponent ); } - const T exponent; + const Tacc exponent; }; #include <thrust/functional.h> @@ -664,4 +672,18 @@ struct MinValuePair { } }; +template <typename T> +struct AddOp { + __device__ __forceinline__ T operator()(T &lhs, T &rhs) { + return THCNumerics<T>::add(lhs, rhs); + } +}; + +template <typename T> +struct MulOp { + __device__ __forceinline__ T operator()(T &lhs, T &rhs) { + return THCNumerics<T>::mul(lhs, rhs); + } +}; + #endif // THC_TENSORMATH_REDUCE_CUH diff --git a/lib/THC/THCTensorMathScan.cu b/lib/THC/THCTensorMathScan.cu index ee532bf..3345e25 100644 --- a/lib/THC/THCTensorMathScan.cu +++ b/lib/THC/THCTensorMathScan.cu @@ -5,20 +5,7 @@ #include "THCApply.cuh" #include "THCReduce.cuh" #include "THCNumerics.cuh" - -template <typename T> -struct AddOp { - __device__ __forceinline__ T operator()(T &lhs, T &rhs) { - return THCNumerics<T>::add(lhs, rhs); - } -}; - -template <typename T> -struct MulOp { - __device__ __forceinline__ T operator()(T &lhs, T &rhs) { - return THCNumerics<T>::mul(lhs, rhs); - } -}; +#include "THCTensorMathReduce.cuh" /* Perform an inclusive scan along an outer dimension of a tensor. * diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index 1e21d03..a8184b7 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -219,6 +219,30 @@ THCTensor_(normall)(THCState *state, THCTensor *self, real value) return result; } +accreal THCTensor_(dist)(THCState *state, THCTensor *self, + THCTensor *src, real value) +{ + THAssert(THCTensor_(checkGPU)(state, 2, self, src)); + self = THCTensor_(newContiguous)(state, self); + ptrdiff_t size = THCTensor_(nElement)(state, self); + src = THCTensor_(newContiguous)(state, src); + thrust::device_ptr<real> self_data(THCTensor_(data)(state, self)); + thrust::device_ptr<real> src_data(THCTensor_(data)(state, src)); + + accreal result = thrust::inner_product( +#if CUDA_VERSION >= 7000 + thrust::cuda::par.on(THCState_getCurrentStream(state)), +#endif + self_data, self_data+size, src_data, ScalarConvert<int, accreal>::to(0), + thrust::plus<accreal>(), + TensorDistOp<accreal, real>(ScalarConvert<real, accreal>::to(value))); + + THCTensor_(free)(state, src); + THCTensor_(free)(state, self); + + return THCNumerics<accreal>::pow(result, 1.0 / ScalarConvert<real, accreal>::to(value)); +} + #endif THC_API accreal diff --git a/lib/THC/generic/THCTensorMathReduce.h b/lib/THC/generic/THCTensorMathReduce.h index 09a26fc..dc38ed6 100644 --- a/lib/THC/generic/THCTensorMathReduce.h +++ b/lib/THC/generic/THCTensorMathReduce.h @@ -35,4 +35,7 @@ THC_API void THCTensor_(max)(THCState *state, THC_API real THCTensor_(minall)(THCState *state, THCTensor *self); THC_API real THCTensor_(maxall)(THCState *state, THCTensor *self); +THC_API accreal THCTensor_(dist)(THCState *state, THCTensor *self, THCTensor *src, + real value); + #endif diff --git a/test/test.lua b/test/test.lua index 882168e..00cfa66 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1877,7 +1877,12 @@ function test.dist() local sz2 = chooseInt(minsize, maxsize) local x = torch.FloatTensor():rand(sz1, sz2) local y = torch.FloatTensor():rand(sz1, sz2) - compareFloatAndCudaTensorArgs(x, 'dist', y) + for _, typename in ipairs(float_typenames) do + local x = x:type(t2cpu[typename]) + local y = y:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'dist', y) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'dist', y, 3) + end checkMultiDevice(x, 'dist', y) end |