diff options
author | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2016-12-26 07:21:22 +0300 |
---|---|---|
committer | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2016-12-28 00:37:42 +0300 |
commit | 12ffddf5cd7e29944246497ad7a499169c4bbd4e (patch) | |
tree | 119b8540ec9b1afdeb8eb32dce7a15a61d1db7c3 /test | |
parent | 33449f817dbd5f32cb0162a4f500b6753152718d (diff) |
Adding support for empty tensors in cat, catArray
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 78 |
1 files changed, 52 insertions, 26 deletions
diff --git a/test/test.lua b/test/test.lua index cd7fd61..e0ae867 100644 --- a/test/test.lua +++ b/test/test.lua @@ -3595,17 +3595,30 @@ end function test.cat() for k, typename in ipairs(typenames) do for dim = 1, 3 do - local x = torch.Tensor(13, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local y = torch.Tensor(17, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local mx = torch.cat(x, y, dim) - tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') - - local mxx = torch.Tensor():type(typename) - torch.cat(mxx, x, y, dim) - tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + local x = torch.Tensor(13, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local y = torch.Tensor(17, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local mx = torch.cat(x, y, dim) + tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') + + local mxx = torch.Tensor():type(typename) + torch.cat(mxx, x, y, dim) + tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + + local x = torch.CudaTensor(1, 2, 3):uniform() + local y = torch.CudaTensor() + local mx = torch.cat(x,y,dim) + tester:asserteq(mx:size(1),1,'torch.cat size') + tester:asserteq(mx:size(2),2,'torch.cat size') + tester:asserteq(mx:size(3),3,'torch.cat size') + tester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.CudaTensor() + local y = torch.CudaTensor() + local mx = torch.cat(x,y,dim) + tester:asserteq(mx:dim(),0,'torch.cat dim') end end end @@ -3613,21 +3626,34 @@ end function test.catArray() for k, typename in ipairs(typenames) do for dim = 1, 3 do - local x = torch.Tensor(13, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local y = torch.Tensor(17, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local z = torch.Tensor(19, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - - local mx = torch.cat({x, y, z}, dim) - tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value') - - local mxx = torch.Tensor():type(typename) - torch.cat(mxx, {x, y, z}, dim) - tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + local x = torch.Tensor(13, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local y = torch.Tensor(17, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local z = torch.Tensor(19, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + + local mx = torch.cat({x, y, z}, dim) + tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value') + + local mxx = torch.Tensor():type(typename) + torch.cat(mxx, {x, y, z}, dim) + tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + + local x = torch.CudaTensor(1, 2, 3):uniform() + local y = torch.CudaTensor() + local mx = torch.cat({x,y},dim) + tester:asserteq(mx:size(1),1,'torch.cat size') + tester:asserteq(mx:size(2),2,'torch.cat size') + tester:asserteq(mx:size(3),3,'torch.cat size') + tester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.CudaTensor() + local y = torch.CudaTensor() + local mx = torch.cat({x,y},dim) + tester:asserteq(mx:dim(),0,'torch.cat dim') end end end |