diff options
author | Alexander Matyasko <alexander.matyasko@gmail.com> | 2017-05-05 09:06:23 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-05-09 02:44:15 +0300 |
commit | cbaf928fcff1e6f4a4217d2b544c34eb649d4a86 (patch) | |
tree | 3199e34eb7c16568a827bec4505abf4d2616f8d6 /lib | |
parent | 3502535d13adf117fa90ac3c0827c44d0946d158 (diff) |
Fix bug in magma qr decomposition and add tests for larger matrices
Diffstat (limited to 'lib')
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.cu | 33 |
1 files changed, 26 insertions, 7 deletions
```diff
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index c3c2dca..c35a83e 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -608,19 +608,42 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
   int k = (m < n ? m : n);
 
 #ifdef MAGMA_V2
+#if defined(THC_REAL_IS_FLOAT)
   int nb = magma_get_sgeqrf_nb(m, n);
 #else
+  int nb = magma_get_dgeqrf_nb(m, n);
+#endif
+#else
+#if defined(THC_REAL_IS_FLOAT)
   int nb = magma_get_sgeqrf_nb(m);
+#else
+  int nb = magma_get_dgeqrf_nb(m);
+#endif
 #endif
 
   real *a_data = THCTensor_(data)(state, a);
-  real *tau_data = th_magma_malloc_pinned<real>(n*n);
-
-  THCTensor *work = THCTensor_(newWithSize1d)(state, (2*k + ((n+31)/32)*32)*nb);
+  real *tau_data = th_magma_malloc_pinned<real>(k);
+  THCTensor *work = THCTensor_(newWithSize1d)(state, (2*k + magma_roundup(n, 32))*nb);
   real *work_data = THCTensor_(data)(state, work);
 
   int info;
 #if defined(THC_REAL_IS_FLOAT)
+  magma_sgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
+#else
+  magma_dgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
+#endif
+
+  if (info != 0)
+    THError("MAGMA geqrf2 : Argument %d : illegal value.", -info);
+
+  THCTensor_(narrow)(state, a, a, 0, 0, k);
+  THCTensor_(triu)(state, rr_, a, 0);
+  THCTensor_(free)(state, a);
+
+  a = THCTensor_(newColumnMajor)(state, rq_, a_);
+  a_data = THCTensor_(data)(state, a);
+
+#if defined(THC_REAL_IS_FLOAT)
   magma_sgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
 #else
   magma_dgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info);
@@ -632,10 +655,6 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
   THCTensor *q = THCTensor_(newColumnMajor)(state, rq_, a);
   real *q_data = THCTensor_(data)(state, q);
 
-  THCTensor_(narrow)(state, a, a, 0, 0, k);
-  THCTensor_(triu)(state, rr_, a, 0);
-  THCTensor_(free)(state, a);
-
 #if defined(THC_REAL_IS_FLOAT)
   magma_sorgqr_gpu(m, k, k, q_data, m, tau_data, work_data, nb, &info);
 #else
```