Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-10-03 23:22:01 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-10-07 21:50:28 +0300
commit167fb1025c65e25c25e78ee8017d52f1f29d349c (patch)
tree4ea60842a4803952349279a9450ca56eaef06ffb
parent61f8e132a92e0f935cfa4f1eb4a7575f77792702 (diff)
[cutorch refactor] fixes for norm, wrap/test
-rw-r--r--TensorMath.lua10
-rw-r--r--lib/THC/THCNumerics.cuh2
-rw-r--r--lib/THC/generic/THCTensorMathReduce.cu9
-rw-r--r--test/test.lua10
4 files changed, 24 insertions, 7 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 0bbf8d7..bfd1c06 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -797,6 +797,16 @@ for k, Tensor_ in pairs(handledTypenames) do
end
+ wrap("norm",
+ cname("normall"),
+ {{name=Tensor},
+ {name=real, default=2},
+ {name=accreal, creturned=true}},
+ cname("norm"),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor},
+ {name=real},
+ {name="index"}})
wrap("renorm",
cname("renorm"),
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index c718ba5..81ed0d1 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -571,7 +571,7 @@ struct THCNumerics<double> {
static inline __host__ __device__ double div (double a, double b) { return a / b; }
static inline __host__ __device__ double mul (double a, double b) { return a * b; }
static inline __host__ __device__ double sub (double a, double b) { return a - b; }
- static inline __host__ __device__ double pow (double a, double b) { return pow(a, b); }
+ static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); }
};
/// `half` has some type conversion issues associated with it, since it
diff --git a/lib/THC/generic/THCTensorMathReduce.cu b/lib/THC/generic/THCTensorMathReduce.cu
index a9e58a9..1b86027 100644
--- a/lib/THC/generic/THCTensorMathReduce.cu
+++ b/lib/THC/generic/THCTensorMathReduce.cu
@@ -267,7 +267,7 @@ THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src,
THCudaCheck(cudaGetLastError());
}
-accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value)
+THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value)
{
THAssert(THCTensor_(checkGPU)(state, 1, self));
accreal result;
@@ -293,7 +293,7 @@ accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value)
ReduceAdd<accreal, accreal>(),
ScalarConvert<float, accreal>::to(0.0f),
&result, 0);
- result = THCNumerics<accreal>::pow(result, ScalarConvert<float, accreal>::to(0.5f));
+ result = THCNumerics<accreal>::sqrt(result);
} else {
THC_reduceAll(state, self,
TensorNormOp<real, -1>(value),
@@ -303,10 +303,7 @@ accreal THCTensor_(normall)(THCState *state, THCTensor *self, real value)
&result, 0);
result = THCNumerics<accreal>::pow(
result,
- THCNumerics<accreal>::div(
- ScalarConvert<float, accreal>::to(1.0f),
- ScalarConvert<real, accreal>::to(value)
- )
+ ScalarConvert<real, accreal>::to(THCNumerics<real>::cinv(value))
);
}
diff --git a/test/test.lua b/test/test.lua
index ed1022c..ddf67e4 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1647,6 +1647,16 @@ function test.indexFill()
end
function test.norm()
+ for n = 0, 3 do
+ local cpu = torch.FloatTensor(chooseInt(20, 50), 2):uniform(-0.5, 0.5)
+ for _, typename in ipairs(float_typenames) do
+ local x = cpu:type(t2cpu[typename])
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'norm', n)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'norm', n, 1)
+ compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'norm', n, 2)
+ end
+ end
+
for i = 1, 5 do
for n = 0, 3 do
local cpu = torch.FloatTensor(chooseInt(20, 50), 2):uniform(-0.5, 0.5)