diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-10-03 23:22:01 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-10-07 21:50:28 +0300 |
commit | 167fb1025c65e25c25e78ee8017d52f1f29d349c (patch) | |
tree | 4ea60842a4803952349279a9450ca56eaef06ffb | |
parent | 61f8e132a92e0f935cfa4f1eb4a7575f77792702 (diff) |
[cutorch refactor] fixes for norm, wrap/test
-rw-r--r-- | TensorMath.lua | 10 | ||||
-rw-r--r-- | lib/THC/THCNumerics.cuh | 2 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathReduce.cu | 9 | ||||
-rw-r--r-- | test/test.lua | 10 |
4 files changed, 24 insertions, 7 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 0bbf8d7..bfd1c06 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -797,6 +797,16 @@ for k, Tensor_ in pairs(handledTypenames) do end + wrap("norm", + cname("normall"), + {{name=Tensor}, + {name=real, default=2}, + {name=accreal, creturned=true}}, + cname("norm"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=real}, + {name="index"}}) wrap("renorm", cname("renorm"), diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh index c718ba5..81ed0d1 100644 --- a/lib/THC/THCNumerics.cuh +++ b/lib/THC/THCNumerics.cuh @@ -571,7 +571,7 @@ struct THCNumerics<double> { static inline __host__ __device__ double div (double a, double b) { return a / b; } static inline __host__ __device__ double mul (double a, double b) { return a * b; } static inline __host__ __device__ double sub (double a, double b) { return a - b; } - static inline __host__ __device__ double pow (double a, double b) { return pow(a, b); } + static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); } }; /// `half` has some type conversion issues associated with it, since it diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu index a9e58a9..1b86027 100644 --- a/lib/THC/generic/THCTensorMathReduce.cu +++ b/lib/THC/generic/THCTensorMathReduce.cu @@ -267,7 +267,7 @@ THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, THCudaCheck(cudaGetLastError()); } -accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value) +THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value) { THAssert(THCTensor_(checkGPU)(state, 1, self)); accreal result; @@ -293,7 +293,7 @@ accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value) ReduceAdd<accreal, accreal>(), ScalarConvert<float, accreal>::to(0.0f), &result, 0); - result = THCNumerics<accreal>::pow(result, ScalarConvert<float, accreal>::to(0.5f)); + result = THCNumerics<accreal>::sqrt(result); } else { THC_reduceAll(state, self, TensorNormOp<real, -1>(value), @@ -303,10 +303,7 @@ accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value) &result, 0); result = THCNumerics<accreal>::pow( result, - THCNumerics<accreal>::div( - ScalarConvert<float, accreal>::to(1.0f), - ScalarConvert<real, accreal>::to(value) - ) + ScalarConvert<real, accreal>::to(THCNumerics<real>::cinv(value)) ); } diff --git a/test/test.lua b/test/test.lua index ed1022c..ddf67e4 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1647,6 +1647,16 @@ function test.indexFill() end function test.norm() + for n = 0, 3 do + local cpu = torch.FloatTensor(chooseInt(20, 50), 2):uniform(-0.5, 0.5) + for _, typename in ipairs(float_typenames) do + local x = cpu:type(t2cpu[typename]) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'norm', n) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'norm', n, 1) + compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'norm', n, 2) + end + end + for i = 1, 5 do for n = 0, 3 do local cpu = torch.FloatTensor(chooseInt(20, 50), 2):uniform(-0.5, 0.5) |