Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorDavid Saxton <saxton@google.com>2016-02-24 20:54:10 +0300
committerDavid Saxton <saxton@google.com>2016-02-25 19:01:54 +0300
commitd8ff64c5707f716199e7782d630e44e2cf402a54 (patch)
treee3e2bcc1d0872cf7c363b7de12cd91818e3038e3 /test
parent3bbb49f62f2716c952f43115ea8caa450c8785d4 (diff)
Replace torch.Tester with totem.Tester + extra stuff.
This should bring a lot of benefit to code that uses torch.Tester (totem will eventually become deprecated). Note that torch.Tester and totem.Tester once shared the same code - this change brings it full circle. At a glance, extra functionality includes: - A general equality checker that accepts many different objects. - Deep table comparison with precision checking. - Stricter argument checking in using the test functions. - Better output. - torch.Storage comparison. - Extra features for fine-grained control of testing.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua105
-rw-r--r--test/test_Tester.lua626
-rw-r--r--test/test_qr.lua2
-rw-r--r--test/test_sharedmem.lua2
-rw-r--r--test/test_writeObject.lua4
5 files changed, 665 insertions, 74 deletions
diff --git a/test/test.lua b/test/test.lua
index f86fc72..0640b2f 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1,7 +1,7 @@
--require 'torch'
local mytester
-local torchtest = {}
+local torchtest = torch.TestSuite()
local msize = 100
local precision
@@ -725,22 +725,22 @@ function torchtest.addbmm()
local res2 = torch.Tensor():resizeAs(res[1]):zero()
res2:addbmm(b1,b2)
- mytester:assertTensorEq(res2, res:sum(1), precision, 'addbmm result wrong')
+ mytester:assertTensorEq(res2, res:sum(1)[1], precision, 'addbmm result wrong')
res2:addbmm(1,b1,b2)
- mytester:assertTensorEq(res2, res:sum(1)*2, precision, 'addbmm result wrong')
+ mytester:assertTensorEq(res2, res:sum(1)[1]*2, precision, 'addbmm result wrong')
res2:addbmm(1,res2,.5,b1,b2)
- mytester:assertTensorEq(res2, res:sum(1)*2.5, precision, 'addbmm result wrong')
+ mytester:assertTensorEq(res2, res:sum(1)[1]*2.5, precision, 'addbmm result wrong')
local res3 = torch.addbmm(1,res2,0,b1,b2)
mytester:assertTensorEq(res3, res2, precision, 'addbmm result wrong')
local res4 = torch.addbmm(1,res2,.5,b1,b2)
- mytester:assertTensorEq(res4, res:sum(1)*3, precision, 'addbmm result wrong')
+ mytester:assertTensorEq(res4, res:sum(1)[1]*3, precision, 'addbmm result wrong')
local res5 = torch.addbmm(0,res2,1,b1,b2)
- mytester:assertTensorEq(res5, res:sum(1), precision, 'addbmm result wrong')
+ mytester:assertTensorEq(res5, res:sum(1)[1], precision, 'addbmm result wrong')
local res6 = torch.addbmm(.1,res2,.5,b1,b2)
mytester:assertTensorEq(res6, res2*.1 + res:sum(1)*.5, precision, 'addbmm result wrong')
@@ -1510,8 +1510,10 @@ function torchtest.kthvalue()
local mx, ix = torch.kthvalue(x, k)
local mxx, ixx = torch.sort(x)
- mytester:assertTensorEq(mxx:select(3, k), mx, 0, 'torch.kthvalue value')
- mytester:assertTensorEq(ixx:select(3, k), ix, 0, 'torch.kthvalue index')
+ mytester:assertTensorEq(mxx:select(3, k), mx:select(3, 1), 0,
+ 'torch.kthvalue value')
+ mytester:assertTensorEq(ixx:select(3, k), ix:select(3, 1), 0,
+ 'torch.kthvalue index')
end
do -- test use of result tensors
local k = math.random(1, msize)
@@ -1519,15 +1521,19 @@ function torchtest.kthvalue()
local ix = torch.LongTensor()
torch.kthvalue(mx, ix, x, k)
local mxx, ixx = torch.sort(x)
- mytester:assertTensorEq(mxx:select(3, k), mx, 0, 'torch.kthvalue value')
- mytester:assertTensorEq(ixx:select(3, k), ix, 0, 'torch.kthvalue index')
+ mytester:assertTensorEq(mxx:select(3, k), mx:select(3, 1), 0,
+ 'torch.kthvalue value')
+ mytester:assertTensorEq(ixx:select(3, k), ix:select(3, 1), 0,
+ 'torch.kthvalue index')
end
do -- test non-default dim
local k = math.random(1, msize)
local mx, ix = torch.kthvalue(x, k, 1)
local mxx, ixx = torch.sort(x, 1)
- mytester:assertTensorEq(mxx:select(1, k), mx, 0, 'torch.kthvalue value')
- mytester:assertTensorEq(ixx:select(1, k), ix, 0, 'torch.kthvalue index')
+ mytester:assertTensorEq(mxx:select(1, k), mx[1], 0,
+ 'torch.kthvalue value')
+ mytester:assertTensorEq(ixx:select(1, k), ix[1], 0,
+ 'torch.kthvalue index')
end
do -- non-contiguous
local y = x:narrow(2, 1, 1)
@@ -1557,8 +1563,10 @@ function torchtest.median()
local mxx, ixx = torch.sort(x)
local ind = math.floor((msize+1)/2)
- mytester:assertTensorEq(mxx:select(2, ind), mx, 0, 'torch.median value')
- mytester:assertTensorEq(ixx:select(2, ind), ix, 0, 'torch.median index')
+ mytester:assertTensorEq(mxx:select(2, ind), mx:select(2, 1), 0,
+ 'torch.median value')
+ mytester:assertTensorEq(ixx:select(2, ind), ix:select(2, 1), 0,
+ 'torch.median index')
-- Test use of result tensor
local mr = torch.Tensor()
@@ -1570,8 +1578,10 @@ function torchtest.median()
-- Test non-default dim
mx, ix = torch.median(x, 1)
mxx, ixx = torch.sort(x, 1)
- mytester:assertTensorEq(mxx:select(1, ind), mx, 0,'torch.median value')
- mytester:assertTensorEq(ixx:select(1, ind), ix, 0,'torch.median index')
+ mytester:assertTensorEq(mxx:select(1, ind), mx[1], 0,
+ 'torch.median value')
+ mytester:assertTensorEq(ixx:select(1, ind), ix[1], 0,
+ 'torch.median index')
-- input unchanged
mytester:assertTensorEq(x, x0, 0, 'torch.median modified input')
@@ -1658,7 +1668,7 @@ function torchtest.catArray()
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
end
end
-function torchtest.sin()
+function torchtest.sin_2()
local x = torch.rand(msize,msize,msize)
local mx = torch.sin(x)
local mxx = torch.Tensor()
@@ -2205,30 +2215,6 @@ function torchtest.conv3_conv2_eq()
mytester:assertlt(maxdiff(o3,o32),precision,'torch.conv3_conv2_eq')
end
-function torchtest.fxcorr3_fxcorr2_eq()
- local ix = math.floor(torch.uniform(20,40))
- local iy = math.floor(torch.uniform(20,40))
- local iz = math.floor(torch.uniform(20,40))
- local kx = math.floor(torch.uniform(5,10))
- local ky = math.floor(torch.uniform(5,10))
- local kz = math.floor(torch.uniform(5,10))
-
- local x = torch.rand(ix,iy,iz)
- local k = torch.rand(kx,ky,kz)
-
- local o3 = torch.xcorr3(x,k,'F')
-
- local o32 = torch.zeros(o3:size())
-
- for i=1,x:size(1) do
- for j=1,k:size(1) do
- o32[i+j-1]:add(torch.xcorr2(x[i],k[k:size(1)-j + 1],'F'))
- end
- end
-
- mytester:assertlt(maxdiff(o3,o32),precision,'torch.conv3_conv2_eq')
-end
-
function torchtest.fconv3_fconv2_eq()
local ix = math.floor(torch.uniform(20,40))
local iy = math.floor(torch.uniform(20,40))
@@ -2269,27 +2255,6 @@ function torchtest.logical()
mytester:asserteq(x:nElement(),all:double():sum() , 'torch.logical')
end
-function torchtest.TestAsserts()
- mytester:assertError(function() error('hello') end, 'assertError: Error not caught')
- mytester:assertErrorPattern(function() error('hello') end, '.*ll.*', 'assertError: ".*ll.*" Error not caught')
-
- local x = torch.rand(100,100)*2-1;
- local xx = x:clone();
- mytester:assertTensorEq(x, xx, 1e-16, 'assertTensorEq: not deemed equal')
- mytester:assertTensorNe(x, xx+1, 1e-16, 'assertTensorNe: not deemed different')
- mytester:assertalmosteq(0, 1e-250, 1e-16, 'assertalmosteq: not deemed different')
-end
-
-function torchtest.BugInAssertTableEq()
- local t = {1,2,3}
- local tt = {1,2,3}
- mytester:assertTableEq(t, tt, 'assertTableEq: not deemed equal')
- mytester:assertTableNe(t, {3,2,1}, 'assertTableNe: not deemed different')
- mytester:assertTableEq({1,2,{4,5}}, {1,2,{4,5}}, 'assertTableEq: fails on recursive lists')
- mytester:assertTableNe(t, {1,2}, 'assertTableNe: different size not deemed different')
- mytester:assertTableNe(t, {1,2,3,4}, 'assertTableNe: different size not deemed different')
-end
-
function torchtest.RNGState()
local state = torch.getRNGState()
local stateCloned = state:clone()
@@ -2750,14 +2715,14 @@ function torchtest.classInModule()
-- Need a global for this module
_mymodule123 = {}
local x = torch.class('_mymodule123.myclass')
- mytester:assert(x, 'Could not create class in module')
+ mytester:assert(x ~= nil, 'Could not create class in module')
-- Remove the global
_G['_mymodule123'] = nil
end
function torchtest.classNoModule()
local x = torch.class('_myclass123')
- mytester:assert(x, 'Could not create class in module')
+ mytester:assert(x ~= nil, 'Could not create class in module')
end
function torchtest.type()
@@ -3003,7 +2968,7 @@ function torchtest.split()
mytester:assertTensorEq(tensor:narrow(dim, start, targetSize[i][dim]), split, 0.000001, 'Result content error in split '..i)
start = start + targetSize[i][dim]
end
- mytester:asserteq(#splits,#result, 0, 'Non-consistent output size from split')
+ mytester:asserteq(#splits, #result, 'Non-consistent output size from split')
for i, split in ipairs(splits) do
mytester:assertTensorEq(split,result[i], 0, 'Non-consistent outputs from split')
end
@@ -3136,13 +3101,13 @@ function torchtest.nonzero()
table.insert(dst, i)
end
end
- mytester:assertTensorEq(dst1, torch.LongTensor(dst), 0.0,
+ mytester:assertTensorEq(dst1:select(2, 1), torch.LongTensor(dst), 0.0,
"nonzero error")
- mytester:assertTensorEq(dst2, torch.LongTensor(dst), 0.0,
+ mytester:assertTensorEq(dst2:select(2, 1), torch.LongTensor(dst), 0.0,
"nonzero error")
- --mytester:assertTensorEq(dst3, torch.LongTensor(dst), 0.0,
- -- "nonzero error")
- mytester:assertTensorEq(dst4, torch.LongTensor(dst), 0.0,
+ --mytester:assertTensorEq(dst3:select(2, 1), torch.LongTensor(dst),
+ -- 0.0, "nonzero error")
+ mytester:assertTensorEq(dst4:select(2, 1), torch.LongTensor(dst), 0.0,
"nonzero error")
elseif shape:size() == 2 then
-- This test will allow through some false positives. It only checks
diff --git a/test/test_Tester.lua b/test/test_Tester.lua
new file mode 100644
index 0000000..a283360
--- /dev/null
+++ b/test/test_Tester.lua
@@ -0,0 +1,626 @@
+require 'torch'
+
+local tester = torch.Tester()
+
+local MESSAGE = "a really useful informative error message"
+
+local subtester = torch.Tester()
+-- The message only interests us in case of failure
+subtester._success = function(self) return true, MESSAGE end
+subtester._failure = function(self, message) return false, message end
+
+local tests = torch.TestSuite()
+
+local test_name_passed_to_setUp
+local calls_to_setUp = 0
+local calls_to_tearDown = 0
+
+local originalIoWrite = io.write
+local function disableIoWrite()
+ io.write = function() end
+end
+local function enableIoWrite()
+ io.write = originalIoWrite
+end
+
+local function meta_assert_success(success, message)
+ tester:assert(success == true, "assert wasn't successful")
+ tester:assert(string.find(message, MESSAGE) ~= nil, "message doesn't match")
+end
+local function meta_assert_failure(success, message)
+ tester:assert(success == false, "assert didn't fail")
+ tester:assert(string.find(message, MESSAGE) ~= nil, "message doesn't match")
+end
+
+function tests.really_test_assert()
+ assert((subtester:assert(true, MESSAGE)),
+ "subtester:assert doesn't actually work!")
+ assert(not (subtester:assert(false, MESSAGE)),
+ "subtester:assert doesn't actually work!")
+end
+
+function tests.setEarlyAbort()
+ disableIoWrite()
+
+ for _, earlyAbort in ipairs{false, true} do
+ local myTester = torch.Tester()
+
+ local invokedCount = 0
+ local myTests = {}
+ function myTests.t1()
+ invokedCount = invokedCount + 1
+ myTester:assert(false)
+ end
+ myTests.t2 = myTests.t1
+
+ myTester:setEarlyAbort(earlyAbort)
+ myTester:add(myTests)
+ pcall(myTester.run, myTester)
+
+ tester:assert(invokedCount == (earlyAbort and 1 or 2),
+ "wrong number of tests invoked for use with earlyAbort")
+ end
+
+ enableIoWrite()
+end
+
+function tests.setRethrowErrors()
+ disableIoWrite()
+
+ local myTester = torch.Tester()
+ myTester:setRethrowErrors(true)
+ myTester:add(function() error("a throw") end)
+
+ tester:assertErrorPattern(function() myTester:run() end,
+ "a throw",
+ "error should be rethrown")
+
+ enableIoWrite()
+end
+
+function tests.disable()
+ disableIoWrite()
+
+ for disableCount = 1, 2 do
+ local myTester = torch.Tester()
+ local tests = {}
+ local test1Invoked = false
+ local test2Invoked = false
+ function tests.test1()
+ test1Invoked = true
+ end
+ function tests.test2()
+ test2Invoked = true
+ end
+ myTester:add(tests)
+
+ if disableCount == 1 then
+ myTester:disable('test1'):run()
+ tester:assert((not test1Invoked) and test2Invoked,
+ "disabled test shouldn't have been invoked")
+ else
+ myTester:disable({'test1', 'test2'}):run()
+ tester:assert((not test1Invoked) and (not test2Invoked),
+ "disabled tests shouldn't have been invoked")
+ end
+ end
+
+ enableIoWrite()
+end
+
+function tests.assert()
+ meta_assert_success(subtester:assert(true, MESSAGE))
+ meta_assert_failure(subtester:assert(false, MESSAGE))
+end
+
+local function testEqNe(eqExpected, ...)
+ if eqExpected then
+ meta_assert_success(subtester:eq(...))
+ meta_assert_failure(subtester:ne(...))
+ else
+ meta_assert_failure(subtester:eq(...))
+ meta_assert_success(subtester:ne(...))
+ end
+end
+
+--[[ Test :assertGeneralEq and :assertGeneralNe (also known as :eq and :ne).
+
+Note that in-depth testing of testing of many specific types of data (such as
+Tensor) is covered below, when we test specific functions (such as
+:assertTensorEq). This just does a general check, as well as testing of testing
+of mixed datatypes.
+]]
+function tests.assertGeneral()
+ local one = torch.Tensor{1}
+
+ testEqNe(true, one, one, MESSAGE)
+ testEqNe(false, one, 1, MESSAGE)
+ testEqNe(true, "hi", "hi", MESSAGE)
+ testEqNe(true, {one, 1}, {one, 1}, MESSAGE)
+ testEqNe(true, {{{one}}}, {{{one}}}, MESSAGE)
+ testEqNe(false, {{{one}}}, {{one}}, MESSAGE)
+ testEqNe(true, torch.Storage{1}, torch.Storage{1}, MESSAGE)
+ testEqNe(false, torch.FloatStorage{1}, torch.LongStorage{1}, MESSAGE)
+ testEqNe(false, torch.Storage{1}, torch.Storage{1, 2}, MESSAGE)
+ testEqNe(false, "one", 1, MESSAGE)
+ testEqNe(false, {one}, {one + torch.Tensor{1e-10}}, MESSAGE)
+ testEqNe(true, {one}, {one + torch.Tensor{1e-10}}, 1e-9, MESSAGE)
+end
+
+function tests.assertlt()
+ meta_assert_success(subtester:assertlt(1, 2, MESSAGE))
+ meta_assert_failure(subtester:assertlt(2, 1, MESSAGE))
+ meta_assert_failure(subtester:assertlt(1, 1, MESSAGE))
+end
+
+function tests.assertgt()
+ meta_assert_success(subtester:assertgt(2, 1, MESSAGE))
+ meta_assert_failure(subtester:assertgt(1, 2, MESSAGE))
+ meta_assert_failure(subtester:assertgt(1, 1, MESSAGE))
+end
+
+function tests.assertle()
+ meta_assert_success(subtester:assertle(1, 2, MESSAGE))
+ meta_assert_failure(subtester:assertle(2, 1, MESSAGE))
+ meta_assert_success(subtester:assertle(1, 1, MESSAGE))
+end
+
+function tests.assertge()
+ meta_assert_success(subtester:assertge(2, 1, MESSAGE))
+ meta_assert_failure(subtester:assertge(1, 2, MESSAGE))
+ meta_assert_success(subtester:assertge(1, 1, MESSAGE))
+end
+
+function tests.asserteq()
+ meta_assert_success(subtester:asserteq(1, 1, MESSAGE))
+ meta_assert_failure(subtester:asserteq(1, 2, MESSAGE))
+end
+
+function tests.assertalmosteq()
+ meta_assert_success(subtester:assertalmosteq(1, 1, MESSAGE))
+ meta_assert_success(subtester:assertalmosteq(1, 1 + 1e-17, MESSAGE))
+ meta_assert_success(subtester:assertalmosteq(1, 2, 2, MESSAGE))
+ meta_assert_failure(subtester:assertalmosteq(1, 2, MESSAGE))
+ meta_assert_failure(subtester:assertalmosteq(1, 3, 1, MESSAGE))
+end
+
+function tests.assertne()
+ meta_assert_success(subtester:assertne(1, 2, MESSAGE))
+ meta_assert_failure(subtester:assertne(1, 1, MESSAGE))
+end
+
+-- The `alsoTestEq` flag is provided to test :eq in addition to :assertTensorEq.
+-- The behaviour of the two isn't always the same due to handling of tensors of
+-- different dimensions but the same number of elements.
+local function testTensorEqNe(eqExpected, alsoTestEq, ...)
+ if eqExpected then
+ meta_assert_success(subtester:assertTensorEq(...))
+ meta_assert_failure(subtester:assertTensorNe(...))
+ if alsoTestEq then
+ meta_assert_success(subtester:eq(...))
+ meta_assert_failure(subtester:ne(...))
+ end
+ else
+ meta_assert_failure(subtester:assertTensorEq(...))
+ meta_assert_success(subtester:assertTensorNe(...))
+ if alsoTestEq then
+ meta_assert_failure(subtester:eq(...))
+ meta_assert_success(subtester:ne(...))
+ end
+ end
+end
+
+function tests.assertTensor_types()
+ local allTypes = {
+ torch.ByteTensor,
+ torch.CharTensor,
+ torch.ShortTensor,
+ torch.IntTensor,
+ torch.LongTensor,
+ torch.FloatTensor,
+ torch.DoubleTensor,
+ }
+ for _, tensor1 in ipairs(allTypes) do
+ for _, tensor2 in ipairs(allTypes) do
+ local t1 = tensor1():ones(10)
+ local t2 = tensor2():ones(10)
+ testTensorEqNe(tensor1 == tensor2, true, t1, t2, 1e-6, MESSAGE)
+ end
+ end
+
+ testTensorEqNe(false, true, torch.FloatTensor(), torch.LongTensor(), MESSAGE)
+end
+
+function tests.assertTensor_sizes()
+ local t = torch.Tensor() -- no dimensions
+ local t2 = torch.ones(2)
+ local t3 = torch.ones(3)
+ local t12 = torch.ones(1, 2)
+ assert(subtester._assertTensorEqIgnoresDims == true) -- default state
+ testTensorEqNe(false, false, t, t2, 1e-6, MESSAGE)
+ testTensorEqNe(false, false, t, t3, 1e-6, MESSAGE)
+ testTensorEqNe(false, false, t, t12, 1e-6, MESSAGE)
+ testTensorEqNe(false, false, t2, t3, 1e-6, MESSAGE)
+ testTensorEqNe(true, false, t2, t12, 1e-6, MESSAGE)
+ testTensorEqNe(false, false, t3, t12, 1e-6, MESSAGE)
+ subtester._assertTensorEqIgnoresDims = false
+ testTensorEqNe(false, true, t, t2, 1e-6, MESSAGE)
+ testTensorEqNe(false, true, t, t3, 1e-6, MESSAGE)
+ testTensorEqNe(false, true, t, t12, 1e-6, MESSAGE)
+ testTensorEqNe(false, true, t2, t3, 1e-6, MESSAGE)
+ testTensorEqNe(false, true, t2, t12, 1e-6, MESSAGE)
+ testTensorEqNe(false, true, t3, t12, 1e-6, MESSAGE)
+ subtester._assertTensorEqIgnoresDims = true -- reset back
+end
+
+function tests.assertTensor_epsilon()
+ local t1 = torch.rand(100, 100)
+ local t2 = torch.rand(100, 100) * 1e-5
+ local t3 = t1 + t2
+ testTensorEqNe(true, true, t1, t3, 1e-4, MESSAGE)
+ testTensorEqNe(false, true, t1, t3, 1e-6, MESSAGE)
+end
+
+function tests.assertTensor_arg()
+ local one = torch.Tensor{1}
+
+ tester:assertErrorPattern(
+ function() subtester:assertTensorEq(one, 2) end,
+ "Second argument should be a Tensor")
+
+ -- Test that assertTensorEq support message and tolerance in either ordering
+ tester:assertNoError(
+ function() subtester:assertTensorEq(one, one, 0.1, MESSAGE) end)
+ tester:assertNoError(
+ function() subtester:assertTensorEq(one, one, MESSAGE, 0.1) end)
+end
+
+function tests.assertTensor()
+ local t1 = torch.randn(100, 100)
+ local t2 = t1:clone()
+ local t3 = torch.randn(100, 100)
+ testTensorEqNe(true, true, t1, t2, 1e-6, MESSAGE)
+ testTensorEqNe(false, true, t1, t3, 1e-6, MESSAGE)
+ testTensorEqNe(true, true, torch.Tensor(), torch.Tensor(), MESSAGE)
+end
+
+-- Check that calling assertTensorEq with two tensors with the same content but
+-- different dimensions gives a warning.
+function tests.assertTensorDimWarning()
+ local myTester = torch.Tester()
+ myTester:add(
+ function()
+ myTester:assertTensorEq(torch.Tensor{{1}}, torch.Tensor{1})
+ end)
+
+ local warningGiven = false
+ io.write = function(s)
+ if string.match(s, 'but different dimensions') then
+ warningGiven = true
+ end
+ end
+
+ myTester:run()
+ enableIoWrite()
+
+ tester:assert(warningGiven,
+ "Calling :assertTensorEq({{1}}, {1}) should give a warning")
+end
+
+local function testTableEqNe(eqExpected, ...)
+ if eqExpected then
+ meta_assert_success(subtester:assertTableEq(...))
+ meta_assert_failure(subtester:assertTableNe(...))
+ meta_assert_success(subtester:eq(...))
+ meta_assert_failure(subtester:ne(...))
+ else
+ meta_assert_failure(subtester:assertTableEq(...))
+ meta_assert_success(subtester:assertTableNe(...))
+ meta_assert_failure(subtester:eq(...))
+ meta_assert_success(subtester:ne(...))
+ end
+end
+
+function tests.assertTable()
+ testTableEqNe(true, {1, 2, 3}, {1, 2, 3}, MESSAGE)
+ testTableEqNe(false, {1, 2, 3}, {3, 2, 1}, MESSAGE)
+ testTableEqNe(true, {1, 2, {4, 5}}, {1, 2, {4, 5}}, MESSAGE)
+ testTableEqNe(false, {1, 2, 3}, {1,2}, MESSAGE)
+ testTableEqNe(false, {1, 2, 3}, {1, 2, 3, 4}, MESSAGE)
+ testTableEqNe(true, {{1}}, {{1}}, MESSAGE)
+ testTableEqNe(false, {{1}}, {{{1}}}, MESSAGE)
+ testTableEqNe(true, {false}, {false}, MESSAGE)
+ testTableEqNe(false, {true}, {false}, MESSAGE)
+ testTableEqNe(false, {false}, {true}, MESSAGE)
+
+ local tensor = torch.rand(100, 100)
+ local t1 = {1, "a", key = "value", tensor = tensor, subtable = {"nested"}}
+ local t2 = {1, "a", key = "value", tensor = tensor, subtable = {"nested"}}
+ testTableEqNe(true, t1, t2, MESSAGE)
+ for k, v in pairs(t1) do
+ local x = "something else"
+ t2[k] = nil
+ t2[x] = v
+ testTableEqNe(false, t1, t2, MESSAGE)
+ t2[x] = nil
+ t2[k] = x
+ testTableEqNe(false, t1, t2, MESSAGE)
+ t2[k] = v
+ testTableEqNe(true, t1, t2, MESSAGE)
+ end
+end
+
+local function good_fn() end
+local function bad_fn() error("muahaha!") end
+
+function tests.assertError()
+ meta_assert_success(subtester:assertError(bad_fn, MESSAGE))
+ meta_assert_failure(subtester:assertError(good_fn, MESSAGE))
+end
+
+function tests.assertNoError()
+ meta_assert_success(subtester:assertNoError(good_fn, MESSAGE))
+ meta_assert_failure(subtester:assertNoError(bad_fn, MESSAGE))
+end
+
+function tests.assertErrorPattern()
+ meta_assert_success(subtester:assertErrorPattern(bad_fn, "haha", MESSAGE))
+ meta_assert_failure(subtester:assertErrorPattern(bad_fn, "hehe", MESSAGE))
+end
+
+function tests.testSuite_duplicateTests()
+ local function createDuplicateTests()
+ local tests = torch.TestSuite()
+ function tests.testThis() end
+ function tests.testThis() end
+ end
+ tester:assertErrorPattern(createDuplicateTests,
+ "Test testThis is already defined.")
+end
+
+--[[ Returns a Tester with `numSuccess` success cases, `numFailure` failure
+ cases, and with an error if `hasError` is true.
+ Success and fail tests are evaluated with tester:eq
+]]
+local function genDummyTest(numSuccess, numFailure, hasError)
+ hasError = hasError or false
+
+ local dummyTester = torch.Tester()
+ local dummyTests = torch.TestSuite()
+
+ if numSuccess > 0 then
+ function dummyTests.testDummySuccess()
+ for i = 1, numSuccess do
+ dummyTester:eq({1}, {1}, '', 0)
+ end
+ end
+ end
+
+ if numFailure > 0 then
+ function dummyTests.testDummyFailure()
+ for i = 1, numFailure do
+ dummyTester:eq({1}, {2}, '', 0)
+ end
+ end
+ end
+
+ if hasError then
+ function dummyTests.testDummyError()
+ error('dummy error')
+ end
+ end
+
+ return dummyTester:add(dummyTests)
+end
+
+function tests.runStatusAndAssertCounts()
+ local emptyTest = genDummyTest(0, 0, false)
+ local sucTest = genDummyTest(1, 0, false)
+ local multSucTest = genDummyTest(4, 0, false)
+ local failTest = genDummyTest(0, 1, false)
+ local errTest = genDummyTest(0, 0, true)
+ local errFailTest = genDummyTest(0, 1, true)
+ local errSucTest = genDummyTest(1, 0, true)
+ local failSucTest = genDummyTest(1, 1, false)
+ local failSucErrTest = genDummyTest(1, 1, true)
+
+ disableIoWrite()
+
+ local success, msg = pcall(emptyTest.run, emptyTest)
+ tester:asserteq(success, true, "pcall should succeed for empty tests")
+
+ local success, msg = pcall(sucTest.run, sucTest)
+ tester:asserteq(success, true, "pcall should succeed for 1 successful test")
+
+ local success, msg = pcall(multSucTest.run, multSucTest)
+ tester:asserteq(success, true,
+ "pcall should succeed for 2+ successful tests")
+
+ local success, msg = pcall(failTest.run, failTest)
+ tester:asserteq(success, false, "pcall should fail for tests with failure")
+
+ local success, msg = pcall(errTest.run, errTest)
+ tester:asserteq(success, false, "pcall should fail for tests with error")
+
+ local success, msg = pcall(errFailTest.run, errFailTest)
+ tester:asserteq(success, false, "pcall should fail for error+fail tests")
+
+ local success, msg = pcall(errSucTest.run, errSucTest)
+ tester:asserteq(success, false, "pcall should fail for error+success tests")
+
+ local success, msg = pcall(failSucTest.run, failSucTest)
+ tester:asserteq(success, false, "pcall should fail for fail+success tests")
+
+ local success, msg = pcall(failSucErrTest.run, failSucErrTest)
+ tester:asserteq(success, false,
+ "pcall should fail for fail+success+err test")
+
+ enableIoWrite()
+
+ tester:asserteq(emptyTest.countasserts, 0,
+ "emptyTest should have 0 asserts")
+ tester:asserteq(sucTest.countasserts, 1, "sucTest should have 1 assert")
+ tester:asserteq(multSucTest.countasserts, 4,
+ "multSucTest should have 4 asserts")
+ tester:asserteq(failTest.countasserts, 1, "failTest should have 1 assert")
+ tester:asserteq(errTest.countasserts, 0, "errTest should have 0 asserts")
+ tester:asserteq(errFailTest.countasserts, 1,
+ "errFailTest should have 1 assert")
+ tester:asserteq(errSucTest.countasserts, 1,
+ "errSucTest should have 0 asserts")
+ tester:asserteq(failSucTest.countasserts, 2,
+ "failSucTest should have 2 asserts")
+end
+
+function tests.checkNestedTestsForbidden()
+ disableIoWrite()
+
+ local myTester = torch.Tester()
+ local myTests = {{function() end}}
+ tester:assertErrorPattern(function() myTester:add(myTests) end,
+ "Nested sets",
+ "tester should forbid adding nested test sets")
+
+ enableIoWrite()
+end
+
+function tests.checkWarningOnAssertObject()
+ -- This test checks that calling assert with an object generates a warning
+ local myTester = torch.Tester()
+ local myTests = {}
+ function myTests.assertAbuse()
+ myTester:assert({})
+ end
+ myTester:add(myTests)
+
+ local warningGiven = false
+ io.write = function(s)
+ if string.match(s, 'should only be used for boolean') then
+ warningGiven = true
+ end
+ end
+
+ myTester:run()
+ enableIoWrite()
+
+ tester:assert(warningGiven, "Should warn on calling :assert(object)")
+end
+
+function tests.checkWarningOnAssertNeObject()
+ -- This test checks that calling assertne with two objects generates warning
+ local myTester = torch.Tester()
+ local myTests = {}
+ function myTests.assertAbuse()
+ myTester:assertne({}, {})
+ end
+ myTester:add(myTests)
+
+ local warningGiven = false
+ io.write = function(s)
+ if string.match(s, 'assertne should only be used to compare basic') then
+ warningGiven = true
+ end
+ end
+
+ myTester:run()
+ enableIoWrite()
+
+ tester:assert(warningGiven, "Should warn on calling :assertne(obj, obj)")
+end
+
+function tests.checkWarningOnExtraAssertArguments()
+ -- This test checks that calling assert with extra args gives a lua error
+ local myTester = torch.Tester()
+ local myTests = {}
+ function myTests.assertAbuse()
+ myTester:assert(true, "some message", "extra argument")
+ end
+ myTester:add(myTests)
+
+ local errorGiven = false
+ io.write = function(s)
+ if string.match(s, 'Unexpected arguments') then
+ errorGiven = true
+ end
+ end
+ tester:assertError(function() myTester:run() end)
+ enableIoWrite()
+
+ tester:assert(errorGiven, ":assert should fail on extra arguments")
+end
+
+function tests.checkWarningOnUsingTable()
+ -- Checks that if we don't use a TestSuite then gives a warning
+ local myTester = torch.Tester()
+ local myTests = {}
+ myTester:add(myTests)
+
+ local errorGiven = false
+ io.write = function(s)
+ if string.match(s, 'use TestSuite rather than plain lua table') then
+ errorGiven = true
+ end
+ end
+ myTester:run()
+
+ enableIoWrite()
+ tester:assert(errorGiven, "Using a plain lua table for testsuite should warn")
+end
+
+function tests.checkMaxAllowedSetUpAndTearDown()
+ -- Checks can have at most 1 set-up and at most 1 tear-down function
+ local function f() end
+ local myTester = torch.Tester()
+
+ for _, name in ipairs({'_setUp', '_tearDown'}) do
+ tester:assertNoError(function() myTester:add(f, name) end,
+ "Adding 1 set-up / tear-down should be fine")
+ tester:assertErrorPattern(function() myTester:add(f, name) end,
+ "Only one",
+ "Adding second set-up / tear-down should fail")
+ end
+end
+
+function tests.test_setUp()
+ tester:asserteq(test_name_passed_to_setUp, 'test_setUp')
+ for key, value in pairs(tester.tests) do
+ tester:assertne(key, '_setUp')
+ end
+end
+
+function tests.test_tearDown()
+ for key, value in pairs(tester.tests) do
+ tester:assertne(key, '_tearDown')
+ end
+end
+
+function tests._setUp(name)
+ test_name_passed_to_setUp = name
+ calls_to_setUp = calls_to_setUp + 1
+end
+
+function tests._tearDown(name)
+ calls_to_tearDown = calls_to_tearDown + 1
+end
+
+tester:add(tests):run()
+
+-- Additional tests to check that _setUp and _tearDown were called.
+local test_count = 0
+for _ in pairs(tester.tests) do
+ test_count = test_count + 1
+end
+local postTests = torch.TestSuite()
+local postTester = torch.Tester()
+
+function postTests.test_setUp(tester)
+ postTester:asserteq(calls_to_setUp, test_count,
+ "Expected " .. test_count .. " calls to _setUp")
+end
+
+function postTests.test_tearDown()
+ postTester:asserteq(calls_to_tearDown, test_count,
+ "Expected " .. test_count .. " calls to _tearDown")
+end
+
+postTester:add(postTests):run()
diff --git a/test/test_qr.lua b/test/test_qr.lua
index c00c604..c850c3f 100644
--- a/test/test_qr.lua
+++ b/test/test_qr.lua
@@ -2,7 +2,7 @@
-- torch.qr(), torch.geqrf() and torch.orgqr().
local torch = require 'torch'
local tester = torch.Tester()
-local tests = {}
+local tests = torch.TestSuite()
-- torch.qr() with result tensors given.
local function qrInPlace(tensorFunc)
diff --git a/test/test_sharedmem.lua b/test/test_sharedmem.lua
index 9f594fe..14cdeaf 100644
--- a/test/test_sharedmem.lua
+++ b/test/test_sharedmem.lua
@@ -1,7 +1,7 @@
require 'torch'
local tester = torch.Tester()
-local tests = {}
+local tests = torch.TestSuite()
local function createSharedMemStorage(name, size, storageType)
local storageType = storageType or 'FloatStorage'
diff --git a/test/test_writeObject.lua b/test/test_writeObject.lua
index 8bccf10..ccf7eba 100644
--- a/test/test_writeObject.lua
+++ b/test/test_writeObject.lua
@@ -2,7 +2,7 @@ require 'torch'
local myTester = torch.Tester()
-local tests = {}
+local tests = torch.TestSuite()
-- checks that an object can be written and unwritten
@@ -91,7 +91,7 @@ function tests.test_error_msg()
end
local ok, msg = pcall(torch.save, 'saved.t7', evil_func)
myTester:assert(not ok)
- myTester:assert(msg:find('at <%?>%.outer%.theinner%.baz%.torch'))
+ myTester:assert(msg:find('at <%?>%.outer%.theinner%.baz%.torch') ~= nil)
end
function tests.test_warning_msg()