summaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp')
-rw-r--r--mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp26
1 files changed, 19 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4a7fe4..cdb715064b0f 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -31,7 +31,8 @@ namespace {
class ConvertToLLVMPassInterface {
public:
ConvertToLLVMPassInterface(MLIRContext *context,
- ArrayRef<std::string> filterDialects);
+ ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback = true);
virtual ~ConvertToLLVMPassInterface() = default;
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ protected:
MLIRContext *context;
/// List of dialects names to use as filters.
ArrayRef<std::string> filterDialects;
+ /// An experimental flag to disallow pattern rollback. This is more efficient
+ /// but not supported by all lowering patterns.
+ bool allowPatternRollback;
};
/// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
/// Apply the conversion driver.
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
- if (failed(applyPartialConversion(op, *target, *patterns)))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, *target, *patterns, config)))
return failure();
return success();
}
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
patterns);
// Apply the conversion.
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
return failure();
return success();
}
@@ -206,9 +214,11 @@ public:
std::shared_ptr<ConvertToLLVMPassInterface> impl;
// Choose the pass implementation.
if (useDynamic)
- impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
else
- impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
if (failed(impl->initialize()))
return failure();
this->impl = impl;
@@ -228,8 +238,10 @@ public:
//===----------------------------------------------------------------------===//
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
- MLIRContext *context, ArrayRef<std::string> filterDialects)
- : context(context), filterDialects(filterDialects) {}
+ MLIRContext *context, ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback)
+ : context(context), filterDialects(filterDialects),
+ allowPatternRollback(allowPatternRollback) {}
void ConvertToLLVMPassInterface::getDependentDialects(
DialectRegistry &registry) {