diff options
| author | Diego Caballero <dieg0ca6aller0@gmail.com> | 2025-09-16 18:01:30 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-16 18:01:30 -0700 |
| commit | 7bdd88c1e3a70d8213f8bc68403fbd844f11b00c (patch) | |
| tree | d0d622f140ab7ab09cfccd772639befaa8675d35 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 87bceae3fc64359c5a6ca362b466f8e938f4986c (diff) | |
[mlir][Vector] Add patterns to lower `vector.shuffle` (#157611)
This PR adds patterns to lower `vector.shuffle` with inputs with
different vector sizes more efficiently. The current LLVM lowering for
these cases degenerates to a sequence of `vector.extract` and
`vector.insert` operations. With this PR, the smaller input is promoted
to larger vector size by introducing an extra `vector.shuffle`.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 72dd103b33f7..79bfc9bbcda7 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -994,6 +994,22 @@ struct TestEliminateVectorMasks VscaleRange{vscaleMin, vscaleMax}); } }; + +struct TestVectorShuffleLowering + : public PassWrapper<TestVectorShuffleLowering, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering) + + StringRef getArgument() const final { return "test-vector-shuffle-lowering"; } + StringRef getDescription() const final { + return "Test lowering patterns for vector.shuffle with mixed-size inputs"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorShuffleLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; } // namespace namespace mlir { @@ -1023,6 +1039,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorScanLowering>(); + PassRegistration<TestVectorShuffleLowering>(); + PassRegistration<TestVectorDistribution>(); PassRegistration<TestVectorExtractStridedSliceLowering>(); |
