Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2022-10-11 18:23:48 +0300
committerAlex Zinenko <zinenko@google.com>2022-10-12 11:16:28 +0300
commit32f0bde548cface29b26ee26763881dbcfb8bb58 (patch)
tree171aab61529a08939f2fea6badcdd93037c8a63a /mlir
parent812ad2167bd2e27f5d0dee07bb03a5910616e0b6 (diff)
[mlir] add transform dialect entry point
Introduce `transform::applyTransforms` as a top-level entry point to the Transform dialect-driven transformation infrastructure, by analogy with `applyFull/PartialConversion`. Clients are expected to use this function and no longer need to maintain the transformation state. Make the constructor of the TransformState private for that purpose. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D135681
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td27
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h29
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp31
-rw-r--r--mlir/test/Dialect/Transform/test-interpreter.mlir36
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp9
5 files changed, 95 insertions, 37 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index cf5072bc5015..5fef82b63a7b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -16,19 +16,18 @@ def Transform_Dialect : Dialect {
let description = [{
## Disclaimer
- ** Proceed with care: not ready for general use. **
+ **This dialect is actively developed and may change frequently.**
- This dialect is evolving rapidly and may change on a very short notice. To
- decrease the maintenance burden and churn, only a few in-tree use cases are
- currently supported in the main tree:
+ To decrease the maintenance burden and churn, please post a description of
+ the intended use case on the MLIR forum. A few in-tree use cases are
+ currently supported:
- high-level transformations on "structured ops" (i.e. ops that operate on
chunks of data in a way that can be decomposed into operations on
smaller chunks of data and control flow) in Linalg, Tensor and Vector
- dialects.
-
- *Please post a description of the intended use case on the MLIR forum and
- wait for confirmation.*
+ dialects;
+ - loop transformations in the SCF dialect.
+
## Overview
@@ -79,6 +78,18 @@ def Transform_Dialect : Dialect {
expected to have the `PossibleTopLevelTransformOpTrait` and may be used
without arguments.
+ A program transformation expressed using the Transform dialect can be
+ programmatically triggered by calling:
+
+ ```c++
+ LogicalResult transform::applyTransforms(Operation *payloadRoot,
+ TransformOpInterface transform,
+ const TransformOptions &options);
+ ```
+
+ that applies the transformations specified by the top-level `transform` to
+ payload IR contained in `payloadRoot`.
+
## Dialect Extension Mechanism
This dialect is designed to be extensible, that is, clients of this dialect
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 25f61d639844..2c81985e3c12 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -206,6 +206,16 @@ private:
bool expensiveChecksEnabled = true;
};
+/// Entry point to the Transform dialect infrastructure. Applies the
+/// transformation specified by `transform` to payload IR contained in
+/// `payloadRoot`. The `transform` operation may contain other operations that
+/// will be executed following the internal logic of the operation. It must
+/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
+/// This function internally keeps track of the transformation state.
+LogicalResult
+applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+ const TransformOptions &options = TransformOptions());
+
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
/// surrounding structure are referred to as transform IR. The operations to
@@ -250,15 +260,11 @@ class TransformState {
TransformOpReverseMapping reverse;
};
-public:
- /// Creates a state for transform ops living in the given region. The parent
- /// operation of the region. The second argument points to the root operation
- /// in the payload IR being transformed, which may or may not contain the
- /// region with transform ops. Additional options can be provided through the
- /// trailing configuration object.
- TransformState(Region &region, Operation *root,
- const TransformOptions &options = TransformOptions());
+ friend LogicalResult applyTransforms(Operation *payloadRoot,
+ TransformOpInterface transform,
+ const TransformOptions &options);
+public:
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
@@ -438,6 +444,13 @@ private:
/// Identifier for storing top-level value in the `operations` mapping.
static constexpr Value kTopLevelValue = Value();
+ /// Creates a state for transform ops living in the given region. The second
+ /// argument points to the root operation in the payload IR being transformed,
+ /// which may or may not contain the region with transform ops. Additional
+ /// options can be provided through the trailing configuration object.
+ TransformState(Region *region, Operation *payloadRoot,
+ const TransformOptions &options = TransformOptions());
+
/// Returns the mappings frame for the reigon in which the value is defined.
const Mappings &getMapping(Value value) const {
return const_cast<TransformState *>(this)->getMapping(value);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 60414e80f75a..2810444ea864 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "transform-dialect"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
@@ -25,14 +26,15 @@ using namespace mlir;
constexpr const Value transform::TransformState::kTopLevelValue;
-transform::TransformState::TransformState(Region &region, Operation *root,
+transform::TransformState::TransformState(Region *region,
+ Operation *payloadRoot,
const TransformOptions &options)
- : topLevel(root), options(options) {
- auto result = mappings.try_emplace(&region);
+ : topLevel(payloadRoot), options(options) {
+ auto result = mappings.try_emplace(region);
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- regionStack.push_back(&region);
+ regionStack.push_back(region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
@@ -448,6 +450,27 @@ void transform::onlyReadsPayload(
}
//===----------------------------------------------------------------------===//
+// Entry point.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::applyTransforms(Operation *payloadRoot,
+ TransformOpInterface transform,
+ const TransformOptions &options) {
+#ifndef NDEBUG
+ if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
+ transform->getNumOperands() != 0) {
+ transform->emitError()
+ << "expected transform to start at the top-level transform op";
+ llvm::report_fatal_error("could not run transforms",
+ /*gen_crash_diag=*/false);
+ }
+#endif // NDEBUG
+
+ TransformState state(transform->getParentRegion(), payloadRoot, options);
+ return state.applyTransform(transform).checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 57cbcacb1605..735491a05aa6 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1,29 +1,41 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
-// expected-remark @below {{applying transformation}}
-transform.test_transform_op
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-remark @below {{applying transformation}}
+ transform.test_transform_op
+}
// -----
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-// expected-remark @below {{succeeded}}
-transform.test_consume_operand_if_matches_param_or_fail %0[42]
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+ // expected-remark @below {{succeeded}}
+ transform.test_consume_operand_if_matches_param_or_fail %0[42]
+}
// -----
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-// expected-error @below {{expected the operand to be associated with 21 got 42}}
-transform.test_consume_operand_if_matches_param_or_fail %0[21]
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+ // expected-error @below {{expected the operand to be associated with 21 got 42}}
+ transform.test_consume_operand_if_matches_param_or_fail %0[21]
+}
// -----
// It is okay to have multiple handles to the same payload op as long
// as only one of them is consumed. The expensive checks mode is necessary
// to detect double-consumption.
-%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
-%1 = transform.test_copy_payload %0
-// expected-remark @below {{succeeded}}
-transform.test_consume_operand_if_matches_param_or_fail %0[42]
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+ %1 = transform.test_copy_payload %0
+ // expected-remark @below {{succeeded}}
+ transform.test_consume_operand_if_matches_param_or_fail %0[42]
+}
// -----
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index ad5dcab9c184..1696cae0b446 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -41,13 +41,12 @@ public:
void runOnOperation() override {
ModuleOp module = getOperation();
- transform::TransformState state(
- module.getBodyRegion(), module,
- transform::TransformOptions().enableExpensiveChecks(
- enableExpensiveChecks));
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
- if (failed(state.applyTransform(op).checkAndReport()))
+ if (failed(transform::applyTransforms(
+ module, op,
+ transform::TransformOptions().enableExpensiveChecks(
+ enableExpensiveChecks))))
return signalPassFailure();
}
}