diff options
| author | donald chen <chenxunyu1993@gmail.com> | 2024-10-11 21:59:05 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-11 21:59:05 +0800 |
| commit | 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5 (patch) | |
| tree | 94a8dac0f9cc347af9467ca29a0969fa9a47d0a9 /mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp | |
| parent | 5dac691b66accd2f80c4291280efd5368986d7af (diff) | |
[mlir] [dataflow] unify semantics of program point (#110344)
The concept of a 'program point' in the original data flow framework is
ambiguous. It can refer to either an operation or a block itself. This
representation has different interpretations in forward and backward
data-flow analysis. In forward data-flow analysis, the program point of
an operation represents the state after the operation, while in backward
data flow analysis, it represents the state before the operation. When
using forward or backward data-flow analysis, it is crucial to carefully
handle this distinction to ensure correctness.
This patch refactors the definition of program point, unifying the
interpretation of program points in both forward and backward data-flow
analysis.
How to integrate this patch?
For dense forward data-flow analysis and other analysis (except dense
backward data-flow analysis), the program point corresponding to the
original operation can be obtained by `getProgramPointAfter(op)`, and
the program point corresponding to the original block can be obtained by
`getProgramPointBefore(block)`.
For dense backward data-flow analysis, the program point corresponding
to the original operation can be obtained by
`getProgramPointBefore(op)`, and the program point corresponding to the
original block can be obtained by `getProgramPointAfter(block)`.
NOTE: If you need to get the lattice of other data-flow analyses in
dense backward data-flow analysis, you should still use the dense
forward data-flow approach. For example, to get the Executable state of
a block in dense backward data-flow analysis and add the dependency of
the current operation, you should write:
``getOrCreateFor<Executable>(getProgramPointBefore(op),
getProgramPointBefore(block))``
In case above, we use getProgramPointBefore(op) because the analysis we
rely on is dense backward data-flow, and we use
getProgramPointBefore(block) because the lattice we query is the result
of a non-dense backward data flow computation.
related dsscussion:
https://discourse.llvm.org/t/rfc-unify-the-semantics-of-program-points/80671/8
corresponding PSA:
https://discourse.llvm.org/t/psa-program-point-semantics-change/81479
Diffstat (limited to 'mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp')
| -rw-r--r-- | mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp | 76 |
1 files changed, 46 insertions, 30 deletions
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index beb68018a3b1..3c190d4e9919 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -46,22 +46,23 @@ void Executable::print(raw_ostream &os) const { void Executable::onUpdate(DataFlowSolver *solver) const { AnalysisState::onUpdate(solver); - if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) { - if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) { + if (ProgramPoint *pp = llvm::dyn_cast_if_present<ProgramPoint *>(anchor)) { + if (pp->isBlockStart()) { // Re-invoke the analyses on the block itself. for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({block, analysis}); + solver->enqueue({pp, analysis}); // Re-invoke the analyses on all operations in the block. for (DataFlowAnalysis *analysis : subscribers) - for (Operation &op : *block) - solver->enqueue({&op, analysis}); + for (Operation &op : *pp->getBlock()) + solver->enqueue({solver->getProgramPointAfter(&op), analysis}); } } else if (auto *latticeAnchor = llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) { // Re-invoke the analysis on the successor block. if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) { for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({edge->getTo(), analysis}); + solver->enqueue( + {solver->getProgramPointBefore(edge->getTo()), analysis}); } } } @@ -125,7 +126,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { for (Region ®ion : top->getRegions()) { if (region.empty()) continue; - auto *state = getOrCreate<Executable>(®ion.front()); + auto *state = + getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); } @@ -154,7 +156,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { // Public symbol callables or those for which we can't see all uses have // potentially unknown callsites. if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { - auto *state = getOrCreate<PredecessorState>(callable); + auto *state = + getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); } foundSymbolCallable = true; @@ -171,7 +174,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { // If we couldn't gather the symbol uses, conservatively assume that // we can't track information for any nested symbols. return top->walk([&](CallableOpInterface callable) { - auto *state = getOrCreate<PredecessorState>(callable); + auto *state = + getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); }); } @@ -182,7 +186,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { // If a callable symbol has a non-call use, then we can't be guaranteed to // know all callsites. Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef()); - auto *state = getOrCreate<PredecessorState>(symbol); + auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol)); propagateIfChanged(state, state->setHasUnknownPredecessors()); } }; @@ -193,7 +197,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { /// Returns true if the operation is a returning terminator in region /// control-flow or the terminator of a callable region. static bool isRegionOrCallableReturn(Operation *op) { - return !op->getNumSuccessors() && + return op->getBlock() != nullptr && !op->getNumSuccessors() && isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) && op->getBlock()->getTerminator() == op; } @@ -205,9 +209,10 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { // When the liveness of the parent block changes, make sure to re-invoke the // analysis on the op. if (op->getBlock()) - getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this); + getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) + ->blockContentSubscribe(this); // Visit the op. - if (failed(visit(op))) + if (failed(visit(getProgramPointAfter(op)))) return failure(); } // Recurse on nested operations. @@ -219,7 +224,7 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { } void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { - auto *state = getOrCreate<Executable>(to); + auto *state = getOrCreate<Executable>(getProgramPointBefore(to)); propagateIfChanged(state, state->setToLive()); auto *edgeState = getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to)); @@ -230,18 +235,20 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { for (Region ®ion : op->getRegions()) { if (region.empty()) continue; - auto *state = getOrCreate<Executable>(®ion.front()); + auto *state = + getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); } } -LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { - if (point.is<Block *>()) +LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { + if (point->isBlockStart()) return success(); - auto *op = point.get<Operation *>(); + Operation *op = point->getPrevOp(); // If the parent block is not executable, there is nothing to do. - if (!getOrCreate<Executable>(op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) return success(); // We have a live call op. Add this as a live predecessor of the callee. @@ -256,7 +263,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { // Check if this is a callable operation. } else if (auto callable = dyn_cast<CallableOpInterface>(op)) { - const auto *callsites = getOrCreateFor<PredecessorState>(op, callable); + const auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); // If the callsites could not be resolved or are known to be non-empty, // mark the callable as executable. @@ -316,11 +324,13 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { if (isa_and_nonnull<SymbolOpInterface>(callableOp) && !isExternalCallable(callableOp)) { // Add the live callsite. - auto *callsites = getOrCreate<PredecessorState>(callableOp); + auto *callsites = + getOrCreate<PredecessorState>(getProgramPointAfter(callableOp)); propagateIfChanged(callsites, callsites->join(call)); } else { // Mark this call op's predecessors as overdefined. - auto *predecessors = getOrCreate<PredecessorState>(call); + auto *predecessors = + getOrCreate<PredecessorState>(getProgramPointAfter(call)); propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); } } @@ -378,9 +388,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation( branch.getEntrySuccessorRegions(*operands, successors); for (const RegionSuccessor &successor : successors) { // The successor can be either an entry block or the parent operation. - ProgramPoint point = successor.getSuccessor() - ? &successor.getSuccessor()->front() - : ProgramPoint(branch); + ProgramPoint *point = + successor.getSuccessor() + ? getProgramPointBefore(&successor.getSuccessor()->front()) + : getProgramPointAfter(branch); // Mark the entry block as executable. auto *state = getOrCreate<Executable>(point); propagateIfChanged(state, state->setToLive()); @@ -409,12 +420,15 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, for (const RegionSuccessor &successor : successors) { PredecessorState *predecessors; if (Region *region = successor.getSuccessor()) { - auto *state = getOrCreate<Executable>(®ion->front()); + auto *state = + getOrCreate<Executable>(getProgramPointBefore(®ion->front())); propagateIfChanged(state, state->setToLive()); - predecessors = getOrCreate<PredecessorState>(®ion->front()); + predecessors = getOrCreate<PredecessorState>( + getProgramPointBefore(®ion->front())); } else { // Add this terminator as a predecessor to the parent op. - predecessors = getOrCreate<PredecessorState>(branch); + predecessors = + getOrCreate<PredecessorState>(getProgramPointAfter(branch)); } propagateIfChanged(predecessors, predecessors->join(op, successor.getSuccessorInputs())); @@ -424,11 +438,13 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, void DeadCodeAnalysis::visitCallableTerminator(Operation *op, CallableOpInterface callable) { // Add as predecessors to all callsites this return op. - auto *callsites = getOrCreateFor<PredecessorState>(op, callable); + auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); bool canResolve = op->hasTrait<OpTrait::ReturnLike>(); for (Operation *predecessor : callsites->getKnownPredecessors()) { assert(isa<CallOpInterface>(predecessor)); - auto *predecessors = getOrCreate<PredecessorState>(predecessor); + auto *predecessors = + getOrCreate<PredecessorState>(getProgramPointAfter(predecessor)); if (canResolve) { propagateIfChanged(predecessors, predecessors->join(op)); } else { |
