Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-03-24 19:01:36 +0300
committerAdam Paszke <adam.paszke@gmail.com>2017-03-24 19:01:36 +0300
commit67d511a1d0ac35c9bc419e382cad8f191a10f880 (patch)
treeafa34e048e1de6eddf21732456990d1c38d2090c
parent9c471e8b0c0f0cddf81ee0a0d27abdbbfa5f3802 (diff)
Make rinfo_ argument optional in btrifact
-rw-r--r--lib/TH/generic/THTensorLapack.c30
1 files changed, 18 insertions, 12 deletions
diff --git a/lib/TH/generic/THTensorLapack.c b/lib/TH/generic/THTensorLapack.c
index c8f10d4..9e3309c 100644
--- a/lib/TH/generic/THTensorLapack.c
+++ b/lib/TH/generic/THTensorLapack.c
@@ -973,26 +973,28 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
THTensor *rai = THTensor_(new)();
THIntTensor *rpivoti = THIntTensor_new();
- if (!THIntTensor_isContiguous(rinfo_)) {
- THError("Error: rinfo_ is not contiguous.");
- }
- if (!THIntTensor_isContiguous(rpivots_)) {
- THError("Error: rpivots_ is not contiguous.");
+ int info = 0;
+ int *info_ptr = &info;
+ if (rinfo_) {
+ THIntTensor_resize1d(rinfo_, num_batches);
+ info_ptr = THIntTensor_data(rinfo_);
}
+
THIntTensor_resize2d(rpivots_, num_batches, n);
- THIntTensor_resize1d(rinfo_, num_batches);
- for (long batch = 0; batch < num_batches; ++batch) {
+ long batch = 0;
+ for (; batch < num_batches; ++batch) {
THTensor_(select)(ai, a, 0, batch);
THTensor_(select)(rai, ra__, 0, batch);
THIntTensor_select(rpivoti, rpivots_, 0, batch);
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_FLOAT)
THLapack_(getrf)(n, n, THTensor_(data)(rai), lda,
- THIntTensor_data(rpivoti), &THIntTensor_data(rinfo_)[batch]);
-#else
- THError("Unimplemented");
-#endif
+ THIntTensor_data(rpivoti), info_ptr);
+ if (rinfo_) {
+ info_ptr++;
+ } else if (info != 0) {
+ break;
+ }
}
THTensor_(free)(ai);
@@ -1002,6 +1004,10 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf
if (ra__ != ra_) {
THTensor_(freeCopyTo)(ra__, ra_);
}
+
+ if (!rinfo_ && info != 0) {
+ THError("failed to factorize batch element %ld (info == %d)", batch, info);
+ }
}
void THTensor_(btrisolve)(THTensor *rb_, THTensor *atf, THTensor *b, THIntTensor *pivots)