diff options
| author | Fangrui Song <i@maskray.me> | 2024-10-11 21:39:06 -0700 |
|---|---|---|
| committer | Amir Ayupov <aaupov@fb.com> | 2024-10-11 21:39:06 -0700 |
| commit | 436701d88c1384d3f72c44dd152cd55e47ef2de3 (patch) | |
| tree | c9825a370f1ba14e5fff19cea1279a0e7a7e9b54 /mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | |
| parent | efa1900174cb940f3750ce9e8cb6f06e69b4f3f0 (diff) | |
| parent | dd326b122506421aba2368053103767f4c56e2ba (diff) | |
[𝘀𝗽𝗿] changes introduced through rebaseusers/aaupov/spr/main.boltnfc-speedup-batwritemaps
Created using spr 1.3.4
[skip ci]
Diffstat (limited to 'mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp')
| -rw-r--r-- | mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 87 |
1 files changed, 50 insertions, 37 deletions
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 1bd6defef90b..67cf8c9c5b81 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -36,7 +36,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { // Push all users of the value to the queue. for (Operation *user : anchor.get<Value>().getUsers()) for (DataFlowAnalysis *analysis : useDefSubscribers) - solver->enqueue({user, analysis}); + solver->enqueue({solver->getProgramPointAfter(user), analysis}); } //===----------------------------------------------------------------------===// @@ -72,7 +72,8 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region) { - getOrCreate<Executable>(&block)->blockContentSubscribe(this); + getOrCreate<Executable>(getProgramPointBefore(&block)) + ->blockContentSubscribe(this); visitBlock(&block); for (Operation &op : block) if (failed(initializeRecursively(&op))) @@ -83,10 +84,11 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { return success(); } -LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) { - if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point)) - return visitOperation(op); - visitBlock(point.get<Block *>()); +LogicalResult +AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) { + if (!point->isBlockStart()) + return visitOperation(point->getPrevOp()); + visitBlock(point->getBlock()); return success(); } @@ -97,7 +99,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { return success(); // If the containing block is not executable, bail out. - if (!getOrCreate<Executable>(op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) return success(); // Get the result lattices. @@ -110,7 +113,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // The results of a region branch operation are determined by control-flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - visitRegionSuccessors({branch}, branch, + visitRegionSuccessors(getProgramPointAfter(branch), branch, /*successor=*/RegionBranchPoint::parent(), resultLattices); return success(); @@ -138,7 +141,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // Otherwise, the results of a call operation are determined by the // callgraph. - const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); + const auto *predecessors = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(call)); // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) { @@ -148,7 +152,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { for (Operation *predecessor : predecessors->getKnownPredecessors()) for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) - join(resLattice, *getLatticeElementFor(op, operand)); + join(resLattice, + *getLatticeElementFor(getProgramPointAfter(op), operand)); return success(); } @@ -162,7 +167,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { return; // If the block is not executable, bail out. - if (!getOrCreate<Executable>(block)->isLive()) + if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive()) return; // Get the argument lattices. @@ -179,7 +184,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { // Check if this block is the entry block of a callable region. auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { - const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); + const auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointBefore(block), getProgramPointAfter(callable)); // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. if (!callsites->allPredecessorsKnown() || @@ -189,15 +195,17 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { for (Operation *callsite : callsites->getKnownPredecessors()) { auto call = cast<CallOpInterface>(callsite); for (auto it : llvm::zip(call.getArgOperands(), argLattices)) - join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); + join(std::get<1>(it), + *getLatticeElementFor(getProgramPointBefore(block), + std::get<0>(it))); } return; } // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { - return visitRegionSuccessors(block, branch, block->getParent(), - argLattices); + return visitRegionSuccessors(getProgramPointBefore(block), branch, + block->getParent(), argLattices); } // Otherwise, we can't reason about the data-flow. @@ -226,7 +234,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { branch.getSuccessorOperands(it.getSuccessorIndex()); for (auto [idx, lattice] : llvm::enumerate(argLattices)) { if (Value operand = operands[idx]) { - join(lattice, *getLatticeElementFor(block, operand)); + join(lattice, + *getLatticeElementFor(getProgramPointBefore(block), operand)); } else { // Conservatively consider internally produced arguments as entry // points. @@ -240,7 +249,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { } void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( - ProgramPoint point, RegionBranchOpInterface branch, + ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) { const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); assert(predecessors->allPredecessorsKnown() && @@ -270,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( unsigned firstIndex = 0; if (inputs.size() != lattices.size()) { - if (llvm::dyn_cast_if_present<Operation *>(point)) { + if (!point->isBlockStart()) { if (!inputs.empty()) firstIndex = cast<OpResult>(inputs.front()).getResultNumber(); visitNonControlFlowArgumentsImpl( @@ -281,7 +290,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( } else { if (!inputs.empty()) firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber(); - Region *region = point.get<Block *>()->getParent(); + Region *region = point->getBlock()->getParent(); visitNonControlFlowArgumentsImpl( branch, RegionSuccessor(region, region->getArguments().slice( @@ -296,7 +305,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( } const AbstractSparseLattice * -AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, +AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point, Value value) { AbstractSparseLattice *state = getLatticeElement(value); addDependency(state, point); @@ -336,7 +345,8 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region) { - getOrCreate<Executable>(&block)->blockContentSubscribe(this); + getOrCreate<Executable>(getProgramPointBefore(&block)) + ->blockContentSubscribe(this); // Initialize ops in reverse order, so we can do as much initial // propagation as possible without having to go through the // solver queue. @@ -349,14 +359,14 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { } LogicalResult -AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) { - if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point)) - return visitOperation(op); +AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { // For backward dataflow, we don't have to do any work for the blocks // themselves. CFG edges between blocks are processed by the BranchOp // logic in `visitOperation`, and entry blocks for functions are tied // to the CallOp arguments by visitOperation. - return success(); + if (point->isBlockStart()) + return success(); + return visitOperation(point->getPrevOp()); } SmallVector<AbstractSparseLattice *> @@ -372,7 +382,7 @@ AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) { SmallVector<const AbstractSparseLattice *> AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor( - ProgramPoint point, ValueRange values) { + ProgramPoint *point, ValueRange values) { SmallVector<const AbstractSparseLattice *> resultLattices; resultLattices.reserve(values.size()); for (Value result : values) { @@ -390,13 +400,14 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // If we're in a dead block, bail out. - if (!getOrCreate<Executable>(op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) return success(); SmallVector<AbstractSparseLattice *> operandLattices = getLatticeElements(op->getOperands()); SmallVector<const AbstractSparseLattice *> resultLattices = - getLatticeElementsFor(op, op->getResults()); + getLatticeElementsFor(getProgramPointAfter(op), op->getResults()); // Block arguments of region branch operations flow back into the operands // of the parent op @@ -425,7 +436,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { detail::getBranchSuccessorArgument( successorOperands, operand.getOperandNumber(), block)) { meet(getLatticeElement(operand.get()), - *getLatticeElementFor(op, *blockArg)); + *getLatticeElementFor(getProgramPointAfter(op), *blockArg)); } } } @@ -467,7 +478,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { for (auto [blockArg, argOpOperand] : llvm::zip(block.getArguments(), argOpOperands)) { meet(getLatticeElement(argOpOperand.get()), - *getLatticeElementFor(op, blockArg)); + *getLatticeElementFor(getProgramPointAfter(op), blockArg)); unaccounted.reset(argOpOperand.getOperandNumber()); } @@ -502,12 +513,13 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { - const PredecessorState *callsites = - getOrCreateFor<PredecessorState>(op, callable); + const PredecessorState *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); if (callsites->allPredecessorsKnown()) { for (Operation *call : callsites->getKnownPredecessors()) { SmallVector<const AbstractSparseLattice *> callResultLattices = - getLatticeElementsFor(op, call->getResults()); + getLatticeElementsFor(getProgramPointAfter(op), + call->getResults()); for (auto [op, result] : llvm::zip(operandLattices, callResultLattices)) meet(op, *result); @@ -542,7 +554,8 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands); ValueRange inputs = successor.getSuccessorInputs(); for (auto [operand, input] : llvm::zip(opoperands, inputs)) { - meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input)); + meet(getLatticeElement(operand.get()), + *getLatticeElementFor(getProgramPointAfter(op), input)); unaccounted.reset(operand.getOperandNumber()); } } @@ -576,7 +589,7 @@ void AbstractSparseBackwardDataFlowAnalysis:: MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands); for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { meet(getLatticeElement(opOperand.get()), - *getLatticeElementFor(terminator, input)); + *getLatticeElementFor(getProgramPointAfter(terminator), input)); unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber()); } } @@ -588,8 +601,8 @@ void AbstractSparseBackwardDataFlowAnalysis:: } const AbstractSparseLattice * -AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, - Value value) { +AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor( + ProgramPoint *point, Value value) { AbstractSparseLattice *state = getLatticeElement(value); addDependency(state, point); return state; |
