diff options
| author | Yang Bai <baiyang0132@gmail.com> | 2025-08-19 01:09:12 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-18 10:09:12 -0700 |
| commit | 4eb1a07d7d1a9722e84490b0ff79d3ae5e260f76 (patch) | |
| tree | 23059f1cf8f99313ece4f31e2c363052910c4f2a /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 8135b7c1abd7d22f98cf3dbd7d7a93c9fc7755c6 (diff) | |
[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (#151175)
This patch introduces a new unrolling-based approach for lowering
multi-dimensional `vector.from_elements` operations.
**Implementation Details:**
1. **New Transform Pattern**: Added `UnrollFromElements` that unrolls a
N-D(N>=2) from_elements op to a (N-1)-D from_elements op align the
outermost dimension.
2. **Utility Functions**: Added `unrollVectorOp` to reuse the unroll
algo of vector.gather for vector.from_elements.
3. **Integration**: Added the unrolling pattern to the
convert-vector-to-llvm pass as a temporal transformation.
4. Use direct LLVM dialect operations instead of intermediate
vector.insert operations for efficiency in `VectorFromElementsLowering`.
**Example:**
```mlir
// unroll
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>
=>
%poison_2d = ub.poison : vector<2x2xf32>
%vec_1d_0 = vector.from_elements %e0, %e1 : vector<2xf32>
%vec_2d_0 = vector.insert %vec_1d_0, %poison_2d [0] : vector<2xf32> into vector<2x2xf32>
%vec_1d_1 = vector.from_elements %e2, %e3 : vector<2xf32>
%result = vector.insert %vec_1d_1, %vec_2d_0 [1] : vector<2xf32> into vector<2x2xf32>
// convert-vector-to-llvm
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>
=>
%poison_2d = ub.poison : vector<2x2xf32>
%poison_2d_cast = builtin.unrealized_conversion_cast %poison_2d : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
%poison_1d_0 = llvm.mlir.poison : vector<2xf32>
%c0_0 = llvm.mlir.constant(0 : i64) : i64
%vec_1d_0_0 = llvm.insertelement %e0, %poison_1d_0[%c0_0 : i64] : vector<2xf32>
%c1_0 = llvm.mlir.constant(1 : i64) : i64
%vec_1d_0_1 = llvm.insertelement %e1, %vec_1d_0_0[%c1_0 : i64] : vector<2xf32>
%vec_2d_0 = llvm.insertvalue %vec_1d_0_1, %poison_2d_cast[0] : !llvm.array<2 x vector<2xf32>>
%poison_1d_1 = llvm.mlir.poison : vector<2xf32>
%c0_1 = llvm.mlir.constant(0 : i64) : i64
%vec_1d_1_0 = llvm.insertelement %e2, %poison_1d_1[%c0_1 : i64] : vector<2xf32>
%c1_1 = llvm.mlir.constant(1 : i64) : i64
%vec_1d_1_1 = llvm.insertelement %e3, %vec_1d_1_0[%c1_1 : i64] : vector<2xf32>
%vec_2d_1 = llvm.insertvalue %vec_1d_1_1, %vec_2d_0[1] : !llvm.array<2 x vector<2xf32>>
%result = builtin.unrealized_conversion_cast %vec_2d_1 : !llvm.array<2 x vector<2xf32>> to vector<2x2xf32>
```
---------
Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
Co-authored-by: Yang Bai <yangb@nvidia.com>
Co-authored-by: James Newling <james.newling@gmail.com>
Co-authored-by: Diego Caballero <dieg0ca6aller0@gmail.com>
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f89c944b5c56..bb1598ee3efe 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -786,6 +786,28 @@ struct TestVectorGatherLowering } }; +struct TestUnrollVectorFromElements + : public PassWrapper<TestUnrollVectorFromElements, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements) + + StringRef getArgument() const final { + return "test-unroll-vector-from-elements"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for from_elements ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorFromElementsLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, OperationPass<func::FuncOp>> { @@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorGatherLowering>(); + PassRegistration<TestUnrollVectorFromElements>(); + PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); PassRegistration<TestVectorEmulateMaskedLoadStore>(); |
