summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp352
1 files changed, 293 insertions, 59 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2470f2b122de..001c13e1ab08 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -182,15 +182,24 @@ private:
/// conversions.)
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values)) {
+ if (v.getDefiningOp() != op)
+ return nullptr;
+ }
+ return op;
+}
+
/// A vector of values is a pure type conversion if all values are defined by
/// the same operation and the operation has the `kPureTypeConversionMarker`
/// attribute.
static bool isPureTypeConversion(const ValueVector &values) {
assert(!values.empty() && "expected non-empty value vector");
- Operation *op = values.front().getDefiningOp();
- for (Value v : llvm::drop_begin(values))
- if (v.getDefiningOp() != op)
- return false;
+ Operation *op = getCommonDefiningOp(values);
return op && op->hasAttr(kPureTypeConversionMarker);
}
@@ -841,7 +850,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -863,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// failure.
template <typename RewriteTy, typename... Args>
void appendRewrite(Args &&...args) {
+ assert(config.allowPatternRollback && "appending rewrites is not allowed");
rewrites.push_back(
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
}
@@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasOpReplaced(Operation *op) const;
/// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
+ /// mapping, taking into account only replacements. Perform a best-effort
+ /// search for existing materializations with the desired types.
///
/// If `skipPureTypeConversions` is "true", materializations that are pure
/// type conversions are not considered.
@@ -1066,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
+ /// This vector is maintained only if `allowPatternRollback` is set to
+ /// "true". Otherwise, all IR rewrites are materialized immediately and no
+ /// bookkeeping is needed.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization.
@@ -1089,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
+ /// A list of unresolved materializations that were created by the current
+ /// pattern.
+ DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
@@ -1104,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// A set of erased operations. This set is utilized only if
+ /// `allowPatternRollback` is set to "false". Conceptually, this set is
+ /// similar to `replacedOps` (which is maintained when the flag is set to
+ /// "true"). However, erasing from a DenseSet is more efficient than erasing
+ /// from a SetVector.
+ DenseSet<Operation *> erasedOps;
+
+ /// A set of erased blocks. This set is utilized only if
+ /// `allowPatternRollback` is set to "false".
+ DenseSet<Block *> erasedBlocks;
+
+ /// A rewriter that notifies the listener (if any) about all IR
+ /// modifications. This rewriter is utilized only if `allowPatternRollback`
+ /// is set to "false". If the flag is set to "true", the listener is notified
+ /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+ IRRewriter notifyingRewriter;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1140,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
- if (!repl)
- return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
rewriter.replaceAllUsesWith(arg, repl);
return;
@@ -1161,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ if (!repl)
+ return;
+ performReplaceBlockArg(rewriter, arg, repl);
+}
+
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1246,6 +1277,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up a single value.
+ auto lookup = [&](const ValueVector &values) -> ValueVector {
+ assert(!values.empty() && "expected non-empty value vector");
+
+ // If the pattern rollback is enabled, use the mapping to look up the
+ // values.
+ if (config.allowPatternRollback)
+ return mapping.lookup(values);
+
+ // Otherwise, look up values by examining the IR. All replacements have
+ // already been materialized in IR.
+ Operation *op = getCommonDefiningOp(values);
+ if (!op)
+ return {};
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+ if (!castOp)
+ return {};
+ if (!this->unresolvedMaterializations.contains(castOp))
+ return {};
+ if (castOp.getOutputs() != values)
+ return {};
+ return castOp.getInputs();
+ };
+
// Helper function that looks up each value in `values` individually and then
// composes the results. If that fails, it tries to look up the entire vector
// at once.
@@ -1253,7 +1308,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// If possible, replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : values) {
- ValueVector r = mapping.lookup({v});
+ ValueVector r = lookup({v});
if (!r.empty()) {
llvm::append_range(next, r);
} else {
@@ -1273,7 +1328,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
- ValueVector r = mapping.lookup(values);
+ ValueVector r = lookup(values);
if (r.empty()) {
// No mapping found: The lookup stops here.
return {};
@@ -1347,15 +1402,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
- if (!config.allowPatternRollback &&
- !isa<UnresolvedMaterializationRewrite>(rewrite)) {
- // Unresolved materializations can always be rolled back (erased).
- llvm::report_fatal_error("pattern '" + patternName +
- "' rollback of IR modifications requested");
- }
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
rewrite->rollback();
- }
rewrites.resize(numRewritesToKeep);
}
@@ -1419,12 +1467,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ return wasOpReplaced(op) || ignoredOps.count(op);
}
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
- return replacedOps.count(op);
+ return replacedOps.count(op) || erasedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1508,7 +1556,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// a bit more efficient, so we try to do that when possible.
bool fastPath = !config.listener;
if (fastPath) {
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ if (config.allowPatternRollback)
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
} else {
while (!block->empty())
@@ -1556,7 +1605,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+ if (config.allowPatternRollback)
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1585,23 +1635,32 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// tracking the materialization like we do for other operations.
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- auto convertOp =
+ UnrealizedConversionCastOp convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
if (isPureTypeConversion)
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
- if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
+
+ // Register the materialization.
if (castOp)
*castOp = convertOp;
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
- std::move(valuesToMap));
+ if (config.allowPatternRollback) {
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
+ } else {
+ patternMaterializations.insert(convertOp);
+ }
return convertOp.getResults();
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ assert(config.allowPatternRollback &&
+ "this code path is valid only in rollback mode");
+
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
@@ -1663,26 +1722,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
"attempting to insert into a block within a replaced/erased op");
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+
if (wasDetached) {
- // If the op was detached, it is most likely a newly created op.
- // TODO: If the same op is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same op multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateOperationRewrite>(op);
+ // If the op was detached, it is most likely a newly created op. Add it the
+ // set of newly created ops, so that it will be legalized. If this op is
+ // not a newly created op, it will be legalized a second time, which is
+ // inefficient but harmless.
patternNewOps.insert(op);
+
+ if (config.allowPatternRollback) {
+ // TODO: If the same op is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same op multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateOperationRewrite>(op);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased operations that must be kept up to date.
+ erasedOps.erase(op);
+ }
return;
}
// The op was moved from one place to another.
- appendRewrite<MoveOperationRewrite>(op, previous);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveOperationRewrite>(op, previous);
+}
+
+/// Given that `fromRange` is about to be replaced with `toRange`, compute
+/// replacement values with the types of `fromRange`.
+static SmallVector<Value>
+getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
+ const SmallVector<SmallVector<Value>> &toRange,
+ const TypeConverter *converter) {
+ assert(!impl.config.allowPatternRollback &&
+ "this code path is valid only in 'no rollback' mode");
+ SmallVector<Value> repls;
+ for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+ if (from.use_empty()) {
+ // The replaced value is dead. No replacement value is needed.
+ repls.push_back(Value());
+ continue;
+ }
+
+ if (to.empty()) {
+ // The replaced value is dropped. Materialize a replacement value "out of
+ // thin air".
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+ /*outputTypes=*/from.getType(), /*originalType=*/Type(),
+ converter)[0];
+ repls.push_back(srcMat);
+ continue;
+ }
+
+ if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+ // The replacement value already has the correct type. Use it directly.
+ repls.push_back(to[0]);
+ continue;
+ }
+
+ // The replacement value has the wrong type. Build a source materialization
+ // to the original type.
+ // TODO: This is a bit inefficient. We should try to reuse existing
+ // materializations if possible. This would require an extension of the
+ // `lookupOrDefault` API.
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+ /*originalType=*/Type(), converter)[0];
+ repls.push_back(srcMat);
+ }
+
+ return repls;
}
void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
- assert(newValues.size() == op->getNumResults());
+ assert(newValues.size() == op->getNumResults() &&
+ "incorrect number of replacement values");
+
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ SmallVector<Value> repls = getReplacementValues(
+ *this, op->getResults(), newValues, currentTypeConverter);
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ op->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ op->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Replace the op with the replacement values and notify the listener.
+ notifyingRewriter.replaceOp(op, repls);
+ return;
+ }
+
assert(!ignoredOps.contains(op) && "operation was already replaced");
// Check if replaced op is an unresolved materialization, i.e., an
@@ -1722,11 +1874,46 @@ void ConversionPatternRewriterImpl::replaceOp(
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
+ if (!config.allowPatternRollback) {
+ SmallVector<Value> toConv = llvm::to_vector(to);
+ SmallVector<Value> repls =
+ getReplacementValues(*this, from, {toConv}, converter);
+ IRRewriter r(from.getContext());
+ Value repl = repls.front();
+ if (!repl)
+ return;
+
+ performReplaceBlockArg(r, from, repl);
+ return;
+ }
+
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ block->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ block->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Erase the block and notify the listener.
+ notifyingRewriter.eraseBlock(block);
+ return;
+ }
+
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
@@ -1760,23 +1947,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(newParentOp) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
"attempting to insert into a region within a replaced/erased op");
(void)newParentOp;
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyBlockInserted(block, previous, previousIt);
+
patternInsertedBlocks.insert(block);
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
- // TODO: If the same block is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same block multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateBlockRewrite>(block);
+ if (config.allowPatternRollback) {
+ // TODO: If the same block is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same block multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateBlockRewrite>(block);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased blocks that must be kept up to date.
+ erasedBlocks.erase(block);
+ }
return;
}
// The block was moved from one place to another.
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1956,7 +2157,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// a bit more efficient, so we try to do that when possible.
bool fastPath = !getConfig().listener;
- if (fastPath)
+ if (fastPath && impl->config.allowPatternRollback)
impl->inlineBlockBefore(source, dest, before);
// Replace all uses of block arguments.
@@ -1982,6 +2183,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ // Pattern rollback is not allowed: no extra bookkeeping is needed.
+ PatternRewriter::startOpModification(op);
+ return;
+ }
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
#ifndef NDEBUG
@@ -1991,20 +2197,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
- assert(!impl->wasOpReplaced(op) &&
- "attempting to modify a replaced/erased op");
- PatternRewriter::finalizeOpModification(op);
impl->patternModifiedOps.insert(op);
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::finalizeOpModification(op);
+ if (getConfig().listener)
+ getConfig().listener->notifyOperationModified(op);
+ return;
+ }
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
}
void ConversionPatternRewriter::cancelOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::cancelOpModification(op);
+ return;
+ }
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -2439,17 +2654,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (!rewriterImpl.config.allowPatternRollback) {
- // Returning "failure" after modifying IR is not allowed.
+ // Erase all unresolved materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ rewriterImpl.patternMaterializations.clear();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Expensive pattern check that can detect API violations.
if (checkOp) {
OperationFingerPrint fingerPrintAfterPattern(checkOp);
if (fingerPrintAfterPattern != *topLevelFingerPrint)
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
"' returned failure but IR did change");
}
- }
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ }
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2473,6 +2694,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Eagerly erase unused materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ if (op->use_empty()) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ }
+ rewriterImpl.patternMaterializations.clear();
+ }
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
@@ -2563,6 +2794,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (Block *block : insertedBlocks) {
+ if (impl.erasedBlocks.contains(block))
+ continue;
+
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)