summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorBenoit Jacob <jacob.benoit.1@gmail.com>2024-10-09 09:24:23 -0400
committerGitHub <noreply@github.com>2024-10-09 09:24:23 -0400
commita9ebdbb5ac7de7a028f6060b789196a43aea7580 (patch)
tree0547923129ccb53ac6022aab9071e2ca51940491 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent1e357cde4836d034d2f7a6d9af099eef23271756 (diff)
[MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (#111614)
This is a reasonable canonicalization because `extract` is more constrained than `extract_strided_slices`, so there is no loss of semantics here, just lifting an op to a special-case higher/constrained op. And the additional `shape_cast` is merely adding leading unit dims to match the original result type. Context: discussion on #111541. I wasn't sure how this would turn out, but in the process of writing this PR, I discovered at least 2 bugs in the pattern introduced in #111541, which shows the value of shared canonicalization patterns which are exercised on a high number of testcases. --------- Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp23
1 files changed, 0 insertions, 23 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d91e955b7064..72aaa7dc4f89 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
}
};
-struct TestVectorContiguousExtractStridedSliceToExtract
- : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestVectorExtractStridedSliceLowering)
-
- StringRef getArgument() const final {
- return "test-vector-contiguous-extract-strided-slice-to-extract";
- }
- StringRef getDescription() const final {
- return "Test lowering patterns that rewrite simple cases of N-D "
- "extract_strided_slice, where the slice is contiguous, into extract "
- "and shape_cast";
- }
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- }
-};
-
struct TestVectorBreakDownBitCast
: public PassWrapper<TestVectorBreakDownBitCast,
OperationPass<func::FuncOp>> {
@@ -956,8 +935,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorExtractStridedSliceLowering>();
- PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
-
PassRegistration<TestVectorBreakDownBitCast>();
PassRegistration<TestCreateVectorBroadcast>();