github.com/torch/cutorch.git
author    Soumith Chintala <soumith@gmail.com>  2016-10-17 08:55:56 +0300
committer GitHub <noreply@github.com>           2016-10-17 08:55:56 +0300
commit    2b598fe5c226995b04c8270e1b723a7b20c027d4 (patch)
tree      d569d178261e40e055824331a319dc7cae9eb31a
parent    ac40c058125cae4abb459ebad3d07e10ce858e8d (diff)
parent    03162f5a4145a3baaa2656d7586ca6434831bd54 (diff)
Merge pull request #554 from torch/catmultiple
torch.cat for multiple cuda types
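
This merge moves the cat/catArray implementation from the CudaTensor-only code in THCTensorMath.cu into the generic THC path, so torch.cat works for every CUDA tensor type. A minimal usage sketch, mirroring the updated tests (the sizes and the CudaDoubleTensor type below are illustrative, and a CUDA device with cutorch loaded is assumed):

    require 'cutorch'

    -- any CUDA tensor type works after this change, not just torch.CudaTensor
    local typename = 'torch.CudaDoubleTensor'
    local x = torch.Tensor(13, 4, 4):uniform():type(typename)
    local y = torch.Tensor(17, 4, 4):uniform():type(typename)

    -- pairwise form: concatenate along dimension 1, giving a 30 x 4 x 4 result
    local z = torch.cat(x, y, 1)

    -- array form, writing into a preallocated result tensor of the same type
    local r = torch.Tensor():type(typename)
    torch.cat(r, {x, y, z}, 1)
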
-rw-r--r--  TensorMath.lua                    13
-rw-r--r--  lib/THC/THCTensorMath.cu          69
-rw-r--r--  lib/THC/THCTensorMath.h            3
-rw-r--r--  lib/THC/generic/THCTensorMath.cu  71
-rw-r--r--  lib/THC/generic/THCTensorMath.h    4
-rw-r--r--  test/test.lua                     53
6 files changed, 118 insertions(+), 95 deletions(-)
diff --git a/TensorMath.lua b/TensorMath.lua
index e917f8c..f6803d1 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -873,7 +873,18 @@ for k, Tensor_ in pairs(handledTypenames) do
wrap("sign",
cname("sign"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
- {name=Tensor, method={default=1}}})
+ {name=Tensor, method={default=1}}})
+
+ wrap("cat",
+ cname("cat"),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor},
+ {name=Tensor},
+ {name="index", default=lastdim(2)}},
+ cname("catArray"),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor .. "Array"},
+ {name="index", default=lastdimarray(2)}})
if real == 'float' or real == 'double' or real == 'half' then
for _,name in ipairs({"log", "log1p", "exp",
diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu
index bf8b399..f0bbd9c 100644
--- a/lib/THC/THCTensorMath.cu
+++ b/lib/THC/THCTensorMath.cu
@@ -6,75 +6,6 @@
#include <cfloat>
-void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension)
-{
- THCudaTensor* inputs[2];
- inputs[0] = ta;
- inputs[1] = tb;
- THCudaTensor_catArray(state, result, inputs, 2, dimension);
-}
-
-void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension)
-{
- THLongStorage *size;
- int i, j;
- long offset;
- int ndim = dimension + 1;
- for (i = 0; i < numInputs; i++)
- {
- ndim = THMax(ndim, THCudaTensor_nDimension(state, inputs[i]));
- }
-
- THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
- THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension+1);
-
- size = THLongStorage_newWithSize(ndim);
- for(i = 0; i < ndim; i++)
- {
- long dimSize = i < THCudaTensor_nDimension(state, inputs[0])
- ? THCudaTensor_size(state, inputs[0], i)
- : 1;
- if (i == dimension)
- {
- for (j = 1; j < numInputs; j++)
- {
- dimSize += i < THCudaTensor_nDimension(state, inputs[j])
- ? THCudaTensor_size(state, inputs[j], i)
- : 1;
- }
- }
- else
- {
- for (j = 1; j < numInputs; j++)
- {
- if (dimSize != (i < THCudaTensor_nDimension(state, inputs[j])
- ? THCudaTensor_size(state, inputs[j], i)
- : 1)) {
- THLongStorage_free(size);
- THError("inconsistent tensor sizes");
- }
- }
- }
- size->data[i] = dimSize;
- }
-
- THCudaTensor_resize(state, result, size, NULL);
- THLongStorage_free(size);
-
- offset = 0;
- for (j = 0; j < numInputs; j++)
- {
- long dimSize = dimension < THCudaTensor_nDimension(state, inputs[j])
- ? THCudaTensor_size(state, inputs[j], dimension)
- : 1;
- THCudaTensor *nt = THCudaTensor_newWithTensor(state, result);
- THCudaTensor_narrow(state, nt, NULL, dimension, offset, dimSize);
- THCudaTensor_copy(state, nt, inputs[j]);
- THCudaTensor_free(state, nt);
- offset += dimSize;
- }
-}
-
template <typename T>
struct TensorFillOp {
TensorFillOp(T v) : val(v) {}
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 3bafd59..3d71469 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -58,9 +58,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
-THC_API void THCudaTensor_cat(THCState *state, THCudaTensor *result, THCudaTensor *ta, THCudaTensor *tb, int dimension);
-THC_API void THCudaTensor_catArray(THCState *state, THCudaTensor *result, THCudaTensor **inputs, int numInputs, int dimension);
-
THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);
THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu
index 557f8f5..231695d 100644
--- a/lib/THC/generic/THCTensorMath.cu
+++ b/lib/THC/generic/THCTensorMath.cu
@@ -65,4 +65,75 @@ THCTensor_(numel)(THCState *state, THCTensor *t)
return THCTensor_(nElement)(state, t);
}
+void THCTensor_(cat)(THCState *state, THCTensor *result,
+ THCTensor *ta, THCTensor *tb, int dimension)
+{
+ THCTensor* inputs[2];
+ inputs[0] = ta;
+ inputs[1] = tb;
+ THCTensor_(catArray)(state, result, inputs, 2, dimension);
+}
+
+void THCTensor_(catArray)(THCState *state, THCTensor *result,
+ THCTensor **inputs, int numInputs, int dimension)
+{
+ THLongStorage *size;
+ int i, j;
+ long offset;
+ int ndim = dimension + 1;
+ for (i = 0; i < numInputs; i++)
+ {
+ ndim = THMax(ndim, THCTensor_(nDimension)(state, inputs[i]));
+ }
+
+ THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
+ THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension+1);
+
+ size = THLongStorage_newWithSize(ndim);
+ for(i = 0; i < ndim; i++)
+ {
+ long dimSize = i < THCTensor_(nDimension)(state, inputs[0])
+ ? THCTensor_(size)(state, inputs[0], i)
+ : 1;
+ if (i == dimension)
+ {
+ for (j = 1; j < numInputs; j++)
+ {
+ dimSize += i < THCTensor_(nDimension)(state, inputs[j])
+ ? THCTensor_(size)(state, inputs[j], i)
+ : 1;
+ }
+ }
+ else
+ {
+ for (j = 1; j < numInputs; j++)
+ {
+ if (dimSize != (i < THCTensor_(nDimension)(state, inputs[j])
+ ? THCTensor_(size)(state, inputs[j], i)
+ : 1)) {
+ THLongStorage_free(size);
+ THError("inconsistent tensor sizes");
+ }
+ }
+ }
+ size->data[i] = dimSize;
+ }
+
+ THCTensor_(resize)(state, result, size, NULL);
+ THLongStorage_free(size);
+
+ offset = 0;
+ for (j = 0; j < numInputs; j++)
+ {
+ long dimSize = dimension < THCTensor_(nDimension)(state, inputs[j])
+ ? THCTensor_(size)(state, inputs[j], dimension)
+ : 1;
+ THCTensor *nt = THCTensor_(newWithTensor)(state, result);
+ THCTensor_(narrow)(state, nt, NULL, dimension, offset, dimSize);
+ THCTensor_(copy)(state, nt, inputs[j]);
+ THCTensor_(free)(state, nt);
+ offset += dimSize;
+ }
+}
+
#endif
diff --git a/lib/THC/generic/THCTensorMath.h b/lib/THC/generic/THCTensorMath.h
index cfc706a..5f4f8ee 100644
--- a/lib/THC/generic/THCTensorMath.h
+++ b/lib/THC/generic/THCTensorMath.h
@@ -9,5 +9,9 @@ THC_API void THCTensor_(zeros)(THCState *state, THCTensor *r_, THLongStorage *si
THC_API void THCTensor_(ones)(THCState *state, THCTensor *r_, THLongStorage *size);
THC_API void THCTensor_(reshape)(THCState *state, THCTensor *r_, THCTensor *t, THLongStorage *size);
THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
+THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
+THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
+
+
#endif
diff --git a/test/test.lua b/test/test.lua
index eb9ee57..058103d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3243,33 +3243,42 @@ function test.topk()
end
function test.cat()
- for dim = 1, 3 do
- local x = torch.CudaTensor(13, minsize, minsize):uniform():transpose(1, dim)
- local y = torch.CudaTensor(17, minsize, minsize):uniform():transpose(1, dim)
- local mx = torch.cat(x, y, dim)
- tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
- tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
+ for k, typename in ipairs(typenames) do
+ for dim = 1, 3 do
+ local x = torch.Tensor(13, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local y = torch.Tensor(17, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local mx = torch.cat(x, y, dim)
+ tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
+ tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
- local mxx = torch.CudaTensor()
- torch.cat(mxx, x, y, dim)
- tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+ local mxx = torch.Tensor():type(typename)
+ torch.cat(mxx, x, y, dim)
+ tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+ end
end
end
function test.catArray()
- for dim = 1, 3 do
- local x = torch.CudaTensor(13, minsize, minsize):uniform():transpose(1, dim)
- local y = torch.CudaTensor(17, minsize, minsize):uniform():transpose(1, dim)
- local z = torch.CudaTensor(19, minsize, minsize):uniform():transpose(1, dim)
-
- local mx = torch.cat({x, y, z}, dim)
- tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
- tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
- tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value')
-
- local mxx = torch.CudaTensor()
- torch.cat(mxx, {x, y, z}, dim)
- tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+ for k, typename in ipairs(typenames) do
+ for dim = 1, 3 do
+ local x = torch.Tensor(13, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local y = torch.Tensor(17, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+ local z = torch.Tensor(19, minsize, minsize):uniform()
+ :type(typename):transpose(1, dim)
+
+ local mx = torch.cat({x, y, z}, dim)
+ tester:assertTensorEq(mx:narrow(dim, 1, 13), x, 0, 'torch.cat value')
+ tester:assertTensorEq(mx:narrow(dim, 14, 17), y, 0, 'torch.cat value')
+ tester:assertTensorEq(mx:narrow(dim, 31, 19), z, 0, 'torch.cat value')
+
+ local mxx = torch.Tensor():type(typename)
+ torch.cat(mxx, {x, y, z}, dim)
+ tester:assertTensorEq(mx, mxx, 0, 'torch.cat value')
+ end
end
end