diff options
author | Peter Klausler <pklausler@nvidia.com> | 2022-04-30 18:32:50 +0300 |
---|---|---|
committer | Peter Klausler <pklausler@nvidia.com> | 2022-05-10 01:09:39 +0300 |
commit | 85fdbc1569f5c97daafd6a0daade54282806aa6a (patch) | |
tree | 4960329543a57ab90c7c175bc0073f3842dbda9f /flang | |
parent | 9641b9be9dfc599cbb6a812c1e587eff2ddd8707 (diff) |
[flang] Correct folding of SPREAD() for higher ranks
The construction of the dimension order vector used to populate the
result array was incorrect, leading to a scrambled-looking result
for rank-3 and higher results. Fix, and extend tests.
Differential Revision: https://reviews.llvm.org/D125113
Diffstat (limited to 'flang')
-rw-r--r-- | flang/lib/Evaluate/fold-implementation.h | 4 | ||||
-rw-r--r-- | flang/test/Evaluate/fold-spread.f90 | 4 |
2 files changed, 5 insertions, 3 deletions
diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index 2e12b502d0fc..317575ef9112 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -890,9 +890,9 @@ template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) { Constant<T> spread{source->Reshape(std::move(shape))}; std::vector<int> dimOrder; for (int j{0}; j < sourceRank; ++j) { - dimOrder.push_back(j); + dimOrder.push_back(j < *dim - 1 ? j : j + 1); } - dimOrder.insert(dimOrder.begin() + *dim - 1, sourceRank); + dimOrder.push_back(*dim - 1); ConstantSubscripts at{spread.lbounds()}; // all 1 spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder); return Expr<T>{std::move(spread)}; diff --git a/flang/test/Evaluate/fold-spread.f90 b/flang/test/Evaluate/fold-spread.f90 index 127de8fbbe6a..b7e493ee061c 100644 --- a/flang/test/Evaluate/fold-spread.f90 +++ b/flang/test/Evaluate/fold-spread.f90 @@ -5,9 +5,11 @@ module m1 logical, parameter :: test_stov = all(spread(1, 1, 2) == [1, 1]) logical, parameter :: test_vtom1 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2])) logical, parameter :: test_vtom2 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3])) - logical, parameter :: test_vtom3 = all(spread([1, 2], 2, 3) == reshape([1, 2, 1, 2, 1, 2], [2, 3])) + logical, parameter :: test_vtom3 = all(spread([1, 2], 1, 3) == reshape([1, 1, 1, 2, 2, 2], [3, 2])) logical, parameter :: test_log1 = all(all(spread([.false., .true.], 1, 2), dim=2) .eqv. [.false., .false.]) logical, parameter :: test_log2 = all(all(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.]) logical, parameter :: test_log3 = all(any(spread([.false., .true.], 1, 2), dim=2) .eqv. [.true., .true.]) logical, parameter :: test_log4 = all(any(spread([.false., .true.], 2, 2), dim=2) .eqv. [.false., .true.]) + logical, parameter :: test_m2toa3 = all(spread(reshape([(j,j=1,6)],[2,3]),1,4) == & + reshape([((j,k=1,4),j=1,6)],[4,2,3])) end module |