diff options
author | Mahesh Ravishankar <ravishankarm@google.com> | 2022-11-03 23:38:34 +0300 |
---|---|---|
committer | Mahesh Ravishankar <ravishankarm@google.com> | 2022-11-03 23:38:34 +0300 |
commit | 38f34e587d10fcd7d18fd240e41248006faa639e (patch) | |
tree | e114ad88787772d16a04a1d1b8c5eda1e2b49f9d /mlir | |
parent | 24f9293de8794963bd29c731745a71ef6a1aab9d (diff) |
[mlir][Arith] Fix folder of CmpIOp to not fail when element type is not integer.
The folder used `cast<IntegerType>` which would segfault if the type were
a vector type. Handle this case appropriately and avoid failure.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D137345
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 | ||||
-rw-r--r-- | mlir/test/Dialect/Arith/canonicalize.mlir | 26 |
2 files changed, 41 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d1d03a549092..2c0fc51d08a4 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::arith; @@ -1444,6 +1445,16 @@ static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { return DenseElementsAttr::get(shapedType, boolAttr); } +static Optional<int64_t> getIntegerWidth(Type t) { + if (auto intType = t.dyn_cast<IntegerType>()) { + return intType.getWidth(); + } + if (auto vectorIntType = t.dyn_cast<VectorType>()) { + return vectorIntType.getElementType().cast<IntegerType>().getWidth(); + } + return llvm::None; +} + OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "cmpi takes two operands"); @@ -1456,13 +1467,17 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) { if (matchPattern(getRhs(), m_Zero())) { if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { // extsi(%x : i1 -> iN) != 0 -> %x - if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 && + Optional<int64_t> integerWidth = + getIntegerWidth(extOp.getOperand().getType()); + if (integerWidth && integerWidth.value() == 1 && getPredicate() == arith::CmpIPredicate::ne) return extOp.getOperand(); } if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { // extui(%x : i1 -> iN) != 0 -> %x - if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 && + Optional<int64_t> integerWidth = + getIntegerWidth(extOp.getOperand().getType()); + if (integerWidth && integerWidth.value() == 1 && getPredicate() == arith::CmpIPredicate::ne) return extOp.getOperand(); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 337eec00f3bf..336324ef4eec 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -162,7 +162,7 @@ func.func @cmpi_const_right(%arg0: i64) // ----- -// CHECK-LABEL: @cmpOfExtSI +// CHECK-LABEL: @cmpOfExtSI( // CHECK-NEXT: return %arg0 func.func @cmpOfExtSI(%arg0: i1) -> i1 { %ext = arith.extsi %arg0 : i1 to i64 @@ -171,7 +171,7 @@ func.func @cmpOfExtSI(%arg0: i1) -> i1 { return %res : i1 } -// CHECK-LABEL: @cmpOfExtUI +// CHECK-LABEL: @cmpOfExtUI( // CHECK-NEXT: return %arg0 func.func @cmpOfExtUI(%arg0: i1) -> i1 { %ext = arith.extui %arg0 : i1 to i64 @@ -182,6 +182,26 @@ func.func @cmpOfExtUI(%arg0: i1) -> i1 { // ----- +// CHECK-LABEL: @cmpOfExtSIVector( +// CHECK-NEXT: return %arg0 +func.func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> { + %ext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64> + %c0 = arith.constant dense<0> : vector<4xi64> + %res = arith.cmpi ne, %ext, %c0 : vector<4xi64> + return %res : vector<4xi1> +} + +// CHECK-LABEL: @cmpOfExtUIVector( +// CHECK-NEXT: return %arg0 +func.func @cmpOfExtUIVector(%arg0: vector<4xi1>) -> vector<4xi1> { + %ext = arith.extui %arg0 : vector<4xi1> to vector<4xi64> + %c0 = arith.constant dense<0> : vector<4xi64> + %res = arith.cmpi ne, %ext, %c0 : vector<4xi64> + return %res : vector<4xi1> +} + +// ----- + // CHECK-LABEL: @extSIOfExtUI // CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64 // CHECK: return %[[res]] @@ -1660,3 +1680,5 @@ func.func @xorxor3(%a : i32, %b : i32) -> i32 { %res = arith.xori %b, %c : i32 return %res : i32 } + +// ----- |