summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorJakub Kuderski <jakub@nod-labs.com>2023-12-18 17:54:54 -0500
committerGitHub <noreply@github.com>2023-12-18 17:54:54 -0500
commit07677113ffeb3744df350ef7c4ece1a93f7a5e1f (patch)
treee3f42340ca8118aaeedb153c7bf3bcabea52115f /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent5b57da32a861ec4d7ef4adc7f1560142cf58d1ed (diff)
[mlir][vector] Add pattern to break down reductions into arith ops (#75727)
The number of vector elements considered 'small' enough to extract is parameterized. This is to avoid going into specialized reduction lowering when a single/couple of arith ops can do. Targets without dedicated reduction intrinsics can use that as an emulation path too. Depends on https://github.com/llvm/llvm-project/pull/75846.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344..126d65b1b848 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
}
};
+struct TestVectorBreakDownReductionPatterns
+ : public PassWrapper<TestVectorBreakDownReductionPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorBreakDownReductionPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-break-down-reduction-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns to break down vector reductions into arith "
+ "reductions";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateBreakDownVectorReductionPatterns(patterns,
+ /*maxNumElementsToExtract=*/2);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+ PassRegistration<TestVectorBreakDownReductionPatterns>();
+
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();