diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-15 23:30:52 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-16 00:31:32 +0300 |
commit | dfcdce1c7769a3637cebf6bbbd78f5c5b50f9e98 (patch) | |
tree | beb265016f63620679d1f399ed7aca439cb59bbd | |
parent | 5f1b19d98f47966b71d411e3bc1262d58ed6cc5d (diff) |
[cutorch mag2gen] move symeig to generic
-rw-r--r-- | TensorMath.lua | 13 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMathMagma.cu | 49 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.cu | 61 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.h | 1 | ||||
-rw-r--r-- | test/test.lua | 11 |
6 files changed, 81 insertions, 55 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 70c28ca..a67b101 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1189,6 +1189,19 @@ for k, Tensor_ in pairs(handledTypenames) do {name=Tensor}}) end + wrap("symeig", + cname("syev"), + {{name=Tensor, returned=true}, + {name=Tensor, returned=true}, + {name=Tensor}, + {name='charoption', values={'N', 'V'}, default='N'}, + {name='charoption', values={'U', 'L'}, default='U'}}, + cname("syev"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}, + {name='charoption', values={'N', 'V'}, default='N'}, + {name='charoption', values={'U', 'L'}, default='U'}}) end wrap("dot", diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index fc70dc2..32e18cf 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -45,7 +45,6 @@ #include "THCGenerateAllTypes.h" // MAGMA (i.e. CUDA implementation of LAPACK functions) -THC_API void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobz, const char *uplo); THC_API void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvr); THC_API void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu); THC_API void THCudaTensor_gesvd2(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *ra_, THCudaTensor *a, const char *jobu); diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu index 362fc2f..029811e 100644 --- a/lib/THC/THCTensorMathMagma.cu +++ b/lib/THC/THCTensorMathMagma.cu @@ -23,55 +23,6 @@ void THCMagma_init(THCState *state) #endif } -void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a, const char *jobzs, const char *uplos) -{ -#ifdef USE_MAGMA - int n = a->size[0]; - int lda = n; - - magma_uplo_t uplo = uplos[0] == 'U' ? MagmaUpper : MagmaLower; - magma_vec_t jobz = jobzs[0] == 'N' ? MagmaNoVec : MagmaVec; - - THCudaTensor *input = THCudaTensor_newColumnMajor(state, rv_, a); - float *input_data = THCudaTensor_data(state, input); - - // eigen values and workspace - float *w = th_magma_malloc_pinned<float>(n); - float *wA = th_magma_malloc_pinned<float>(lda); - - // compute optimal size of work array - int info; - float lwork; - int liwork; - magma_ssyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, &lwork, -1, &liwork, -1, &info); - - float *work = th_magma_malloc_pinned<float>((size_t)lwork); - int *iwork = th_magma_malloc_pinned<int>(liwork); - - // compute eigenvalues and, optionally, eigenvectors - magma_ssyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, work, (int) lwork, iwork, liwork, &info); - - // copy eigen values from w to re_ - if (info == 0) - THCudaTensor_copyArray1d(state, re_, w, n); - - magma_free_pinned(iwork); - magma_free_pinned(work); - magma_free_pinned(wA); - magma_free_pinned(w); - - // check error value - if (info > 0) - THError("MAGMA syev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info); - else if (info < 0) - THError("MAGMA syev : Argument %d : illegal value", -info); - - THCudaTensor_freeCopyTo(state, input, rv_); -#else - THError(NoMagma(syev)); -#endif -} - void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvrs) { #ifdef USE_MAGMA diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu index 75b8810..feab665 100644 --- a/lib/THC/generic/THCTensorMathMagma.cu +++ b/lib/THC/generic/THCTensorMathMagma.cu @@ -42,7 +42,7 @@ THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, T #endif } -void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_) +THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_) { #ifdef USE_MAGMA THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional"); @@ -87,6 +87,65 @@ void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor #endif } +THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a, const char *jobzs, const char *uplos) +{ +#ifdef USE_MAGMA + int n = a->size[0]; + int lda = n; + + magma_uplo_t uplo = uplos[0] == 'U' ? MagmaUpper : MagmaLower; + magma_vec_t jobz = jobzs[0] == 'N' ? MagmaNoVec : MagmaVec; + + THCTensor *input = THCTensor_(newColumnMajor)(state, rv_, a); + real *input_data = THCTensor_(data)(state, input); + + // eigen values and workspace + real *w = th_magma_malloc_pinned<real>(n); + real *wA = th_magma_malloc_pinned<real>(lda); + + // compute optimal size of work array + int info; + real lwork; + int liwork; + +#if defined(THC_REAL_IS_FLOAT) + magma_ssyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, &lwork, -1, &liwork, -1, &info); +#else + magma_dsyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, &lwork, -1, &liwork, -1, &info); +#endif + + real *work = th_magma_malloc_pinned<real>((size_t)lwork); + int *iwork = th_magma_malloc_pinned<int>(liwork); + + // compute eigenvalues and, optionally, eigenvectors +#if defined(THC_REAL_IS_FLOAT) + magma_ssyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, work, (int) lwork, iwork, liwork, &info); +#else + magma_dsyevd_gpu(jobz, uplo, n, input_data, lda, w, wA, n, work, (int) lwork, iwork, liwork, &info); +#endif + + // copy eigen values from w to re_ + if (info == 0) + THCTensor_(copyArray1d)(state, re_, w, n); + + magma_free_pinned(iwork); + magma_free_pinned(work); + magma_free_pinned(wA); + magma_free_pinned(w); + + // check error value + if (info > 0) + THError("MAGMA syev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info); + else if (info < 0) + THError("MAGMA syev : Argument %d : illegal value", -info); + + THCTensor_(freeCopyTo)(state, input, rv_); +#else + THError(NoMagma(syev)); +#endif +} + + #endif #endif diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h index b20061f..c09a7bb 100644 --- a/lib/THC/generic/THCTensorMathMagma.h +++ b/lib/THC/generic/THCTensorMathMagma.h @@ -61,6 +61,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); +THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo); #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) diff --git a/test/test.lua b/test/test.lua index 25e2416..c7d1c52 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2375,10 +2375,13 @@ if cutorch.magma then {-0.47, -6.39, 4.17, 0.00, 0.00}, {-7.20, 1.50, -1.51, 5.70, 0.00}, {-0.65, -6.34, 2.67, 1.80, -7.10}}):t() - local e1,v1 = torch.symeig(a, 'V') - local e2,v2 = torch.symeig(a:cuda(), 'V') - tester:assertle((e2 - e1:cuda()):abs():max(), 1e-5, "wrong symeig answer") - tester:assertle((v2 - v1:cuda()):abs():max(), 1e-5, "wrong symeig answer") + for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do + local at = a:type(typename) + local e1,v1 = torch.symeig(at, 'V') + local e2,v2 = torch.symeig(at:cuda(), 'V') + tester:assertle((e2 - e1:cuda()):abs():max(), 1e-5, "wrong symeig answer") + tester:assertle((v2 - v1:cuda()):abs():max(), 1e-5, "wrong symeig answer") + end end function test.eig() |