Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Trevor Killeen <killeentm@gmail.com> 2016-11-15 23:37:59 +0300
committer Trevor Killeen <killeentm@gmail.com> 2016-11-16 00:31:32 +0300
commit 18d9e4ecc9f4f4d78d1b6ee5e11ac76d5e5514b1 (patch)
tree 6f4ae01da7cda4fd69e0d6397cd131cd506383f5
parent dfcdce1c7769a3637cebf6bbbd78f5c5b50f9e98 (diff)
[cutorch mag2gen] move eig to generic
-rw-r--r-- TensorMath.lua 13
-rw-r--r-- lib/THC/THCTensorMath.h 1
-rw-r--r-- lib/THC/THCTensorMathMagma.cu 61
-rw-r--r-- lib/THC/generic/THCTensorMathMagma.cu 68
-rw-r--r-- lib/THC/generic/THCTensorMathMagma.h 1
-rw-r--r-- test/test.lua 11
6 files changed, 89 insertions, 66 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index a67b101..bb021e7 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -1202,6 +1202,19 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor},
{name='charoption', values={'N', 'V'}, default='N'},
{name='charoption', values={'U', 'L'}, default='U'}})
+
+ wrap("eig",
+ cname("geev"),
+ {{name=Tensor, returned=true},
+ {name=Tensor, returned=true},
+ {name=Tensor},
+ {name='charoption', values={'N', 'V'}, default='N'}},
+ cname("geev"),
+ {{name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor, default=true, returned=true, invisible=true},
+ {name=Tensor},
+ {name='charoption', values={'N', 'V'}, default='N'}})
+
end
wrap("dot",
diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h
index 32e18cf..7ce7504 100644
--- a/lib/THC/THCTensorMath.h
+++ b/lib/THC/THCTensorMath.h
@@ -45,7 +45,6 @@
#include "THCGenerateAllTypes.h"
// MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvr);
THC_API void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu);
THC_API void THCudaTensor_gesvd2(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *ra_, THCudaTensor *a, const char *jobu);
THC_API void THCudaTensor_getri(THCState *state, THCudaTensor *ra_, THCudaTensor *a);
diff --git a/lib/THC/THCTensorMathMagma.cu b/lib/THC/THCTensorMathMagma.cu
index 029811e..47bc484 100644
--- a/lib/THC/THCTensorMathMagma.cu
+++ b/lib/THC/THCTensorMathMagma.cu
@@ -23,67 +23,6 @@ void THCMagma_init(THCState *state)
#endif
}
-void THCudaTensor_geev(THCState *state, THCudaTensor *re_, THCudaTensor *rv_, THCudaTensor *a_, const char *jobvrs)
-{
-#ifdef USE_MAGMA
- THArgCheck(a_->nDimension == 2, 3, "A should be 2 dimensional");
- THArgCheck(a_->size[0] == a_->size[1], 3, "A should be square");
-
- magma_vec_t jobvr = jobvrs[0] == 'N' ? MagmaNoVec : MagmaVec;
- int n = a_->size[0];
-
- float *a_data = th_magma_malloc_pinned<float>(n * n);
- THCudaTensor_copyTensor2d(state, a_data, a_);
-
- float *wr = th_magma_malloc_pinned<float>(n);
- float *wi = th_magma_malloc_pinned<float>(n);
-
- float *vr_data = NULL;
- int ldvr = 1;
- if (jobvr == MagmaVec)
- {
- vr_data = th_magma_malloc_pinned<float>(n * n);
- ldvr = n;
- }
-
- float wkopt;
- int info;
-
- magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
-
- int lwork = (int) wkopt;
- float *work_data = th_magma_malloc_pinned<float>(lwork);
-
- magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
-
- if (info > 0)
- THError("MAGMA geev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info);
- else if (info < 0)
- THError("MAGMA geev : Argument %d : illegal value", -info);
-
- {
- THCudaTensor_resize2d(state, re_, 2, n);
- THCudaTensor *re = THCudaTensor_newContiguous(state, re_);
- THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset, wr, n*sizeof(float), cudaMemcpyHostToDevice));
- THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset + n, wi, n*sizeof(float), cudaMemcpyHostToDevice));
- THCudaTensor_freeCopyTo(state, re, re_);
- THCudaTensor_transpose(state, re_, NULL, 0, 1);
- }
-
- if (jobvr == MagmaVec)
- THCudaTensor_copyArray2d(state, rv_, vr_data, n, n);
-
- magma_free_pinned(work_data);
- magma_free_pinned(vr_data);
- magma_free_pinned(wi);
- magma_free_pinned(wr);
- magma_free_pinned(a_data);
-
-#else
- THError(NoMagma(geev));
-#endif
-}
-
void THCudaTensor_gesvd(THCState *state, THCudaTensor *ru_, THCudaTensor *rs_, THCudaTensor *rv_, THCudaTensor *a, const char *jobu)
{
#ifdef USE_MAGMA
diff --git a/lib/THC/generic/THCTensorMathMagma.cu b/lib/THC/generic/THCTensorMathMagma.cu
index feab665..e0d505c 100644
--- a/lib/THC/generic/THCTensorMathMagma.cu
+++ b/lib/THC/generic/THCTensorMathMagma.cu
@@ -145,6 +145,74 @@ THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, T
#endif
}
+THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvrs)
+{ // Runs MAGMA geev (sgeev for float, dgeev otherwise) on the square 2-D tensor a_; fills re_ from wr/wi and, when vectors are requested, rv_ from vr_data.
+#ifdef USE_MAGMA
+ THArgCheck(a_->nDimension == 2, 3, "A should be 2 dimensional"); // the literal 3 is presumably the argument index reported on failure -- confirm against THArgCheck
+ THArgCheck(a_->size[0] == a_->size[1], 3, "A should be square");
+
+ magma_vec_t jobvr = jobvrs[0] == 'N' ? MagmaNoVec : MagmaVec; // only the first character is inspected; anything other than 'N' selects MagmaVec
+ int n = a_->size[0]; // matrix order; rows == columns was checked above
+
+ real *a_data = th_magma_malloc_pinned<real>(n * n); // n*n scratch buffer handed to geev as the input matrix
+ THCTensor_(copyTensor2d)(state, a_data, a_); // presumably copies a_ into a_data -- helper is defined outside this hunk
+
+ real *wr = th_magma_malloc_pinned<real>(n); // n-element output of geev; copied into re_ below with cudaMemcpyHostToDevice
+ real *wi = th_magma_malloc_pinned<real>(n); // second n-element output of geev; copied into re_ right after wr
+
+ real *vr_data = NULL; // stays NULL unless vectors are requested
+ int ldvr = 1; // value passed alongside the NULL vr_data when no vectors are requested
+ if (jobvr == MagmaVec)
+ {
+ vr_data = th_magma_malloc_pinned<real>(n * n);
+ ldvr = n;
+ }
+
+ real wkopt; // written by the first geev call; converted to lwork below
+ int info; // status written by geev; checked after the second call
+
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); // first call: lwork = -1, only wkopt is consumed afterwards
+#else
+ magma_dgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); // same call for the non-float build of this generic file
+#endif
+
+ int lwork = (int) wkopt; // workspace length taken from the first call (truncating cast)
+ real *work_data = th_magma_malloc_pinned<real>(lwork);
+
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info); // second call: same arguments, now with the real workspace
+#else
+ magma_dgeev(MagmaNoVec, jobvr, n, a_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
+#endif
+
+ if (info > 0) // NOTE(review): if THError does not return, the magma_free_pinned calls at the end of this function are skipped on both error branches -- confirm
+ THError("MAGMA geev : Failed to converge. %d off-diagonal elements of an didn't converge to zero", info);
+ else if (info < 0)
+ THError("MAGMA geev : Argument %d : illegal value", -info);
+
+ {
+ THCTensor_(resize2d)(state, re_, 2, n); // re_ is sized 2 x n here and transposed at the end of this scope
+ THCTensor *re = THCTensor_(newContiguous)(state, re_);
+ THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset, wr, n*sizeof(real), cudaMemcpyHostToDevice)); // wr fills the first n elements
+ THCudaCheck(cudaMemcpy(re->storage->data + re->storageOffset + n, wi, n*sizeof(real), cudaMemcpyHostToDevice)); // wi fills the next n elements
+ THCTensor_(freeCopyTo)(state, re, re_); // presumably copies re back into re_ and releases re -- helper not shown here
+ THCTensor_(transpose)(state, re_, NULL, 0, 1); // swaps dims 0 and 1 of re_; NULL source presumably means in place -- confirm
+ }
+
+ if (jobvr == MagmaVec)
+ THCTensor_(copyArray2d)(state, rv_, vr_data, n, n); // rv_ is only written when vectors were requested
+
+ magma_free_pinned(work_data);
+ magma_free_pinned(vr_data); // NOTE(review): vr_data is NULL when jobvr == MagmaNoVec -- confirm magma_free_pinned accepts NULL
+ magma_free_pinned(wi);
+ magma_free_pinned(wr);
+ magma_free_pinned(a_data);
+
+#else
+ THError(NoMagma(geev)); // without USE_MAGMA the function only raises this error
+#endif
+}
#endif
diff --git a/lib/THC/generic/THCTensorMathMagma.h b/lib/THC/generic/THCTensorMathMagma.h
index c09a7bb..1e72ef6 100644
--- a/lib/THC/generic/THCTensorMathMagma.h
+++ b/lib/THC/generic/THCTensorMathMagma.h
@@ -62,6 +62,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T
THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
+THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr);
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
diff --git a/test/test.lua b/test/test.lua
index c7d1c52..5c1cacb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -2392,10 +2392,13 @@ if cutorch.magma then
{ 0.5766, -0.6743, 0.6903, 0.3646, -0.4571},
{-0.8956, -0.4074, -0.7583, 0.1838, -0.0091},
}
- local e1,v1 = torch.eig(a, 'V')
- local e2,v2 = torch.eig(a:cuda(), 'V')
- tester:assertle((e2 - e1:cuda()):abs():max(), 1e-6, "wrong eig answer")
- tester:assertle((v2:abs() - v1:abs():cuda()):abs():max(), 1e-6, "wrong eig answer")
+ for _, typename in ipairs({'torch.DoubleTensor', 'torch.FloatTensor'}) do
+ local at = a:type(typename)
+ local e1,v1 = torch.eig(at, 'V')
+ local e2,v2 = torch.eig(at:cuda(), 'V')
+ tester:assertle((e2 - e1:cuda()):abs():max(), 1e-6, "wrong eig answer")
+ tester:assertle((v2:abs() - v1:abs():cuda()):abs():max(), 1e-6, "wrong eig answer")
+ end
end
function test.svd()