diff options
author | Natalia Gimelshein <ngimelshein@nvidia.com> | 2017-07-15 04:02:23 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-07-18 08:33:59 +0300 |
commit | a0799aaaf809985a3be272c5d0e482d5f9d04136 (patch) | |
tree | bac886c4a82af4a68aa506e484ba42983e6fcb23 | |
parent | caf84f3af0b1e9f6bd5c4129c01345f7bd72e431 (diff) |
fix baddbmm for expanded tensors
-rw-r--r-- | lib/THC/generic/THCTensorMathBlas.cu | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu index 61c255a..d109296 100644 --- a/lib/THC/generic/THCTensorMathBlas.cu +++ b/lib/THC/generic/THCTensorMathBlas.cu @@ -492,13 +492,15 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, ldc = result_->stride[2]; } - if (batch1->stride[transpose_result ? 2 : 1] == 1) + if (batch1->stride[transpose_result ? 2 : 1] == 1 && + batch1->stride[transpose_result ? 1 : 2] != 0) { transpose_batch1 = 'n'; batch1_ = batch1; lda = batch1_->stride[transpose_result ? 1 : 2]; } - else if (batch1->stride[transpose_result ? 1 : 2] == 1) + else if (batch1->stride[transpose_result ? 1 : 2] == 1 && + batch1->stride[transpose_result ? 2 : 1] != 0) { transpose_batch1 = 't'; batch1_ = batch1; @@ -511,13 +513,15 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, lda = batch1_->stride[1]; } - if (batch2->stride[transpose_result ? 2 : 1] == 1) + if (batch2->stride[transpose_result ? 2 : 1] == 1 && + batch2->stride[transpose_result ? 1 : 2] != 0) { transpose_batch2 = 'n'; batch2_ = batch2; ldb = batch2_->stride[transpose_result ? 1 : 2]; } - else if (batch2->stride[transpose_result ? 1 : 2] == 1) + else if (batch2->stride[transpose_result ? 1 : 2] == 1 && + batch2->stride[transpose_result ? 2 : 1] != 0) { transpose_batch2 = 't'; batch2_ = batch2; @@ -529,7 +533,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t, batch2_ = THCTensor_(newContiguous)(state, batch2); ldb = batch2_->stride[1]; } - long num_batches = result_->size[0]; #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) |