diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-02-24 15:51:47 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-02-24 15:51:47 +0300 |
commit | 641d9c508e027c0cd550ff435f5d6cfd02c7cecd (patch) | |
tree | da94629b7c84ed0bc2c1135899833a42f3470611 /test | |
parent | 680b3dd56fff4276099353871e5e9e9e4f394e72 (diff) |
add unit test for fill
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua index 3cb4166..ed00c88 100644 --- a/test/test.lua +++ b/test/test.lua @@ -112,7 +112,7 @@ local genericSingleOpTest = [[ end end return maxerrc, maxerrnc -]] +--]] function torchtest.sin() local f = loadstring(string.gsub(genericSingleOpTest, 'functionname', 'sin')) @@ -574,6 +574,47 @@ function torchtest.mv() mytester:assertlt(err, precision, 'error in torch.mv') end +function torchtest.fill() + local types = { + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + } + + for k,t in ipairs(types) do + -- [res] torch.fill([res,] tensor, value) + local m1 = torch.ones(100,100):type(t) + local res1 = m1:clone() + res1[{ 3,{} }]:fill(2) + + local res2 = m1:clone() + for i = 1,m1:size(1) do + res2[{ 3,i }] = 2 + end + + local err = (res1-res2):double():abs():max() + + mytester:assertlt(err, precision, 'error in torch.fill - contiguous') + + local m1 = torch.ones(100,100):type(t) + local res1 = m1:clone() + res1[{ {},3 }]:fill(2) + + local res2 = m1:clone() + for i = 1,m1:size(1) do + res2[{ i,3 }] = 2 + end + + local err = (res1-res2):double():abs():max() + + mytester:assertlt(err, precision, 'error in torch.fill - non contiguous') + end +end + function torchtest.add() -- [res] torch.add([res,] tensor1, tensor2) local m1 = torch.randn(100,100) |