diff options
author | Guray Ozen <guray.ozen@gmail.com> | 2022-11-12 15:02:02 +0300 |
---|---|---|
committer | Guray Ozen <guray.ozen@gmail.com> | 2022-11-12 21:27:25 +0300 |
commit | d93be483eaf5e22f4192325f9357821cbd2e934e (patch) | |
tree | e2483eb1766ff186886278f7f3ae3002cb3d777a | |
parent | 4a28b7ba9816fd7e8ecadfcaf6cfd1949f909dfd (diff) |
[mlir][transform] Make `tile_to_foreach_thread_op` builder to use ArrayAttr
D137413 clarified `scf_foreach_thread` thread mapping nicely. `tile_to_foreach_thread_op` is one of the op that generates `scf_foreach_thread`, however, its builders are still having integer array.
This is bug fix of potential problem.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D137891
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 8 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 18 |
2 files changed, 10 insertions, 16 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index b8638f10de98..b92ed1986306 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -842,22 +842,22 @@ def TileToForeachThreadOp : "ArrayRef<int64_t>":$staticTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef<int64_t>", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef<int64_t>", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef<int64_t>":$staticNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef<int64_t>", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef<int64_t>", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, ]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 7b720a7452a5..cdd4e158f9b1 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1326,7 +1326,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, Value target, ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec, - ArrayRef<int64_t> mapping) { + ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), TileSizesSpec(), mapping); @@ -1335,7 +1335,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec, - ArrayRef<int64_t> mapping) { + ArrayAttr mapping) { SmallVector<int64_t> staticTileSizes; SmallVector<Value> dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, @@ -1346,12 +1346,9 @@ void transform::TileToForeachThreadOp::build( MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); - ArrayAttr mappingAttr; - if (!mapping.empty()) - mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, /*numThreads=*/ValueRange{}, dynamicTileSizes, - /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr); + /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping); } void transform::TileToForeachThreadOp::build(OpBuilder &builder, @@ -1359,7 +1356,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, Value target, ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec, - ArrayRef<int64_t> mapping) { + ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), NumThreadsSpec(), mapping); @@ -1368,7 +1365,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder, void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec, - ArrayRef<int64_t> mapping) { + ArrayAttr mapping) { SmallVector<int64_t> staticNumThreads; SmallVector<Value> dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, @@ -1379,12 +1376,9 @@ void transform::TileToForeachThreadOp::build( MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); - ArrayAttr mappingAttr; - if (!mapping.empty()) - mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, - /*staticTileSizes=*/ArrayAttr(), mappingAttr); + /*staticTileSizes=*/ArrayAttr(), mapping); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( |