diff options
Diffstat (limited to 'mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp')
| -rw-r--r-- | mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index ed5d6d4a7fe4..cdb715064b0f 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -31,7 +31,8 @@ namespace { class ConvertToLLVMPassInterface { public: ConvertToLLVMPassInterface(MLIRContext *context, - ArrayRef<std::string> filterDialects); + ArrayRef<std::string> filterDialects, + bool allowPatternRollback = true); virtual ~ConvertToLLVMPassInterface() = default; /// Get the dependent dialects used by `convert-to-llvm`. @@ -60,6 +61,9 @@ protected: MLIRContext *context; /// List of dialects names to use as filters. ArrayRef<std::string> filterDialects; + /// An experimental flag to disallow pattern rollback. This is more efficient + /// but not supported by all lowering patterns. + bool allowPatternRollback; }; /// This DialectExtension can be attached to the context, which will invoke the @@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { /// Apply the conversion driver. LogicalResult transform(Operation *op, AnalysisManager manager) const final { - if (failed(applyPartialConversion(op, *target, *patterns))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, *target, *patterns, config))) return failure(); return success(); } @@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { patterns); // Apply the conversion. - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, target, std::move(patterns), config))) return failure(); return success(); } @@ -206,9 +214,11 @@ public: std::shared_ptr<ConvertToLLVMPassInterface> impl; // Choose the pass implementation. if (useDynamic) - impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects); + impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects, + allowPatternRollback); else - impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects); + impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects, + allowPatternRollback); if (failed(impl->initialize())) return failure(); this->impl = impl; @@ -228,8 +238,10 @@ public: //===----------------------------------------------------------------------===// ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( - MLIRContext *context, ArrayRef<std::string> filterDialects) - : context(context), filterDialects(filterDialects) {} + MLIRContext *context, ArrayRef<std::string> filterDialects, + bool allowPatternRollback) + : context(context), filterDialects(filterDialects), + allowPatternRollback(allowPatternRollback) {} void ConvertToLLVMPassInterface::getDependentDialects( DialectRegistry ®istry) { |
