Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorNicolas Vasilache <nicolas.vasilache@gmail.com>2022-11-04 01:09:48 +0300
committerNicolas Vasilache <nicolas.vasilache@gmail.com>2022-11-04 20:04:28 +0300
commitc8fab80d64119ffcde78f0e9a70c5babb0da0467 (patch)
treefa930322ece23e40e7bed8c4938e2a796f965bc1 /mlir
parent8eab182bf2b7683fb5637a01b7664b802c759c2f (diff)
[mlir][Transform] NFC - Add custom builders for some useful transforms.
Differential Revision: https://reviews.llvm.org/D137443
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h6
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td45
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformOps.td11
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp118
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp6
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp24
6 files changed, 206 insertions, 4 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index f7952db7e2a2..2583875e2d0e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -20,6 +20,12 @@ namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg
+
+namespace transform {
+// Types needed for builders.
+struct TileSizesSpec {};
+struct NumThreadsSpec {};
+} // namespace transform
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6cb14acb1b08..347def6c9d1b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -112,6 +112,10 @@ def FuseIntoContainingOp :
[TransformMappingAlloc,
TransformMappingWrite]>:$fused_op);
let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+
+ let builders = [
+ OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+ ];
}
def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
@@ -226,6 +230,10 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
// TODO: variadic results when needed.
let results = (outs PDL_Operation:$results);
+ let builders = [
+ OpBuilder<(ins "Value":$target, "ArrayRef<StringRef>":$opNames)>
+ ];
+
let assemblyFormat = [{
(`ops` `{` $ops^ `}`)?
(`interface` `{` $interface^ `}`)?
@@ -600,6 +608,15 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
let assemblyFormat = "$target attr-dict";
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "int64_t":$splitFactor,
+ "int64_t":$insertSplitDimension,
+ CArg<"bool", "false">:$innerParallel,
+ CArg<"bool", "false">:$useScalingAlgorithm,
+ CArg<"bool", "false">:$useAlloc)>
+ ];
+
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
@@ -818,6 +835,30 @@ def TileToForeachThreadOp :
OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);
+
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$staticTileSizes,
+ CArg<"::mlir::transform::TileSizesSpec",
+ "::mlir::transform::TileSizesSpec()">,
+ CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedTileSizes,
+ CArg<"::mlir::transform::TileSizesSpec",
+ "::mlir::transform::TileSizesSpec()">,
+ CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$staticNumThreads,
+ CArg<"::mlir::transform::NumThreadsSpec",
+ "::mlir::transform::NumThreadsSpec()">,
+ CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<OpFoldResult>":$mixedNumThreads,
+ CArg<"::mlir::transform::NumThreadsSpec",
+ "::mlir::transform::NumThreadsSpec()">,
+ CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+ ];
+
let assemblyFormat = [{
$target oilist(
`num_threads` custom<DynamicIndexList>($num_threads,
@@ -943,6 +984,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
let results = (outs PDL_Operation:$transformed);
let assemblyFormat = "$target attr-dict";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)>
+ ];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 4b1bb02ee757..42f8d5cb2769 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -253,6 +253,11 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
let arguments = (ins TransformTypeInterface:$handle,
I64Attr:$num_result_handles);
let results = (outs Variadic<TransformTypeInterface>:$results);
+
+ let builders = [
+ OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
+ ];
+
let assemblyFormat = [{
$handle `in` `[` $num_result_handles `]`
attr-dict `:` functional-type(operands, results)
@@ -305,6 +310,12 @@ def PrintOp : TransformDialectOp<"print",
let arguments = (ins Optional<TransformTypeInterface>:$target,
OptionalAttr<StrAttr>:$name);
let results = (outs);
+
+ let builders = [
+ OpBuilder<(ins CArg<"StringRef", "StringRef()">:$name)>,
+ OpBuilder<(ins "Value":$target, CArg<"StringRef", "StringRef()">:$name)>
+ ];
+
let assemblyFormat = "$target attr-dict (`:` type($target)^)?";
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c8a3cb6946e3..a35dd1448396 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -254,6 +254,14 @@ LogicalResult transform::FuseOp::verify() {
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
+void transform::FuseIntoContainingOp::build(OpBuilder &builder,
+ OperationState &result,
+ Value producerOp,
+ Value containingOp) {
+ result.addOperands({producerOp, containingOp});
+ result.addTypes(pdl::OperationType::get(builder.getContext()));
+}
+
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
@@ -628,6 +636,14 @@ LogicalResult transform::InterchangeOp::verify() {
// MatchOp
//===---------------------------------------------------------------------===//
+void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
+ Value target, ArrayRef<StringRef> opNames) {
+ result.addOperands(target);
+ result.addAttribute(MatchOp::getOpsAttrName(result.name),
+ builder.getStrArrayAttr(opNames));
+ result.addTypes(pdl::OperationType::get(builder.getContext()));
+}
+
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@@ -1069,6 +1085,34 @@ LogicalResult SplitOp::verify() {
// SplitReductionOp
//===----------------------------------------------------------------------===//
+void transform::SplitReductionOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
+ bool useScalingAlgorithm, bool useAlloc) {
+ MLIRContext *ctx = builder.getContext();
+ result.addOperands(target);
+ result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
+ builder.getI64IntegerAttr(splitFactor));
+ result.addAttribute(
+ SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
+ builder.getI64IntegerAttr(insertSplitDimension));
+ if (innerParallel) {
+ result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
+ builder.getUnitAttr());
+ }
+ if (useScalingAlgorithm) {
+ result.addAttribute(
+ SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
+ builder.getUnitAttr());
+ }
+ if (useAlloc) {
+ result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
+ builder.getUnitAttr());
+ }
+ auto resultType = pdl::OperationType::get(ctx);
+ result.addTypes({resultType, resultType, resultType, resultType});
+}
+
DiagnosedSilenceableFailure
transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
@@ -1277,13 +1321,75 @@ void transform::TileOp::getEffects(
// TileToForeachThreadOp
//===----------------------------------------------------------------------===//
+void transform::TileToForeachThreadOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec,
+ ArrayRef<int64_t> threadDimMapping) {
+ return build(builder, result, target,
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ TileSizesSpec(), threadDimMapping);
+}
+
+void transform::TileToForeachThreadOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
+ ArrayRef<int64_t> threadDimMapping) {
+ SmallVector<int64_t> staticTileSizes;
+ SmallVector<Value> dynamicTileSizes;
+ dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
+ ShapedType::kDynamicSize);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this, horrible
+ // bugs ensue.
+ MLIRContext *ctx = builder.getContext();
+ auto operationType = pdl::OperationType::get(ctx);
+ auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
+ ArrayAttr threadDimMappingAttr;
+ if (!threadDimMapping.empty())
+ threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+ build(builder, result, TypeRange{operationType, operationType}, target,
+ /*numThreads=*/ValueRange{}, dynamicTileSizes,
+ /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr,
+ threadDimMappingAttr);
+}
+
+void transform::TileToForeachThreadOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec,
+ ArrayRef<int64_t> threadDimMapping) {
+ return build(builder, result, target,
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
+ NumThreadsSpec(), threadDimMapping);
+}
+
+void transform::TileToForeachThreadOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
+ ArrayRef<int64_t> threadDimMapping) {
+ SmallVector<int64_t> staticNumThreads;
+ SmallVector<Value> dynamicNumThreads;
+ dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
+ staticNumThreads, ShapedType::kDynamicSize);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this, horrible
+ // bugs ensue.
+ MLIRContext *ctx = builder.getContext();
+ auto operationType = pdl::OperationType::get(ctx);
+ auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
+ ArrayAttr threadDimMappingAttr;
+ if (!threadDimMapping.empty())
+ threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+ build(builder, result, TypeRange{operationType, operationType}, target,
+ dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
+ /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr);
+}
+
DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, ArrayRef<Operation *> targets,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
-
if (targets.empty())
return DiagnosedSilenceableFailure(success());
@@ -1573,6 +1679,16 @@ void transform::TileToScfForOp::getEffects(
// VectorizeOp
//===----------------------------------------------------------------------===//
+void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
+ Value target, bool vectorizePadding) {
+ result.addOperands(target);
+ if (vectorizePadding) {
+ result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
+ builder.getUnitAttr());
+ }
+ result.addTypes(pdl::OperationType::get(builder.getContext()));
+}
+
namespace {
/// This is an helper only to call vectorize via a pattern inside of
/// VectorizeOp::applyToOne.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5d84b7b0a603..9b136cccbe6f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -314,11 +314,11 @@ transform::TransformResults::TransformResults(unsigned numSegments) {
void transform::TransformResults::set(OpResult value,
ArrayRef<Operation *> ops) {
- unsigned position = value.getResultNumber();
- assert(position < segments.size() &&
+ int64_t position = value.getResultNumber();
+ assert(position < static_cast<int64_t>(segments.size()) &&
"setting results for a non-existent handle");
assert(segments[position].data() == nullptr && "results already set");
- unsigned start = operations.size();
+ int64_t start = operations.size();
llvm::append_range(operations, ops);
segments[position] = makeArrayRef(operations).drop_front(start);
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 2be1bea91fbe..5fe2d465ee51 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -472,6 +472,16 @@ OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
// SplitHandlesOp
//===----------------------------------------------------------------------===//
+void transform::SplitHandlesOp::build(OpBuilder &builder,
+ OperationState &result, Value target,
+ int64_t numResultHandles) {
+ result.addOperands(target);
+ result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
+ builder.getI64IntegerAttr(numResultHandles));
+ auto pdlOpType = pdl::OperationType::get(builder.getContext());
+ result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
+}
+
DiagnosedSilenceableFailure
transform::SplitHandlesOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@@ -812,6 +822,20 @@ LogicalResult transform::WithPDLPatternsOp::verify() {
// PrintOp
//===----------------------------------------------------------------------===//
+void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
+ StringRef name) {
+ if (!name.empty()) {
+ result.addAttribute(PrintOp::getNameAttrName(result.name),
+ builder.getStrArrayAttr(name));
+ }
+}
+
+void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
+ Value target, StringRef name) {
+ result.addOperands({target});
+ build(builder, result, name);
+}
+
DiagnosedSilenceableFailure
transform::PrintOp::apply(transform::TransformResults &results,
transform::TransformState &state) {