Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorGeorgOstrovski <ostrovski@google.com>2014-10-14 17:17:30 +0400
committerSoumith Chintala <soumith@gmail.com>2014-10-27 17:05:17 +0300
commit200a10387b21a215ee6ee7f438e4b365cebbd898 (patch)
tree2040dc79fbd18be5804a0faf49d5121c680aeaa7 /test
parente85b059511fce2f7b846118bf6521e4abe30c08f (diff)
Relax equality tests from exact precision 0.
Expecting exact precision 0 can lead to failing tests if underlying implementation changes, a high expected precision of say 1e-15 seem more appropriate (nntest.Power sometimes fails because of this)
Diffstat (limited to 'test')
-rw-r--r--test/test.lua48
1 files changed, 24 insertions, 24 deletions
diff --git a/test/test.lua b/test/test.lua
index aebc755..022ee8a 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -218,7 +218,7 @@ function nntest.Power()
local module = nn.Power(2)
local out = module:forward(in1)
local err = out:dist(in1:cmul(in1))
- mytester:asserteq(err, 0, torch.typename(module) .. ' - forward err ')
+ mytester:assertlt(err, 1e-15, torch.typename(module) .. ' - forward err ')
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -241,7 +241,7 @@ function nntest.Square()
local module = nn.Square()
local out = module:forward(in1)
local err = out:dist(in1:cmul(in1))
- mytester:asserteq(err, 0, torch.typename(module) .. ' - forward err ')
+ mytester:assertlt(err, 1e-15, torch.typename(module) .. ' - forward err ')
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -263,7 +263,7 @@ function nntest.Sqrt()
local module = nn.Sqrt()
local out = module:forward(in1)
local err = out:dist(in1:sqrt())
- mytester:asserteq(err, 0, torch.typename(module) .. ' - forward err ')
+ mytester:assertlt(err, 1e-15, torch.typename(module) .. ' - forward err ')
local ini = math.random(3,5)
local inj = math.random(3,5)
@@ -1060,16 +1060,16 @@ function nntest.SpatialFullConvolution()
end
function nntest.SpatialFullConvolutionMap()
- local from = math.ceil(torch.uniform(2,4))
- local to = math.ceil(torch.uniform(2,5))
- local fanin = math.ceil(torch.uniform(1, from))
+ local from = math.random(2,4)
+ local to = math.random(2,5)
+ local fanin = math.random(1, from)
local tt = nn.tables.random(from, to, fanin)
- local ki = math.ceil(torch.uniform(2,5))
- local kj = math.ceil(torch.uniform(2,5))
- local si = math.ceil(torch.uniform(1,3))
- local sj = math.ceil(torch.uniform(1,3))
- local ini = math.ceil(torch.uniform(5,7))
- local inj = math.ceil(torch.uniform(5,7))
+ local ki = math.random(2,5)
+ local kj = math.random(2,5)
+ local si = math.random(1,3)
+ local sj = math.random(1,3)
+ local ini = math.random(5,7)
+ local inj = math.random(5,7)
local module = nn.SpatialFullConvolutionMap(tt, ki, kj, si, sj)
local input = torch.Tensor(from, inj, ini):zero()
@@ -1105,15 +1105,15 @@ function nntest.SpatialFullConvolutionMap()
end
function nntest.SpatialFullConvolutionCompare()
- local from = math.ceil(torch.uniform(2,4))
- local to = math.ceil(torch.uniform(2,5))
+ local from = math.random(2,4)
+ local to = math.random(2,5)
local tt = nn.tables.full(from, to)
- local ki = math.ceil(torch.uniform(2,5))
- local kj = math.ceil(torch.uniform(2,5))
- local si = math.ceil(torch.uniform(1,3))
- local sj = math.ceil(torch.uniform(1,3))
- local ini = math.ceil(torch.uniform(7,8))
- local inj = math.ceil(torch.uniform(7,8))
+ local ki = math.random(2,5)
+ local kj = math.random(2,5)
+ local si = math.random(1,3)
+ local sj = math.random(1,3)
+ local ini = math.random(7,8)
+ local inj = math.random(7,8)
local module1 = nn.SpatialFullConvolutionMap(tt, ki, kj, si, sj)
local module2 = nn.SpatialFullConvolution(from, to, ki, kj, si, sj)
local input = torch.rand(from, inj, ini)
@@ -1307,12 +1307,12 @@ end
function nntest.SpatialMaxPooling()
local from = math.random(1,5)
- local ki = math.random(1,5)
- local kj = math.random(1,5)
+ local ki = math.random(1,4)
+ local kj = math.random(1,4)
local si = math.random(1,3)
local sj = math.random(1,3)
- local outi = math.random(2,4)
- local outj = math.random(2,4)
+ local outi = math.random(4,5)
+ local outj = math.random(4,5)
local ini = (outi-1)*si+ki
local inj = (outj-1)*sj+kj