diff options
author | Alexander Matyasko <alexander.matyasko@gmail.com> | 2017-04-26 13:14:30 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-04-28 15:03:39 +0300 |
commit | 1442d41436fe324a948e5315d4ada25e4422a905 (patch) | |
tree | 217be0c8ddc3bd786dc91a4cb47c5bf5b805e614 | |
parent | 0882670e24f2e7d0dfd307ba608999390b9df706 (diff) |
Change magma_sgesvd to magma_sgesdd which is significantly faster
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.cu | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu index 2d07102..c3c2dca 100644 --- a/lib/THC/generic/THCTensorMathMagma.cu +++ b/lib/THC/generic/THCTensorMathMagma.cu @@ -286,14 +286,14 @@ THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_, #ifdef USE_MAGMA THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); - magma_vec_t jobu = jobus[0] == 'A' ? MagmaAllVec : jobus[0] == 'S' ? MagmaSomeVec : jobus[0] == 'O' ? MagmaOverwriteVec : MagmaNoVec; - magma_vec_t jobvt = jobu; + magma_vec_t jobz = jobus[0] == 'A' ? MagmaAllVec : jobus[0] == 'S' ? MagmaSomeVec : jobus[0] == 'O' ? MagmaOverwriteVec : MagmaNoVec; + int iunused[1]; int m = a->size[0]; int n = a->size[1]; int k = m < n ? m : n; - int j = (jobu == MagmaAllVec) ? m : k; - int jv = (jobvt == MagmaAllVec) ? n : k; + int j = (jobz == MagmaAllVec) ? m : k; + int jv = (jobz == MagmaAllVec) ? n : k; real *a_data = th_magma_malloc_pinned<real>(m * n); THCTensor_(copyTensor2d)(state, a_data, a); @@ -306,34 +306,36 @@ THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_, int info; #if defined(THC_REAL_IS_FLOAT) - magma_sgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, &info); + magma_sgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, iunused, &info); #else - magma_dgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, &info); + magma_dgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, iunused, &info); #endif int lwork = (int) wkopt; real *work_data = th_magma_malloc_pinned<real>(lwork); + int *iwork = th_magma_malloc_pinned<int>(8 * k); #if defined(THC_REAL_IS_FLOAT) - magma_sgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, &info); + magma_sgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, iwork, &info); #else - magma_dgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, &info); + magma_dgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, iwork, &info); #endif if (info > 0) - THError("MAGMA gesvd : %d superdiagonals failed to converge", info); + THError("MAGMA gesdd : the updating process of SBDSDC did not converge (error: %d)", info); else if (info < 0) - THError("MAGMA gesvd : Argument %d : illegal value", -info); + THError("MAGMA gesdd : Argument %d : illegal value", -info); THCTensor_(copyArray2d)(state, rv_, rv_data, n, n); THCTensor_(transpose)(state, rv_, NULL, 0, 1); - if (jobvt != MagmaAllVec) + if (jobz != MagmaAllVec) THCTensor_(narrow)(state, rv_, rv_, 1, 0, jv); THCTensor_(copyArray2d)(state, ru_, ru_data, m, j); THCTensor_(copyArray1d)(state, rs_, rs_data, k); THCTensor_(copyArray2d)(state, ra_, a_data, m, n); magma_free_pinned(work_data); + magma_free_pinned(iwork); magma_free_pinned(rv_data); magma_free_pinned(ru_data); magma_free_pinned(rs_data); |