diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-10-17 08:55:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-10-17 08:55:56 +0300 |
commit | 2b598fe5c226995b04c8270e1b723a7b20c027d4 (patch) | |
tree | d569d178261e40e055824331a319dc7cae9eb31a | |
parent | ac40c058125cae4abb459ebad3d07e10ce858e8d (diff) | |
parent | 03162f5a4145a3baaa2656d7586ca6434831bd54 (diff) |
Merge pull request #554 from torch/catmultiple
torch.cat for multiple cuda types
-rw-r--r-- | TensorMath.lua | 13 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.cu | 69 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 3 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.cu | 71 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMath.h | 4 | ||||
-rw-r--r-- | test/test.lua | 53 |
6 files changed, 118 insertions, 95 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index e917f8c..f6803d1 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -873,7 +873,18 @@ for k, Tensor_ in pairs(handledTypenames) do wrap("sign", cname("sign"), {{name=Tensor, default=true, returned=true, method={default='nil'}}, - {name=Tensor, method={default=1}}}) + {name=Tensor, method={default=1}}}) + + wrap("cat", + cname("cat"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor}, + {name=Tensor}, + {name="index", default=lastdim(2)}}, + cname("catArray"), + {{name=Tensor, default=true, returned=true}, + {name=Tensor .. "Array"}, + {name="index", default=lastdimarray(2)}}) if real == 'float' or real == 'double' or real == 'half' then for _,name in ipairs({"log", "log1p", "exp", diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu index bf8b399..f0bbd9c 100644 --- a/lib/THC/THCTensorMath.cu +++ b/lib/THC/THCTensorMath.cu @@ -6,75 +6,6 @@ #include <cfloat> -void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension) -{ - THCudaTensor* inputs[2]; - inputs[0] = ta; - inputs[1] = tb; - THCudaTensor_catArray(state, result, inputs, 2, dimension); -} - -void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension) -{ - THLongStorage *size; - int i, j; - long offset; - int ndim = dimension + 1; - for (i = 0; i < numInputs; i++) - { - ndim = THMax(ndim, THCudaTensor_nDimension(state, inputs[i])); - } - - THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); - THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension+1); - - size = THLongStorage_newWithSize(ndim); - for(i = 0; i < ndim; i++) - { - long dimSize = i < THCudaTensor_nDimension(state, inputs[0]) - ? THCudaTensor_size(state, inputs[0], i) - : 1; - if (i == dimension) - { - for (j = 1; j < numInputs; j++) - { - dimSize += i < THCudaTensor_nDimension(state, inputs[j]) - ? THCudaTensor_size(state, inputs[j], i) - : 1; - } - } - else - { - for (j = 1; j < numInputs; j++) - { - if (dimSize != (i < THCudaTensor_nDimension(state, inputs[j]) - ? THCudaTensor_size(state, inputs[j], i) - : 1)) { - THLongStorage_free(size); - THError("inconsistent tensor sizes"); - } - } - } - size->data[i] = dimSize; - } - - THCudaTensor_resize(state, result, size, NULL); - THLongStorage_free(size); - - offset = 0; - for (j = 0; j < numInputs; j++) - { - long dimSize = dimension < THCudaTensor_nDimension(state, inputs[j]) - ? THCudaTensor_size(state, inputs[j], dimension) - : 1; - THCudaTensor *nt = THCudaTensor_newWithTensor(state, result); - THCudaTensor_narrow(state, nt, NULL, dimension, offset, dimSize); - THCudaTensor_copy(state, nt, inputs[j]); - THCudaTensor_free(state, nt); - offset += dimSize; - } -} - template <typename T> struct TensorFillOp { TensorFillOp(T v) : val(v) {} diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 3bafd59..3d71469 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -58,9 +58,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b); THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a); -THC_API void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension); -THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension); - THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value); THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size); diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 557f8f5..231695d 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -65,4 +65,75 @@ THCTensor_(numel)(THCState *state, THCTensor *t) return THCTensor_(nElement)(state, t); } +void THCTensor_(cat)(THCState *state, THCTensor *result, + THCTensor *ta, THCTensor *tb, int dimension) +{ + THCTensor* inputs[2]; + inputs[0] = ta; + inputs[1] = tb; + THCTensor_(catArray)(state, result, inputs, 2, dimension); +} + +void THCTensor_(catArray)(THCState *state, THCTensor *result, + THCTensor **inputs, int numInputs, int dimension) +{ + THLongStorage *size; + int i, j; + long offset; + int ndim = dimension + 1; + for (i = 0; i < numInputs; i++) + { + ndim = THMax(ndim, THCTensor_(nDimension)(state, inputs[i])); + } + + THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); + THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension+1); + + size = THLongStorage_newWithSize(ndim); + for(i = 0; i < ndim; i++) + { + long dimSize = i < THCTensor_(nDimension)(state, inputs[0]) + ? THCTensor_(size)(state, inputs[0], i) + : 1; + if (i == dimension) + { + for (j = 1; j < numInputs; j++) + { + dimSize += i < THCTensor_(nDimension)(state, inputs[j]) + ? THCTensor_(size)(state, inputs[j], i) + : 1; + } + } + else + { + for (j = 1; j < numInputs; j++) + { + if (dimSize != (i < THCTensor_(nDimension)(state, inputs[j]) + ? THCTensor_(size)(state, inputs[j], i) + : 1)) { + THLongStorage_free(size); + THError("inconsistent tensor sizes"); + } + } + } + size->data[i] = dimSize; + } + + THCTensor_(resize)(state, result, size, NULL); + THLongStorage_free(size); + + offset = 0; + for (j = 0; j < numInputs; j++) + { + long dimSize = dimension < THCTensor_(nDimension)(state, inputs[j]) + ? THCTensor_(size)(state, inputs[j], dimension) + : 1; + THCTensor *nt = THCTensor_(newWithTensor)(state, result); + THCTensor_(narrow)(state, nt, NULL, dimension, offset, dimSize); + THCTensor_(copy)(state, nt, inputs[j]); + THCTensor_(free)(state, nt); + offset += dimSize; + } +} + #endif diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h index cfc706a..5f4f8ee 100644 --- a/lib/THC/generic/THCTensorMath.h +++ b/lib/THC/generic/THCTensorMath.h @@ -9,5 +9,9 @@ THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *si THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size); THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size); THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); +THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension); +THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension); + + #endif diff --git a/test/test.lua b/test/test.lua index eb9ee57..058103d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -3243,33 +3243,42 @@ function test.topk() end function test.cat() - for dim = 1, 3 do - local x = torch.CudaTensor(13, minsize, minsize):uniform():transpose(1, dim) - local y = torch.CudaTensor(17, minsize, minsize):uniform():transpose(1, dim) - local mx = torch.cat(x, y, dim) - tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') + for k, typename in ipairs(typenames) do + for dim = 1, 3 do + local x = torch.Tensor(13, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local y = torch.Tensor(17, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local mx = torch.cat(x, y, dim) + tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') - local mxx = torch.CudaTensor() - torch.cat(mxx, x, y, dim) - tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + local mxx = torch.Tensor():type(typename) + torch.cat(mxx, x, y, dim) + tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + end end end function test.catArray() - for dim = 1, 3 do - local x = torch.CudaTensor(13, minsize, minsize):uniform():transpose(1, dim) - local y = torch.CudaTensor(17, minsize, minsize):uniform():transpose(1, dim) - local z = torch.CudaTensor(19, minsize, minsize):uniform():transpose(1, dim) - - local mx = torch.cat({x, y, z}, dim) - tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') - tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value') - - local mxx = torch.CudaTensor() - torch.cat(mxx, {x, y, z}, dim) - tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + for k, typename in ipairs(typenames) do + for dim = 1, 3 do + local x = torch.Tensor(13, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local y = torch.Tensor(17, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + local z = torch.Tensor(19, minsize, minsize):uniform() + :type(typename):transpose(1, dim) + + local mx = torch.cat({x, y, z}, dim) + tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value') + tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value') + + local mxx = torch.Tensor():type(typename) + torch.cat(mxx, {x, y, z}, dim) + tester:assertTensorEq(mx, mxx, 0, 'torch.cat value') + end end end |