diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda7..e8ea0cc02d7f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -179,6 +179,28 @@ struct TestVectorUnrollingPatterns return success(isa<vector::StepOp>(op)); })); populateVectorUnrollPatterns( + patterns, + UnrollVectorOptions() + .setNativeShapeFn( + [](Operation *op) -> std::optional<SmallVector<int64_t>> { + auto shapeCast = dyn_cast<vector::ShapeCastOp>(op); + if (!shapeCast) + return std::nullopt; + + auto resultShape = shapeCast.getResultVectorType().getShape(); + // Special case with leading unit dims and different inner dim + // for result and target shape. + if (resultShape.size() == 2 && resultShape[0] == 1 && + resultShape[1] == 32) { + return SmallVector<int64_t>{1, 16}; + } + // Default case: [2,4] for all tests. + return SmallVector<int64_t>{2, 4}; + }) + .setFilterConstraint([](Operation *op) { + return success(isa<vector::ShapeCastOp>(op)); + })); + populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) .setFilterConstraint([](Operation *op) { |
