Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-01-01 20:47:19 +0300
committerGitHub <noreply@github.com>2017-01-01 20:47:19 +0300
commit659fe96bc72dbf26331eb6fe7b4f8455fbef70d4 (patch)
treecf0640d17664533fdaa5d561010155a078128a18 /test
parent6b408c71248ba1a4c340441ccbe6bf97902d7efa (diff)
parent12ffddf5cd7e29944246497ad7a499169c4bbd4e (diff)
Merge pull request #651 from pavanky/cat
Adding support for empty tensors in cat, catArray
Diffstat (limited to 'test')
-rw-r--r--test/test.lua78
1 files changed, 52 insertions, 26 deletions
diff --git a/test/test.lua b/test/test.lua
index 6748b1e..0eec0bf 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3625,17 +3625,30 @@ end
function test.cat()
for k, typename in ipairs(typenames) do
for dim = 1, 3 do
- local x = torch.Tensor(13, minsize, minsize):uniform()
- :type(typename):transpose(1, dim)
- local y = torch.Tensor(17, minsize, minsize):uniform()
- :type(typename):transpose(1, dim)
- local mx = torch.cat(x, y, dim)
- tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
- tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
-
- local mxx = torch.Tensor():type(typename)
- torch.cat(mxx, x, y, dim)
- tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+ local x = torch.Tensor(13, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local y = torch.Tensor(17, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local mx = torch.cat(x, y, dim)
+ tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
+ tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
+
+ local mxx = torch.Tensor():type(typename)
+ torch.cat(mxx, x, y, dim)
+ tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+
+ local x = torch.CudaTensor(1, 2, 3):uniform()
+ local y = torch.CudaTensor()
+ local mx = torch.cat(x,y,dim)
+ tester:asserteq(mx:size(1),1,'torch.cat size')
+ tester:asserteq(mx:size(2),2,'torch.cat size')
+ tester:asserteq(mx:size(3),3,'torch.cat size')
+ tester:assertTensorEq(mx, x, 0, 'torch.cat value')
+
+ local x = torch.CudaTensor()
+ local y = torch.CudaTensor()
+ local mx = torch.cat(x,y,dim)
+ tester:asserteq(mx:dim(),0,'torch.cat dim')
end
end
end
@@ -3643,21 +3656,34 @@ end
function test.catArray()
for k, typename in ipairs(typenames) do
for dim = 1, 3 do
- local x = torch.Tensor(13, minsize, minsize):uniform()
- :type(typename):transpose(1, dim)
- local y = torch.Tensor(17, minsize, minsize):uniform()
- :type(typename):transpose(1, dim)
- local z = torch.Tensor(19, minsize, minsize):uniform()
- :type(typename):transpose(1, dim)
-
- local mx = torch.cat({x, y, z}, dim)
- tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
- tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
- tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value')
-
- local mxx = torch.Tensor():type(typename)
- torch.cat(mxx, {x, y, z}, dim)
- tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+ local x = torch.Tensor(13, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local y = torch.Tensor(17, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local z = torch.Tensor(19, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+
+ local mx = torch.cat({x, y, z}, dim)
+ tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
+ tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
+ tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value')
+
+ local mxx = torch.Tensor():type(typename)
+ torch.cat(mxx, {x, y, z}, dim)
+ tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+
+ local x = torch.CudaTensor(1, 2, 3):uniform()
+ local y = torch.CudaTensor()
+ local mx = torch.cat({x,y},dim)
+ tester:asserteq(mx:size(1),1,'torch.cat size')
+ tester:asserteq(mx:size(2),2,'torch.cat size')
+ tester:asserteq(mx:size(3),3,'torch.cat size')
+ tester:assertTensorEq(mx, x, 0, 'torch.cat value')
+
+ local x = torch.CudaTensor()
+ local y = torch.CudaTensor()
+ local mx = torch.cat({x,y},dim)
+ tester:asserteq(mx:dim(),0,'torch.cat dim')
end
end
end