diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-16 00:14:46 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-16 00:31:57 +0300 |
commit | 610a7905c9e571015c8189aebdffea7202819786 (patch) | |
tree | 9fc3d6b81d8ed5ad8b9dfde02e9212d41c5b46d5 | |
parent | 7250cc589f279bf0b3dd61563c7ce8087e7e63c6 (diff) |
[cutorch mag2gen] move qr to generic
-rw-r--r-- | TensorMath.lua | 10 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 4 | ||||
-rw-r--r-- | lib/THC/THCTensorMathMagma.cu | 50 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.cu | 63 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.h | 3 | ||||
-rw-r--r-- | test/test.lua | 11 |
6 files changed, 80 insertions, 61 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 0db19a7..61cd4e9 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1269,6 +1269,16 @@ for k, Tensor_ in pairs(handledTypenames) do {name=Tensor}, {name='charoption', values={'U', 'L'}, default='U'}}) + wrap("qr", + cname("qr"), + {{name=Tensor, returned=true}, + {name=Tensor, returned=true}, + {name=Tensor}}, + cname("qr"), + {{name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor, default=true, returned=true, invisible=true}, + {name=Tensor}}) + end wrap("dot", diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 0850e3c..0b9ddb2 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -44,10 +44,6 @@ #include "generic/THCTensorSort.h" #include "THCGenerateAllTypes.h" -// MAGMA (i.e. CUDA implementation of LAPACK functions) -THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a); - - THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self); THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self); diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu index 7edcae9..cac5d73 100644 --- a/lib/THC/THCTensorMathMagma.cu +++ b/lib/THC/THCTensorMathMagma.cu @@ -23,55 +23,5 @@ void THCMagma_init(THCState *state) #endif } -void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a_) -{ -#ifdef USE_MAGMA - THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional"); - - THCudaTensor *a = THCudaTensor_newColumnMajor(state, rr_, a_); - int m = a->size[0]; - int n = a->size[1]; - int k = (m < n ? m : n); - -#ifdef MAGMA_V2 - int nb = magma_get_sgeqrf_nb(m, n); -#else - int nb = magma_get_sgeqrf_nb(m); -#endif - - float *a_data = THCudaTensor_data(state, a); - float *tau_data = th_magma_malloc_pinned<float>(n*n); - - THCudaTensor *work = THCudaTensor_newWithSize1d(state, (2*k + ((n+31)/32)*32)*nb); - float *work_data = THCudaTensor_data(state, work); - - int info; - magma_sgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info); - - if (info != 0) - THError("MAGMA geqrf : Argument %d : illegal value.", -info); - - THCudaTensor *q = THCudaTensor_newColumnMajor(state, rq_, a); - float *q_data = THCudaTensor_data(state, q); - - THCudaTensor_narrow(state, a, a, 0, 0, k); - THCudaTensor_triu(state, rr_, a, 0); - THCudaTensor_free(state, a); - - magma_sorgqr_gpu(m, n, k, q_data, m, tau_data, work_data, nb, &info); - - if (info != 0) - THError("MAGMA orgqr : Argument %d : illegal value.", -info); - - THCudaTensor_free(state, work); - magma_free_pinned(tau_data); - - THCudaTensor_narrow(state, q, q, 1, 0, k); - THCudaTensor_freeCopyTo(state, q, rq_); -#else - THError(NoMagma(qr)); -#endif -} - #include "generic/THCTensorMathMagma.cu" #include "THCGenerateAllTypes.h" diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu index 41e7569..d874259 100644 --- a/lib/THC/generic/THCTensorMathMagma.cu +++ b/lib/THC/generic/THCTensorMathMagma.cu @@ -423,7 +423,7 @@ __global__ void THCTensor_(copyLowerSymmetric)(real *input, int n, int len) } } -void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) +THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) { #ifdef USE_MAGMA THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); @@ -463,7 +463,7 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char #endif } -void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) +THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) { #ifdef USE_MAGMA THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional"); @@ -499,7 +499,7 @@ void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char #endif } -void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo) +THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo) { #ifdef USE_MAGMA THArgCheck(a->size[0] == a->size[1], 2, "A should be square"); @@ -531,6 +531,63 @@ void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor #endif } +THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a_) +{ +#ifdef USE_MAGMA + THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional"); + + THCTensor *a = THCTensor_(newColumnMajor)(state, rr_, a_); + int m = a->size[0]; + int n = a->size[1]; + int k = (m < n ? m : n); + +#ifdef MAGMA_V2 + int nb = magma_get_sgeqrf_nb(m, n); +#else + int nb = magma_get_sgeqrf_nb(m); +#endif + + real *a_data = THCTensor_(data)(state, a); + real *tau_data = th_magma_malloc_pinned<real>(n*n); + + THCTensor *work = THCTensor_(newWithSize1d)(state, (2*k + ((n+31)/32)*32)*nb); + real *work_data = THCTensor_(data)(state, work); + + int info; +#if defined(THC_REAL_IS_FLOAT) + magma_sgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info); +#else + magma_dgeqrf_gpu(m, n, a_data, m, tau_data, work_data, &info); +#endif + + if (info != 0) + THError("MAGMA geqrf : Argument %d : illegal value.", -info); + + THCTensor *q = THCTensor_(newColumnMajor)(state, rq_, a); + real *q_data = THCTensor_(data)(state, q); + + THCTensor_(narrow)(state, a, a, 0, 0, k); + THCTensor_(triu)(state, rr_, a, 0); + THCTensor_(free)(state, a); + +#if defined(THC_REAL_IS_FLOAT) + magma_sorgqr_gpu(m, n, k, q_data, m, tau_data, work_data, nb, &info); +#else + magma_dorgqr_gpu(m, n, k, q_data, m, tau_data, work_data, nb, &info); +#endif + + if (info != 0) + THError("MAGMA orgqr : Argument %d : illegal value.", -info); + + THCTensor_(free)(state, work); + magma_free_pinned(tau_data); + + THCTensor_(narrow)(state, q, q, 1, 0, k); + THCTensor_(freeCopyTo)(state, q, rq_); +#else + THError(NoMagma(qr)); +#endif +} #endif diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h index ce0ed29..364a8a7 100644 --- a/lib/THC/generic/THCTensorMathMagma.h +++ b/lib/THC/generic/THCTensorMathMagma.h @@ -59,6 +59,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T return self; } +// MAGMA (i.e. CUDA implementation of LAPACK functions) THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo); @@ -69,6 +70,8 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a); THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo); THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo); THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo); +THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a); + #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) diff --git a/test/test.lua b/test/test.lua index c49e17a..c508d5d 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2504,10 +2504,13 @@ if cutorch.magma then {-0.2987, 1.9035, -1.4192, -0.9738, 1.4384}, {-0.5315, 0.4958, 0.4449, -0.4676, -0.4878}, } - local q1,r1 = torch.qr(A) - local q2,r2 = torch.qr(A:cuda()) - tester:assertle((q2 - q1:cuda()):abs():max(), 1e-5, "wrong qr answer") - tester:assertle((r2 - r1:cuda()):abs():max(), 1e-5, "wrong qr answer") + for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do + local at = A:type(typename) + local q1,r1 = torch.qr(at) + local q2,r2 = torch.qr(at:cuda()) + tester:assertle((q2 - q1:cuda()):abs():max(), 1e-5, "wrong qr answer") + tester:assertle((r2 - r1:cuda()):abs():max(), 1e-5, "wrong qr answer") + end end end |