diff options
| author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2023-12-13 20:29:12 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-12-13 20:29:12 +0000 |
| commit | c02d07fdf007afc6b928cda0342751889cc2604b (patch) | |
| tree | 0f4b356dbcedadd2e0e21f79c75f74d8900bd67f /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 9512d6d2133a15a3e6272cbadd7fbb479011ccdb (diff) | |
[mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (#74817)
For vectors with either leading or trailing unit dim, replaces:
elementwise(a, b)
with:
sc_a = shape_cast(a)
sc_b = shape_cast(b)
res = elementwise(sc_a, sc_b)
return shape_cast(res)
The newly inserted shape_cast Ops fold (before elementwise Op) and then
restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
required to be rank > 1.
Example:
```mlir
%mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
%cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```
gets converted to:
```mlir
%B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
%A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
%mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
%mul_sc = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
%cast = vector.shape_cast %mul_sc : vector<1x[4]xf32> to vector<[4]xf32>
```
In practice, the bottom 2 shape_cast(s) will be folded away.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 1 |
1 files changed, 1 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index e593c0defcd2..a643343e9342 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -455,6 +455,7 @@ struct TestFlattenVectorTransferPatterns void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<memref::MemRefDialect>(); registry.insert<affine::AffineDialect>(); + registry.insert<vector::VectorDialect>(); } void runOnOperation() override { RewritePatternSet patterns(&getContext()); |
