Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Trevor Killeen <killeentm@gmail.com> 2016-11-15 23:13:46 +0300
committer Trevor Killeen <killeentm@gmail.com> 2016-11-16 00:31:32 +0300
commit 5f1b19d98f47966b71d411e3bc1262d58ed6cc5d (patch)
tree 678a8acc8e418f67067c9fba531778d53f09c863
parent d0eb61548f948b07cfcb0b4aaa3a89778eadbea9 (diff)
[cutorch mag2gen] move gels to generic
-rw-r--r-- TensorMath.lua 2
-rw-r--r-- lib/THC/THCTensorMath.h 1
-rw-r--r-- lib/THC/THCTensorMathMagma.cu 35
-rw-r--r-- lib/THC/generic/THCTensorMathMagma.cu 45
-rw-r--r-- lib/THC/generic/THCTensorMathMagma.h 1
-rw-r--r-- test/test.lua 12
6 files changed, 55 insertions, 41 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 56e4452..70c28ca 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -1175,7 +1175,7 @@ for k, Tensor_ in pairs(handledTypenames) do
if real == 'float' or real == 'double' then
- for _,name in ipairs({"gesv"}) do
+ for _,name in ipairs({"gesv", "gels"}) do
wrap(name,
cname(name),
{{name=Tensor, returned=true},
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 2f032cf..fc70dc2 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -45,7 +45,6 @@
#include "THCGenerateAllTypes.h"
// MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
THC_API void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobz, const char *uplo);
THC_API void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvr);
THC_API void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu);
diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu
index 82afa66..362fc2f 100644
--- a/lib/THC/THCTensorMathMagma.cu
+++ b/lib/THC/THCTensorMathMagma.cu
@@ -23,41 +23,6 @@ void THCMagma_init(THCState *state)
#endif
}
-void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_)
-{
-#ifdef USE_MAGMA
- THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional");
- THArgCheck(b_->nDimension == 2, 1, "b should be 2 dimensional");
- THArgCheck(a_->size[0] == b_->size[0], 2, "size incompatible A,b");
- THArgCheck(a_->size[0] >= a_->size[1], 2, "A should have m >= n");
-
- THCudaTensor *a = THCudaTensor_newColumnMajor(state, ra_, a_);
- THCudaTensor *b = THCudaTensor_newColumnMajor(state, rb_, b_);
- float *a_data = THCudaTensor_data(state, a);
- float *b_data = THCudaTensor_data(state, b);
-
- int m = a->size[0];
- int n = a->size[1];
- int nrhs = b->size[1];
- float wkopt;
-
- int info;
- magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
-
- float *hwork = th_magma_malloc_pinned<float>((size_t)wkopt);
- magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info);
- magma_free_pinned(hwork);
-
- if (info != 0)
- THError("MAGMA gels : Argument %d : illegal value", -info);
-
- THCudaTensor_freeCopyTo(state, a, ra_);
- THCudaTensor_freeCopyTo(state, b, rb_);
-#else
- THError(NoMagma(gels));
-#endif
-}
-
void THCudaTensor_syev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a, const char *jobzs, const char *uplos)
{
#ifdef USE_MAGMA
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index c9358cc..75b8810 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -42,6 +42,51 @@ THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, T
#endif
}
+void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional");
+ THArgCheck(b_->nDimension == 2, 1, "b should be 2 dimensional");
+ THArgCheck(a_->size[0] == b_->size[0], 2, "size incompatible A,b");
+ THArgCheck(a_->size[0] >= a_->size[1], 2, "A should have m >= n");
+
+ THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
+ THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_);
+ real *a_data = THCTensor_(data)(state, a);
+ real *b_data = THCTensor_(data)(state, b);
+
+ int m = a->size[0];
+ int n = a->size[1];
+ int nrhs = b->size[1];
+ real wkopt;
+
+ int info;
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
+#else
+ magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info);
+#endif
+
+ real *hwork = th_magma_malloc_pinned<real>((size_t)wkopt);
+
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info);
+#else
+ magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info);
+#endif
+
+ magma_free_pinned(hwork);
+
+ if (info != 0)
+ THError("MAGMA gels : Argument %d : illegal value", -info);
+
+ THCTensor_(freeCopyTo)(state, a, ra_);
+ THCTensor_(freeCopyTo)(state, b, rb_);
+#else
+ THError(NoMagma(gels));
+#endif
+}
+
#endif
#endif
diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h
index e0f2bd4..b20061f 100644
--- a/lib/THC/generic/THCTensorMathMagma.h
+++ b/lib/THC/generic/THCTensorMathMagma.h
@@ -60,6 +60,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T
}
THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
+THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
diff --git a/test/test.lua b/test/test.lua
index 3f2f66d..25e2416 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2359,10 +2359,14 @@ if cutorch.magma then
{ 0.5360, 0.2048, 0.2745},
{ 0.8535,-0.3938,-0.2140},
}
- local rb1, ra1 = torch.gels(b, a)
- local rb2, ra2 = torch.gels(b:cuda(), a:cuda())
- tester:assertle((rb2 - rb1:cuda()):abs():max(), 5e-4, "wrong gels answer")
- tester:assertle((ra2 - ra1:cuda()):abs():max(), 5e-4, "wrong gels answer")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = a:type(typename)
+ local bt = b:type(typename)
+ local rb1, ra1 = torch.gels(bt, at)
+ local rb2, ra2 = torch.gels(bt:cuda(), at:cuda())
+ tester:assertle((rb2 - rb1:cuda()):abs():max(), 5e-4, "wrong gels answer")
+ tester:assertle((ra2 - ra1:cuda()):abs():max(), 5e-4, "wrong gels answer")
+ end
end
function test.symeig()