Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNatalia Gimelshein <ngimelshein@nvidia.com>2017-07-15 04:02:23 +0300
committerSoumith Chintala <soumith@gmail.com>2017-07-18 08:33:59 +0300
commita0799aaaf809985a3be272c5d0e482d5f9d04136 (patch)
treebac886c4a82af4a68aa506e484ba42983e6fcb23
parentcaf84f3af0b1e9f6bd5c4129c01345f7bd72e431 (diff)
fix baddbmm for expanded tensors
-rw-r--r--lib/THC/generic/THCTensorMathBlas.cu13
1 files changed, 8 insertions, 5 deletions
diff --git a/lib/THC/generic/THCTensorMathBlas.cu b/lib/THC/generic/THCTensorMathBlas.cu
index 61c255a..d109296 100644
--- a/lib/THC/generic/THCTensorMathBlas.cu
+++ b/lib/THC/generic/THCTensorMathBlas.cu
@@ -492,13 +492,15 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
ldc = result_->stride[2];
}
- if (batch1->stride[transpose_result ? 2 : 1] == 1)
+ if (batch1->stride[transpose_result ? 2 : 1] == 1 &&
+ batch1->stride[transpose_result ? 1 : 2] != 0)
{
transpose_batch1 = 'n';
batch1_ = batch1;
lda = batch1_->stride[transpose_result ? 1 : 2];
}
- else if (batch1->stride[transpose_result ? 1 : 2] == 1)
+ else if (batch1->stride[transpose_result ? 1 : 2] == 1 &&
+ batch1->stride[transpose_result ? 2 : 1] != 0)
{
transpose_batch1 = 't';
batch1_ = batch1;
@@ -511,13 +513,15 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
lda = batch1_->stride[1];
}
- if (batch2->stride[transpose_result ? 2 : 1] == 1)
+ if (batch2->stride[transpose_result ? 2 : 1] == 1 &&
+ batch2->stride[transpose_result ? 1 : 2] != 0)
{
transpose_batch2 = 'n';
batch2_ = batch2;
ldb = batch2_->stride[transpose_result ? 1 : 2];
}
- else if (batch2->stride[transpose_result ? 1 : 2] == 1)
+ else if (batch2->stride[transpose_result ? 1 : 2] == 1 &&
+ batch2->stride[transpose_result ? 2 : 1] != 0)
{
transpose_batch2 = 't';
batch2_ = batch2;
@@ -529,7 +533,6 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
batch2_ = THCTensor_(newContiguous)(state, batch2);
ldb = batch2_->stride[1];
}
-
long num_batches = result_->size[0];
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)