Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-02-24 15:51:47 +0300
committerSoumith Chintala <soumith@gmail.com>2017-02-24 15:51:47 +0300
commit641d9c508e027c0cd550ff435f5d6cfd02c7cecd (patch)
treeda94629b7c84ed0bc2c1135899833a42f3470611 /test
parent680b3dd56fff4276099353871e5e9e9e4f394e72 (diff)
add unit test for fill
Diffstat (limited to 'test')
-rw-r--r--test/test.lua43
1 files changed, 42 insertions, 1 deletions
diff --git a/test/test.lua b/test/test.lua
index 3cb4166..ed00c88 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -112,7 +112,7 @@ local genericSingleOpTest = [[
end
end
return maxerrc, maxerrnc
-]]
+--]]
function torchtest.sin()
local f = loadstring(string.gsub(genericSingleOpTest, 'functionname', 'sin'))
@@ -574,6 +574,47 @@ function torchtest.mv()
mytester:assertlt(err, precision, 'error in torch.mv')
end
+function torchtest.fill()
+ local types = {
+ 'torch.ByteTensor',
+ 'torch.CharTensor',
+ 'torch.ShortTensor',
+ 'torch.IntTensor',
+ 'torch.FloatTensor',
+ 'torch.DoubleTensor',
+ 'torch.LongTensor',
+ }
+
+ for k,t in ipairs(types) do
+ -- [res] torch.fill([res,] tensor, value)
+ local m1 = torch.ones(100,100):type(t)
+ local res1 = m1:clone()
+ res1[{ 3,{} }]:fill(2)
+
+ local res2 = m1:clone()
+ for i = 1,m1:size(1) do
+ res2[{ 3,i }] = 2
+ end
+
+ local err = (res1-res2):double():abs():max()
+
+ mytester:assertlt(err, precision, 'error in torch.fill - contiguous')
+
+ local m1 = torch.ones(100,100):type(t)
+ local res1 = m1:clone()
+ res1[{ {},3 }]:fill(2)
+
+ local res2 = m1:clone()
+ for i = 1,m1:size(1) do
+ res2[{ i,3 }] = 2
+ end
+
+ local err = (res1-res2):double():abs():max()
+
+ mytester:assertlt(err, precision, 'error in torch.fill - non contiguous')
+ end
+end
+
function torchtest.add()
-- [res] torch.add([res,] tensor1, tensor2)
local m1 = torch.randn(100,100)