diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 915f713f7047..f14fb18706d1 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -840,6 +840,9 @@ struct TestVectorLinearize final : public PassWrapper<TestVectorLinearize, OperationPass<>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) + TestVectorLinearize() = default; + TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {} + StringRef getArgument() const override { return "test-vector-linearize"; } StringRef getDescription() const override { return "Linearizes ND vectors for N >= 2 into 1D vectors"; @@ -848,6 +851,11 @@ struct TestVectorLinearize final registry.insert<vector::VectorDialect>(); } + Option<unsigned> targetVectorBitwidth{ + *this, "target-vector-bitwidth", + llvm::cl::desc( + "Minimum vector bitwidth to enable the flattening transformation"), + llvm::cl::init(std::numeric_limits<unsigned>::max())}; void runOnOperation() override { auto *context = &getContext(); @@ -855,8 +863,8 @@ struct TestVectorLinearize final RewritePatternSet patterns(context); ConversionTarget target(*context); - vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter, - patterns, target); + vector::populateVectorLinearizeTypeConversionsAndLegality( + typeConverter, patterns, target, targetVectorBitwidth); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); |
