Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2016-11-16 00:14:46 +0300
committerTrevor Killeen <killeentm@gmail.com>2016-11-16 00:31:57 +0300
commit610a7905c9e571015c8189aebdffea7202819786 (patch)
tree9fc3d6b81d8ed5ad8b9dfde02e9212d41c5b46d5
parent7250cc589f279bf0b3dd61563c7ce8087e7e63c6 (diff)
[cutorch mag2gen] move qr to generic
-rw-r--r--TensorMath.lua10
-rw-r--r--lib/THC/THCTensorMath.h4
-rw-r--r--lib/THC/THCTensorMathMagma.cu50
-rw-r--r--lib/THC/generic/THCTensorMathMagma.cu63
-rw-r--r--lib/THC/generic/THCTensorMathMagma.h3
-rw-r--r--test/test.lua11
6 files changed, 80 insertions, 61 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 0db19a7..61cd4e9 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -1269,6 +1269,16 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor},
{name='charoption', values={'U', 'L'}, default='U'}})
+ wrap("qr",
+ cname("qr"),
+ {{name=Tensor, returned=true},
+ {name=Tensor, returned=true},
+ {name=Tensor}},
+ cname("qr"),
+ {{name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor}})
+
end
wrap("dot",
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 0850e3c..0b9ddb2 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -44,10 +44,6 @@
#include "generic/THCTensorSort.h"
#include "THCGenerateAllTypes.h"
-// MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);
-
-
THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu
index 7edcae9..cac5d73 100644
--- a/lib/THC/THCTensorMathMagma.cu
+++ b/lib/THC/THCTensorMathMagma.cu
@@ -23,55 +23,5 @@ void THCMagma_init(THCState *state)
#endif
}
-void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a_)
-{
-#ifdef USE_MAGMA
- THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional");
-
- THCudaTensor *a = THCudaTensor_newColumnMajor(state, rr_, a_);
- int m = a->size[0];
- int n = a->size[1];
- int k = (m < n ? m : n);
-
-#ifdef MAGMA_V2
- int nb = magma_get_sgeqrf_nb(m, n);
-#else
- int nb = magma_get_sgeqrf_nb(m);
-#endif
-
- float *a_data = THCudaTensor_data(state, a);
- float *tau_data = th_magma_malloc_pinned<float>(n*n);
-
- THCudaTensor *work = THCudaTensor_newWithSize1d(state, (2*k + ((n+31)/32)*32)*nb);
- float *work_data = THCudaTensor_data(state, work);
-
- int info;
- magma_sgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
-
- if (info != 0)
- THError("MAGMA geqrf : Argument %d : illegal value.", -info);
-
- THCudaTensor *q = THCudaTensor_newColumnMajor(state, rq_, a);
- float *q_data = THCudaTensor_data(state, q);
-
- THCudaTensor_narrow(state, a, a, 0, 0, k);
- THCudaTensor_triu(state, rr_, a, 0);
- THCudaTensor_free(state, a);
-
- magma_sorgqr_gpu(m, n, k, q_data, m, tau_data, work_data, nb, &info);
-
- if (info != 0)
- THError("MAGMA orgqr : Argument %d : illegal value.", -info);
-
- THCudaTensor_free(state, work);
- magma_free_pinned(tau_data);
-
- THCudaTensor_narrow(state, q, q, 1, 0, k);
- THCudaTensor_freeCopyTo(state, q, rq_);
-#else
- THError(NoMagma(qr));
-#endif
-}
-
#include "generic/THCTensorMathMagma.cu"
#include "THCGenerateAllTypes.h"
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index 41e7569..d874259 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -423,7 +423,7 @@ __global__ void THCTensor_(copyLowerSymmetric)(real *input, int n, int len)
}
}
-void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo)
+THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo)
{
#ifdef USE_MAGMA
THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
@@ -463,7 +463,7 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char
#endif
}
-void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo)
+THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo)
{
#ifdef USE_MAGMA
THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
@@ -499,7 +499,7 @@ void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char
#endif
}
-void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo)
+THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo)
{
#ifdef USE_MAGMA
THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
@@ -531,6 +531,63 @@ void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor
#endif
}
+THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a_)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional");
+
+ THCTensor *a = THCTensor_(newColumnMajor)(state, rr_, a_);
+ int m = a->size[0];
+ int n = a->size[1];
+ int k = (m < n ? m : n);
+
+#ifdef MAGMA_V2
+ int nb = magma_get_sgeqrf_nb(m, n);
+#else
+ int nb = magma_get_sgeqrf_nb(m);
+#endif
+
+ real *a_data = THCTensor_(data)(state, a);
+ real *tau_data = th_magma_malloc_pinned<real>(n*n);
+
+ THCTensor *work = THCTensor_(newWithSize1d)(state, (2*k + ((n+31)/32)*32)*nb);
+ real *work_data = THCTensor_(data)(state, work);
+
+ int info;
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
+#else
+ magma_dgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
+#endif
+
+ if (info != 0)
+ THError("MAGMA geqrf : Argument %d : illegal value.", -info);
+
+ THCTensor *q = THCTensor_(newColumnMajor)(state, rq_, a);
+ real *q_data = THCTensor_(data)(state, q);
+
+ THCTensor_(narrow)(state, a, a, 0, 0, k);
+ THCTensor_(triu)(state, rr_, a, 0);
+ THCTensor_(free)(state, a);
+
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sorgqr_gpu(m, n, k, q_data, m, tau_data, work_data, nb, &info);
+#else
+ magma_dorgqr_gpu(m, n, k, q_data, m, tau_data, work_data, nb, &info);
+#endif
+
+ if (info != 0)
+ THError("MAGMA orgqr : Argument %d : illegal value.", -info);
+
+ THCTensor_(free)(state, work);
+ magma_free_pinned(tau_data);
+
+ THCTensor_(narrow)(state, q, q, 1, 0, k);
+ THCTensor_(freeCopyTo)(state, q, rq_);
+#else
+ THError(NoMagma(qr));
+#endif
+}
#endif
diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h
index ce0ed29..364a8a7 100644
--- a/lib/THC/generic/THCTensorMathMagma.h
+++ b/lib/THC/generic/THCTensorMathMagma.h
@@ -59,6 +59,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T
return self;
}
+// MAGMA (i.e. CUDA implementation of LAPACK functions)
THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
@@ -69,6 +70,8 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
+THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);
+
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
diff --git a/test/test.lua b/test/test.lua
index c49e17a..c508d5d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2504,10 +2504,13 @@ if cutorch.magma then
{-0.2987, 1.9035, -1.4192, -0.9738, 1.4384},
{-0.5315, 0.4958, 0.4449, -0.4676, -0.4878},
}
- local q1,r1 = torch.qr(A)
- local q2,r2 = torch.qr(A:cuda())
- tester:assertle((q2 - q1:cuda()):abs():max(), 1e-5, "wrong qr answer")
- tester:assertle((r2 - r1:cuda()):abs():max(), 1e-5, "wrong qr answer")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = A:type(typename)
+ local q1,r1 = torch.qr(at)
+ local q2,r2 = torch.qr(at:cuda())
+ tester:assertle((q2 - q1:cuda()):abs():max(), 1e-5, "wrong qr answer")
+ tester:assertle((r2 - r1:cuda()):abs():max(), 1e-5, "wrong qr answer")
+ end
end
end