diff options
author | rkayaith <rkayaith@gmail.com> | 2022-05-18 10:38:42 +0300 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2022-05-18 10:55:59 +0300 |
commit | 7814b559bd5e1dbb3c016b393068698bc5781cc5 (patch) | |
tree | e7a6392ccc44e838543f74463e644a98844ae12c /mlir | |
parent | e9a1c82d695472820c93af40cbf3d9fde2a149c6 (diff) |
[GreedyPatternRewriter] Avoid reversing constant order
The previous fix from af371f9f98da only applied when using a bottom-up
traversal. The change here applies the constant preprocessing logic to the
top-down case as well. This resolves the issue with the canonicalizer pass still
reordering constants, since it uses a top-down traversal by default.
Fixes #51892
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D125623
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 28 | ||||
-rw-r--r-- | mlir/test/Dialect/Arithmetic/canonicalize.mlir | 32 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/canonicalize.mlir | 8 | ||||
-rw-r--r-- | mlir/test/Transforms/test-operation-folder.mlir | 3 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 13 |
5 files changed, 51 insertions, 33 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 0b80cd66459b..b2e965945101 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -133,6 +133,16 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) { }; #endif + auto insertKnownConstant = [&](Operation *op) { + // Check for existing constants when populating the worklist. This avoids + // accidentally reversing the constant order during processing. + Attribute constValue; + if (matchPattern(op, m_Constant(&constValue))) + if (!folder.insertKnownConstant(op, constValue)) + return true; + return false; + }; + bool changed = false; unsigned iteration = 0; do { @@ -142,22 +152,18 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) { if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. for (auto ®ion : regions) { - region.walk([this](Operation *op) { - // If we aren't processing top-down, check for existing constants when - // populating the worklist. This avoids accidentally reversing the - // constant order during processing. - Attribute constValue; - if (matchPattern(op, m_Constant(&constValue))) - if (!folder.insertKnownConstant(op, constValue)) - return; - addToWorklist(op); + region.walk([&](Operation *op) { + if (!insertKnownConstant(op)) + addToWorklist(op); }); } } else { // Add all nested operations to the worklist in preorder. for (auto ®ion : regions) - region.walk<WalkOrder::PreOrder>( - [this](Operation *op) { worklist.push_back(op); }); + region.walk<WalkOrder::PreOrder>([&](Operation *op) { + if (!insertKnownConstant(op)) + worklist.push_back(op); + }); // Reverse the list so our pop-back loop processes them in-order. std::reverse(worklist.begin(), worklist.end()); diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir index c560222b18ec..11e20458bf60 100644 --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -733,8 +733,8 @@ func.func @bitcastOfBitcast(%arg : i16) -> i16 { // ----- // CHECK-LABEL: test_maxsi -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127 // CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -749,8 +749,8 @@ func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) { } // CHECK-LABEL: test_maxsi2 -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127 // CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -767,8 +767,8 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- // CHECK-LABEL: test_maxui -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1 // CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -783,8 +783,8 @@ func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) { } // CHECK-LABEL: test_maxui -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1 // CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -801,8 +801,8 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- // CHECK-LABEL: test_minsi -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128 // CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -817,8 +817,8 @@ func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) { } // CHECK-LABEL: test_minsi -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128 // CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -835,8 +835,8 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- // CHECK-LABEL: test_minui -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0 // CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) { @@ -851,8 +851,8 @@ func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) { } // CHECK-LABEL: test_minui -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0 +// CHECK-DAG: %[[C0:.+]] = arith.constant 42 +// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0 // CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) { diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 986819e7b569..cca2f439bd70 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1036,9 +1036,9 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3 } return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32> } -// CHECK: %[[CST42:.*]] = arith.constant dense<42> -// CHECK: %[[ONE:.*]] = arith.constant dense<1> // CHECK: %[[ZERO:.*]] = arith.constant dense<0> +// CHECK: %[[ONE:.*]] = arith.constant dense<1> +// CHECK: %[[CST42:.*]] = arith.constant dense<42> // CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]]) // CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}} // CHECK: tensor.extract %{{.*}}[] @@ -1069,9 +1069,9 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens } return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32> } -// CHECK: %[[CST42:.*]] = arith.constant dense<42> -// CHECK: %[[ONE:.*]] = arith.constant dense<1> // CHECK: %[[ZERO:.*]] = arith.constant dense<0> +// CHECK: %[[ONE:.*]] = arith.constant dense<1> +// CHECK: %[[CST42:.*]] = arith.constant dense<42> // CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]]) // CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]] // CHECK: tensor.extract %{{.*}}[] diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir index 488231a226ca..670ec232a392 100644 --- a/mlir/test/Transforms/test-operation-folder.mlir +++ b/mlir/test/Transforms/test-operation-folder.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt -test-patterns -test-patterns %s | FileCheck %s +// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s +// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s func.func @foo() -> i32 { %c42 = arith.constant 42 : i32 diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 09c7a12d96df..264b118c8956 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -151,6 +151,9 @@ struct TestPatternDriver : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) + TestPatternDriver() = default; + TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} + StringRef getArgument() const final { return "test-patterns"; } StringRef getDescription() const final { return "Run test dialect patterns"; } void runOnOperation() override { @@ -162,8 +165,16 @@ struct TestPatternDriver FolderInsertBeforePreviouslyFoldedConstantPattern, FolderCommutativeOp2WithConstant>(&getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + GreedyRewriteConfig config; + config.useTopDownTraversal = this->useTopDownTraversal; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); } + + Option<bool> useTopDownTraversal{ + *this, "top-down", + llvm::cl::desc("Seed the worklist in general top-down order"), + llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; }; } // namespace |