diff options
| author | Lei Zhang <antiagainst@google.com> | 2023-05-17 08:57:13 -0700 |
|---|---|---|
| committer | Lei Zhang <antiagainst@google.com> | 2023-05-17 09:01:19 -0700 |
| commit | e000b62a342cac907fd77cfdd070f0b055f0c3c4 (patch) | |
| tree | 45564fc0513335d152027662e842538b271157a3 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 4eab303404d6bb2252b4baf807c5ac87a0fa3125 (diff) | |
[mlir][vector] Separate out vector transfer + tensor slice patterns
These patterns touches the structure generated from tiling so it
affects later steps like bufferization and vector hoisting.
Instead of putting them in canonicalization, this commit creates
separate entry points for them to be called explicitly.
This is NFC regarding the functionality and tests of those patterns.
It also addresses two TODO items in the codebase.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D150702
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index d0c79ab98915..50dfeff635cc 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -679,6 +679,26 @@ struct TestVectorGatherLowering } }; +struct TestVectorTransferTensorSlicePatterns + : public PassWrapper<TestVectorTransferTensorSlicePatterns, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestVectorTransferTensorSlicePatterns) + + StringRef getArgument() const final { + return "test-vector-transfer-tensor-slice-patterns"; + } + StringRef getDescription() const final { + return "Test patterns that fold vector transfer and tensor slice ops"; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorTransferTensorSliceTransforms(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + } // namespace namespace mlir { @@ -713,6 +733,8 @@ void registerTestVectorLowerings() { PassRegistration<TestCreateVectorBroadcast>(); PassRegistration<TestVectorGatherLowering>(); + + PassRegistration<TestVectorTransferTensorSlicePatterns>(); } } // namespace test } // namespace mlir |
