diff options
author | Trevor Killeen <killeentm@gmail.com> | 2016-11-15 23:13:46 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2016-11-16 00:31:32 +0300 |
commit | 5f1b19d98f47966b71d411e3bc1262d58ed6cc5d (patch) | |
tree | 678a8acc8e418f67067c9fba531778d53f09c863 | |
parent | d0eb61548f948b07cfcb0b4aaa3a89778eadbea9 (diff) |
[cutorch mag2gen] move gels to generic
-rw-r--r-- | TensorMath.lua | 2 | ||||
-rw-r--r-- | lib/THC/THCTensorMath.h | 1 | ||||
-rw-r--r-- | lib/THC/THCTensorMathMagma.cu | 35 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.cu | 45 | ||||
-rw-r--r-- | lib/THC/generic/THCTensorMathMagma.h | 1 | ||||
-rw-r--r-- | test/test.lua | 12 |
6 files changed, 55 insertions, 41 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 56e4452..70c28ca 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1175,7 +1175,7 @@ for k, Tensor_ in pairs(handledTypenames) do if real == 'float' or real == 'double' then - for _,name in ipairs({"gesv"}) do + for _,name in ipairs({"gesv", "gels"}) do wrap(name, cname(name), {{name=Tensor, returned=true}, diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 2f032cf..fc70dc2 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -45,7 +45,6 @@ #include "THCGenerateAllTypes.h" // MAGMA (i.e. CUDA implementation of LAPACK functions) -THC_API void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_); THC_API void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobz, const char *uplo); THC_API void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvr); THC_API void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu); diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu index 82afa66..362fc2f 100644 --- a/lib/THC/THCTensorMathMagma.cu +++ b/lib/THC/THCTensorMathMagma.cu @@ -23,41 +23,6 @@ void THCMagma_init(THCState *state) #endif } -void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_) -{ -#ifdef USE_MAGMA - THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional"); - THArgCheck(b_->nDimension == 2, 1, "b should be 2 dimensional"); - THArgCheck(a_->size[0] == b_->size[0], 2, "size incompatible A,b"); - THArgCheck(a_->size[0] >= a_->size[1], 2, "A should have m >= n"); - - THCudaTensor *a = THCudaTensor_newColumnMajor(state, ra_, a_); - THCudaTensor *b = THCudaTensor_newColumnMajor(state, rb_, b_); - float *a_data = THCudaTensor_data(state, a); - float *b_data = THCudaTensor_data(state, b); - - int m = a->size[0]; - int n = a->size[1]; - int nrhs = b->size[1]; - float wkopt; - - int info; - magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info); - - float *hwork = th_magma_malloc_pinned<float>((size_t)wkopt); - magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info); - magma_free_pinned(hwork); - - if (info != 0) - THError("MAGMA gels : Argument %d : illegal value", -info); - - THCudaTensor_freeCopyTo(state, a, ra_); - THCudaTensor_freeCopyTo(state, b, rb_); -#else - THError(NoMagma(gels)); -#endif -} - void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a, const char *jobzs, const char *uplos) { #ifdef USE_MAGMA diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu index c9358cc..75b8810 100644 --- a/lib/THC/generic/THCTensorMathMagma.cu +++ b/lib/THC/generic/THCTensorMathMagma.cu @@ -42,6 +42,51 @@ THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, T #endif } +void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_) +{ +#ifdef USE_MAGMA + THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional"); + THArgCheck(b_->nDimension == 2, 1, "b should be 2 dimensional"); + THArgCheck(a_->size[0] == b_->size[0], 2, "size incompatible A,b"); + THArgCheck(a_->size[0] >= a_->size[1], 2, "A should have m >= n"); + + THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_); + THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_); + real *a_data = THCTensor_(data)(state, a); + real *b_data = THCTensor_(data)(state, b); + + int m = a->size[0]; + int n = a->size[1]; + int nrhs = b->size[1]; + real wkopt; + + int info; +#if defined(THC_REAL_IS_FLOAT) + magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info); +#else + magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info); +#endif + + real *hwork = th_magma_malloc_pinned<real>((size_t)wkopt); + +#if defined(THC_REAL_IS_FLOAT) + magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info); +#else + magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info); +#endif + + magma_free_pinned(hwork); + + if (info != 0) + THError("MAGMA gels : Argument %d : illegal value", -info); + + THCTensor_(freeCopyTo)(state, a, ra_); + THCTensor_(freeCopyTo)(state, b, rb_); +#else + THError(NoMagma(gels)); +#endif +} + #endif #endif diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h index e0f2bd4..b20061f 100644 --- a/lib/THC/generic/THCTensorMathMagma.h +++ b/lib/THC/generic/THCTensorMathMagma.h @@ -60,6 +60,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T } THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); +THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) diff --git a/test/test.lua b/test/test.lua index 3f2f66d..25e2416 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2359,10 +2359,14 @@ if cutorch.magma then { 0.5360, 0.2048, 0.2745}, { 0.8535,-0.3938,-0.2140}, } - local rb1, ra1 = torch.gels(b, a) - local rb2, ra2 = torch.gels(b:cuda(), a:cuda()) - tester:assertle((rb2 - rb1:cuda()):abs():max(), 5e-4, "wrong gels answer") - tester:assertle((ra2 - ra1:cuda()):abs():max(), 5e-4, "wrong gels answer") + for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do + local at = a:type(typename) + local bt = b:type(typename) + local rb1, ra1 = torch.gels(bt, at) + local rb2, ra2 = torch.gels(bt:cuda(), at:cuda()) + tester:assertle((rb2 - rb1:cuda()):abs():max(), 5e-4, "wrong gels answer") + tester:assertle((ra2 - ra1:cuda()):abs():max(), 5e-4, "wrong gels answer") + end end function test.symeig() |