diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-01 20:47:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-01 20:47:19 +0300 |
commit | 659fe96bc72dbf26331eb6fe7b4f8455fbef70d4 (patch) | |
tree | cf0640d17664533fdaa5d561010155a078128a18 /test | |
parent | 6b408c71248ba1a4c340441ccbe6bf97902d7efa (diff) | |
parent | 12ffddf5cd7e29944246497ad7a499169c4bbd4e (diff) |
Merge pull request #651 from pavanky/cat
Adding support for empty tensors in cat, catArray
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 78 |
1 files changed, 52 insertions, 26 deletions
diff --git a/test/test.lua b/test/test.lua index 6748b1e..0eec0bf 100644 --- a/test/test.lua +++ b/test/test.lua @@ -3625,17 +3625,30 @@ end function test.cat() for k, typename in ipairs(typenames) do for dim = 1, 3 do - local x = torch.Tensor(13, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local y = torch.Tensor(17, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local mx = torch.cat(x, y, dim) - tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') - - local mxx = torch.Tensor():type(typename) - torch.cat(mxx, x, y, dim) - tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + local x = torch.Tensor(13, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local y = torch.Tensor(17, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local mx = torch.cat(x, y, dim) + tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') + + local mxx = torch.Tensor():type(typename) + torch.cat(mxx, x, y, dim) + tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + + local x = torch.CudaTensor(1, 2, 3):uniform() + local y = torch.CudaTensor() + local mx = torch.cat(x,y,dim) + tester:asserteq(mx:size(1),1,'torch.cat size') + tester:asserteq(mx:size(2),2,'torch.cat size') + tester:asserteq(mx:size(3),3,'torch.cat size') + tester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.CudaTensor() + local y = torch.CudaTensor() + local mx = torch.cat(x,y,dim) + tester:asserteq(mx:dim(),0,'torch.cat dim') end end end @@ -3643,21 +3656,34 @@ end function test.catArray() for k, typename in ipairs(typenames) do for dim = 1, 3 do - local x = torch.Tensor(13, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local y = torch.Tensor(17, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - local z = torch.Tensor(19, minsize, minsize):uniform() - :type(typename):transpose(1, dim) - - local mx = torch.cat({x, y, z}, dim) - tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value') - - local mxx = torch.Tensor():type(typename) - torch.cat(mxx, {x, y, z}, dim) - tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + local x = torch.Tensor(13, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local y = torch.Tensor(17, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local z = torch.Tensor(19, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + + local mx = torch.cat({x, y, z}, dim) + tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value') + + local mxx = torch.Tensor():type(typename) + torch.cat(mxx, {x, y, z}, dim) + tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + + local x = torch.CudaTensor(1, 2, 3):uniform() + local y = torch.CudaTensor() + local mx = torch.cat({x,y},dim) + tester:asserteq(mx:size(1),1,'torch.cat size') + tester:asserteq(mx:size(2),2,'torch.cat size') + tester:asserteq(mx:size(3),3,'torch.cat size') + tester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.CudaTensor() + local y = torch.CudaTensor() + local mx = torch.cat({x,y},dim) + tester:asserteq(mx:dim(),0,'torch.cat dim') end end end |