summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda7..e8ea0cc02d7f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,6 +179,28 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShapeFn(
+ [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
+ if (!shapeCast)
+ return std::nullopt;
+
+ auto resultShape = shapeCast.getResultVectorType().getShape();
+ // Special case with leading unit dims and different inner dim
+ // for result and target shape.
+ if (resultShape.size() == 2 && resultShape[0] == 1 &&
+ resultShape[1] == 32) {
+ return SmallVector<int64_t>{1, 16};
+ }
+ // Default case: [2,4] for all tests.
+ return SmallVector<int64_t>{2, 4};
+ })
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ShapeCastOp>(op));
+ }));
+ populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {