From e0cea169f7a77996ba066d9252b22f43726c23c7 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Wed, 12 Oct 2022 05:21:52 -0700
Subject: [mlir][Linalg] Drop filter-based splitReduction

This transformation is available and tested via the transform dialect.

Differential Revision: https://reviews.llvm.org/D135767
---
 .../mlir/Dialect/Linalg/Transforms/Transforms.h    |   9 -
 .../Dialect/Linalg/Transforms/SplitReduction.cpp   |  34 +---
 mlir/test/Dialect/Linalg/split_reduction.mlir      | 193 ---------------------
 .../lib/Dialect/Linalg/TestLinalgTransforms.cpp    |  40 -----
 4 files changed, 5 insertions(+), 271 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/split_reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4fd9e3fe0cf2..bbac0899338d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1050,7 +1050,6 @@ using ControlSplitReductionFn =
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &f = LinalgTransformationFilter(),
     bool useAlloc = false);
 
 /// Apply transformation to split the single linalg op reduction into a parallel
@@ -1094,14 +1093,6 @@ void populateSplitReductionPattern(
 ///    linalg.yield %5 : f32
 ///  } -> tensor<f32>
 /// ```
-FailureOr<LinalgOp>
-splitReduction(PatternRewriter &b, LinalgOp op,
-               const ControlSplitReductionFn &controlSplitReductionFn,
-               const LinalgTransformationFilter &f, bool useAlloc = false);
-
-/// Filterless version of the above.
-/// Returns both the new linalg ops as well as the fillOp needed to initialize
-/// the temporary expanded tensor with the proper neutral element.
 struct SplitReductionResult {
   Operation *initOrAlloc;
   FillOp fillOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 7a257137cfb6..7df65c823a2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -58,26 +58,6 @@ static Attribute getNeutralElement(Operation *op) {
   return Attribute();
 }
 
-FailureOr<LinalgOp> mlir::linalg::splitReduction(
-    PatternRewriter &b, LinalgOp op,
-    const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &filter, bool useAlloc) {
-  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
-      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
-      !op.hasOnlyProjectedPermutations())
-    return b.notifyMatchFailure(op, "precondition not met");
-
-  FailureOr<SplitReductionResult> res =
-      splitReduction(b, op, controlSplitReductionFn, useAlloc);
-  if (failed(res))
-    return failure();
-
-  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
-  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
-
-  return res->splitLinalgOp;
-}
-
 FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
@@ -481,30 +461,26 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
   LinalgSplitReduction(MLIRContext *context,
                        ControlSplitReductionFn controlSplitReductionFn,
-                       LinalgTransformationFilter f, bool useAlloc = false,
-                       PatternBenefit benefit = 1)
+                       bool useAlloc = false, PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlSplitReductionFn(std::move(controlSplitReductionFn)),
-        useAlloc(useAlloc), filter(std::move(f)) {}
+        useAlloc(useAlloc) {}
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
-                          useAlloc);
+    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
   }
 
 private:
   ControlSplitReductionFn controlSplitReductionFn;
   bool useAlloc;
-  LinalgTransformationFilter filter;
 };
 
 } // namespace
 
 void linalg::populateSplitReductionPattern(
     RewritePatternSet &patterns,
-    const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &f, bool useAlloc) {
+    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   patterns.add<LinalgSplitReduction>(patterns.getContext(),
-                                     controlSplitReductionFn, f, useAlloc);
+                                     controlSplitReductionFn, useAlloc);
 }
diff --git a/mlir/test/Dialect/Linalg/split_reduction.mlir b/mlir/test/Dialect/Linalg/split_reduction.mlir
deleted file mode 100644
index 25a499958e27..000000000000
--- a/mlir/test/Dialect/Linalg/split_reduction.mlir
+++ /dev/null
@@ -1,193 +0,0 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction-inner-parallel -split-input-file | FileCheck %s --check-prefix=INNERPARALLELCHECK
-
-func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
-                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
-  return %0: tensor<16x32xf32>
-}
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: @matmul_split
-// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
-// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
-// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
-// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
-// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
-// CHECK: arith.mulf
-// CHECK: arith.addf
-// CHECK: linalg.yield
-// CHECK: } -> tensor<16x32x4xf32>
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
-// CHECK: arith.addf
-// CHECK: linalg.yield %{{.*}} : f32
-// CHECK: } -> tensor<16x32xf32>
-// CHECK: return %[[R]] : tensor<16x32xf32>
-
-// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
-// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// INNERPARALLELCHECK-LABEL: @matmul_split
-// INNERPARALLELCHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
-// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
-// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
-// INNERPARALLELCHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
-// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
-// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
-// INNERPARALLELCHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
-// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
-// INNERPARALLELCHECK: arith.mulf
-// INNERPARALLELCHECK: arith.addf
-// INNERPARALLELCHECK: linalg.yield
-// INNERPARALLELCHECK: } -> tensor<16x32x4xf32>
-// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
-// INNERPARALLELCHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
-// INNERPARALLELCHECK: arith.addf
-// INNERPARALLELCHECK: linalg.yield %{{.*}} : f32
-// INNERPARALLELCHECK: } -> tensor<16x32xf32>
-// INNERPARALLELCHECK: return %[[R]] : tensor<16x32xf32>
-
-// -----
-
-func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
-  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
-                                          affine_map<(d0) -> ()>,
-                                          affine_map<(d0) -> ()>],
-          iterator_types = ["reduction"]}
-          ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
-          outs(%out : tensor<f32>) {
-          ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
-            %40 = arith.subf %arg7, %arg8 : f32
-            %41 = math.exp %40 : f32
-            %42 = arith.mulf %41, %arg9 : f32
-            linalg.yield %42 : f32
-          } -> tensor<f32>
-  return %red : tensor<f32>
-}
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
-//CHECK-LABEL: @generic_split_1d
-// CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
-// CHECK: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
-// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
-// CHECK: %[[G:.*]] = linalg.generic
-// CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
-// CHECK: iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
-// CHECK: arith.subf
-// CHECK: math.exp
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// CHECK: } -> tensor<4xf32>
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
-// CHECK: arith.mulf
-// CHECK: linalg.yield
-// CHECK: } -> tensor<f32>
-// CHECK: return %[[R]] : tensor<f32>
-
-// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
-// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
-// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
-// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
-//INNERPARALLELCHECK-LABEL: @generic_split_1d
-// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
-// INNERPARALLELCHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
-// INNERPARALLELCHECK: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
-// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
-// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic
-// INNERPARALLELCHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
-// INNERPARALLELCHECK: iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
-// INNERPARALLELCHECK: arith.subf
-// INNERPARALLELCHECK: math.exp
-// INNERPARALLELCHECK: arith.mulf
-// INNERPARALLELCHECK: linalg.yield
-// INNERPARALLELCHECK: } -> tensor<4xf32>
-// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
-// INNERPARALLELCHECK: arith.mulf
-// INNERPARALLELCHECK: linalg.yield
-// INNERPARALLELCHECK: } -> tensor<f32>
-// INNERPARALLELCHECK: return %[[R]] : tensor<f32>
-
-// -----
-
-func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
-  -> tensor<5x2xf32>
-{
-  %0 = linalg.generic {
-      indexing_maps = [
-        affine_map<(d0, d1, d2) -> (d1, d0)>,
-        affine_map<(d0, d1, d2) -> (d2, d1)>,
-        affine_map<(d0, d1, d2) -> (d2, d0)>
-      ],
-      iterator_types = ["parallel", "reduction", "parallel"]
-    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
-    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
-      %3 = arith.addf %arg0, %arg1 : f32
-      %4 = arith.maxf %3, %arg2 : f32
-      linalg.yield %4 : f32
-    } -> tensor<5x2xf32>
-  return %0 : tensor<5x2xf32>
-}
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-LABEL: func @generic_split_3d
-// CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
-// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
-// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
-// CHECK: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
-// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
-// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
-// CHECK: arith.addf
-// CHECK: arith.maxf
-// CHECK: linalg.yield
-// CHECK: } -> tensor<5x2x4xf32>
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
-// CHECK: arith.maxf
-// CHECK: linalg.yield
-// CHECK: } -> tensor<5x2xf32>
-// CHECK: return %[[R]] : tensor<5x2xf32>
-
-// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
-// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
-// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
-// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// INNERPARALLELCHECK-LABEL: func @generic_split_3d
-// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
-// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
-// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
-// INNERPARALLELCHECK: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
-// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
-// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
-// INNERPARALLELCHECK: arith.addf
-// INNERPARALLELCHECK: arith.maxf
-// INNERPARALLELCHECK: linalg.yield
-// INNERPARALLELCHECK: } -> tensor<5x2x4xf32>
-// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
-// INNERPARALLELCHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
-// INNERPARALLELCHECK: arith.maxf
-// INNERPARALLELCHECK: linalg.yield
-// INNERPARALLELCHECK: } -> tensor<5x2xf32>
-// INNERPARALLELCHECK: return %[[R]] : tensor<5x2xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 41dce75b711a..146a32edac52 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -84,14 +84,6 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
                      "tensor.pad(subtensor)"),
       llvm::cl::init(false)};
-  Option<bool> testSplitReduction{
-      *this, "test-split-reduction",
-      llvm::cl::desc("Test split reduction transformation"),
-      llvm::cl::init(false)};
-  Option<bool> testSplitReductionInnerParallel{
-      *this, "test-split-reduction-inner-parallel",
-      llvm::cl::desc("Test split reduction with inner parallel transformation"),
-      llvm::cl::init(false)};
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
@@ -176,34 +168,6 @@ static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applySplitReduction(func::FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  linalg::populateSplitReductionPattern(
-      patterns,
-      [](LinalgOp op) {
-        unsigned insertDimIndex = op.getNumLoops() - 1;
-        return SplitReductionOptions{4, insertDimIndex, false};
-      },
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>{},
-          StringAttr::get(funcOp.getContext(), "SPLIT")));
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
-static void applySplitReductionInnerParallel(func::FuncOp funcOp) {
-  RewritePatternSet patterns(funcOp.getContext());
-  linalg::populateSplitReductionPattern(
-      patterns,
-      [](LinalgOp op) {
-        unsigned insertDimIndex = op.getNumLoops() - 1;
-        return SplitReductionOptions{4, insertDimIndex, true};
-      },
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>{},
-          StringAttr::get(funcOp.getContext(), "SPLIT")));
-  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   populateBubbleUpExtractSliceOpPatterns(patterns);
@@ -237,10 +201,6 @@ void TestLinalgTransforms::runOnOperation() {
     return applyGeneralizePadTensorPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
-  if (testSplitReduction)
-    return applySplitReduction(getOperation());
-  if (testSplitReductionInnerParallel)
-    return applySplitReductionInnerParallel(getOperation());
   if (testBubbleUpExtractSliceOpPattern)
     return applyBubbleUpExtractSliceOpPattern(getOperation());
   if (testSwapExtractSliceWithFill)
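
Usage note (not part of the patch): the commit message points at the transform dialect as the supported way to drive split reduction now that the filter-based path is gone. The sketch below mirrors what the removed -test-split-reduction flag exercised (split factor 4, split dimension inserted at the innermost loop of a matmul) using transform.structured.split_reduction. Attribute names, the number of results, and the handle type (!pdl.operation here) follow the transform dialect of roughly this era and may differ in other MLIR versions, so treat it as an illustrative sketch rather than a verbatim recipe.

  // Hypothetical transform-dialect schedule standing in for the removed
  // pattern-driver test flag.
  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    // Match the reduction op to split, e.g. a linalg.matmul.
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
    // Four result handles: the init/alloc op, the fill op, the split
    // (partial-reduction) linalg op, and the final combining linalg op.
    %1:4 = transform.structured.split_reduction %0
             { split_factor = 4, insert_split_dimension = 2 }
  }

The pattern-based C++ path also remains available through the now filter-free populateSplitReductionPattern(patterns, controlSplitReductionFn, useAlloc) entry point shown in the diff above.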