Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAlexander Matyasko <alexander.matyasko@gmail.com>2017-05-05 09:06:23 +0300
committerSoumith Chintala <soumith@gmail.com>2017-05-09 02:44:15 +0300
commitcbaf928fcff1e6f4a4217d2b544c34eb649d4a86 (patch)
tree3199e34eb7c16568a827bec4505abf4d2616f8d6 /lib
parent3502535d13adf117fa90ac3c0827c44d0946d158 (diff)
Fix bug in magma qr decomposition and add tests for larger matrices
Diffstat (limited to 'lib')
-rw-r--r--lib/THC/generic/THCTensorMathMagma.cu33
1 files changed, 26 insertions, 7 deletions
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index c3c2dca..c35a83e 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -608,19 +608,42 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
int k = (m < n ? m : n);
#ifdef MAGMA_V2
+#if defined(THC_REAL_IS_FLOAT)
int nb = magma_get_sgeqrf_nb(m, n);
#else
+ int nb = magma_get_dgeqrf_nb(m, n);
+#endif
+#else
+#if defined(THC_REAL_IS_FLOAT)
int nb = magma_get_sgeqrf_nb(m);
+#else
+ int nb = magma_get_dgeqrf_nb(m);
+#endif
#endif
real *a_data = THCTensor_(data)(state, a);
- real *tau_data = th_magma_malloc_pinned<real>(n*n);
-
- THCTensor *work = THCTensor_(newWithSize1d)(state, (2*k + ((n+31)/32)*32)*nb);
+ real *tau_data = th_magma_malloc_pinned<real>(k);
+ THCTensor *work = THCTensor_(newWithSize1d)(state, (2*k + magma_roundup(n, 32))*nb);
real *work_data = THCTensor_(data)(state, work);
int info;
#if defined(THC_REAL_IS_FLOAT)
+ magma_sgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
+#else
+ magma_dgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
+#endif
+
+ if (info != 0)
+ THError("MAGMA geqrf2 : Argument %d : illegal value.", -info);
+
+ THCTensor_(narrow)(state, a, a, 0, 0, k);
+ THCTensor_(triu)(state, rr_, a, 0);
+ THCTensor_(free)(state, a);
+
+ a = THCTensor_(newColumnMajor)(state, rq_, a_);
+ a_data = THCTensor_(data)(state, a);
+
+#if defined(THC_REAL_IS_FLOAT)
magma_sgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
#else
magma_dgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
@@ -632,10 +655,6 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
THCTensor *q = THCTensor_(newColumnMajor)(state, rq_, a);
real *q_data = THCTensor_(data)(state, q);
- THCTensor_(narrow)(state, a, a, 0, 0, k);
- THCTensor_(triu)(state, rr_, a, 0);
- THCTensor_(free)(state, a);
-
#if defined(THC_REAL_IS_FLOAT)
magma_sorgqr_gpu(m, k, k, q_data, m, tau_data, work_data, nb, &info);
#else