Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMahesh Ravishankar <ravishankarm@google.com>2022-11-03 23:38:34 +0300
committerMahesh Ravishankar <ravishankarm@google.com>2022-11-03 23:38:34 +0300
commit38f34e587d10fcd7d18fd240e41248006faa639e (patch)
treee114ad88787772d16a04a1d1b8c5eda1e2b49f9d /mlir
parent24f9293de8794963bd29c731745a71ef6a1aab9d (diff)
[mlir][Arith] Fix folder of CmpIOp to not fail when element type is not integer.
The folder used `cast<IntegerType>` which would segfault if the type were a vector type. Handle this case appropriately and avoid failure. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D137345
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp19
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir26
2 files changed, 41 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d1d03a549092..2c0fc51d08a4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arith;
@@ -1444,6 +1445,16 @@ static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
return DenseElementsAttr::get(shapedType, boolAttr);
}
+static Optional<int64_t> getIntegerWidth(Type t) {
+ if (auto intType = t.dyn_cast<IntegerType>()) {
+ return intType.getWidth();
+ }
+ if (auto vectorIntType = t.dyn_cast<VectorType>()) {
+ return vectorIntType.getElementType().cast<IntegerType>().getWidth();
+ }
+ return llvm::None;
+}
+
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two operands");
@@ -1456,13 +1467,17 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_Zero())) {
if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
// extsi(%x : i1 -> iN) != 0 -> %x
- if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+ Optional<int64_t> integerWidth =
+ getIntegerWidth(extOp.getOperand().getType());
+ if (integerWidth && integerWidth.value() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
// extui(%x : i1 -> iN) != 0 -> %x
- if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
+ Optional<int64_t> integerWidth =
+ getIntegerWidth(extOp.getOperand().getType());
+ if (integerWidth && integerWidth.value() == 1 &&
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 337eec00f3bf..336324ef4eec 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -162,7 +162,7 @@ func.func @cmpi_const_right(%arg0: i64)
// -----
-// CHECK-LABEL: @cmpOfExtSI
+// CHECK-LABEL: @cmpOfExtSI(
// CHECK-NEXT: return %arg0
func.func @cmpOfExtSI(%arg0: i1) -> i1 {
%ext = arith.extsi %arg0 : i1 to i64
@@ -171,7 +171,7 @@ func.func @cmpOfExtSI(%arg0: i1) -> i1 {
return %res : i1
}
-// CHECK-LABEL: @cmpOfExtUI
+// CHECK-LABEL: @cmpOfExtUI(
// CHECK-NEXT: return %arg0
func.func @cmpOfExtUI(%arg0: i1) -> i1 {
%ext = arith.extui %arg0 : i1 to i64
@@ -182,6 +182,26 @@ func.func @cmpOfExtUI(%arg0: i1) -> i1 {
// -----
+// CHECK-LABEL: @cmpOfExtSIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %ext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64>
+ %c0 = arith.constant dense<0> : vector<4xi64>
+ %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtUIVector(
+// CHECK-NEXT: return %arg0
+func.func @cmpOfExtUIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+ %ext = arith.extui %arg0 : vector<4xi1> to vector<4xi64>
+ %c0 = arith.constant dense<0> : vector<4xi64>
+ %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+ return %res : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: @extSIOfExtUI
// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
// CHECK: return %[[res]]
@@ -1660,3 +1680,5 @@ func.func @xorxor3(%a : i32, %b : i32) -> i32 {
%res = arith.xori %b, %c : i32
return %res : i32
}
+
+// -----