diff options
| author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2024-08-16 16:53:53 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-16 16:53:53 +0100 |
| commit | 42944da5ba7617bbc02f341e9ef401c325310a73 (patch) | |
| tree | 26f2c323b09b0540fc3c595e480381c7143c46c4 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | a434cac523da6db542350fd747967520aaae8fbb (diff) | |
[mlir][vector] Group re-order patterns together (#102856)
Group all patterns that re-order vector.transpose and vector.broadcast
Ops (*) under `populateSinkVectorOpsPatterns`. These patterns are
normally used to "sink" redundant Vector Ops, hence grouping together.
Example:
```mlir
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
```
would get converted to:
```mlir
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
```
This patch also moves all tests for these patterns so that all of them
are:
* run under one test-flag: `test-vector-sink-patterns`,
* located in one file: "vector-sink.mlir".
To facilitate this change:
* `-test-sink-vector-broadcast` is renamed as
`test-vector-sink-patterns`,
* "sink-vector-broadcast.mlir" is renamed as "vector-sink.mlir",
* tests for `ReorderCastOpsOnBroadcast` and
`ReorderElementwiseOpsOnTranspose` patterns are moved from
"vector-reduce-to-contract.mlir" to "vector-sink.mlir",
* `ReorderElementwiseOpsOnTranspose` patterns are removed from
`populateVectorReductionToContractPatterns` and added to (newly
created) `populateSinkVectorOpsPatterns`,
* `ReorderCastOpsOnBroadcast` patterns are removed from
`populateVectorReductionToContractPatterns` - these are already
present in `populateSinkVectorOpsPatterns`.
This should allow us better layering and more straightforward testing.
For the latter, the goal is to be able to easily identify which pattern
a particular test is exercising (especially when it's a specific
pattern).
NOTES FOR DOWNSTREAM USERS
In order to preserve the current functionality, please make sure to add
* `populateSinkVectorOpsPatterns`,
wherever you are using `populateVectorReductionToContractPatterns`.
Also, rename `populateSinkVectorBroadcastPatterns` as
`populateSinkVectorOpsPatterns`.
(*) I didn't notice any other re-order patterns.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 29c763b622e8..72aaa7dc4f89 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims } }; -struct TestSinkVectorBroadcast - : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast) +struct TestVectorSinkPatterns + : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns) - TestSinkVectorBroadcast() = default; - TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default; + TestVectorSinkPatterns() = default; + TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<memref::MemRefDialect, affine::AffineDialect>(); } - StringRef getArgument() const final { return "test-sink-vector-broadcast"; } + StringRef getArgument() const final { return "test-vector-sink-patterns"; } StringRef getDescription() const final { return "Test lowering patterns that eliminate redundant brodacast " - "operations."; + "and transpose operations."; } void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateSinkVectorBroadcastPatterns(patterns); + populateSinkVectorOpsPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -919,7 +919,7 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); - PassRegistration<TestSinkVectorBroadcast>(); + PassRegistration<TestVectorSinkPatterns>(); PassRegistration<TestVectorReduceToContractPatternsPatterns>(); |
