diff options
| author | Andrzej Warzynski <andrzej.warzynski@arm.com> | 2023-06-02 15:32:12 +0100 |
|---|---|---|
| committer | Andrzej Warzynski <andrzej.warzynski@gmail.com> | 2023-06-15 10:13:41 +0100 |
| commit | 4d339ec91e81ae33b0f3ea0f8a3596d99645a0e9 (patch) | |
| tree | 6a15cfcd9c6004aa34f27ac65a3e7904d3226547 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | e9d77cd9b267cb43bf7a968053517ca499959f2f (diff) | |
[mlir][Vector] Add pattern to reorder elementwise and broadcast ops
The new pattern will replace elementwise(broadcast) with
broadcast(elementwise) when safe.
This change affects tests for vectorising nD-extract. In one case
("vectorize_nd_tensor_extract_with_tensor_extract") I just trimmed the
test and only preserved the key parts (scalar and contiguous load from
the original Op). We could do the same with some other tests if that
helps maintainability.
Differential Revision: https://reviews.llvm.org/D152812
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index a5de1fd4de43..554a7b6db472 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -374,6 +374,31 @@ struct TestVectorTransferCollapseInnerMostContiguousDims } }; +struct TestSinkVectorBroadcast + : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast) + + TestSinkVectorBroadcast() = default; + TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<memref::MemRefDialect, affine::AffineDialect>(); + } + + StringRef getArgument() const final { return "test-sink-vector-broadcast"; } + + StringRef getDescription() const final { + return "Test lowering patterns that eliminate redundant brodacast " + "operations."; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateSinkVectorBroadcastPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestVectorReduceToContractPatternsPatterns : public PassWrapper<TestVectorReduceToContractPatternsPatterns, OperationPass<func::FuncOp>> { @@ -735,6 +760,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); + PassRegistration<TestSinkVectorBroadcast>(); + PassRegistration<TestVectorReduceToContractPatternsPatterns>(); PassRegistration<TestFlattenVectorTransferPatterns>(); |
