summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorYang Bai <baiyang0132@gmail.com>2025-08-19 01:09:12 +0800
committerGitHub <noreply@github.com>2025-08-18 10:09:12 -0700
commit4eb1a07d7d1a9722e84490b0ff79d3ae5e260f76 (patch)
tree23059f1cf8f99313ece4f31e2c363052910c4f2a /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent8135b7c1abd7d22f98cf3dbd7d7a93c9fc7755c6 (diff)
[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (#151175)
This patch introduces a new unrolling-based approach for lowering multi-dimensional `vector.from_elements` operations. **Implementation Details:** 1. **New Transform Pattern**: Added `UnrollFromElements` that unrolls a N-D(N>=2) from_elements op to a (N-1)-D from_elements op align the outermost dimension. 2. **Utility Functions**: Added `unrollVectorOp` to reuse the unroll algo of vector.gather for vector.from_elements. 3. **Integration**: Added the unrolling pattern to the convert-vector-to-llvm pass as a temporal transformation. 4. Use direct LLVM dialect operations instead of intermediate vector.insert operations for efficiency in `VectorFromElementsLowering`. **Example:** ```mlir // unroll %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> => %poison_2d = ub.poison : vector<2x2xf32> %vec_1d_0 = vector.from_elements %e0, %e1 : vector<2xf32> %vec_2d_0 = vector.insert %vec_1d_0, %poison_2d [0] : vector<2xf32> into vector<2x2xf32> %vec_1d_1 = vector.from_elements %e2, %e3 : vector<2xf32> %result = vector.insert %vec_1d_1, %vec_2d_0 [1] : vector<2xf32> into vector<2x2xf32> // convert-vector-to-llvm %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> => %poison_2d = ub.poison : vector<2x2xf32> %poison_2d_cast = builtin.unrealized_conversion_cast %poison_2d : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> %poison_1d_0 = llvm.mlir.poison : vector<2xf32> %c0_0 = llvm.mlir.constant(0 : i64) : i64 %vec_1d_0_0 = llvm.insertelement %e0, %poison_1d_0[%c0_0 : i64] : vector<2xf32> %c1_0 = llvm.mlir.constant(1 : i64) : i64 %vec_1d_0_1 = llvm.insertelement %e1, %vec_1d_0_0[%c1_0 : i64] : vector<2xf32> %vec_2d_0 = llvm.insertvalue %vec_1d_0_1, %poison_2d_cast[0] : !llvm.array<2 x vector<2xf32>> %poison_1d_1 = llvm.mlir.poison : vector<2xf32> %c0_1 = llvm.mlir.constant(0 : i64) : i64 %vec_1d_1_0 = llvm.insertelement %e2, %poison_1d_1[%c0_1 : i64] : vector<2xf32> %c1_1 = llvm.mlir.constant(1 : i64) : i64 %vec_1d_1_1 = llvm.insertelement %e3, %vec_1d_1_0[%c1_1 : i64] : vector<2xf32> %vec_2d_1 = llvm.insertvalue %vec_1d_1_1, %vec_2d_0[1] : !llvm.array<2 x vector<2xf32>> %result = builtin.unrealized_conversion_cast %vec_2d_1 : !llvm.array<2 x vector<2xf32>> to vector<2x2xf32> ``` --------- Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com> Co-authored-by: Yang Bai <yangb@nvidia.com> Co-authored-by: James Newling <james.newling@gmail.com> Co-authored-by: Diego Caballero <dieg0ca6aller0@gmail.com>
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f89c944b5c56..bb1598ee3efe 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -786,6 +786,28 @@ struct TestVectorGatherLowering
}
};
+struct TestUnrollVectorFromElements
+ : public PassWrapper<TestUnrollVectorFromElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-from-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for from_elements ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorFromElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
+ PassRegistration<TestUnrollVectorFromElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();