github.com/torch/torch7.git

author    nkoumchatzky <nkoumchatzky@twitter.com>  2016-12-21 03:24:35 +0300
committer nkoumchatzky <nkoumchatzky@twitter.com>  2016-12-26 18:23:42 +0300
commit    d41580eccefcbc1d11d404e0c4ae522560f8e263 (patch)
tree      1bf3233d39ec499418fe4daa1b802b6e6b6095da
parent    7ca7ec9d08f1ef2c753e72cbd014397736d6b5af (diff)
Add a faster code path for catting contiguous tensors along the first dimension.
Fix a bug in cat when catting with an empty tensor along the first dim (it added an extra dim). Fix the ambiguous 'catting along last dimension' sentence in the doc, and change the behavior to pick the maximum last dimension over all input tensors. Empty tensors are now allowed.
-rw-r--r--  TensorMath.lua                 | 12
-rwxr-xr-x  doc/maths.md                   | 10
-rw-r--r--  lib/TH/generic/THTensorMath.c  | 98
-rw-r--r--  test/test.lua                  | 54
4 files changed, 145 insertions(+), 29 deletions(-)
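
The behavioral fix is easiest to see in a short example. The following is a minimal sketch, not part of the commit, assuming a build with this patch applied:

```lua
require 'torch'

local x = torch.rand(3, 2)
local y = torch.Tensor()   -- empty tensor

-- Before this patch, catting with an empty tensor along the first dim
-- added a spurious extra dimension; now the empty input is ignored.
local z = torch.cat(x, y, 1)
assert(z:dim() == 2 and z:size(1) == 3 and z:size(2) == 2)
assert(torch.equal(z, x))

-- Catting only empty tensors yields an empty tensor.
assert(torch.cat(torch.Tensor(), torch.Tensor(), 1):dim() == 0)
```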
diff --git a/TensorMath.lua b/TensorMath.lua
index 682de23..5971a7b 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -9,7 +9,7 @@ local argtypes = wrap.CInterface.argtypes
argtypes['ptrdiff_t'] = {
helpname = function(arg)
- return 'ptrdiff_t'
+ return 'ptrdiff_t'
end,
declare = function(arg)
@@ -35,7 +35,7 @@ argtypes['ptrdiff_t'] = {
end
end
end,
-
+
carg = function(arg)
return string.format('arg%d', arg.i)
end,
@@ -43,13 +43,13 @@ argtypes['ptrdiff_t'] = {
creturn = function(arg)
return string.format('arg%d', arg.i)
end,
-
+
precall = function(arg)
if arg.returned then
return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
end
end,
-
+
postcall = function(arg)
if arg.creturned then
return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
@@ -738,11 +738,11 @@ wrap("topk",
wrap("cat",
cname("cat"),
{{name=Tensor, default=true, returned=true},
{name=Tensor},
{name=Tensor},
- {name="index", default=lastdim(2)}},
+ {name="index", default=-1}},
cname("catArray"),
{{name=Tensor, default=true, returned=true},
{name=Tensor .. "Array"},
- {name="index", default=lastdimarray(2)}})
+ {name="index", default=-1}})
if Tensor == 'ByteTensor' then -- we declare this only once
interface:print(
diff --git a/doc/maths.md b/doc/maths.md
index 252b52d..44e5ea6 100755
--- a/doc/maths.md
+++ b/doc/maths.md
@@ -60,12 +60,14 @@ The advantage of second case is, same `res2` `Tensor` can be used successively i
<a name="torch.cat"></a>
`x = torch.cat(x_1, x_2, [dimension])` returns a `Tensor` `x` which is the concatenation of `Tensor`s `x_1` and `x_2` along dimension `dimension`.
-If `dimension` is not specified it is the last dimension.
+If `dimension` is not specified, or if it is `-1`, it is the maximum last dimension over all input tensors; if all tensors are empty, it is `1`.
The other dimensions of `x_1` and `x_2` have to be equal.
Also supports arrays with arbitrary numbers of `Tensor`s as inputs.
+Empty tensors are ignored during catting and do not cause an error. Catting only empty tensors always results in an empty tensor.
+
Examples:
```lua
> torch.cat(torch.ones(3), torch.zeros(2))
@@ -116,6 +118,12 @@ Examples:
0.2206 0.7449
[torch.DoubleTensor of size 7x2]
+> torch.cat({torch.Tensor(), torch.rand(3, 2)}, 1)
+ 0.3227 0.0493
+ 0.9161 0.1086
+ 0.2206 0.7449
+[torch.DoubleTensor of size 3x2]
+
```
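
As a complement to the doc change above, here is a brief sketch of the new default-dimension rule (names and shapes are illustrative only, assuming the patched build):

```lua
require 'torch'

local a = torch.ones(2, 3)
local b = torch.zeros(2, 3)

-- With no dimension given (or -1), cat uses the maximum last dimension
-- over all inputs; both inputs here are 2D, so dim 2 is chosen.
assert(torch.cat(a, b):isSameSizeAs(torch.Tensor(2, 6)))
assert(torch.cat(a, b, -1):isSameSizeAs(torch.Tensor(2, 6)))

-- Empty inputs are ignored and do not affect the chosen dimension.
assert(torch.cat(torch.Tensor(), a):isSameSizeAs(a))
```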
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index e04d3b6..9fc1577 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -2035,53 +2035,111 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
THLongStorage *size;
int i, j;
long offset;
- int ndim = dimension + 1;
+ int maxDim = dimension + 1;
+ int allEmpty = 1;
+ int allContiguous = 1;
+ int ldimension = dimension;
+
for (i = 0; i < numInputs; i++)
{
- ndim = THMax(ndim, inputs[i]->nDimension);
+ maxDim = THMax(maxDim, inputs[i]->nDimension);
+ }
+
+ // When the user passes dimension -1 (which becomes -2 in C after index adjustment),
+ // pick the maximum last dimension across all input tensors.
+ if ( dimension == -2 )
+ {
+ ldimension = maxDim?(maxDim-1):0;
}
THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
- THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
+ THArgCheck(ldimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
- size = THLongStorage_newWithSize(ndim);
- for(i = 0; i < ndim; i++)
+ size = THLongStorage_newWithSize(maxDim);
+
+ for(i = 0; i < maxDim; i++)
{
- long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : 1;
- if (i == dimension)
+ // dimSize is the size of dim i when the tensor has that dim; otherwise 1 for a non-empty tensor and 0 for an empty one
+ long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : THMin(inputs[0]->nDimension, 1);
+ if (i == ldimension)
{
+      // The loop below starts at j = 1, so account for the first input's contiguity here.
+      if (inputs[0]->nDimension)
+      {
+        allContiguous = allContiguous && THTensor_(isContiguous)(inputs[0]);
+      }
for (j = 1; j < numInputs; j++)
{
- dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : 1;
+ // accumulate the size over the dimension we want to cat on.
+ // Empty tensors are allowed
+ dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1);
+ if(inputs[j]->nDimension)
+ {
+ allContiguous = allContiguous && THTensor_(isContiguous)(inputs[j]);
+ }
}
}
else
{
for (j = 1; j < numInputs; j++)
{
- if (dimSize != (i < inputs[j]->nDimension ? inputs[j]->size[i] : 1))
+ long sz = (i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1));
+ // For a dimension we are not catting on,
+ // fail only if the sizes differ and both are nonzero
+ if (dimSize != sz && dimSize && sz)
{
THLongStorage_free(size);
THError("inconsistent tensor sizes");
}
+ else if(!dimSize)
+ {
+ dimSize = sz;
+ }
}
}
+ allEmpty = allEmpty && !dimSize;
size->data[i] = dimSize;
}
- THTensor_(resize)(result, size, NULL);
- THLongStorage_free(size);
-
- offset = 0;
- for (j = 0; j < numInputs; j++)
+ // Resize and perform the cat
+ // only if at least one of the inputs is non-empty
+ if (!allEmpty)
{
- long dimSize = dimension < inputs[j]->nDimension ? inputs[j]->size[dimension] : 1;
- THTensor *nt = THTensor_(newWithTensor)(result);
- THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
- THTensor_(copy)(nt, inputs[j]);
- THTensor_(free)(nt);
- offset += dimSize;
+ THTensor_(resize)(result, size, NULL);
+
+ allContiguous = allContiguous && THTensor_(isContiguous)(result);
+
+ // Fast path: catting along the first dimension with all tensors
+ // (inputs and result) contiguous reduces to one memcpy per input;
+ // otherwise fall back to the generic narrow/copy path.
+ if (ldimension == 0 && allContiguous)
+ {
+ real* result_data = result->storage->data + result->storageOffset;
+ offset = 0;
+ for (j = 0; j < numInputs; j++)
+ {
+ if (inputs[j]->nDimension)
+ {
+ THTensor* input0 = inputs[j];
+ real* input0_data = input0->storage->data + input0->storageOffset;
+ long input0_size = THTensor_(nElement)(input0);
+ memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
+ offset += input0_size;
+ }
+ }
+ }
+ else
+ {
+ offset = 0;
+ for (j = 0; j < numInputs; j++)
+ {
+ if (inputs[j]->nDimension)
+ {
+ long dimSize = ldimension < inputs[j]->nDimension ? inputs[j]->size[ldimension] : 1;
+ THTensor *nt = THTensor_(newWithTensor)(result);
+ THTensor_(narrow)(nt, NULL, ldimension, offset, dimSize);
+ THTensor_(copy)(nt, inputs[j]);
+ THTensor_(free)(nt);
+ offset += dimSize;
+ }
+ }
+ }
}
+ THLongStorage_free(size);
}
int THTensor_(equal)(THTensor *ta, THTensor* tb)
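
The fast path above replaces the generic narrow/copy loop with a single `memcpy` per input when every tensor is contiguous and the cat dimension is the first one. A rough timing sketch of when it applies (illustrative only; the sizes are arbitrary):

```lua
require 'torch'

local a = torch.rand(1000, 100)   -- contiguous
local b = torch.rand(1000, 100)   -- contiguous

-- Contiguous inputs catted along dim 1 take the memcpy fast path.
local timer = torch.Timer()
local c = torch.cat(a, b, 1)
print(string.format('contiguous dim-1 cat: %.4f s', timer:time().real))

-- Transposed (non-contiguous) inputs fall back to the narrow/copy path,
-- as does catting along any dimension other than the first.
timer:reset()
local d = torch.cat(a:t(), b:t(), 2)
print(string.format('non-contiguous cat:   %.4f s', timer:time().real))
```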
diff --git a/test/test.lua b/test/test.lua
index 3eb119f..eb7cf0a 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1827,7 +1827,32 @@ function torchtest.cat()
local mxx = torch.Tensor()
torch.cat(mxx, x, y, dim)
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
- end
+
+ local x = torch.rand(1,2,3)
+ local y = torch.Tensor()
+ local mx = torch.cat(x,y,dim)
+ mytester:asserteq(mx:size(1),1,'torch.cat size')
+ mytester:asserteq(mx:size(2),2,'torch.cat size')
+ mytester:asserteq(mx:size(3),3,'torch.cat size')
+ mytester:assertTensorEq(mx, x, 0, 'torch.cat value')
+
+ local x = torch.Tensor()
+ local y = torch.Tensor()
+ local mx = torch.cat(x,y,dim)
+ mytester:asserteq(mx:dim(),0,'torch.cat dim')
+ end
+ local x = torch.Tensor()
+ local y = torch.rand(1,2,3)
+ local mx = torch.cat(x,y)
+ mytester:asserteq(mx:size(1),1,'torch.cat size')
+ mytester:asserteq(mx:size(2),2,'torch.cat size')
+ mytester:asserteq(mx:size(3),3,'torch.cat size')
+ mytester:assertTensorEq(mx, y, 0, 'torch.cat value')
+
+ local x = torch.Tensor()
+ local y = torch.Tensor()
+ local mx = torch.cat(x,y)
+ mytester:asserteq(mx:dim(),0,'torch.cat dim')
end
function torchtest.catArray()
for dim = 1, 3 do
@@ -1849,7 +1874,32 @@ function torchtest.catArray()
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
torch.cat(mxx:double(), {x:double(), y:double(), z:double()}, dim)
mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
- end
+
+ local x = torch.rand(1,2,3)
+ local y = torch.Tensor()
+ local mx = torch.cat({x,y},dim)
+ mytester:asserteq(mx:size(1),1,'torch.cat size')
+ mytester:asserteq(mx:size(2),2,'torch.cat size')
+ mytester:asserteq(mx:size(3),3,'torch.cat size')
+ mytester:assertTensorEq(mx, x, 0, 'torch.cat value')
+
+ local x = torch.Tensor()
+ local y = torch.Tensor()
+ local mx = torch.cat({x,y},dim)
+ mytester:asserteq(mx:dim(),0,'torch.cat dim')
+ end
+ local x = torch.Tensor()
+ local y = torch.rand(1,2,3)
+ local mx = torch.cat({x,y})
+ mytester:asserteq(mx:size(1),1,'torch.cat size')
+ mytester:asserteq(mx:size(2),2,'torch.cat size')
+ mytester:asserteq(mx:size(3),3,'torch.cat size')
+ mytester:assertTensorEq(mx, y, 0, 'torch.cat value')
+
+ local x = torch.Tensor()
+ local y = torch.Tensor()
+ local mx = torch.cat({x,y})
+ mytester:asserteq(mx:dim(),0,'torch.cat dim')
end
function torchtest.sin_2()
local x = torch.rand(msize,msize,msize)
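
The added cases can be exercised with torch's own test harness; a hedged invocation sketch, assuming a torch7 checkout with this patch built and installed:

```lua
require 'torch'
-- Run only the cat-related suites touched by this commit.
torch.test{'cat', 'catArray'}
```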