summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp159
1 files changed, 30 insertions, 129 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f288c7fc2cb7..b58a95c3baf7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -624,10 +624,9 @@ private:
class ReplaceOperationRewrite : public OperationRewrite {
public:
ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Operation *op, const TypeConverter *converter,
- bool changedResults)
+ Operation *op, const TypeConverter *converter)
: OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
- converter(converter), changedResults(changedResults) {}
+ converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::ReplaceOperation;
@@ -641,15 +640,10 @@ public:
const TypeConverter *getConverter() const { return converter; }
- bool hasChangedResults() const { return changedResults; }
-
private:
/// An optional type converter that can be used to materialize conversions
/// between the new and old values if necessary.
const TypeConverter *converter;
-
- /// A boolean flag that indicates whether result types have changed or not.
- bool changedResults;
};
class CreateOperationRewrite : public OperationRewrite {
@@ -941,6 +935,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
+ /// A set of all unresolved materializations.
+ DenseSet<Operation *> unresolvedMaterializations;
+
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1066,6 +1063,7 @@ void UnresolvedMaterializationRewrite::rollback() {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
+ rewriterImpl.unresolvedMaterializations.erase(op);
op->erase();
}
@@ -1347,6 +1345,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
@@ -1379,22 +1378,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
- // Track if any of the results changed, e.g. erased and replaced with null.
- bool resultChanged = false;
-
// Create mappings for each of the new result values.
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
if (!newValue) {
- resultChanged = true;
- continue;
+ // This result was dropped and no replacement value was provided.
+ if (unresolvedMaterializations.contains(op)) {
+ // Do not create another materializations if we are erasing a
+ // materialization.
+ continue;
+ }
+
+ // Materialize a replacement value "out of thin air".
+ newValue = buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(result),
+ result.getLoc(), /*inputs=*/ValueRange(),
+ /*outputType=*/result.getType(), currentTypeConverter);
}
+
// Remap, and check for any result type changes.
mapping.map(result, newValue);
- resultChanged |= (newValue.getType() != result.getType());
}
- appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
- resultChanged);
+ appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
// Mark this operation and all nested ops as replaced.
op->walk([&](Operation *op) { replacedOps.insert(op); });
@@ -2359,11 +2364,6 @@ private:
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);
- /// Legalize an operation result that was marked as "erased".
- LogicalResult
- legalizeErasedResult(Operation *op, OpResult result,
- ConversionPatternRewriterImpl &rewriterImpl);
-
/// Dialect conversion configuration.
ConversionConfig config;
@@ -2455,77 +2455,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
return failure();
}
-/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results
-/// are not used (transitively) by any op that is not in the given list of
-/// cast ops.
-///
-/// In particular, this function erases cyclic casts that may be inserted
-/// during the dialect conversion process. E.g.:
-/// %0 = unrealized_conversion_cast(%1)
-/// %1 = unrealized_conversion_cast(%0)
-// Note: This step will become unnecessary when
-// https://github.com/llvm/llvm-project/pull/106760 has been merged.
-static void eraseDeadUnrealizedCasts(
- ArrayRef<UnrealizedConversionCastOp> castOps,
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
- // Ops that have already been visited or are currently being visited.
- DenseSet<Operation *> visited;
- // Set of all cast ops for faster lookups.
- DenseSet<Operation *> castOpSet;
- // Set of all cast ops that have been determined to be alive.
- DenseSet<Operation *> live;
-
- for (UnrealizedConversionCastOp op : castOps)
- castOpSet.insert(op);
-
- // Visit a cast operation. Return "true" if the operation is live.
- std::function<bool(Operation *)> visit = [&](Operation *op) -> bool {
- // No need to traverse any IR if the op was already marked as live.
- if (live.contains(op))
- return true;
-
- // Do not visit ops multiple times. If we find a circle, no live user was
- // found on the current path.
- if (!visited.insert(op).second)
- return false;
-
- // Visit all users.
- for (Operation *user : op->getUsers()) {
- // If the user is not an unrealized_conversion_cast op, then the given op
- // is live.
- if (!castOpSet.contains(user)) {
- live.insert(op);
- return true;
- }
- // Otherwise, it is live if a live op can be reached from one of its
- // users (which must all be unrealized_conversion_cast ops).
- if (visit(user)) {
- live.insert(op);
- return true;
- }
- }
-
- return false;
- };
-
- // Visit all cast ops.
- for (UnrealizedConversionCastOp op : castOps) {
- visit(op);
- visited.clear();
- }
-
- // Erase all cast ops that are dead.
- for (UnrealizedConversionCastOp op : castOps) {
- if (live.contains(op)) {
- if (remainingCastOps)
- remainingCastOps->push_back(op);
- continue;
- }
- op->dropAllUses();
- op->erase();
- }
-}
-
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
@@ -2584,14 +2513,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by
// patterns.)
- SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
- eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1);
- reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2);
+ SmallVector<UnrealizedConversionCastOp> remainingCastOps;
+ reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
// Try to legalize all unresolved materializations.
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
- for (UnrealizedConversionCastOp castOp : remainingCastOps2) {
+ for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = rewriteMap.find(castOp.getOperation());
assert(it != rewriteMap.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
@@ -2646,30 +2574,22 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!opReplacement || !opReplacement->hasChangedResults())
+ if (!opReplacement)
continue;
Operation *op = opReplacement->getOperation();
for (OpResult result : op->getResults()) {
- Value newValue = rewriterImpl.mapping.lookupOrNull(result);
-
- // If the operation result was replaced with null, all of the uses of this
- // value should be replaced.
- if (!newValue) {
- if (failed(legalizeErasedResult(op, result, rewriterImpl)))
- return failure();
+ // If the type of this op result changed and the result is still live,
+ // we need to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
continue;
- }
-
- // Otherwise, check to see if the type of the result changed.
- if (result.getType() == newValue.getType())
- continue;
-
Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
continue;
// Legalize this result.
+ Value newValue = rewriterImpl.mapping.lookupOrNull(result);
+ assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
@@ -2727,25 +2647,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
return success();
}
-LogicalResult OperationConverter::legalizeErasedResult(
- Operation *op, OpResult result,
- ConversionPatternRewriterImpl &rewriterImpl) {
- // If the operation result was replaced with null, all of the uses of this
- // value should be replaced.
- auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
- return rewriterImpl.isOpIgnored(user);
- });
- if (liveUserIt != result.user_end()) {
- InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
- << op->getName() << "' marked as erased";
- diag.attachNote(liveUserIt->getLoc())
- << "found live user of result #" << result.getResultNumber() << ": "
- << *liveUserIt;
- return failure();
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//