Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorrkayaith <rkayaith@gmail.com>2022-05-18 10:38:42 +0300
committerRiver Riddle <riddleriver@gmail.com>2022-05-18 10:55:59 +0300
commit7814b559bd5e1dbb3c016b393068698bc5781cc5 (patch)
treee7a6392ccc44e838543f74463e644a98844ae12c /mlir
parente9a1c82d695472820c93af40cbf3d9fde2a149c6 (diff)
[GreedyPatternRewriter] Avoid reversing constant order
The previous fix from af371f9f98da only applied when using a bottom-up traversal. The change here applies the constant preprocessing logic to the top-down case as well. This resolves the issue with the canonicalizer pass still reordering constants, since it uses a top-down traversal by default. Fixes #51892 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D125623
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp28
-rw-r--r--mlir/test/Dialect/Arithmetic/canonicalize.mlir32
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir8
-rw-r--r--mlir/test/Transforms/test-operation-folder.mlir3
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp13
5 files changed, 51 insertions, 33 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 0b80cd66459b..b2e965945101 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -133,6 +133,16 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
};
#endif
+ auto insertKnownConstant = [&](Operation *op) {
+ // Check for existing constants when populating the worklist. This avoids
+ // accidentally reversing the constant order during processing.
+ Attribute constValue;
+ if (matchPattern(op, m_Constant(&constValue)))
+ if (!folder.insertKnownConstant(op, constValue))
+ return true;
+ return false;
+ };
+
bool changed = false;
unsigned iteration = 0;
do {
@@ -142,22 +152,18 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
for (auto &region : regions) {
- region.walk([this](Operation *op) {
- // If we aren't processing top-down, check for existing constants when
- // populating the worklist. This avoids accidentally reversing the
- // constant order during processing.
- Attribute constValue;
- if (matchPattern(op, m_Constant(&constValue)))
- if (!folder.insertKnownConstant(op, constValue))
- return;
- addToWorklist(op);
+ region.walk([&](Operation *op) {
+ if (!insertKnownConstant(op))
+ addToWorklist(op);
});
}
} else {
// Add all nested operations to the worklist in preorder.
for (auto &region : regions)
- region.walk<WalkOrder::PreOrder>(
- [this](Operation *op) { worklist.push_back(op); });
+ region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (!insertKnownConstant(op))
+ worklist.push_back(op);
+ });
// Reverse the list so our pop-back loop processes them in-order.
std::reverse(worklist.begin(), worklist.end());
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index c560222b18ec..11e20458bf60 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -733,8 +733,8 @@ func.func @bitcastOfBitcast(%arg : i16) -> i16 {
// -----
// CHECK-LABEL: test_maxsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
// CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -749,8 +749,8 @@ func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
}
// CHECK-LABEL: test_maxsi2
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
// CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -767,8 +767,8 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
// CHECK-LABEL: test_maxui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
// CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -783,8 +783,8 @@ func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
}
// CHECK-LABEL: test_maxui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
// CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -801,8 +801,8 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
// CHECK-LABEL: test_minsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
// CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -817,8 +817,8 @@ func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
}
// CHECK-LABEL: test_minsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
// CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -835,8 +835,8 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
// CHECK-LABEL: test_minui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
// CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
@@ -851,8 +851,8 @@ func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
}
// CHECK-LABEL: test_minui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[C0:.+]] = arith.constant 42
+// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
// CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 986819e7b569..cca2f439bd70 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1036,9 +1036,9 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
}
return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
-// CHECK: %[[CST42:.*]] = arith.constant dense<42>
-// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
// CHECK: tensor.extract %{{.*}}[]
@@ -1069,9 +1069,9 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens
}
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
-// CHECK: %[[CST42:.*]] = arith.constant dense<42>
-// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
// CHECK: tensor.extract %{{.*}}[]
diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 488231a226ca..670ec232a392 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -test-patterns -test-patterns %s | FileCheck %s
+// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
+// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 09c7a12d96df..264b118c8956 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -151,6 +151,9 @@ struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
+ TestPatternDriver() = default;
+ TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
+
StringRef getArgument() const final { return "test-patterns"; }
StringRef getDescription() const final { return "Run test dialect patterns"; }
void runOnOperation() override {
@@ -162,8 +165,16 @@ struct TestPatternDriver
FolderInsertBeforePreviouslyFoldedConstantPattern,
FolderCommutativeOp2WithConstant>(&getContext());
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ GreedyRewriteConfig config;
+ config.useTopDownTraversal = this->useTopDownTraversal;
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config);
}
+
+ Option<bool> useTopDownTraversal{
+ *this, "top-down",
+ llvm::cl::desc("Seed the worklist in general top-down order"),
+ llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
};
} // namespace