From c02d07fdf007afc6b928cda0342751889cc2604b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Wed, 13 Dec 2023 20:29:12 +0000 Subject: [mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (#74817) For vectors with either leading or trailing unit dim, replaces: elementwise(a, b) with: sc_a = shape_cast(a) sc_b = shape_cast(b) res = elementwise(sc_a, sc_b) return shape_cast(res) The newly inserted shape_cast Ops fold (before elementwise Op) and then restore (after elementwise Op) the unit dim. Vectors `a` and `b` are required to be rank > 1. Example: ```mlir %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32> %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32> ``` gets converted to: ```mlir %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32> %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32> %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32> %mul_sc = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32> %cast = vector.shape_cast %mul_sc : vector<1x[4]xf32> to vector<[4]xf32> ``` In practice, the bottom 2 shape_cast(s) will be folded away. --- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 1 + 1 file changed, 1 insertion(+) (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp') diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index e593c0defcd2..a643343e9342 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -455,6 +455,7 @@ struct TestFlattenVectorTransferPatterns void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); + registry.insert(); } void runOnOperation() override { RewritePatternSet patterns(&getContext()); -- cgit v1.2.3