Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Alexander Matyasko <alexander.matyasko@gmail.com> 2017-04-26 13:14:30 +0300
committer: Soumith Chintala <soumith@gmail.com> 2017-04-28 15:03:39 +0300
commit: 1442d41436fe324a948e5315d4ada25e4422a905 (patch)
tree: 217be0c8ddc3bd786dc91a4cb47c5bf5b805e614
parent: 0882670e24f2e7d0dfd307ba608999390b9df706 (diff)
Change magma_sgesvd to magma_sgesdd which is significantly faster
-rw-r--r--  lib/THC/generic/THCTensorMathMagma.cu  24
1 file changed, 13 insertions, 11 deletions
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index 2d07102..c3c2dca 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -286,14 +286,14 @@ THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_,
#ifdef USE_MAGMA
THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- magma_vec_t jobu = jobus[0] == 'A' ? MagmaAllVec : jobus[0] == 'S' ? MagmaSomeVec : jobus[0] == 'O' ? MagmaOverwriteVec : MagmaNoVec;
- magma_vec_t jobvt = jobu;
+ magma_vec_t jobz = jobus[0] == 'A' ? MagmaAllVec : jobus[0] == 'S' ? MagmaSomeVec : jobus[0] == 'O' ? MagmaOverwriteVec : MagmaNoVec;
+ int iunused[1];
int m = a->size[0];
int n = a->size[1];
int k = m < n ? m : n;
- int j = (jobu == MagmaAllVec) ? m : k;
- int jv = (jobvt == MagmaAllVec) ? n : k;
+ int j = (jobz == MagmaAllVec) ? m : k;
+ int jv = (jobz == MagmaAllVec) ? n : k;
real *a_data = th_magma_malloc_pinned<real>(m * n);
THCTensor_(copyTensor2d)(state, a_data, a);
@@ -306,34 +306,36 @@ THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_,
int info;
#if defined(THC_REAL_IS_FLOAT)
- magma_sgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, &info);
+ magma_sgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, iunused, &info);
#else
- magma_dgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, &info);
+ magma_dgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, &wkopt, -1, iunused, &info);
#endif
int lwork = (int) wkopt;
real *work_data = th_magma_malloc_pinned<real>(lwork);
+ int *iwork = th_magma_malloc_pinned<int>(8 * k);
#if defined(THC_REAL_IS_FLOAT)
- magma_sgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, &info);
+ magma_sgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, iwork, &info);
#else
- magma_dgesvd(jobu, jobvt, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, &info);
+ magma_dgesdd(jobz, m, n, a_data, m, rs_data, ru_data, m, rv_data, n, work_data, lwork, iwork, &info);
#endif
if (info > 0)
- THError("MAGMA gesvd : %d superdiagonals failed to converge", info);
+ THError("MAGMA gesdd : the updating process of SBDSDC did not converge (error: %d)", info);
else if (info < 0)
- THError("MAGMA gesvd : Argument %d : illegal value", -info);
+ THError("MAGMA gesdd : Argument %d : illegal value", -info);
THCTensor_(copyArray2d)(state, rv_, rv_data, n, n);
THCTensor_(transpose)(state, rv_, NULL, 0, 1);
- if (jobvt != MagmaAllVec)
+ if (jobz != MagmaAllVec)
THCTensor_(narrow)(state, rv_, rv_, 1, 0, jv);
THCTensor_(copyArray2d)(state, ru_, ru_data, m, j);
THCTensor_(copyArray1d)(state, rs_, rs_data, k);
THCTensor_(copyArray2d)(state, ra_, a_data, m, n);
magma_free_pinned(work_data);
+ magma_free_pinned(iwork);
magma_free_pinned(rv_data);
magma_free_pinned(ru_data);
magma_free_pinned(rs_data);