diff options
author | Nicolas Vasilache <nicolas.vasilache@gmail.com> | 2022-11-04 01:09:48 +0300 |
---|---|---|
committer | Nicolas Vasilache <nicolas.vasilache@gmail.com> | 2022-11-04 20:04:28 +0300 |
commit | c8fab80d64119ffcde78f0e9a70c5babb0da0467 (patch) | |
tree | fa930322ece23e40e7bed8c4938e2a796f965bc1 /mlir | |
parent | 8eab182bf2b7683fb5637a01b7664b802c759c2f (diff) |
[mlir][Transform] NFC - Add custom builders for some useful transforms.
Differential Revision: https://reviews.llvm.org/D137443
Diffstat (limited to 'mlir')
6 files changed, 206 insertions, 4 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h index f7952db7e2a2..2583875e2d0e 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -20,6 +20,12 @@ namespace linalg { class GenericOp; class LinalgOp; } // namespace linalg + +namespace transform { +// Types needed for builders. +struct TileSizesSpec {}; +struct NumThreadsSpec {}; +} // namespace transform } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 6cb14acb1b08..347def6c9d1b 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -112,6 +112,10 @@ def FuseIntoContainingOp : [TransformMappingAlloc, TransformMappingWrite]>:$fused_op); let assemblyFormat = "$producer_op `into` $containing_op attr-dict"; + + let builders = [ + OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)> + ]; } def GeneralizeOp : Op<Transform_Dialect, "structured.generalize", @@ -226,6 +230,10 @@ def MatchOp : Op<Transform_Dialect, "structured.match", // TODO: variadic results when needed. let results = (outs PDL_Operation:$results); + let builders = [ + OpBuilder<(ins "Value":$target, "ArrayRef<StringRef>":$opNames)> + ]; + let assemblyFormat = [{ (`ops` `{` $ops^ `}`)? (`interface` `{` $interface^ `}`)? @@ -600,6 +608,15 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction", let assemblyFormat = "$target attr-dict"; + let builders = [ + OpBuilder<(ins "Value":$target, + "int64_t":$splitFactor, + "int64_t":$insertSplitDimension, + CArg<"bool", "false">:$innerParallel, + CArg<"bool", "false">:$useScalingAlgorithm, + CArg<"bool", "false">:$useAlloc)> + ]; + let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, @@ -818,6 +835,30 @@ def TileToForeachThreadOp : OptionalAttr<I64ArrayAttr>:$thread_dim_mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); + + let builders = [ + OpBuilder<(ins "Value":$target, + "ArrayRef<int64_t>":$staticTileSizes, + CArg<"::mlir::transform::TileSizesSpec", + "::mlir::transform::TileSizesSpec()">, + CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>, + OpBuilder<(ins "Value":$target, + "ArrayRef<OpFoldResult>":$mixedTileSizes, + CArg<"::mlir::transform::TileSizesSpec", + "::mlir::transform::TileSizesSpec()">, + CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>, + OpBuilder<(ins "Value":$target, + "ArrayRef<int64_t>":$staticNumThreads, + CArg<"::mlir::transform::NumThreadsSpec", + "::mlir::transform::NumThreadsSpec()">, + CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>, + OpBuilder<(ins "Value":$target, + "ArrayRef<OpFoldResult>":$mixedNumThreads, + CArg<"::mlir::transform::NumThreadsSpec", + "::mlir::transform::NumThreadsSpec()">, + CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>, + ]; + let assemblyFormat = [{ $target oilist( `num_threads` custom<DynamicIndexList>($num_threads, @@ -943,6 +984,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize", let results = (outs PDL_Operation:$transformed); let assemblyFormat = "$target attr-dict"; + + let builders = [ + OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)> + ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 4b1bb02ee757..42f8d5cb2769 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -253,6 +253,11 @@ def SplitHandlesOp : TransformDialectOp<"split_handles", let arguments = (ins TransformTypeInterface:$handle, I64Attr:$num_result_handles); let results = (outs Variadic<TransformTypeInterface>:$results); + + let builders = [ + OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)> + ]; + let assemblyFormat = [{ $handle `in` `[` $num_result_handles `]` attr-dict `:` functional-type(operands, results) @@ -305,6 +310,12 @@ def PrintOp : TransformDialectOp<"print", let arguments = (ins Optional<TransformTypeInterface>:$target, OptionalAttr<StrAttr>:$name); let results = (outs); + + let builders = [ + OpBuilder<(ins CArg<"StringRef", "StringRef()">:$name)>, + OpBuilder<(ins "Value":$target, CArg<"StringRef", "StringRef()">:$name)> + ]; + let assemblyFormat = "$target attr-dict (`:` type($target)^)?"; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index c8a3cb6946e3..a35dd1448396 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -254,6 +254,14 @@ LogicalResult transform::FuseOp::verify() { // FuseIntoContainingOp //===----------------------------------------------------------------------===// +void transform::FuseIntoContainingOp::build(OpBuilder &builder, + OperationState &result, + Value producerOp, + Value containingOp) { + result.addOperands({producerOp, containingOp}); + result.addTypes(pdl::OperationType::get(builder.getContext())); +} + /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails. @@ -628,6 +636,14 @@ LogicalResult transform::InterchangeOp::verify() { // MatchOp //===---------------------------------------------------------------------===// +void transform::MatchOp::build(OpBuilder &builder, OperationState &result, + Value target, ArrayRef<StringRef> opNames) { + result.addOperands(target); + result.addAttribute(MatchOp::getOpsAttrName(result.name), + builder.getStrArrayAttr(opNames)); + result.addTypes(pdl::OperationType::get(builder.getContext())); +} + DiagnosedSilenceableFailure transform::MatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -1069,6 +1085,34 @@ LogicalResult SplitOp::verify() { // SplitReductionOp //===----------------------------------------------------------------------===// +void transform::SplitReductionOp::build( + OpBuilder &builder, OperationState &result, Value target, + int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, + bool useScalingAlgorithm, bool useAlloc) { + MLIRContext *ctx = builder.getContext(); + result.addOperands(target); + result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), + builder.getI64IntegerAttr(splitFactor)); + result.addAttribute( + SplitReductionOp::getInsertSplitDimensionAttrName(result.name), + builder.getI64IntegerAttr(insertSplitDimension)); + if (innerParallel) { + result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), + builder.getUnitAttr()); + } + if (useScalingAlgorithm) { + result.addAttribute( + SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), + builder.getUnitAttr()); + } + if (useAlloc) { + result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), + builder.getUnitAttr()); + } + auto resultType = pdl::OperationType::get(ctx); + result.addTypes({resultType, resultType, resultType, resultType}); +} + DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl<Operation *> &results, @@ -1277,13 +1321,75 @@ void transform::TileOp::getEffects( // TileToForeachThreadOp //===----------------------------------------------------------------------===// +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec, + ArrayRef<int64_t> threadDimMapping) { + return build(builder, result, target, + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + TileSizesSpec(), threadDimMapping); +} + +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec, + ArrayRef<int64_t> threadDimMapping) { + SmallVector<int64_t> staticTileSizes; + SmallVector<Value> dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, + ShapedType::kDynamicSize); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, horrible + // bugs ensue. + MLIRContext *ctx = builder.getContext(); + auto operationType = pdl::OperationType::get(ctx); + auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); + ArrayAttr threadDimMappingAttr; + if (!threadDimMapping.empty()) + threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); + build(builder, result, TypeRange{operationType, operationType}, target, + /*numThreads=*/ValueRange{}, dynamicTileSizes, + /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, + threadDimMappingAttr); +} + +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec, + ArrayRef<int64_t> threadDimMapping) { + return build(builder, result, target, + getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), + NumThreadsSpec(), threadDimMapping); +} + +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec, + ArrayRef<int64_t> threadDimMapping) { + SmallVector<int64_t> staticNumThreads; + SmallVector<Value> dynamicNumThreads; + dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, + staticNumThreads, ShapedType::kDynamicSize); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, horrible + // bugs ensue. + MLIRContext *ctx = builder.getContext(); + auto operationType = pdl::OperationType::get(ctx); + auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); + ArrayAttr threadDimMappingAttr; + if (!threadDimMapping.empty()) + threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); + build(builder, result, TypeRange{operationType, operationType}, target, + dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, + /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr); +} + DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef<Operation *> targets, ArrayRef<OpFoldResult> mixedNumThreads, ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping, SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) { - if (targets.empty()) return DiagnosedSilenceableFailure(success()); @@ -1573,6 +1679,16 @@ void transform::TileToScfForOp::getEffects( // VectorizeOp //===----------------------------------------------------------------------===// +void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result, + Value target, bool vectorizePadding) { + result.addOperands(target); + if (vectorizePadding) { + result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name), + builder.getUnitAttr()); + } + result.addTypes(pdl::OperationType::get(builder.getContext())); +} + namespace { /// This is an helper only to call vectorize via a pattern inside of /// VectorizeOp::applyToOne. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 5d84b7b0a603..9b136cccbe6f 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -314,11 +314,11 @@ transform::TransformResults::TransformResults(unsigned numSegments) { void transform::TransformResults::set(OpResult value, ArrayRef<Operation *> ops) { - unsigned position = value.getResultNumber(); - assert(position < segments.size() && + int64_t position = value.getResultNumber(); + assert(position < static_cast<int64_t>(segments.size()) && "setting results for a non-existent handle"); assert(segments[position].data() == nullptr && "results already set"); - unsigned start = operations.size(); + int64_t start = operations.size(); llvm::append_range(operations, ops); segments[position] = makeArrayRef(operations).drop_front(start); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 2be1bea91fbe..5fe2d465ee51 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -472,6 +472,16 @@ OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) { // SplitHandlesOp //===----------------------------------------------------------------------===// +void transform::SplitHandlesOp::build(OpBuilder &builder, + OperationState &result, Value target, + int64_t numResultHandles) { + result.addOperands(target); + result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name), + builder.getI64IntegerAttr(numResultHandles)); + auto pdlOpType = pdl::OperationType::get(builder.getContext()); + result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType)); +} + DiagnosedSilenceableFailure transform::SplitHandlesOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -812,6 +822,20 @@ LogicalResult transform::WithPDLPatternsOp::verify() { // PrintOp //===----------------------------------------------------------------------===// +void transform::PrintOp::build(OpBuilder &builder, OperationState &result, + StringRef name) { + if (!name.empty()) { + result.addAttribute(PrintOp::getNameAttrName(result.name), + builder.getStrArrayAttr(name)); + } +} + +void transform::PrintOp::build(OpBuilder &builder, OperationState &result, + Value target, StringRef name) { + result.addOperands({target}); + build(builder, result, name); +} + DiagnosedSilenceableFailure transform::PrintOp::apply(transform::TransformResults &results, transform::TransformState &state) { |