diff options
author | nkoumchatzky <nkoumchatzky@twitter.com> | 2016-12-21 03:24:35 +0300 |
---|---|---|
committer | nkoumchatzky <nkoumchatzky@twitter.com> | 2016-12-26 18:23:42 +0300 |
commit | d41580eccefcbc1d11d404e0c4ae522560f8e263 (patch) | |
tree | 1bf3233d39ec499418fe4daa1b802b6e6b6095da | |
parent | 7ca7ec9d08f1ef2c753e72cbd014397736d6b5af (diff) |
Add a different code path for catting contiguous tensors along the first dimension, for speed reasons.
Fix a bug in cat when catting with an empty tensor along first dim (it added an extra dim).
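Fix the ambiguous 'catting along last dimension' sentence in the doc and change the default behavior: when the dimension is not specified (or is -1), pick the maximum last dimension over all input tensors, i.e. the last dimension of the input tensor that has the most dimensions.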
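Empty tensors are now allowed as inputs: they are ignored during concatenation, and concatenating only empty tensors yields an empty tensor.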
-rw-r--r-- | TensorMath.lua | 12 | ||||
-rwxr-xr-x | doc/maths.md | 10 | ||||
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 98 | ||||
-rw-r--r-- | test/test.lua | 54 |
4 files changed, 145 insertions, 29 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 682de23..5971a7b 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -9,7 +9,7 @@ local argtypes = wrap.CInterface.argtypes argtypes['ptrdiff_t'] = { helpname = function(arg) - return 'ptrdiff_t' + return 'ptrdiff_t' end, declare = function(arg) @@ -35,7 +35,7 @@ argtypes['ptrdiff_t'] = { end end end, - + carg = function(arg) return string.format('arg%d', arg.i) end, @@ -43,13 +43,13 @@ argtypes['ptrdiff_t'] = { creturn = function(arg) return string.format('arg%d', arg.i) end, - + precall = function(arg) if arg.returned then return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) end end, - + postcall = function(arg) if arg.creturned then return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i) @@ -738,11 +738,11 @@ wrap("topk", {{name=Tensor, default=true, returned=true}, {name=Tensor}, {name=Tensor}, - {name="index", default=lastdim(2)}}, + {name="index", default=-1}}, cname("catArray"), {{name=Tensor, default=true, returned=true}, {name=Tensor .. "Array"}, - {name="index", default=lastdimarray(2)}}) + {name="index", default=-1}}) if Tensor == 'ByteTensor' then -- we declare this only once interface:print( diff --git a/doc/maths.md b/doc/maths.md index 252b52d..44e5ea6 100755 --- a/doc/maths.md +++ b/doc/maths.md @@ -60,12 +60,14 @@ The advantage of second case is, same `res2` `Tensor` can be used successively i <a name="torch.cat"></a> `x = torch.cat(x_1, x_2, [dimension])` returns a `Tensor` `x` which is the concatenation of `Tensor`s `x_1` and `x_2` along dimension `dimension`. -If `dimension` is not specified it is the last dimension. +If `dimension` is not specified or if it is `-1`, it is the maximum last dimension over all input tensors, except if all tensors are empty, then it is `1`. The other dimensions of `x_1` and `x_2` have to be equal. Also supports arrays with arbitrary numbers of `Tensor`s as inputs. 
+Empty tensors are ignored during catting, and thus do not throw an error. Performing cat on empty tensors only will always result in an empty tensor. + Examples: ```lua > torch.cat(torch.ones(3), torch.zeros(2)) @@ -116,6 +118,12 @@ Examples: 0.2206 0.7449 [torch.DoubleTensor of size 7x2] +> torch.cat({torch.Tensor(), torch.rand(3, 2)}, 1) + 0.3227 0.0493 + 0.9161 0.1086 + 0.2206 0.7449 +[torch.DoubleTensor of size 3x2] + ``` diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index e04d3b6..9fc1577 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -2035,53 +2035,111 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int THLongStorage *size; int i, j; long offset; - int ndim = dimension + 1; + int maxDim = dimension + 1; + int allEmpty = 1; + int allContiguous = 1; + int ldimension = dimension; + for (i = 0; i < numInputs; i++) { - ndim = THMax(ndim, inputs[i]->nDimension); + maxDim = THMax(maxDim, inputs[i]->nDimension); + } + + // When the user input dimension is -1 (i.e. -2 in C) + // Then we pick the maximum last dimension across all tensors. + if ( dimension == -2 ) + { + ldimension = maxDim?(maxDim-1):0; } THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); - THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE); + THArgCheck(ldimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE); - size = THLongStorage_newWithSize(ndim); - for(i = 0; i < ndim; i++) + size = THLongStorage_newWithSize(maxDim); + + for(i = 0; i < maxDim; i++) { - long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : 1; - if (i == dimension) + // dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0 + long dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : THMin(inputs[0]->nDimension, 1); + if (i == ldimension) { for (j = 1; j < numInputs; j++) { - dimSize += i < inputs[j]->nDimension ? 
inputs[j]->size[i] : 1; + // accumulate the size over the dimension we want to cat on. + // Empty tensors are allowed + dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1); + if(inputs[j]->nDimension) + { + allContiguous = allContiguous && THTensor_(isContiguous)(inputs[j]); + } } } else { for (j = 1; j < numInputs; j++) { - if (dimSize != (i < inputs[j]->nDimension ? inputs[j]->size[i] : 1)) + long sz = (i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1)); + // If it's a dimension we're not catting on + // Then fail if sizes are different AND > 0 + if (dimSize != sz && dimSize && sz) { THLongStorage_free(size); THError("inconsistent tensor sizes"); } + else if(!dimSize) + { + dimSize = sz; + } } } + allEmpty = allEmpty && !dimSize; size->data[i] = dimSize; } - THTensor_(resize)(result, size, NULL); - THLongStorage_free(size); - - offset = 0; - for (j = 0; j < numInputs; j++) + // Initiate catting and resizing + // If at least one of the input is not empty + if (!allEmpty) { - long dimSize = dimension < inputs[j]->nDimension ? 
inputs[j]->size[dimension] : 1; - THTensor *nt = THTensor_(newWithTensor)(result); - THTensor_(narrow)(nt, NULL, dimension, offset, dimSize); - THTensor_(copy)(nt, inputs[j]); - THTensor_(free)(nt); - offset += dimSize; + THTensor_(resize)(result, size, NULL); + + allContiguous = allContiguous && THTensor_(isContiguous)(result); + + // First path is for contiguous inputs along dim 1 + // Second path for non-contiguous + if (ldimension == 0 && allContiguous) + { + real* result_data = result->storage->data + result->storageOffset; + offset = 0; + for (j = 0; j < numInputs; j++) + { + if (inputs[j]->nDimension) + { + THTensor* input0 = inputs[j]; + real* input0_data = input0->storage->data + input0->storageOffset; + long input0_size = THTensor_(nElement)(input0); + memcpy(result_data + offset, input0_data, input0_size*sizeof(real)); + offset += input0_size; + } + } + } + else + { + offset = 0; + for (j = 0; j < numInputs; j++) + { + if (inputs[j]->nDimension) + { + long dimSize = ldimension < inputs[j]->nDimension ? 
inputs[j]->size[ldimension] : 1; + THTensor *nt = THTensor_(newWithTensor)(result); + THTensor_(narrow)(nt, NULL, ldimension, offset, dimSize); + THTensor_(copy)(nt, inputs[j]); + THTensor_(free)(nt); + offset += dimSize; + } + } + } } + THLongStorage_free(size); } int THTensor_(equal)(THTensor *ta, THTensor* tb) diff --git a/test/test.lua b/test/test.lua index 3eb119f..eb7cf0a 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1827,7 +1827,32 @@ function torchtest.cat() local mxx = torch.Tensor() torch.cat(mxx, x, y, dim) mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') - end + + local x = torch.rand(1,2,3) + local y = torch.Tensor() + local mx = torch.cat(x,y,dim) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat(x,y,dim) + mytester:asserteq(mx:dim(),0,'torch.cat dim') + end + local x = torch.Tensor() + local y = torch.rand(1,2,3) + local mx = torch.cat(x,y) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, y, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat(x,y) + mytester:asserteq(mx:dim(),0,'torch.cat dim') end function torchtest.catArray() for dim = 1, 3 do @@ -1849,7 +1874,32 @@ function torchtest.catArray() mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') torch.cat(mxx:double(), {x:double(), y:double(), z:double()}, dim) mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') - end + + local x = torch.rand(1,2,3) + local y = torch.Tensor() + local mx = torch.cat({x,y},dim) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + 
mytester:assertTensorEq(mx, x, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat({x,y},dim) + mytester:asserteq(mx:dim(),0,'torch.cat dim') + end + local x = torch.Tensor() + local y = torch.rand(1,2,3) + local mx = torch.cat({x,y}) + mytester:asserteq(mx:size(1),1,'torch.cat size') + mytester:asserteq(mx:size(2),2,'torch.cat size') + mytester:asserteq(mx:size(3),3,'torch.cat size') + mytester:assertTensorEq(mx, y, 0, 'torch.cat value') + + local x = torch.Tensor() + local y = torch.Tensor() + local mx = torch.cat({x,y}) + mytester:asserteq(mx:dim(),0,'torch.cat dim') end function torchtest.sin_2() local x = torch.rand(msize,msize,msize) |