summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2023-05-17 08:57:13 -0700
committerLei Zhang <antiagainst@google.com>2023-05-17 09:01:19 -0700
commite000b62a342cac907fd77cfdd070f0b055f0c3c4 (patch)
tree45564fc0513335d152027662e842538b271157a3 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent4eab303404d6bb2252b4baf807c5ac87a0fa3125 (diff)
[mlir][vector] Separate out vector transfer + tensor slice patterns
These patterns touches the structure generated from tiling so it affects later steps like bufferization and vector hoisting. Instead of putting them in canonicalization, this commit creates separate entry points for them to be called explicitly. This is NFC regarding the functionality and tests of those patterns. It also addresses two TODO items in the codebase. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D150702
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d0c79ab98915..50dfeff635cc 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -679,6 +679,26 @@ struct TestVectorGatherLowering
}
};
+struct TestVectorTransferTensorSlicePatterns
+ : public PassWrapper<TestVectorTransferTensorSlicePatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorTransferTensorSlicePatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-transfer-tensor-slice-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns that fold vector transfer and tensor slice ops";
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorTransferTensorSliceTransforms(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
@@ -713,6 +733,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestCreateVectorBroadcast>();
PassRegistration<TestVectorGatherLowering>();
+
+ PassRegistration<TestVectorTransferTensorSlicePatterns>();
}
} // namespace test
} // namespace mlir