diff options
| author | Matthias Springer <me@m-sp.org> | 2025-06-15 07:59:06 +0000 |
|---|---|---|
| committer | Matthias Springer <me@m-sp.org> | 2025-06-18 09:26:35 +0000 |
| commit | 90de0cfa3cbc8bb453a2379e6e44c2a43aafcc92 (patch) | |
| tree | aeb58614dbe6585827a14c55b07210f37f3d287a | |
| parent | c6ca8c86fe46e61c94df3ba550e5a67dff3b554a (diff) | |
immediately materializeusers/matthias-springer/dialect_conv_immediately_materialize
| -rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c4b85ec4f67d..f115c6876282 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -116,6 +116,8 @@ struct ValueVectorMapInfo { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { + ConversionValueMapping(const ConversionConfig &config) : config(config) {} + /// Return "true" if an SSA value is mapped to the given value. May return /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } @@ -144,6 +146,12 @@ struct ConversionValueMapping { template <typename OldVal, typename NewVal> std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value> map(OldVal &&oldVal, NewVal &&newVal) { + if (!config.allowPatternRollback) { + // Rollbacks are not allowed. Rewrites are applied immediately. The + // mapping is not used. + return; + } + LLVM_DEBUG({ ValueVector next(newVal); while (true) { @@ -186,6 +194,9 @@ private: /// All SSA values that are mapped to. May contain false positives. DenseSet<Value> mappedTo; + + /// The configuration of the dialect conversion. + const ConversionConfig &config; }; } // namespace @@ -848,7 +859,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : context(ctx), config(config) {} + : context(ctx), mapping(config), config(config) {} //===--------------------------------------------------------------------===// // State Management @@ -870,8 +881,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template <typename RewriteTy, typename... Args> void appendRewrite(Args &&...args) { - rewrites.push_back( - std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); + if (config.allowPatternRollback) { + rewrites.push_back( + std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); + return; + } + + // Rollbacks are not allowed. Apply the rewrite immediately. + IRRewriter rewriter(context, config.listener); + RewriteTy rewrite(*this, std::forward<Args>(args)...); + rewrite.commit(rewriter); + rewrite.cleanup(rewriter); } /// Undo the rewrites (motions, splits) one by one in reverse order until @@ -1595,8 +1615,11 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); - // Mark this operation and all nested ops as replaced. - op->walk([&](Operation *op) { replacedOps.insert(op); }); + + if (config.allowPatternRollback) { + // Mark this operation and all nested ops as replaced. + op->walk([&](Operation *op) { replacedOps.insert(op); }); + } } void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { |
