summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp12
1 files changed, 10 insertions, 2 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 915f713f7047..f14fb18706d1 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -840,6 +840,9 @@ struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+ TestVectorLinearize() = default;
+ TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
StringRef getArgument() const override { return "test-vector-linearize"; }
StringRef getDescription() const override {
return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -848,6 +851,11 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> targetVectorBitwidth{
+ *this, "target-vector-bitwidth",
+ llvm::cl::desc(
+ "Minimum vector bitwidth to enable the flattening transformation"),
+ llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
auto *context = &getContext();
@@ -855,8 +863,8 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
- patterns, target);
+ vector::populateVectorLinearizeTypeConversionsAndLegality(
+ typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();