summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2025-09-22 13:49:03 +0200
committerGitHub <noreply@github.com>2025-09-22 13:49:03 +0200
commit96a3a58e18c65b424f2ffccc1dfacdb2015fb942 (patch)
treed40457a6d30b01a152cce43b92dea3a85d2507d3 /mlir/lib/Transforms
parentec5460bc7034b351b928d00432273bff9261fc11 (diff)
[mlir][Transforms] Simplify `ConversionPatternRewriter::replaceOp` implementation (#158075)
Move the logic for building "out-of-thin-air" source materializations during op replacements from `replaceOp` to `findOrBuildReplacementValue`. That function already builds source materializations and can handle the case where an op result is dropped. This commit is in preparation of turning `replaceOp` into a non-virtual function. (It is sufficient for `replaceAllUsesWith` and `eraseOp` to be virtual.)
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp59
1 files changed, 24 insertions, 35 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff1e31536cea..bf0136b39e03 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1618,6 +1618,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
+ // Note: Materialization must be built here because we cannot find a
+ // valid insertion point in the new block. (Will point to the old block.)
Value mat =
buildUnresolvedMaterialization(
MaterializationKind::Source,
@@ -1725,29 +1727,29 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// (regardless of the type) and build a source materialization to the
// original type.
repl = lookupOrNull(value);
+
+ // Compute the insertion point of the materialization.
+ OpBuilder::InsertPoint ip;
if (repl.empty()) {
- // No replacement value is registered in the mapping. This means that the
- // value is dropped and no longer needed. (If the value were still needed,
- // a source materialization producing a replacement value "out of thin air"
- // would have already been created during `replaceOp` or
- // `applySignatureConversion`.)
- return Value();
+ // The source materialization has no inputs. Insert it right before the
+ // value that it is replacing.
+ ip = computeInsertPoint(value);
+ } else {
+ // Compute the "earliest" insertion point at which all values in `repl` are
+ // defined. It is important to emit the materialization at that location
+ // because the same materialization may be reused in a different context.
+ // (That's because materializations are cached in the conversion value
+ // mapping.) The insertion point of the materialization must be valid for
+ // all future users that may be created later in the conversion process.
+ ip = computeInsertPoint(repl);
}
-
- // Note: `computeInsertPoint` computes the "earliest" insertion point at
- // which all values in `repl` are defined. It is important to emit the
- // materialization at that location because the same materialization may be
- // reused in a different context. (That's because materializations are cached
- // in the conversion value mapping.) The insertion point of the
- // materialization must be valid for all future users that may be created
- // later in the conversion process.
- Value castValue =
- buildUnresolvedMaterialization(MaterializationKind::Source,
- computeInsertPoint(repl), value.getLoc(),
- /*valuesToMap=*/repl, /*inputs=*/repl,
- /*outputTypes=*/value.getType(),
- /*originalType=*/Type(), converter)
- .front();
+ Value castValue = buildUnresolvedMaterialization(
+ MaterializationKind::Source, ip, value.getLoc(),
+ /*valuesToMap=*/repl, /*inputs=*/repl,
+ /*outputTypes=*/value.getType(),
+ /*originalType=*/Type(), converter,
+ /*isPureTypeConversion=*/!repl.empty())
+ .front();
return castValue;
}
@@ -1897,21 +1899,8 @@ void ConversionPatternRewriterImpl::replaceOp(
}
// Create mappings for each of the new result values.
- for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
- if (repl.empty()) {
- // This result was dropped and no replacement value was provided.
- // Materialize a replacement value "out of thin air".
- buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(result),
- result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
- /*outputTypes=*/result.getType(), /*originalType=*/Type(),
- currentTypeConverter, /*isPureTypeConversion=*/false);
- continue;
- }
-
- // Remap result to replacement value.
+ for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
mapping.map(static_cast<Value>(result), std::move(repl));
- }
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
// Mark this operation and all nested ops as replaced.