diff options
| author | Nishant Patel <nishant.b.patel@intel.com> | 2025-11-19 16:16:44 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-19 16:16:44 -0800 |
| commit | af73aeaa19929127655d544b48a5145105e9e28c (patch) | |
| tree | 26c8d6e0647cc17e1371da8d1150c3695607197a /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 7de59f0b247a481d81a3f2a4ce9f322c5a5c68ef (diff) | |
[MLIR][Vector] Add unroll pattern for vector.shape_cast (#167738)
This PR adds pattern for unrolling shape_cast given a targetShape. This
PR is a follow up of #164010 which was very general and was using
inserts and extracts on each element (which is also
LowerVectorShapeCast.cpp is doing).
After doing some more research on use cases, we (me and @Jianhui-Li )
realized that the previous version in #164010 is unnecessarily generic
and doesn't fit our performance needs.
Our use case requires that targetShape is contiguous in both source and
result vector.
This pattern only applies when contiguous slices can be extracted from
the source vector and inserted into the result vector such that each
slice remains in vector form with targetShape (and not decompose to
scalars). In these cases, the unrolling proceeds as:
vector.extract_strided_slice -> vector.shape_cast (on the slice
unrolled) -> vector.insert_strided_slice
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda7..e8ea0cc02d7f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -179,6 +179,28 @@ struct TestVectorUnrollingPatterns return success(isa<vector::StepOp>(op)); })); populateVectorUnrollPatterns( + patterns, + UnrollVectorOptions() + .setNativeShapeFn( + [](Operation *op) -> std::optional<SmallVector<int64_t>> { + auto shapeCast = dyn_cast<vector::ShapeCastOp>(op); + if (!shapeCast) + return std::nullopt; + + auto resultShape = shapeCast.getResultVectorType().getShape(); + // Special case with leading unit dims and different inner dim + // for result and target shape. + if (resultShape.size() == 2 && resultShape[0] == 1 && + resultShape[1] == 32) { + return SmallVector<int64_t>{1, 16}; + } + // Default case: [2,4] for all tests. + return SmallVector<int64_t>{2, 4}; + }) + .setFilterConstraint([](Operation *op) { + return success(isa<vector::ShapeCastOp>(op)); + })); + populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) .setFilterConstraint([](Operation *op) { |
