diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-03-24 19:01:36 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-03-24 19:01:36 +0300 |
commit | 67d511a1d0ac35c9bc419e382cad8f191a10f880 (patch) | |
tree | afa34e048e1de6eddf21732456990d1c38d2090c | |
parent | 9c471e8b0c0f0cddf81ee0a0d27abdbbfa5f3802 (diff) |
Make rinfo_ argument optional in btrifact
-rw-r--r-- | lib/TH/generic/THTensorLapack.c | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/lib/TH/generic/THTensorLapack.c b/lib/TH/generic/THTensorLapack.c index c8f10d4..9e3309c 100644 --- a/lib/TH/generic/THTensorLapack.c +++ b/lib/TH/generic/THTensorLapack.c @@ -973,26 +973,28 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf THTensor *rai = THTensor_(new)(); THIntTensor *rpivoti = THIntTensor_new(); - if (!THIntTensor_isContiguous(rinfo_)) { - THError("Error: rinfo_ is not contiguous."); - } - if (!THIntTensor_isContiguous(rpivots_)) { - THError("Error: rpivots_ is not contiguous."); + int info = 0; + int *info_ptr = &info; + if (rinfo_) { + THIntTensor_resize1d(rinfo_, num_batches); + info_ptr = THIntTensor_data(rinfo_); } + THIntTensor_resize2d(rpivots_, num_batches, n); - THIntTensor_resize1d(rinfo_, num_batches); - for (long batch = 0; batch < num_batches; ++batch) { + long batch = 0; + for (; batch < num_batches; ++batch) { THTensor_(select)(ai, a, 0, batch); THTensor_(select)(rai, ra__, 0, batch); THIntTensor_select(rpivoti, rpivots_, 0, batch); -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_FLOAT) THLapack_(getrf)(n, n, THTensor_(data)(rai), lda, - THIntTensor_data(rpivoti), &THIntTensor_data(rinfo_)[batch]); -#else - THError("Unimplemented"); -#endif + THIntTensor_data(rpivoti), info_ptr); + if (rinfo_) { + info_ptr++; + } else if (info != 0) { + break; + } } THTensor_(free)(ai); @@ -1002,6 +1004,10 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf if (ra__ != ra_) { THTensor_(freeCopyTo)(ra__, ra_); } + + if (!rinfo_ && info != 0) { + THError("failed to factorize batch element %ld (info == %d)", batch, info); + } } void THTensor_(btrisolve)(THTensor *rb_, THTensor *atf, THTensor *b, THIntTensor *pivots) |