diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 03ddebe82344..126d65b1b848 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns } }; +struct TestVectorBreakDownReductionPatterns + : public PassWrapper<TestVectorBreakDownReductionPatterns, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestVectorBreakDownReductionPatterns) + + StringRef getArgument() const final { + return "test-vector-break-down-reduction-patterns"; + } + StringRef getDescription() const final { + return "Test patterns to break down vector reductions into arith " + "reductions"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateBreakDownVectorReductionPatterns(patterns, + /*maxNumElementsToExtract=*/2); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFlattenVectorTransferPatterns : public PassWrapper<TestFlattenVectorTransferPatterns, OperationPass<func::FuncOp>> { @@ -827,6 +848,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorChainedReductionFoldingPatterns>(); + PassRegistration<TestVectorBreakDownReductionPatterns>(); + PassRegistration<TestFlattenVectorTransferPatterns>(); PassRegistration<TestVectorScanLowering>(); |
