Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-11-03 01:53:45 +0300
committerGitHub <noreply@github.com>2016-11-03 01:53:45 +0300
commita665438965bffdf99136b3bc02e0d968fde54835 (patch)
tree9aec8acbc63cc3320138dd34a39f5c80b4e6c286
parent052e136928c2bf4c7f997c8aa244055d60d7269f (diff)
parent96847ac63686aea63349aa70bc346d8b9978a557 (diff)
Merge pull request #828 from apaszke/lapack
Add more size checks and improve some LAPACK error messages
-rw-r--r--lib/TH/generic/THTensorLapack.c89
1 files changed, 73 insertions, 16 deletions
diff --git a/lib/TH/generic/THTensorLapack.c b/lib/TH/generic/THTensorLapack.c
index 62d730a..7929939 100644
--- a/lib/TH/generic/THTensorLapack.c
+++ b/lib/TH/generic/THTensorLapack.c
@@ -103,12 +103,23 @@ static THTensor *THTensor_(cloneColumnMajor)(THTensor *self, THTensor *src)
void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
{
+ int free_b = 0;
if (a == NULL) a = ra_;
if (b == NULL) b = rb_;
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(b->nDimension == 2, 1, "B should be 2 dimensional");
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
- THArgCheck(a->size[0] == b->size[0], 2, "A,b size incompatible");
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square, but is %ldx%ld",
+ a->size[0], a->size[1]);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
+ if (b->nDimension == 1) {
+ b = THTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = 1;
+ }
int n, nrhs, lda, ldb, info;
THIntTensor *ipiv;
@@ -132,23 +143,36 @@ void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
THCleanup(
THTensor_(free)(ra__);
THTensor_(free)(rb__);
- THIntTensor_free(ipiv);),
+ THIntTensor_free(ipiv);
+ if (free_b) THTensor_(free)(b);),
"gesv", info, info);
THTensor_(freeCopyTo)(ra__, ra_);
THTensor_(freeCopyTo)(rb__, rb_);
THIntTensor_free(ipiv);
+ if (free_b) THTensor_(free)(b);
}
void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a,
const char *uplo, const char *trans, const char *diag)
{
+ int free_b = 0;
if (a == NULL) a = ra_;
if (b == NULL) b = rb_;
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(b->nDimension == 2, 1, "A should be 2 dimensional");
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
- THArgCheck(b->size[0] == a->size[0], 2, "A,b size incompatible");
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square, but is %ldx%ld",
+ a->size[0], a->size[1]);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
+ if (b->nDimension == 1) {
+ b = THTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = 1;
+ }
int n, nrhs, lda, ldb, info;
THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS
@@ -168,21 +192,35 @@ void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a,
THLapackCheckWithCleanup("Lapack Error in %s : A(%d,%d) is zero, singular A",
- THCleanup(THTensor_(free)(ra__); THTensor_(free)(rb__);),
+ THCleanup(
+ THTensor_(free)(ra__);
+ THTensor_(free)(rb__);
+ if (free_b) THTensor_(free)(b);),
"trtrs", info, info);
THTensor_(freeCopyTo)(ra__, ra_);
THTensor_(freeCopyTo)(rb__, rb_);
+ if (free_b) THTensor_(free)(b);
}
void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
{
+ int free_b = 0;
// Note that a = NULL is interpreted as a = ra_, and b = NULL as b = rb_.
if (a == NULL) a = ra_;
if (b == NULL) b = rb_;
- THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
- THArgCheck(b->nDimension == 2, 1, "B should be 2 dimensional");
- THArgCheck(a->size[0] == b->size[0], 2, "size incompatible A,b");
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
+ if (b->nDimension == 1) {
+ b = THTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = 1;
+ }
int m, n, nrhs, lda, ldb, info, lwork;
THTensor *work = NULL;
@@ -217,7 +255,8 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
THLapackCheckWithCleanup("Lapack Error in %s : The %d-th diagonal element of the triangular factor of A is zero",
THCleanup(THTensor_(free)(ra__);
THTensor_(free)(rb__);
- THTensor_(free)(work);),
+ THTensor_(free)(work);
+ if (free_b) THTensor_(free)(b);),
"gels", info,"");
/* rb__ is currently ldb by nrhs; resize it to n by nrhs */
@@ -228,6 +267,7 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
THTensor_(freeCopyTo)(ra__, ra_);
THTensor_(freeCopyTo)(rb__, rb_);
THTensor_(free)(work);
+ if (free_b) THTensor_(free)(b);
}
void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr)
@@ -312,6 +352,7 @@ void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz
{
if (a == NULL) a = rv_;
THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional");
+ THArgCheck(a->size[0] == a->size[1], 1,"A should be square");
int n, lda, lwork, info;
THTensor *work;
@@ -572,9 +613,23 @@ void THTensor_(potrf)(THTensor *ra_, THTensor *a, const char *uplo)
void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
{
+ int free_b = 0;
if (b == NULL) b = rb_;
- THArgCheck(a->size[0] == a->size[1], 2, "A should be square");
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square, but is %ldx%ld",
+ a->size[0], a->size[1]);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
+ if (b->nDimension == 1) {
+ b = THTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = 1;
+ }
int n, nrhs, lda, ldb, info;
THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS
@@ -595,9 +650,11 @@ void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
THLapackCheckWithCleanup("Lapack Error in %s : A(%d,%d) is zero, singular A",
THCleanup(
THTensor_(free)(ra__);
- THTensor_(free)(rb__);),
+ THTensor_(free)(rb__);
+ if (free_b) THTensor_(free)(b);),
"potrs", info, info);
+ if (free_b) THTensor_(free)(b);
THTensor_(free)(ra__);
THTensor_(freeCopyTo)(rb__, rb_);
}