summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorAndrzej Warzynski <andrzej.warzynski@arm.com>2023-06-02 15:32:12 +0100
committerAndrzej Warzynski <andrzej.warzynski@gmail.com>2023-06-15 10:13:41 +0100
commit4d339ec91e81ae33b0f3ea0f8a3596d99645a0e9 (patch)
tree6a15cfcd9c6004aa34f27ac65a3e7904d3226547 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parente9d77cd9b267cb43bf7a968053517ca499959f2f (diff)
[mlir][Vector] Add pattern to reorder elementwise and broadcast ops
The new pattern will replace elementwise(broadcast) with broadcast(elementwise) when safe. This change affects tests for vectorising nD-extract. In one case ("vectorize_nd_tensor_extract_with_tensor_extract") I just trimmed the test and only preserved the key parts (scalar and contiguous load from the original Op). We could do the same with some other tests if that helps maintainability. Differential Revision: https://reviews.llvm.org/D152812
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp27
1 files changed, 27 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a5de1fd4de43..554a7b6db472 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -374,6 +374,31 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
};
+struct TestSinkVectorBroadcast
+ : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
+
+ TestSinkVectorBroadcast() = default;
+ TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<memref::MemRefDialect, affine::AffineDialect>();
+ }
+
+ StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
+
+ StringRef getDescription() const final {
+ return "Test lowering patterns that eliminate redundant brodacast "
+ "operations.";
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateSinkVectorBroadcastPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestVectorReduceToContractPatternsPatterns
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
OperationPass<func::FuncOp>> {
@@ -735,6 +760,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
+ PassRegistration<TestSinkVectorBroadcast>();
+
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
PassRegistration<TestFlattenVectorTransferPatterns>();