diff options
| author | Jakub Kuderski <jakub@nod-labs.com> | 2023-11-22 10:30:04 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-22 10:30:04 -0500 |
| commit | d33bad66d86a6fdb443c59561f9524f451a82db0 (patch) | |
| tree | b681b2657878d295a600104a18dfa54d1263cb63 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 8c02b34e3b9b1e2596651959ba76c66a7afaf545 (diff) | |
[mlir][vector] Add patterns to simplify chained reductions (#73048)
Chained reductions get created during vector unrolling. These patterns
simplify them into a series of adds followed by a final reductions.
This is preferred on GPU targets like SPIR-V/Vulkan where vector
reduction gets lowered into subgroup operations that are generally more
expensive than simple vector additions.
For now, only the `add` combining kind is handled.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 1a177fa31de3..feb716cdbf40 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -420,6 +420,25 @@ struct TestVectorReduceToContractPatternsPatterns } }; +struct TestVectorChainedReductionFoldingPatterns + : public PassWrapper<TestVectorChainedReductionFoldingPatterns, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestVectorChainedReductionFoldingPatterns) + + StringRef getArgument() const final { + return "test-vector-chained-reduction-folding-patterns"; + } + StringRef getDescription() const final { + return "Test patterns to fold chained vector reductions"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateChainedVectorReductionFoldingPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFlattenVectorTransferPatterns : public PassWrapper<TestFlattenVectorTransferPatterns, OperationPass<func::FuncOp>> { @@ -773,6 +792,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorReduceToContractPatternsPatterns>(); + PassRegistration<TestVectorChainedReductionFoldingPatterns>(); + PassRegistration<TestFlattenVectorTransferPatterns>(); PassRegistration<TestVectorScanLowering>(); |
