diff options
Diffstat (limited to 'mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp')
| -rw-r--r-- | mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 146 |
1 files changed, 80 insertions, 66 deletions
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 0b39d1404249..016e59dcb744 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { operandLattices.push_back(operandLattice); } - if (auto call = dyn_cast<CallOpInterface>(op)) { - // If the call operation is to an external function, attempt to infer the - // results from the call arguments. - auto callable = - dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); - if (!getSolverConfig().isInterprocedural() || - (callable && !callable.getCallableRegion())) { - visitExternalCallImpl(call, operandLattices, resultLattices); - return success(); - } - - // Otherwise, the results of a call operation are determined by the - // callgraph. - const auto *predecessors = getOrCreateFor<PredecessorState>( - getProgramPointAfter(op), getProgramPointAfter(call)); - // If not all return sites are known, then conservatively assume we can't - // reason about the data-flow. - if (!predecessors->allPredecessorsKnown()) { - setAllToEntryStates(resultLattices); - return success(); - } - for (Operation *predecessor : predecessors->getKnownPredecessors()) - for (auto &&[operand, resLattice] : - llvm::zip(predecessor->getOperands(), resultLattices)) - join(resLattice, - *getLatticeElementFor(getProgramPointAfter(op), operand)); - return success(); - } + if (auto call = dyn_cast<CallOpInterface>(op)) + return visitCallOperation(call, operandLattices, resultLattices); // Invoke the operation transfer function. return visitOperationImpl(op, operandLattices, resultLattices); @@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { if (block->isEntryBlock()) { // Check if this block is the entry block of a callable region. auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); - if (callable && callable.getCallableRegion() == block->getParent()) { - const auto *callsites = getOrCreateFor<PredecessorState>( - getProgramPointBefore(block), getProgramPointAfter(callable)); - // If not all callsites are known, conservatively mark all lattices as - // having reached their pessimistic fixpoints. - if (!callsites->allPredecessorsKnown() || - !getSolverConfig().isInterprocedural()) { - return setAllToEntryStates(argLattices); - } - for (Operation *callsite : callsites->getKnownPredecessors()) { - auto call = cast<CallOpInterface>(callsite); - for (auto it : llvm::zip(call.getArgOperands(), argLattices)) - join(std::get<1>(it), - *getLatticeElementFor(getProgramPointBefore(block), - std::get<0>(it))); - } - return; - } + if (callable && callable.getCallableRegion() == block->getParent()) + return visitCallableOperation(callable, argLattices); // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { @@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { } } +LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation( + CallOpInterface call, + ArrayRef<const AbstractSparseLattice *> operandLattices, + ArrayRef<AbstractSparseLattice *> resultLattices) { + // If the call operation is to an external function, attempt to infer the + // results from the call arguments. + auto callable = + dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); + if (!getSolverConfig().isInterprocedural() || + (callable && !callable.getCallableRegion())) { + visitExternalCallImpl(call, operandLattices, resultLattices); + return success(); + } + + // Otherwise, the results of a call operation are determined by the + // callgraph. + const auto *predecessors = getOrCreateFor<PredecessorState>( + getProgramPointAfter(call), getProgramPointAfter(call)); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) { + setAllToEntryStates(resultLattices); + return success(); + } + for (Operation *predecessor : predecessors->getKnownPredecessors()) + for (auto &&[operand, resLattice] : + llvm::zip(predecessor->getOperands(), resultLattices)) + join(resLattice, + *getLatticeElementFor(getProgramPointAfter(call), operand)); + return success(); +} + +void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation( + CallableOpInterface callable, + ArrayRef<AbstractSparseLattice *> argLattices) { + Block *entryBlock = &callable.getCallableRegion()->front(); + const auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointBefore(entryBlock), getProgramPointAfter(callable)); + // If not all callsites are known, conservatively mark all lattices as + // having reached their pessimistic fixpoints. + if (!callsites->allPredecessorsKnown() || + !getSolverConfig().isInterprocedural()) { + return setAllToEntryStates(argLattices); + } + for (Operation *callsite : callsites->getKnownPredecessors()) { + auto call = cast<CallOpInterface>(callsite); + for (auto it : llvm::zip(call.getArgOperands(), argLattices)) + join(std::get<1>(it), + *getLatticeElementFor(getProgramPointBefore(entryBlock), + std::get<0>(it))); + } +} + void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) { @@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { if (op->hasTrait<OpTrait::ReturnLike>()) { // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. - if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { - const PredecessorState *callsites = getOrCreateFor<PredecessorState>( - getProgramPointAfter(op), getProgramPointAfter(callable)); - if (callsites->allPredecessorsKnown()) { - for (Operation *call : callsites->getKnownPredecessors()) { - SmallVector<const AbstractSparseLattice *> callResultLattices = - getLatticeElementsFor(getProgramPointAfter(op), - call->getResults()); - for (auto [op, result] : - llvm::zip(operandLattices, callResultLattices)) - meet(op, *result); - } - } else { - // If we don't know all the callers, we can't know where the - // returned values go. Note that, in particular, this will trigger - // for the return ops of any public functions. - setAllToExitStates(operandLattices); - } - return success(); - } + if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) + return visitCallableOperation(op, callable, operandLattices); } return visitOperationImpl(op, operandLattices, resultLattices); } +LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation( + Operation *op, CallableOpInterface callable, + ArrayRef<AbstractSparseLattice *> operandLattices) { + const PredecessorState *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); + if (callsites->allPredecessorsKnown()) { + for (Operation *call : callsites->getKnownPredecessors()) { + SmallVector<const AbstractSparseLattice *> callResultLattices = + getLatticeElementsFor(getProgramPointAfter(op), call->getResults()); + for (auto [op, result] : llvm::zip(operandLattices, callResultLattices)) + meet(op, *result); + } + } else { + // If we don't know all the callers, we can't know where the + // returned values go. Note that, in particular, this will trigger + // for the return ops of any public functions. + setAllToExitStates(operandLattices); + } + return success(); +} + void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( RegionBranchOpInterface branch, ArrayRef<AbstractSparseLattice *> operandLattices) { |
