diff options
| author | Jakub Kuderski <jakub@nod-labs.com> | 2023-12-18 17:54:54 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-12-18 17:54:54 -0500 |
| commit | 07677113ffeb3744df350ef7c4ece1a93f7a5e1f (patch) | |
| tree | e3f42340ca8118aaeedb153c7bf3bcabea52115f /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 5b57da32a861ec4d7ef4adc7f1560142cf58d1ed (diff) | |
[mlir][vector] Add pattern to break down reductions into arith ops (#75727)
The number of vector elements considered 'small' enough to extract is
parameterized.
This is to avoid going into specialized reduction lowering when a
single/couple of arith ops can do. Targets without dedicated reduction
intrinsics can use that as an emulation path too.
Depends on https://github.com/llvm/llvm-project/pull/75846.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 03ddebe82344..126d65b1b848 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns } }; +struct TestVectorBreakDownReductionPatterns + : public PassWrapper<TestVectorBreakDownReductionPatterns, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestVectorBreakDownReductionPatterns) + + StringRef getArgument() const final { + return "test-vector-break-down-reduction-patterns"; + } + StringRef getDescription() const final { + return "Test patterns to break down vector reductions into arith " + "reductions"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateBreakDownVectorReductionPatterns(patterns, + /*maxNumElementsToExtract=*/2); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFlattenVectorTransferPatterns : public PassWrapper<TestFlattenVectorTransferPatterns, OperationPass<func::FuncOp>> { @@ -827,6 +848,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorChainedReductionFoldingPatterns>(); + PassRegistration<TestVectorBreakDownReductionPatterns>(); + PassRegistration<TestFlattenVectorTransferPatterns>(); PassRegistration<TestVectorScanLowering>(); |
