summaryrefslogtreecommitdiff
path: root/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
diff options
context:
space:
mode:
authordonald chen <chenxunyu1993@gmail.com>2024-10-11 21:59:05 +0800
committerGitHub <noreply@github.com>2024-10-11 21:59:05 +0800
commit4b3f251bada55cfc20a2c72321fa0bbfd7a759d5 (patch)
tree94a8dac0f9cc347af9467ca29a0969fa9a47d0a9 /mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
parent5dac691b66accd2f80c4291280efd5368986d7af (diff)
[mlir] [dataflow] unify semantics of program point (#110344)
The concept of a 'program point' in the original data flow framework is ambiguous. It can refer to either an operation or a block itself. This representation has different interpretations in forward and backward data-flow analysis. In forward data-flow analysis, the program point of an operation represents the state after the operation, while in backward data flow analysis, it represents the state before the operation. When using forward or backward data-flow analysis, it is crucial to carefully handle this distinction to ensure correctness. This patch refactors the definition of program point, unifying the interpretation of program points in both forward and backward data-flow analysis. How to integrate this patch? For dense forward data-flow analysis and other analysis (except dense backward data-flow analysis), the program point corresponding to the original operation can be obtained by `getProgramPointAfter(op)`, and the program point corresponding to the original block can be obtained by `getProgramPointBefore(block)`. For dense backward data-flow analysis, the program point corresponding to the original operation can be obtained by `getProgramPointBefore(op)`, and the program point corresponding to the original block can be obtained by `getProgramPointAfter(block)`. NOTE: If you need to get the lattice of other data-flow analyses in dense backward data-flow analysis, you should still use the dense forward data-flow approach. For example, to get the Executable state of a block in dense backward data-flow analysis and add the dependency of the current operation, you should write: ``getOrCreateFor<Executable>(getProgramPointBefore(op), getProgramPointBefore(block))`` In case above, we use getProgramPointBefore(op) because the analysis we rely on is dense backward data-flow, and we use getProgramPointBefore(block) because the lattice we query is the result of a non-dense backward data flow computation. related dsscussion: https://discourse.llvm.org/t/rfc-unify-the-semantics-of-program-points/80671/8 corresponding PSA: https://discourse.llvm.org/t/psa-program-point-semantics-change/81479
Diffstat (limited to 'mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp')
-rw-r--r--mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp76
1 files changed, 46 insertions, 30 deletions
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index beb68018a3b1..3c190d4e9919 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -46,22 +46,23 @@ void Executable::print(raw_ostream &os) const {
void Executable::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);
- if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) {
- if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) {
+ if (ProgramPoint *pp = llvm::dyn_cast_if_present<ProgramPoint *>(anchor)) {
+ if (pp->isBlockStart()) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({block, analysis});
+ solver->enqueue({pp, analysis});
// Re-invoke the analyses on all operations in the block.
for (DataFlowAnalysis *analysis : subscribers)
- for (Operation &op : *block)
- solver->enqueue({&op, analysis});
+ for (Operation &op : *pp->getBlock())
+ solver->enqueue({solver->getProgramPointAfter(&op), analysis});
}
} else if (auto *latticeAnchor =
llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) {
// Re-invoke the analysis on the successor block.
if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) {
for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({edge->getTo(), analysis});
+ solver->enqueue(
+ {solver->getProgramPointBefore(edge->getTo()), analysis});
}
}
}
@@ -125,7 +126,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
for (Region &region : top->getRegions()) {
if (region.empty())
continue;
- auto *state = getOrCreate<Executable>(&region.front());
+ auto *state =
+ getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
}
@@ -154,7 +156,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// Public symbol callables or those for which we can't see all uses have
// potentially unknown callsites.
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
- auto *state = getOrCreate<PredecessorState>(callable);
+ auto *state =
+ getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
}
foundSymbolCallable = true;
@@ -171,7 +174,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
return top->walk([&](CallableOpInterface callable) {
- auto *state = getOrCreate<PredecessorState>(callable);
+ auto *state =
+ getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
});
}
@@ -182,7 +186,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// If a callable symbol has a non-call use, then we can't be guaranteed to
// know all callsites.
Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
- auto *state = getOrCreate<PredecessorState>(symbol);
+ auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
propagateIfChanged(state, state->setHasUnknownPredecessors());
}
};
@@ -193,7 +197,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
/// Returns true if the operation is a returning terminator in region
/// control-flow or the terminator of a callable region.
static bool isRegionOrCallableReturn(Operation *op) {
- return !op->getNumSuccessors() &&
+ return op->getBlock() != nullptr && !op->getNumSuccessors() &&
isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) &&
op->getBlock()->getTerminator() == op;
}
@@ -205,9 +209,10 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
// When the liveness of the parent block changes, make sure to re-invoke the
// analysis on the op.
if (op->getBlock())
- getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
+ getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+ ->blockContentSubscribe(this);
// Visit the op.
- if (failed(visit(op)))
+ if (failed(visit(getProgramPointAfter(op))))
return failure();
}
// Recurse on nested operations.
@@ -219,7 +224,7 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
- auto *state = getOrCreate<Executable>(to);
+ auto *state = getOrCreate<Executable>(getProgramPointBefore(to));
propagateIfChanged(state, state->setToLive());
auto *edgeState =
getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to));
@@ -230,18 +235,20 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
- auto *state = getOrCreate<Executable>(&region.front());
+ auto *state =
+ getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
}
}
-LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
- if (point.is<Block *>())
+LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
+ if (point->isBlockStart())
return success();
- auto *op = point.get<Operation *>();
+ Operation *op = point->getPrevOp();
// If the parent block is not executable, there is nothing to do.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
// We have a live call op. Add this as a live predecessor of the callee.
@@ -256,7 +263,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
- const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
// If the callsites could not be resolved or are known to be non-empty,
// mark the callable as executable.
@@ -316,11 +324,13 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
!isExternalCallable(callableOp)) {
// Add the live callsite.
- auto *callsites = getOrCreate<PredecessorState>(callableOp);
+ auto *callsites =
+ getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
propagateIfChanged(callsites, callsites->join(call));
} else {
// Mark this call op's predecessors as overdefined.
- auto *predecessors = getOrCreate<PredecessorState>(call);
+ auto *predecessors =
+ getOrCreate<PredecessorState>(getProgramPointAfter(call));
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
}
}
@@ -378,9 +388,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
branch.getEntrySuccessorRegions(*operands, successors);
for (const RegionSuccessor &successor : successors) {
// The successor can be either an entry block or the parent operation.
- ProgramPoint point = successor.getSuccessor()
- ? &successor.getSuccessor()->front()
- : ProgramPoint(branch);
+ ProgramPoint *point =
+ successor.getSuccessor()
+ ? getProgramPointBefore(&successor.getSuccessor()->front())
+ : getProgramPointAfter(branch);
// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
@@ -409,12 +420,15 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
for (const RegionSuccessor &successor : successors) {
PredecessorState *predecessors;
if (Region *region = successor.getSuccessor()) {
- auto *state = getOrCreate<Executable>(&region->front());
+ auto *state =
+ getOrCreate<Executable>(getProgramPointBefore(&region->front()));
propagateIfChanged(state, state->setToLive());
- predecessors = getOrCreate<PredecessorState>(&region->front());
+ predecessors = getOrCreate<PredecessorState>(
+ getProgramPointBefore(&region->front()));
} else {
// Add this terminator as a predecessor to the parent op.
- predecessors = getOrCreate<PredecessorState>(branch);
+ predecessors =
+ getOrCreate<PredecessorState>(getProgramPointAfter(branch));
}
propagateIfChanged(predecessors,
predecessors->join(op, successor.getSuccessorInputs()));
@@ -424,11 +438,13 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
CallableOpInterface callable) {
// Add as predecessors to all callsites this return op.
- auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+ auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
for (Operation *predecessor : callsites->getKnownPredecessors()) {
assert(isa<CallOpInterface>(predecessor));
- auto *predecessors = getOrCreate<PredecessorState>(predecessor);
+ auto *predecessors =
+ getOrCreate<PredecessorState>(getProgramPointAfter(predecessor));
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
} else {