summaryrefslogtreecommitdiff
path: root/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp')
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp146
1 files changed, 80 insertions, 66 deletions
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d1404249..016e59dcb744 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
operandLattices.push_back(operandLattice);
}
- if (auto call = dyn_cast<CallOpInterface>(op)) {
- // If the call operation is to an external function, attempt to infer the
- // results from the call arguments.
- auto callable =
- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
- if (!getSolverConfig().isInterprocedural() ||
- (callable && !callable.getCallableRegion())) {
- visitExternalCallImpl(call, operandLattices, resultLattices);
- return success();
- }
-
- // Otherwise, the results of a call operation are determined by the
- // callgraph.
- const auto *predecessors = getOrCreateFor<PredecessorState>(
- getProgramPointAfter(op), getProgramPointAfter(call));
- // If not all return sites are known, then conservatively assume we can't
- // reason about the data-flow.
- if (!predecessors->allPredecessorsKnown()) {
- setAllToEntryStates(resultLattices);
- return success();
- }
- for (Operation *predecessor : predecessors->getKnownPredecessors())
- for (auto &&[operand, resLattice] :
- llvm::zip(predecessor->getOperands(), resultLattices))
- join(resLattice,
- *getLatticeElementFor(getProgramPointAfter(op), operand));
- return success();
- }
+ if (auto call = dyn_cast<CallOpInterface>(op))
+ return visitCallOperation(call, operandLattices, resultLattices);
// Invoke the operation transfer function.
return visitOperationImpl(op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (block->isEntryBlock()) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
- if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(
- getProgramPointBefore(block), getProgramPointAfter(callable));
- // If not all callsites are known, conservatively mark all lattices as
- // having reached their pessimistic fixpoints.
- if (!callsites->allPredecessorsKnown() ||
- !getSolverConfig().isInterprocedural()) {
- return setAllToEntryStates(argLattices);
- }
- for (Operation *callsite : callsites->getKnownPredecessors()) {
- auto call = cast<CallOpInterface>(callsite);
- for (auto it : llvm::zip(call.getArgOperands(), argLattices))
- join(std::get<1>(it),
- *getLatticeElementFor(getProgramPointBefore(block),
- std::get<0>(it)));
- }
- return;
- }
+ if (callable && callable.getCallableRegion() == block->getParent())
+ return visitCallableOperation(callable, argLattices);
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
}
+LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
+ CallOpInterface call,
+ ArrayRef<const AbstractSparseLattice *> operandLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices) {
+ // If the call operation is to an external function, attempt to infer the
+ // results from the call arguments.
+ auto callable =
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+ if (!getSolverConfig().isInterprocedural() ||
+ (callable && !callable.getCallableRegion())) {
+ visitExternalCallImpl(call, operandLattices, resultLattices);
+ return success();
+ }
+
+ // Otherwise, the results of a call operation are determined by the
+ // callgraph.
+ const auto *predecessors = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(call), getProgramPointAfter(call));
+ // If not all return sites are known, then conservatively assume we can't
+ // reason about the data-flow.
+ if (!predecessors->allPredecessorsKnown()) {
+ setAllToEntryStates(resultLattices);
+ return success();
+ }
+ for (Operation *predecessor : predecessors->getKnownPredecessors())
+ for (auto &&[operand, resLattice] :
+ llvm::zip(predecessor->getOperands(), resultLattices))
+ join(resLattice,
+ *getLatticeElementFor(getProgramPointAfter(call), operand));
+ return success();
+}
+
+void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
+ CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> argLattices) {
+ Block *entryBlock = &callable.getCallableRegion()->front();
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
+ // If not all callsites are known, conservatively mark all lattices as
+ // having reached their pessimistic fixpoints.
+ if (!callsites->allPredecessorsKnown() ||
+ !getSolverConfig().isInterprocedural()) {
+ return setAllToEntryStates(argLattices);
+ }
+ for (Operation *callsite : callsites->getKnownPredecessors()) {
+ auto call = cast<CallOpInterface>(callsite);
+ for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+ join(std::get<1>(it),
+ *getLatticeElementFor(getProgramPointBefore(entryBlock),
+ std::get<0>(it)));
+ }
+}
+
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
if (op->hasTrait<OpTrait::ReturnLike>()) {
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
- const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
- getProgramPointAfter(op), getProgramPointAfter(callable));
- if (callsites->allPredecessorsKnown()) {
- for (Operation *call : callsites->getKnownPredecessors()) {
- SmallVector<const AbstractSparseLattice *> callResultLattices =
- getLatticeElementsFor(getProgramPointAfter(op),
- call->getResults());
- for (auto [op, result] :
- llvm::zip(operandLattices, callResultLattices))
- meet(op, *result);
- }
- } else {
- // If we don't know all the callers, we can't know where the
- // returned values go. Note that, in particular, this will trigger
- // for the return ops of any public functions.
- setAllToExitStates(operandLattices);
- }
- return success();
- }
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ return visitCallableOperation(op, callable, operandLattices);
}
return visitOperationImpl(op, operandLattices, resultLattices);
}
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
+ Operation *op, CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> operandLattices) {
+ const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
+ if (callsites->allPredecessorsKnown()) {
+ for (Operation *call : callsites->getKnownPredecessors()) {
+ SmallVector<const AbstractSparseLattice *> callResultLattices =
+ getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
+ for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
+ meet(op, *result);
+ }
+ } else {
+ // If we don't know all the callers, we can't know where the
+ // returned values go. Note that, in particular, this will trigger
+ // for the return ops of any public functions.
+ setAllToExitStates(operandLattices);
+ }
+ return success();
+}
+
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {