summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2023-12-13 20:29:12 +0000
committerGitHub <noreply@github.com>2023-12-13 20:29:12 +0000
commitc02d07fdf007afc6b928cda0342751889cc2604b (patch)
tree0f4b356dbcedadd2e0e21f79c75f74d8900bd67f /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent9512d6d2133a15a3e6272cbadd7fbb479011ccdb (diff)
[mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (#74817)
For vectors with either leading or trailing unit dim, replaces: elementwise(a, b) with: sc_a = shape_cast(a) sc_b = shape_cast(b) res = elementwise(sc_a, sc_b) return shape_cast(res) The newly inserted shape_cast Ops fold (before elementwise Op) and then restore (after elementwise Op) the unit dim. Vectors `a` and `b` are required to be rank > 1. Example: ```mlir %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32> %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32> ``` gets converted to: ```mlir %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32> %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32> %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32> %mul_sc = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32> %cast = vector.shape_cast %mul_sc : vector<1x[4]xf32> to vector<[4]xf32> ``` In practice, the bottom 2 shape_cast(s) will be folded away.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp1
1 files changed, 1 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e593c0defcd2..a643343e9342 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -455,6 +455,7 @@ struct TestFlattenVectorTransferPatterns
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
registry.insert<affine::AffineDialect>();
+ registry.insert<vector::VectorDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());