diff options
author | Peiming Liu <peiming@google.com> | 2022-10-31 19:26:32 +0300 |
---|---|---|
committer | Peiming Liu <peiming@google.com> | 2022-11-02 19:53:36 +0300 |
commit | 1ca119728ee1566ecc53bed350cf6c8db6bc88e5 (patch) | |
tree | 348ac3b1da0d5f6d7a736780f29244bbeb9b50cd /mlir | |
parent | 96a74c452728fd330f99394bb25dacecd9325645 (diff) |
[mlir][scf] support 1:N type conversion for scf.if/while/condition
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D137100
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 110 | ||||
-rw-r--r-- | mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir | 65 |
2 files changed, 114 insertions, 61 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index a441b6c80b75..ac3d76d56922 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -155,44 +155,57 @@ public: } // namespace namespace { -class ConvertIfOpTypes : public OpConversionPattern<IfOp> { +class ConvertIfOpTypes + : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // TODO: Generalize this to any type conversion, not just 1:1. - // - // We need to implement something more sophisticated here that tracks - // which types convert to which other types and does the appropriate - // materialization logic. - // For example, it's possible that one result type converts to 0 types and - // another to 2 types, so newResultTypes would at least be the right size - // to not crash in the llvm::zip call below, but then we would set the the - // wrong type on the SSA values! These edge cases are also why we cannot - // safely use the TypeConverter::convertTypes helper here. - SmallVector<Type, 6> newResultTypes; - for (auto type : op.getResultTypes()) { - Type newType = typeConverter->convertType(type); - if (!newType) - return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); - newResultTypes.push_back(newType); - } + using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - // See comments in the ForOp pattern for why we clone without regions and - // then inline. - IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation())); + Optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + TypeRange dstTypes) const { + + IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes, + adaptor.getCondition(), true); + newOp->setAttrs(op->getAttrs()); + + // We do not need the empty blocks created by rewriter. + rewriter.eraseBlock(newOp.elseBlock()); + rewriter.eraseBlock(newOp.thenBlock()); + + // Inlines block from the original operation. rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), newOp.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), newOp.getElseRegion().end()); - // Update the operands and types. - newOp->setOperands(adaptor.getOperands()); - for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) - std::get<0>(t).setType(std::get<1>(t)); - rewriter.replaceOp(op, newOp.getResults()); - return success(); + return newOp; + } +}; +} // namespace + +namespace { +class ConvertWhileOpTypes + : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> { +public: + using Structural1ToNConversionPattern::Structural1ToNConversionPattern; + + Optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + TypeRange dstTypes) const { + // Unpacked the iteration arguments. + SmallVector<Value> flatArgs; + for (Value arg : adaptor.getOperands()) + unpackUnrealizedConversionCast(arg, flatArgs); + + auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs); + + for (auto i : {0u, 1u}) { + if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) + return llvm::None; + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + } + return newOp; } }; } // namespace @@ -218,42 +231,17 @@ public: } // namespace namespace { -class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> { -public: - using OpConversionPattern<WhileOp>::OpConversionPattern; - - LogicalResult - matchAndRewrite(WhileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *converter = getTypeConverter(); - assert(converter); - SmallVector<Type> newResultTypes; - if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) - return failure(); - - auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes, - adaptor.getOperands()); - for (auto i : {0u, 1u}) { - auto &dstRegion = newOp.getRegion(i); - rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); - if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) - return rewriter.notifyMatchFailure(op, "could not convert body types"); - } - rewriter.replaceOp(op, newOp.getResults()); - return success(); - } -}; -} // namespace - -namespace { class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> { public: using OpConversionPattern<ConditionOp>::OpConversionPattern; LogicalResult matchAndRewrite(ConditionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.updateRootInPlace( - op, [&]() { op->setOperands(adaptor.getOperands()); }); + SmallVector<Value> unpackedYield; + for (Value operand : adaptor.getOperands()) + unpackUnrealizedConversionCast(operand, unpackedYield); + + rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); }); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir index 334d58c62393..207e46b3d45a 100644 --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -30,3 +30,68 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>, return %1 : tensor<1024xf32, #SparseVector> } + +// CHECK-LABEL: func @if( +// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>, +// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>, +// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>, +// CHECK-SAME: %[[DIM_SIZE_1:.*6]]: memref<1xindex>, +// CHECK-SAME: %[[DIM_CURSOR_1:.*7]]: memref<1xindex>, +// CHECK-SAME: %[[MEM_SIZE_1:.*8]]: memref<3xindex>, +// CHECK-SAME: %[[POINTER_1:.*9]]: memref<?xindex>, +// CHECK-SAME: %[[INDICES_1:.*10]]: memref<?xindex>, +// CHECK-SAME: %[[VALUE_1:.*11]]: memref<?xf32>, +// CHECK-SAME: %[[TMP_arg12:.*12]]: i1) -> +// CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) { +// CHECK: %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) { +// CHECK: scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> +// CHECK: } else { +// CHECK: scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> +// CHECK: } +// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> +func.func @if(%t: tensor<1024xf32, #SparseVector>, + %f: tensor<1024xf32, #SparseVector>, + %c: i1) -> tensor<1024xf32, #SparseVector> { + %1 = scf.if %c -> tensor<1024xf32, #SparseVector> { + scf.yield %t : tensor<1024xf32, #SparseVector> + } else { + scf.yield %f : tensor<1024xf32, #SparseVector> + } + + return %1 : tensor<1024xf32, #SparseVector> +} + +// CHECK-LABEL: func @while( +// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>, +// CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>, +// CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>, +// CHECK-SAME: %[[TMP_arg6:.*6]]: i1) -> +// CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) { +// CHECK: %[[SV:.*]]:6 = scf.while ( +// CHECK-SAME: %[[TMP_arg7:.*]] = %[[DIM_SIZE]], +// CHECK-SAME: %[[TMP_arg8:.*]] = %[[DIM_CURSOR]], +// CHECK-SAME: %[[TMP_arg9:.*]] = %[[MEM_SIZE]], +// CHECK-SAME: %[[TMP_arg10:.*]] = %[[POINTER]], +// CHECK-SAME: %[[TMP_arg11:.*]] = %[[INDICES]], +// CHECK-SAME: %[[TMP_arg12:.*]] = %[[VALUE]]) +// CHECK: scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> +// CHECK: } do { +// CHECK: ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>): +// CHECK: scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> +// CHECK: } +// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> +func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> { + %0 = scf.while (%arg4 = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> { + scf.condition(%c) %arg4 : tensor<1024xf32, #SparseVector> + } do { + ^bb0(%arg7: tensor<1024xf32, #SparseVector>): + scf.yield %arg7 : tensor<1024xf32, #SparseVector> + } + return %0: tensor<1024xf32, #SparseVector> +} |