summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorJakub Kuderski <jakub@nod-labs.com>2023-11-22 10:30:04 -0500
committerGitHub <noreply@github.com>2023-11-22 10:30:04 -0500
commitd33bad66d86a6fdb443c59561f9524f451a82db0 (patch)
treeb681b2657878d295a600104a18dfa54d1263cb63 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent8c02b34e3b9b1e2596651959ba76c66a7afaf545 (diff)
[mlir][vector] Add patterns to simplify chained reductions (#73048)
Chained reductions get created during vector unrolling. These patterns simplify them into a series of adds followed by a final reductions. This is preferred on GPU targets like SPIR-V/Vulkan where vector reduction gets lowered into subgroup operations that are generally more expensive than simple vector additions. For now, only the `add` combining kind is handled.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 1a177fa31de3..feb716cdbf40 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -420,6 +420,25 @@ struct TestVectorReduceToContractPatternsPatterns
}
};
+struct TestVectorChainedReductionFoldingPatterns
+ : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorChainedReductionFoldingPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-chained-reduction-folding-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns to fold chained vector reductions";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateChainedVectorReductionFoldingPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -773,6 +792,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
+ PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();