summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2024-08-16 16:53:53 +0100
committerGitHub <noreply@github.com>2024-08-16 16:53:53 +0100
commit42944da5ba7617bbc02f341e9ef401c325310a73 (patch)
tree26f2c323b09b0540fc3c595e480381c7143c46c4 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parenta434cac523da6db542350fd747967520aaae8fbb (diff)
[mlir][vector] Group re-order patterns together (#102856)
Group all patterns that re-order vector.transpose and vector.broadcast Ops (*) under `populateSinkVectorOpsPatterns`. These patterns are normally used to "sink" redundant Vector Ops, hence grouping together. Example: ```mlir %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> %r = arith.addf %at, %bt : vector<2x4xf32> ``` would get converted to: ```mlir %0 = arith.addf %a, %b : vector<4x2xf32> %r = vector.transpose %0, [1, 0] : vector<2x4xf32> ``` This patch also moves all tests for these patterns so that all of them are: * run under one test-flag: `test-vector-sink-patterns`, * located in one file: "vector-sink.mlir". To facilitate this change: * `-test-sink-vector-broadcast` is renamed as `test-vector-sink-patterns`, * "sink-vector-broadcast.mlir" is renamed as "vector-sink.mlir", * tests for `ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose` patterns are moved from "vector-reduce-to-contract.mlir" to "vector-sink.mlir", * `ReorderElementwiseOpsOnTranspose` patterns are removed from `populateVectorReductionToContractPatterns` and added to (newly created) `populateSinkVectorOpsPatterns`, * `ReorderCastOpsOnBroadcast` patterns are removed from `populateVectorReductionToContractPatterns` - these are already present in `populateSinkVectorOpsPatterns`. This should allow us better layering and more straightforward testing. For the latter, the goal is to be able to easily identify which pattern a particular test is exercising (especially when it's a specific pattern). NOTES FOR DOWNSTREAM USERS In order to preserve the current functionality, please make sure to add * `populateSinkVectorOpsPatterns`, wherever you are using `populateVectorReductionToContractPatterns`. Also, rename `populateSinkVectorBroadcastPatterns` as `populateSinkVectorOpsPatterns`. (*) I didn't notice any other re-order patterns.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 29c763b622e8..72aaa7dc4f89 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
};
-struct TestSinkVectorBroadcast
- : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
+struct TestVectorSinkPatterns
+ : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
- TestSinkVectorBroadcast() = default;
- TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
+ TestVectorSinkPatterns() = default;
+ TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
}
- StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
+ StringRef getArgument() const final { return "test-vector-sink-patterns"; }
StringRef getDescription() const final {
return "Test lowering patterns that eliminate redundant brodacast "
- "operations.";
+ "and transpose operations.";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateSinkVectorBroadcastPatterns(patterns);
+ populateSinkVectorOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -919,7 +919,7 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
- PassRegistration<TestSinkVectorBroadcast>();
+ PassRegistration<TestVectorSinkPatterns>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();