summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorDiego Caballero <dieg0ca6aller0@gmail.com>2025-09-16 18:01:30 -0700
committerGitHub <noreply@github.com>2025-09-16 18:01:30 -0700
commit7bdd88c1e3a70d8213f8bc68403fbd844f11b00c (patch)
treed0d622f140ab7ab09cfccd772639befaa8675d35 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent87bceae3fc64359c5a6ca362b466f8e938f4986c (diff)
[mlir][Vector] Add patterns to lower `vector.shuffle` (#157611)
This PR adds patterns to lower `vector.shuffle` with inputs with different vector sizes more efficiently. The current LLVM lowering for these cases degenerates to a sequence of `vector.extract` and `vector.insert` operations. With this PR, the smaller input is promoted to larger vector size by introducing an extra `vector.shuffle`.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp18
1 files changed, 18 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 72dd103b33f7..79bfc9bbcda7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -994,6 +994,22 @@ struct TestEliminateVectorMasks
VscaleRange{vscaleMin, vscaleMax});
}
};
+
+struct TestVectorShuffleLowering
+ : public PassWrapper<TestVectorShuffleLowering,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering)
+
+ StringRef getArgument() const final { return "test-vector-shuffle-lowering"; }
+ StringRef getDescription() const final {
+ return "Test lowering patterns for vector.shuffle with mixed-size inputs";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorShuffleLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -1023,6 +1039,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorScanLowering>();
+ PassRegistration<TestVectorShuffleLowering>();
+
PassRegistration<TestVectorDistribution>();
PassRegistration<TestVectorExtractStridedSliceLowering>();