Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-16 00:08:27 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-16 00:31:32 +0300
commit7250cc589f279bf0b3dd61563c7ce8087e7e63c6 (patch)
tree182801d89c0b089a2010f6341fa26da383891be2
parent5c719c51da75e72430bd3ff5b45ce2d062c7c349 (diff)
[cutorch mag2gen] move potr* to generic
-rw-r--r--TensorMath.lua32
-rw-r--r--lib/THC/THCTensorMath.h3
-rw-r--r--lib/THC/THCTensorMathMagma.cu117
-rw-r--r--lib/THC/generic/THCTensorMathMagma.cu131
-rw-r--r--lib/THC/generic/THCTensorMathMagma.h3
-rw-r--r--test/test.lua46
6 files changed, 194 insertions, 138 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 0b8bad8..0db19a7 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -1237,6 +1237,38 @@ for k, Tensor_ in pairs(handledTypenames) do
{{name=Tensor, default=true, returned=true, invisible=true},
{name=Tensor}})
+ wrap("potri",
+ cname("potri"),
+ {{name=Tensor, returned=true},
+ {name=Tensor},
+ {name='charoption', values={'U', 'L'}, default='U'}},
+ cname("potri"),
+ {{name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor},
+ {name='charoption', values={'U', 'L'}, default='U'}})
+
+ wrap("potrf",
+ cname("potrf"),
+ {{name=Tensor, returned=true},
+ {name=Tensor},
+ {name='charoption', values={'U', 'L'}, default='U'}},
+ cname("potrf"),
+ {{name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor},
+ {name='charoption', values={'U', 'L'}, default='U'}})
+
+ wrap("potrs",
+ cname("potrs"),
+ {{name=Tensor, returned=true},
+ {name=Tensor},
+ {name=Tensor},
+ {name='charoption', values={'U', 'L'}, default='U'}},
+ cname("potrs"),
+ {{name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor},
+ {name=Tensor},
+ {name='charoption', values={'U', 'L'}, default='U'}})
+
end
wrap("dot",
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 7e02c6b..0850e3c 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -45,9 +45,6 @@
#include "THCGenerateAllTypes.h"
// MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCudaTensor_potri(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo);
-THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo);
-THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b, const char *uplo);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu
index a60b792..7edcae9 100644
--- a/lib/THC/THCTensorMathMagma.cu
+++ b/lib/THC/THCTensorMathMagma.cu
@@ -23,123 +23,6 @@ void THCMagma_init(THCState *state)
#endif
}
-__global__ void THCudaTensor_copyUpperSymmetric(float *input, int n, int len)
-{
- for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) {
- const int r = idx % n;
- const int c = idx / n;
- if (r > c) {
- input[idx] = input[r*n + c];
- }
- }
-}
-
-__global__ void THCudaTensor_copyLowerSymmetric(float *input, int n, int len)
-{
- for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) {
- const int r = idx % n;
- const int c = idx / n;
- if (r < c) {
- input[idx] = input[r*n + c];
- }
- }
-}
-
-void THCudaTensor_potri(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo)
-{
-#ifdef USE_MAGMA
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
-
- int n = a->size[0];
- magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
-
- THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a);
- float *input_data = THCudaTensor_data(state, input);
-
- int info;
- magma_spotri_gpu(ul, n, input_data, n, &info);
- if (info > 0)
- THError("MAGMA potri : A(%d,%d) is 0, A cannot be factorized", info, info);
- else if (info < 0)
- THError("MAGMA potri : Argument %d : illegal value", -info);
-
- cudaStream_t stream = THCState_getCurrentStream(state);
- const int len = n*n;
- dim3 blocks(std::min(DIVUP(len, 128), 65535));
- dim3 threads(128);
- if (uplo[0] == 'U') {
- THCudaTensor_copyUpperSymmetric<<<blocks, threads, 0, stream>>>(input_data, n, len);
- } else {
- THCudaTensor_copyLowerSymmetric<<<blocks, threads, 0, stream>>>(input_data, n, len);
- }
-
- THCudaTensor_freeCopyTo(state, input, ra_);
-#else
- THError(NoMagma(potri));
-#endif
-}
-
-void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor *a, const char *uplo)
-{
-#ifdef USE_MAGMA
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
-
- int n = a->size[0];
- magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
-
- THCudaTensor *input = THCudaTensor_newColumnMajor(state, ra_, a);
- float *input_data = THCudaTensor_data(state, input);
-
- int info;
- magma_spotrf_gpu(ul, n, input_data, n, &info);
-
- // check error value
- if (info > 0)
- THError("MAGMA potrf : A(%d,%d) is 0, A cannot be factorized", info, info);
- else if (info < 0)
- THError("MAGMA potrf : Argument %d : illegal value", -info);
-
- if (uplo[0] == 'U') {
- THCudaTensor_triu(state, ra_, input, 0);
- } else {
- THCudaTensor_tril(state, ra_, input, 0);
- }
- THCudaTensor_free(state, input);
-#else
- THError(NoMagma(potrf));
-#endif
-}
-
-void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *b, THCudaTensor *a, const char *uplo)
-{
-#ifdef USE_MAGMA
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
-
- int n = a->size[0];
- int nrhs = b->size[1];
- magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
-
- THCudaTensor *b_ = THCudaTensor_newColumnMajor(state, rb_, b);
- float *b_data = THCudaTensor_data(state, b_);
- THCudaTensor *a_ = THCudaTensor_newColumnMajor(state, a, a);
- float *a_data = THCudaTensor_data(state, a_);
-
- int info;
- magma_spotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
-
- // check error value
- if (info < 0)
- THError("MAGMA potrs : Argument %d : illegal value", -info);
-
- THCudaTensor_freeCopyTo(state, b_, rb_);
- THCudaTensor_free(state, a_);
-#else
- THError(NoMagma(potrs));
-#endif
-}
-
void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a_)
{
#ifdef USE_MAGMA
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index 36e7e0d..41e7569 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -401,6 +401,137 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)
#endif
}
+__global__ void THCTensor_(copyUpperSymmetric)(real *input, int n, int len)
+{
+ for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) {
+ const int r = idx % n;
+ const int c = idx / n;
+ if (r > c) {
+ input[idx] = input[r*n + c];
+ }
+ }
+}
+
+__global__ void THCTensor_(copyLowerSymmetric)(real *input, int n, int len)
+{
+ for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < len; idx += 65535) {
+ const int r = idx % n;
+ const int c = idx / n;
+ if (r < c) {
+ input[idx] = input[r*n + c];
+ }
+ }
+}
+
+void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
+
+ int n = a->size[0];
+ magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
+
+ THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a);
+ real *input_data = THCTensor_(data)(state, input);
+
+ int info;
+#if defined(THC_REAL_IS_FLOAT)
+ magma_spotri_gpu(ul, n, input_data, n, &info);
+#else
+ magma_dpotri_gpu(ul, n, input_data, n, &info);
+#endif
+
+ if (info > 0)
+ THError("MAGMA potri : A(%d,%d) is 0, A cannot be factorized", info, info);
+ else if (info < 0)
+ THError("MAGMA potri : Argument %d : illegal value", -info);
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ const int len = n*n;
+ dim3 blocks(std::min(DIVUP(len, 128), 65535));
+ dim3 threads(128);
+ if (uplo[0] == 'U') {
+ THCTensor_(copyUpperSymmetric)<<<blocks, threads, 0, stream>>>(input_data, n, len);
+ } else {
+ THCTensor_(copyLowerSymmetric)<<<blocks, threads, 0, stream>>>(input_data, n, len);
+ }
+
+ THCTensor_(freeCopyTo)(state, input, ra_);
+#else
+ THError(NoMagma(potri));
+#endif
+}
+
+void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
+
+ int n = a->size[0];
+ magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
+
+ THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a);
+ real *input_data = THCTensor_(data)(state, input);
+
+ int info;
+#if defined(THC_REAL_IS_FLOAT)
+ magma_spotrf_gpu(ul, n, input_data, n, &info);
+#else
+ magma_dpotrf_gpu(ul, n, input_data, n, &info);
+#endif
+
+ // check error value
+ if (info > 0)
+ THError("MAGMA potrf : A(%d,%d) is 0, A cannot be factorized", info, info);
+ else if (info < 0)
+ THError("MAGMA potrf : Argument %d : illegal value", -info);
+
+ if (uplo[0] == 'U') {
+ THCTensor_(triu)(state, ra_, input, 0);
+ } else {
+ THCTensor_(tril)(state, ra_, input, 0);
+ }
+ THCTensor_(free)(state, input);
+#else
+ THError(NoMagma(potrf));
+#endif
+}
+
+void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
+
+ int n = a->size[0];
+ int nrhs = b->size[1];
+ magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
+
+ THCTensor *b_ = THCTensor_(newColumnMajor)(state, rb_, b);
+ real *b_data = THCTensor_(data)(state, b_);
+ THCTensor *a_ = THCTensor_(newColumnMajor)(state, a, a);
+ real *a_data = THCTensor_(data)(state, a_);
+
+ int info;
+#if defined(THC_REAL_IS_FLOAT)
+ magma_spotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
+#else
+ magma_dpotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
+#endif
+
+ // check error value
+ if (info < 0)
+ THError("MAGMA potrs : Argument %d : illegal value", -info);
+
+ THCTensor_(freeCopyTo)(state, b_, rb_);
+ THCTensor_(free)(state, a_);
+#else
+ THError(NoMagma(potrs));
+#endif
+}
+
+
#endif
#endif
diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h
index 4c8dbe3..ce0ed29 100644
--- a/lib/THC/generic/THCTensorMathMagma.h
+++ b/lib/THC/generic/THCTensorMathMagma.h
@@ -66,6 +66,9 @@ THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, T
THC_API void THCTensor_(gesvd)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *a, const char *jobu);
THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_, THCTensor *rv_, THCTensor *ra_, THCTensor *a, const char *jobu);
THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
+THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
+THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
+THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
diff --git a/test/test.lua b/test/test.lua
index e766de3..c49e17a 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2328,7 +2328,7 @@ function test.inverse()
for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
local at = a:type(typename)
local i1 = torch.inverse(at)
- local i2 = torch.inverse(at:cuda())
+ local i2 = torch.inverse(a:cuda())
tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong inverse answer")
end
end
@@ -2439,14 +2439,17 @@ if cutorch.magma then
}
A = A * A:t()
- for _, triarg in ipairs({'U', 'L'}) do
- local chol = torch.potrf(A, triarg)
-
- local i1 = torch.potri(chol, triarg)
- local i2 = torch.potri(chol:cuda(), triarg)
- local M = A:cuda() * i2
- tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potri answer")
- tester:assertle((M - torch.eye(A:size(1)):cuda()):abs():max(), 1e-5, "potri not an inverse")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = A:type(typename)
+ for _, triarg in ipairs({'U', 'L'}) do
+ local chol = torch.potrf(at, triarg)
+
+ local i1 = torch.potri(chol, triarg)
+ local i2 = torch.potri(chol:cuda(), triarg)
+ local M = at:cuda() * i2
+ tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potri answer")
+ tester:assertle((M - torch.eye(at:size(1)):cuda()):abs():max(), 1e-5, "potri not an inverse")
+ end
end
end
@@ -2458,10 +2461,13 @@ if cutorch.magma then
{-0.6738, 0.4734,-1.1123, 2.4071,-1.2756},
{-3.3883, 0.2807, 0.8161,-1.2756, 4.3415},
}
- for _, triarg in ipairs({'U', 'L'}) do
- local i1 = torch.potrf(A, triarg)
- local i2 = torch.potrf(A:cuda(), triarg)
- tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potrf answer")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = A:type(typename)
+ for _, triarg in ipairs({'U', 'L'}) do
+ local i1 = torch.potrf(at, triarg)
+ local i2 = torch.potrf(at:cuda(), triarg)
+ tester:assertle((i2 - i1:cuda()):abs():max(), 1e-5, "wrong potrf answer")
+ end
end
end
@@ -2478,11 +2484,15 @@ if cutorch.magma then
{0.2334, 0.8594, 0.4103},
{0.7556, 0.1966, 0.9637},
{0.1420, 0.7185, 0.7476}})
- for _, triarg in ipairs({'U', 'L'}) do
- local chol = torch.potrf(A, triarg)
- local solve1 = torch.potrs(B, chol, triarg)
- local solve2 = torch.potrs(B:cuda(), chol:cuda(), triarg)
- tester:assertle((solve2 - solve1:cuda()):abs():max(), 1e-4, "wrong potrs answer")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = A:type(typename)
+ local bt = B:type(typename)
+ for _, triarg in ipairs({'U', 'L'}) do
+ local chol = torch.potrf(at, triarg)
+ local solve1 = torch.potrs(bt, chol, triarg)
+ local solve2 = torch.potrs(bt:cuda(), chol:cuda(), triarg)
+ tester:assertle((solve2 - solve1:cuda()):abs():max(), 1e-4, "wrong potrs answer")
+ end
end
end