Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2016-12-29 22:23:26 +0300
committerSoumith Chintala <soumith@gmail.com>2016-12-29 22:23:26 +0300
commita0c0b78471df5f4507791e870cf7df9607a64400 (patch)
treef91b908684ba7ff4727df1331d3a5c9b4b3b9cb8 /test
parent7ca7ec9d08f1ef2c753e72cbd014397736d6b5af (diff)
Add support for torch.HalfTensor (#874)
* Add support for torch.HalfTensor. * Improvements/Simplifications for torch.HalfTensor. Improvements/Simplifications: 1) Defines half type as TH_Half, so as to not conflict with cutorch version. Previously, these were defined as the same "half" type and required proper ordering of includes to ensure type was only defined once, which would have affected all downstream projects. 2) No longer generates math functions that are not actually defined on torch.HalfTensor, e.g. maskedFill, map, etc. 3) Adds tests for all available torch.HalfTensor functions 4) Allows compiling without TH_GENERIC_USE_HALF (so if there's a problem can just unset that in CMakeLists rather than backing out) 5) Some simplifications: removes a new copy optimization and some TH_HALF literal definitions Limitations: Because math functions are not defined, some "non-math" operators on torch.HalfTensor give an error message, e.g. __index__/__newindex__ with a ByteTensor apply a mask, but masks aren't implemented. These limitations aren't always obvious, (e.g. for documentation purposes), but they should always give an error message. * Rename TH_HALF to THHalf.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua161
-rw-r--r--test/test_half.lua55
-rw-r--r--test/test_writeObject.lua11
3 files changed, 200 insertions, 27 deletions
diff --git a/test/test.lua b/test/test.lua
index 3eb119f..30ff339 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -19,6 +19,35 @@ local function maxdiff(x,y)
end
end
+-- workarounds for non-existent functions
+function torch.HalfTensor:__sub(other)
+ return (self:real() - other:real()):half()
+end
+
+function torch.HalfTensor:mean(dim)
+ return self:real():mean(dim):half()
+end
+
+function torch.HalfTensor:abs()
+ return self:real():abs():half()
+end
+
+function torch.HalfTensor:max()
+ return self:real():max()
+end
+
+function torch.HalfTensor:add(a, b)
+ return (self:real():add(a, b:real())):half()
+end
+
+function torch.HalfTensor:reshape(a, b)
+ return (self:real():reshape(a, b)):half()
+end
+
+function torch.HalfTensor:fill(a)
+ return self:real():fill(a):half()
+end
+
function torchtest.dot()
local types = {
['torch.DoubleTensor'] = 1e-8, -- for ddot
@@ -3053,7 +3082,13 @@ function torchtest.isTypeOfPattern()
end
function torchtest.isTensor()
- local t = torch.randn(3,4)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_isTensor(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_isTensor(func)
+ local t = func(torch.randn(3,4))
mytester:assert(torch.isTensor(t), 'error in isTensor')
mytester:assert(torch.isTensor(t[1]), 'error in isTensor for subTensor')
mytester:assert(not torch.isTensor(t[1][2]), 'false positive in isTensor')
@@ -3061,14 +3096,26 @@ function torchtest.isTensor()
end
function torchtest.isStorage()
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_isStorage(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_isStorage(func)
local t = torch.randn(3,4)
mytester:assert(torch.isStorage(t:storage()), 'error in isStorage')
mytester:assert(not torch.isStorage(t), 'false positive in isStorage')
end
function torchtest.view()
- local tensor = torch.rand(15)
- local template = torch.rand(3,5)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_view(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_view(func)
+ local tensor = func(torch.rand(15))
+ local template = func(torch.rand(3,5))
local target = template:size():totable()
mytester:assertTableEq(tensor:viewAs(template):size():totable(), target, 'Error in viewAs')
mytester:assertTableEq(tensor:view(3,5):size():totable(), target, 'Error in view')
@@ -3079,7 +3126,7 @@ function torchtest.view()
tensor_view:fill(torch.rand(1)[1])
mytester:asserteq((tensor_view-tensor):abs():max(), 0, 'Error in view')
- local target_tensor = torch.Tensor()
+ local target_tensor = func(torch.Tensor())
mytester:assertTableEq(target_tensor:viewAs(tensor, template):size():totable(), target, 'Error in viewAs')
mytester:assertTableEq(target_tensor:view(tensor, 3,5):size():totable(), target, 'Error in view')
mytester:assertTableEq(target_tensor:view(tensor, torch.LongStorage{3,5}):size():totable(), target, 'Error in view using LongStorage')
@@ -3090,9 +3137,15 @@ function torchtest.view()
end
function torchtest.expand()
- local result = torch.Tensor()
- local tensor = torch.rand(8,1)
- local template = torch.rand(8,5)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_expand(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_expand(func)
+ local result = func(torch.Tensor())
+ local tensor = func(torch.rand(8,1))
+ local template = func(torch.rand(8,5))
local target = template:size():totable()
mytester:assertTableEq(tensor:expandAs(template):size():totable(), target, 'Error in expandAs')
mytester:assertTableEq(tensor:expand(8,5):size():totable(), target, 'Error in expand')
@@ -3107,8 +3160,14 @@ function torchtest.expand()
end
function torchtest.repeatTensor()
- local result = torch.Tensor()
- local tensor = torch.rand(8,4)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_repeatTensor(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_repeatTensor(func, mean)
+ local result = func(torch.Tensor())
+ local tensor = func(torch.rand(8,4))
local size = {3,1,1}
local sizeStorage = torch.LongStorage(size)
local target = {3,8,4}
@@ -3122,10 +3181,16 @@ function torchtest.repeatTensor()
end
function torchtest.isSameSizeAs()
- local t1 = torch.Tensor(3, 4, 9, 10)
- local t2 = torch.Tensor(3, 4)
- local t3 = torch.Tensor(1, 9, 3, 3)
- local t4 = torch.Tensor(3, 4, 9, 10)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_isSameSizeAs(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_isSameSizeAs(func)
+ local t1 = func(torch.Tensor(3, 4, 9, 10))
+ local t2 = func(torch.Tensor(3, 4))
+ local t3 = func(torch.Tensor(1, 9, 3, 3))
+ local t4 = func(torch.Tensor(3, 4, 9, 10))
mytester:assert(t1:isSameSizeAs(t2) == false, "wrong answer ")
mytester:assert(t1:isSameSizeAs(t3) == false, "wrong answer ")
@@ -3133,15 +3198,21 @@ function torchtest.isSameSizeAs()
end
function torchtest.isSetTo()
- local t1 = torch.Tensor(3, 4, 9, 10)
- local t2 = torch.Tensor(3, 4, 9, 10)
- local t3 = torch.Tensor():set(t1)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_isSetTo(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_isSetTo(func)
+ local t1 = func(torch.Tensor(3, 4, 9, 10))
+ local t2 = func(torch.Tensor(3, 4, 9, 10))
+ local t3 = func(torch.Tensor()):set(t1)
local t4 = t3:reshape(12, 90)
mytester:assert(t1:isSetTo(t2) == false, "tensors do not share storage")
mytester:assert(t1:isSetTo(t3) == true, "tensor is set to other")
mytester:assert(t3:isSetTo(t1) == true, "isSetTo should be symmetric")
mytester:assert(t1:isSetTo(t4) == false, "tensors have different view")
- mytester:assert(not torch.Tensor():isSetTo(torch.Tensor()),
+ mytester:assert(not func(torch.Tensor()):isSetTo(func(torch.Tensor())),
"Tensors with no storages should not appear to be set " ..
"to each other")
end
@@ -3179,7 +3250,13 @@ function torchtest.equal()
end
function torchtest.isSize()
- local t1 = torch.Tensor(3, 4, 5)
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_isSize(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_isSize(func)
+ local t1 = func(torch.Tensor(3, 4, 5))
local s1 = torch.LongStorage({3, 4, 5})
local s2 = torch.LongStorage({5, 4, 3})
@@ -3196,6 +3273,7 @@ function torchtest.elementSize()
local long = torch.LongStorage():elementSize()
local float = torch.FloatStorage():elementSize()
local double = torch.DoubleStorage():elementSize()
+ local half = torch.HalfStorage():elementSize()
mytester:asserteq(byte, torch.ByteTensor():elementSize())
mytester:asserteq(char, torch.CharTensor():elementSize())
@@ -3204,6 +3282,7 @@ function torchtest.elementSize()
mytester:asserteq(long, torch.LongTensor():elementSize())
mytester:asserteq(float, torch.FloatTensor():elementSize())
mytester:asserteq(double, torch.DoubleTensor():elementSize())
+ mytester:asserteq(half, torch.HalfTensor():elementSize())
mytester:assertne(byte, 0)
mytester:assertne(char, 0)
@@ -3212,6 +3291,7 @@ function torchtest.elementSize()
mytester:assertne(long, 0)
mytester:assertne(float, 0)
mytester:assertne(double, 0)
+ mytester:assertne(half, 0)
-- These tests are portable, not necessarily strict for your system.
mytester:asserteq(byte, 1)
@@ -3222,11 +3302,18 @@ function torchtest.elementSize()
mytester:assert(long >= 4)
mytester:assert(long >= int)
mytester:assert(double >= float)
+ mytester:assert(half <= float)
end
function torchtest.split()
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_split(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_split(func)
local result = {}
- local tensor = torch.rand(7,4)
+ local tensor = func(torch.rand(7,4))
local splitSize = 3
local targetSize = {{3,4},{3,4},{1,4}}
local dim = 1
@@ -3251,8 +3338,14 @@ function torchtest.split()
end
function torchtest.chunk()
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_chunk(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_chunk(func)
local result = {}
- local tensor = torch.rand(4,7)
+ local tensor = func(torch.rand(4,7))
local nChunk = 3
local targetSize = {{4,3},{4,3},{4,1}}
local dim = 2
@@ -3272,24 +3365,34 @@ function torchtest.chunk()
end
end
-function torchtest.totable()
+function torchtest.table()
+ local convStorage = {
+ ['real'] = 'FloatStorage',
+ ['half'] = 'HalfStorage'
+ }
+ for k,v in pairs(convStorage) do
+ torchtest_totable(torch.getmetatable(torch.Tensor():type())[k], v)
+ end
+end
+
+function torchtest_totable(func, storageType)
local table0D = {}
- local tensor0D = torch.Tensor(table0D)
+ local tensor0D = func(torch.Tensor(table0D))
mytester:assertTableEq(torch.totable(tensor0D), table0D, 'tensor0D:totable incorrect')
local table1D = {1, 2, 3}
- local tensor1D = torch.Tensor(table1D)
- local storage = torch.Storage(table1D)
+ local tensor1D = func(torch.Tensor(table1D))
+ local storage = torch[storageType](table1D)
mytester:assertTableEq(tensor1D:totable(), table1D, 'tensor1D:totable incorrect')
mytester:assertTableEq(storage:totable(), table1D, 'storage:totable incorrect')
mytester:assertTableEq(torch.totable(tensor1D), table1D, 'torch.totable incorrect for Tensors')
mytester:assertTableEq(torch.totable(storage), table1D, 'torch.totable incorrect for Storages')
local table2D = {{1, 2}, {3, 4}}
- local tensor2D = torch.Tensor(table2D)
+ local tensor2D = func(torch.Tensor(table2D))
mytester:assertTableEq(tensor2D:totable(), table2D, 'tensor2D:totable incorrect')
- local tensor3D = torch.Tensor({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})
+ local tensor3D = func(torch.Tensor({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}))
local tensorNonContig = tensor3D:select(2, 2)
mytester:assert(not tensorNonContig:isContiguous(), 'invalid test')
mytester:assertTableEq(tensorNonContig:totable(), {{3, 4}, {7, 8}},
@@ -3297,6 +3400,12 @@ function torchtest.totable()
end
function torchtest.permute()
+ for k,v in ipairs({"real", "half"}) do
+ torchtest_permute(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function torchtest_permute(func)
local orig = {1,2,3,4,5,6,7}
local perm = torch.randperm(7):totable()
local x = torch.Tensor(unpack(orig)):fill(0)
diff --git a/test/test_half.lua b/test/test_half.lua
new file mode 100644
index 0000000..bf3830b
--- /dev/null
+++ b/test/test_half.lua
@@ -0,0 +1,55 @@
+local mytester
+local torchtest = torch.TestSuite()
+
+-- Lua 5.2 compatibility
+local loadstring = loadstring or load
+local unpack = unpack or table.unpack
+
+function torchtest.easy()
+ local x=torch.randn(5, 6):half()
+ mytester:assert(x:isContiguous(), 'x should be contiguous')
+ mytester:assert(x:dim() == 2, 'x should have dim of 2')
+ mytester:assert(x:nDimension() == 2, 'x should have nDimension of 2')
+ mytester:assert(x:nElement() == 5 * 6, 'x should have 30 elements')
+ local stride = x:stride()
+ local expectedStride = torch.LongStorage{6,1}
+ for i=1,stride:size() do
+ mytester:assert(stride[i] == expectedStride[i], "stride is wrong")
+ end
+
+ x=x:t()
+ mytester:assert(not x:isContiguous(), 'x transpose should not be contiguous')
+ x=x:transpose(1,2)
+ mytester:assert(x:isContiguous(), 'x should be contiguous after 2 transposes')
+
+ local y=torch.HalfTensor()
+ y:resizeAs(x:t()):copy(x:t())
+ mytester:assert(x:isContiguous(), 'after resize and copy, x should be contiguous')
+ mytester:assertTensorEq(y, x:t(), 0.001, 'copy broken after resizeAs')
+ local z=torch.HalfTensor()
+ z:resize(6, 5):copy(x:t())
+ mytester:assertTensorEq(y, x:t(), 0.001, 'copy broken after resize')
+end
+
+function torchtest.narrowSub()
+ local x = torch.randn(5, 6):half()
+ local narrow = x:narrow(1, 2, 3)
+ local sub = x:sub(2, 4)
+ mytester:assertTensorEq(narrow, sub, 0.001, 'narrow not equal to sub')
+end
+
+function torchtest.selectClone()
+ local x = torch.zeros(5, 6)
+ x:select(1,2):fill(2)
+ x=x:half()
+ local y=x:clone()
+ mytester:assertTensorEq(x, y, 0.001, 'not equal after select and clone')
+ x:select(1,1):fill(3)
+ mytester:assert(y[1][1] == 0, 'clone broken')
+end
+
+torch.setheaptracking(true)
+math.randomseed(os.time())
+mytester = torch.Tester()
+mytester:add(torchtest)
+mytester:run(tests)
diff --git a/test/test_writeObject.lua b/test/test_writeObject.lua
index 1013a96..52bcb71 100644
--- a/test/test_writeObject.lua
+++ b/test/test_writeObject.lua
@@ -4,6 +4,9 @@ local myTester = torch.Tester()
local tests = torch.TestSuite()
+function torch.HalfTensor:norm()
+ return self:real():norm()
+end
-- checks that an object can be written and unwritten
-- returns false if an error occurs
@@ -66,7 +69,13 @@ function tests.test_a_recursive_closure()
end
function tests.test_a_tensor()
- local x = torch.rand(5, 10)
+ for k,v in ipairs({"real", "half"}) do
+ tests_test_a_tensor(torch.getmetatable(torch.Tensor():type())[v])
+ end
+end
+
+function tests_test_a_tensor(func)
+ local x = func(torch.rand(5, 10))
local xcopy = serializeAndDeserialize(x)
myTester:assert(x:norm() == xcopy:norm(), 'tensors should be the same')
end