github.com/llvm/llvm-project.git
author    Peiming Liu <peiming@google.com>  2022-10-31 19:26:32 +0300
committer Peiming Liu <peiming@google.com>  2022-11-02 19:53:36 +0300
commit    1ca119728ee1566ecc53bed350cf6c8db6bc88e5 (patch)
tree      348ac3b1da0d5f6d7a736780f29244bbeb9b50cd /mlir
parent    96a74c452728fd330f99394bb25dacecd9325645 (diff)
[mlir][scf] support 1:N type conversion for scf.if/while/condition
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D137100
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 110
-rw-r--r--  mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir        |  65
2 files changed, 114 insertions, 61 deletions
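
For context, the "1:N" in the title means the TypeConverter may map one source type to several result types, so these patterns can no longer assume one converted type per original SSA value. A minimal sketch of such a converter is shown below; it is not the converter used by the sparse compiler, and the particular memref fields are illustrative assumptions modeled on the six buffers visible in the test file further down.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

// Sketch of a 1:N TypeConverter: one ranked tensor with a sparse encoding
// expands to several memrefs (the exact field list is an assumption).
class SparseTensorToBuffersConverter : public mlir::TypeConverter {
public:
  SparseTensorToBuffersConverter() {
    // Leave all other types unchanged.
    addConversion([](mlir::Type type) { return type; });
    // One RankedTensorType -> N memref types.
    addConversion([](mlir::RankedTensorType type,
                     llvm::SmallVectorImpl<mlir::Type> &results)
                      -> llvm::Optional<mlir::LogicalResult> {
      if (!type.getEncoding())
        return llvm::None; // Dense tensors are not handled by this callback.
      auto idx = mlir::IndexType::get(type.getContext());
      auto dyn = mlir::ShapedType::kDynamicSize;
      results.push_back(mlir::MemRefType::get({1}, idx));   // dimension sizes
      results.push_back(mlir::MemRefType::get({3}, idx));   // memory sizes
      results.push_back(mlir::MemRefType::get({dyn}, idx)); // pointers
      results.push_back(mlir::MemRefType::get({dyn}, idx)); // indices
      results.push_back(
          mlir::MemRefType::get({dyn}, type.getElementType())); // values
      return mlir::success();
    });
  }
};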
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index a441b6c80b75..ac3d76d56922 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -155,44 +155,57 @@ public:
} // namespace
namespace {
-class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+class ConvertIfOpTypes
+ : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IfOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // TODO: Generalize this to any type conversion, not just 1:1.
- //
- // We need to implement something more sophisticated here that tracks
- // which types convert to which other types and does the appropriate
- // materialization logic.
- // For example, it's possible that one result type converts to 0 types and
- // another to 2 types, so newResultTypes would at least be the right size
- // to not crash in the llvm::zip call below, but then we would set the the
- // wrong type on the SSA values! These edge cases are also why we cannot
- // safely use the TypeConverter::convertTypes helper here.
- SmallVector<Type, 6> newResultTypes;
- for (auto type : op.getResultTypes()) {
- Type newType = typeConverter->convertType(type);
- if (!newType)
- return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
- newResultTypes.push_back(newType);
- }
+ using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
- // See comments in the ForOp pattern for why we clone without regions and
- // then inline.
- IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+ Optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ TypeRange dstTypes) const {
+
+ IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
+ adaptor.getCondition(), true);
+ newOp->setAttrs(op->getAttrs());
+
+ // We do not need the empty blocks created by the rewriter.
+ rewriter.eraseBlock(newOp.elseBlock());
+ rewriter.eraseBlock(newOp.thenBlock());
+
+ // Inline the blocks from the original operation.
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
newOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
newOp.getElseRegion().end());
- // Update the operands and types.
- newOp->setOperands(adaptor.getOperands());
- for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
- std::get<0>(t).setType(std::get<1>(t));
- rewriter.replaceOp(op, newOp.getResults());
- return success();
+ return newOp;
+ }
+};
+} // namespace
+
+namespace {
+class ConvertWhileOpTypes
+ : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
+public:
+ using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+ Optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ TypeRange dstTypes) const {
+ // Unpack the iteration arguments.
+ SmallVector<Value> flatArgs;
+ for (Value arg : adaptor.getOperands())
+ unpackUnrealizedConversionCast(arg, flatArgs);
+
+ auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+
+ for (auto i : {0u, 1u}) {
+ if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
+ return llvm::None;
+ auto &dstRegion = newOp.getRegion(i);
+ rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+ }
+ return newOp;
}
};
} // namespace
@@ -218,42 +231,17 @@ public:
} // namespace
namespace {
-class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
-public:
- using OpConversionPattern<WhileOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(WhileOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto *converter = getTypeConverter();
- assert(converter);
- SmallVector<Type> newResultTypes;
- if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
- return failure();
-
- auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
- adaptor.getOperands());
- for (auto i : {0u, 1u}) {
- auto &dstRegion = newOp.getRegion(i);
- rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
- if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
- return rewriter.notifyMatchFailure(op, "could not convert body types");
- }
- rewriter.replaceOp(op, newOp.getResults());
- return success();
- }
-};
-} // namespace
-
-namespace {
class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
public:
using OpConversionPattern<ConditionOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.updateRootInPlace(
- op, [&]() { op->setOperands(adaptor.getOperands()); });
+ SmallVector<Value> unpackedYield;
+ for (Value operand : adaptor.getOperands())
+ unpackUnrealizedConversionCast(operand, unpackedYield);
+
+ rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
return success();
}
};
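
Both new patterns, as well as the updated ConditionOp pattern above, rely on an unpackUnrealizedConversionCast helper whose definition is not part of this hunk. A plausible sketch is below, assuming 1:N materializations are represented as builtin.unrealized_conversion_cast ops whose inputs are the unpacked values; the actual helper in the file may differ in detail.

// Sketch: if `v` is produced by an unrealized_conversion_cast that packs
// several converted values into one, append those packed values; otherwise
// the value converted 1:1 and is forwarded as-is.
static void
unpackUnrealizedConversionCast(mlir::Value v,
                               llvm::SmallVectorImpl<mlir::Value> &unpacked) {
  if (auto cast = llvm::dyn_cast_or_null<mlir::UnrealizedConversionCastOp>(
          v.getDefiningOp())) {
    if (cast.getInputs().size() != 1) {
      // 1:N conversion: forward the original packed operands.
      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
      return;
    }
  }
  // 1:1 conversion.
  unpacked.push_back(v);
}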
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index 334d58c62393..207e46b3d45a 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -30,3 +30,68 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
return %1 : tensor<1024xf32, #SparseVector>
}
+
+// CHECK-LABEL: func @if(
+// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME: %[[DIM_SIZE_1:.*6]]: memref<1xindex>,
+// CHECK-SAME: %[[DIM_CURSOR_1:.*7]]: memref<1xindex>,
+// CHECK-SAME: %[[MEM_SIZE_1:.*8]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER_1:.*9]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES_1:.*10]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE_1:.*11]]: memref<?xf32>,
+// CHECK-SAME: %[[TMP_arg12:.*12]]: i1) ->
+// CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK: %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK: scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: } else {
+// CHECK: scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: }
+// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @if(%t: tensor<1024xf32, #SparseVector>,
+ %f: tensor<1024xf32, #SparseVector>,
+ %c: i1) -> tensor<1024xf32, #SparseVector> {
+ %1 = scf.if %c -> tensor<1024xf32, #SparseVector> {
+ scf.yield %t : tensor<1024xf32, #SparseVector>
+ } else {
+ scf.yield %f : tensor<1024xf32, #SparseVector>
+ }
+
+ return %1 : tensor<1024xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @while(
+// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME: %[[TMP_arg6:.*6]]: i1) ->
+// CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK: %[[SV:.*]]:6 = scf.while (
+// CHECK-SAME: %[[TMP_arg7:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME: %[[TMP_arg8:.*]] = %[[DIM_CURSOR]],
+// CHECK-SAME: %[[TMP_arg9:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME: %[[TMP_arg10:.*]] = %[[POINTER]],
+// CHECK-SAME: %[[TMP_arg11:.*]] = %[[INDICES]],
+// CHECK-SAME: %[[TMP_arg12:.*]] = %[[VALUE]])
+// CHECK: scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: } do {
+// CHECK: ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>):
+// CHECK: scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK: }
+// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
+ %0 = scf.while (%arg4 = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
+ scf.condition(%c) %arg4 : tensor<1024xf32, #SparseVector>
+ } do {
+ ^bb0(%arg7: tensor<1024xf32, #SparseVector>):
+ scf.yield %arg7 : tensor<1024xf32, #SparseVector>
+ }
+ return %0: tensor<1024xf32, #SparseVector>
+}
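
For completeness, downstream users do not normally instantiate these patterns by hand; they are registered together with the matching legality rules through the SCF population helper and then run as part of a dialect conversion. A hedged usage sketch follows; the header path and the surrounding pass boilerplate are assumptions, and the converter is expected to be a 1:N one such as the sketch above.

// Assumed header path; older trees declare the helper elsewhere.
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Transforms/DialectConversion.h"

// Register the structural SCF patterns (including the ones in this patch)
// and run them as a partial conversion over `op`.
static mlir::LogicalResult
runSCFStructuralConversion(mlir::Operation *op,
                           mlir::TypeConverter &converter) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::ConversionTarget target(*op->getContext());
  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(converter,
                                                             patterns, target);
  return mlir::applyPartialConversion(op, target, std::move(patterns));
}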