summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorNishant Patel <nishant.b.patel@intel.com>2025-11-19 16:16:44 -0800
committerGitHub <noreply@github.com>2025-11-19 16:16:44 -0800
commitaf73aeaa19929127655d544b48a5145105e9e28c (patch)
tree26c8d6e0647cc17e1371da8d1150c3695607197a /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent7de59f0b247a481d81a3f2a4ce9f322c5a5c68ef (diff)
[MLIR][Vector] Add unroll pattern for vector.shape_cast (#167738)
This PR adds pattern for unrolling shape_cast given a targetShape. This PR is a follow up of #164010 which was very general and was using inserts and extracts on each element (which is also LowerVectorShapeCast.cpp is doing). After doing some more research on use cases, we (me and @Jianhui-Li ) realized that the previous version in #164010 is unnecessarily generic and doesn't fit our performance needs. Our use case requires that targetShape is contiguous in both source and result vector. This pattern only applies when contiguous slices can be extracted from the source vector and inserted into the result vector such that each slice remains in vector form with targetShape (and not decompose to scalars). In these cases, the unrolling proceeds as: vector.extract_strided_slice -> vector.shape_cast (on the slice unrolled) -> vector.insert_strided_slice
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda7..e8ea0cc02d7f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,6 +179,28 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShapeFn(
+ [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
+ if (!shapeCast)
+ return std::nullopt;
+
+ auto resultShape = shapeCast.getResultVectorType().getShape();
+ // Special case with leading unit dims and different inner dim
+ // for result and target shape.
+ if (resultShape.size() == 2 && resultShape[0] == 1 &&
+ resultShape[1] == 32) {
+ return SmallVector<int64_t>{1, 16};
+ }
+ // Default case: [2,4] for all tests.
+ return SmallVector<int64_t>{2, 4};
+ })
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ShapeCastOp>(op));
+ }));
+ populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {