summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344..126d65b1b848 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
}
};
+struct TestVectorBreakDownReductionPatterns
+ : public PassWrapper<TestVectorBreakDownReductionPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorBreakDownReductionPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-break-down-reduction-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns to break down vector reductions into arith "
+ "reductions";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateBreakDownVectorReductionPatterns(patterns,
+ /*maxNumElementsToExtract=*/2);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+ PassRegistration<TestVectorBreakDownReductionPatterns>();
+
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();