Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/llvm/llvm-project.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuray Ozen <guray.ozen@gmail.com>2022-11-12 15:02:02 +0300
committerGuray Ozen <guray.ozen@gmail.com>2022-11-12 21:27:25 +0300
commitd93be483eaf5e22f4192325f9357821cbd2e934e (patch)
treee2483eb1766ff186886278f7f3ae3002cb3d777a
parent4a28b7ba9816fd7e8ecadfcaf6cfd1949f909dfd (diff)
[mlir][transform] Make `tile_to_foreach_thread_op` builder to use ArrayAttr
D137413 clarified `scf_foreach_thread` thread mapping nicely. `tile_to_foreach_thread_op` is one of the op that generates `scf_foreach_thread`, however, its builders are still having integer array. This is bug fix of potential problem. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D137891
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td8
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp18
2 files changed, 10 insertions, 16 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b8638f10de98..b92ed1986306 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -842,22 +842,22 @@ def TileToForeachThreadOp :
"ArrayRef<int64_t>":$staticTileSizes,
CArg<"::mlir::transform::TileSizesSpec",
"::mlir::transform::TileSizesSpec()">,
- CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+ CArg<"ArrayAttr", "{}">:$mapping)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedTileSizes,
CArg<"::mlir::transform::TileSizesSpec",
"::mlir::transform::TileSizesSpec()">,
- CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+ CArg<"ArrayAttr", "{}">:$mapping)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$staticNumThreads,
CArg<"::mlir::transform::NumThreadsSpec",
"::mlir::transform::NumThreadsSpec()">,
- CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+ CArg<"ArrayAttr", "{}">:$mapping)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedNumThreads,
CArg<"::mlir::transform::NumThreadsSpec",
"::mlir::transform::NumThreadsSpec()">,
- CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+ CArg<"ArrayAttr", "{}">:$mapping)>,
];
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7b720a7452a5..cdd4e158f9b1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1326,7 +1326,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
Value target,
ArrayRef<int64_t> staticTileSizes,
transform::TileSizesSpec,
- ArrayRef<int64_t> mapping) {
+ ArrayAttr mapping) {
return build(builder, result, target,
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
TileSizesSpec(), mapping);
@@ -1335,7 +1335,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
void transform::TileToForeachThreadOp::build(
OpBuilder &builder, OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
- ArrayRef<int64_t> mapping) {
+ ArrayAttr mapping) {
SmallVector<int64_t> staticTileSizes;
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
@@ -1346,12 +1346,9 @@ void transform::TileToForeachThreadOp::build(
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
- ArrayAttr mappingAttr;
- if (!mapping.empty())
- mappingAttr = builder.getI64ArrayAttr(mapping);
build(builder, result, TypeRange{operationType, operationType}, target,
/*numThreads=*/ValueRange{}, dynamicTileSizes,
- /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr);
+ /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping);
}
void transform::TileToForeachThreadOp::build(OpBuilder &builder,
@@ -1359,7 +1356,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
Value target,
ArrayRef<int64_t> staticNumThreads,
transform::NumThreadsSpec,
- ArrayRef<int64_t> mapping) {
+ ArrayAttr mapping) {
return build(builder, result, target,
getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
NumThreadsSpec(), mapping);
@@ -1368,7 +1365,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
void transform::TileToForeachThreadOp::build(
OpBuilder &builder, OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
- ArrayRef<int64_t> mapping) {
+ ArrayAttr mapping) {
SmallVector<int64_t> staticNumThreads;
SmallVector<Value> dynamicNumThreads;
dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
@@ -1379,12 +1376,9 @@ void transform::TileToForeachThreadOp::build(
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
- ArrayAttr mappingAttr;
- if (!mapping.empty())
- mappingAttr = builder.getI64ArrayAttr(mapping);
build(builder, result, TypeRange{operationType, operationType}, target,
dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
- /*staticTileSizes=*/ArrayAttr(), mappingAttr);
+ /*staticTileSizes=*/ArrayAttr(), mapping);
}
DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(