summaryrefslogtreecommitdiff
path: root/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp')
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp87
1 files changed, 50 insertions, 37 deletions
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 1bd6defef90b..67cf8c9c5b81 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -36,7 +36,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
// Push all users of the value to the queue.
for (Operation *user : anchor.get<Value>().getUsers())
for (DataFlowAnalysis *analysis : useDefSubscribers)
- solver->enqueue({user, analysis});
+ solver->enqueue({solver->getProgramPointAfter(user), analysis});
}
//===----------------------------------------------------------------------===//
@@ -72,7 +72,8 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region) {
- getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+ getOrCreate<Executable>(getProgramPointBefore(&block))
+ ->blockContentSubscribe(this);
visitBlock(&block);
for (Operation &op : block)
if (failed(initializeRecursively(&op)))
@@ -83,10 +84,11 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
return success();
}
-LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
- return visitOperation(op);
- visitBlock(point.get<Block *>());
+LogicalResult
+AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
+ if (!point->isBlockStart())
+ return visitOperation(point->getPrevOp());
+ visitBlock(point->getBlock());
return success();
}
@@ -97,7 +99,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
return success();
// If the containing block is not executable, bail out.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
// Get the result lattices.
@@ -110,7 +113,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- visitRegionSuccessors({branch}, branch,
+ visitRegionSuccessors(getProgramPointAfter(branch), branch,
/*successor=*/RegionBranchPoint::parent(),
resultLattices);
return success();
@@ -138,7 +141,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// Otherwise, the results of a call operation are determined by the
// callgraph.
- const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
+ const auto *predecessors = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(call));
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown()) {
@@ -148,7 +152,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
for (Operation *predecessor : predecessors->getKnownPredecessors())
for (auto &&[operand, resLattice] :
llvm::zip(predecessor->getOperands(), resultLattices))
- join(resLattice, *getLatticeElementFor(op, operand));
+ join(resLattice,
+ *getLatticeElementFor(getProgramPointAfter(op), operand));
return success();
}
@@ -162,7 +167,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
return;
// If the block is not executable, bail out.
- if (!getOrCreate<Executable>(block)->isLive())
+ if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive())
return;
// Get the argument lattices.
@@ -179,7 +184,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointBefore(block), getProgramPointAfter(callable));
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown() ||
@@ -189,15 +195,17 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
- join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
+ join(std::get<1>(it),
+ *getLatticeElementFor(getProgramPointBefore(block),
+ std::get<0>(it)));
}
return;
}
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- return visitRegionSuccessors(block, branch, block->getParent(),
- argLattices);
+ return visitRegionSuccessors(getProgramPointBefore(block), branch,
+ block->getParent(), argLattices);
}
// Otherwise, we can't reason about the data-flow.
@@ -226,7 +234,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
branch.getSuccessorOperands(it.getSuccessorIndex());
for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
if (Value operand = operands[idx]) {
- join(lattice, *getLatticeElementFor(block, operand));
+ join(lattice,
+ *getLatticeElementFor(getProgramPointBefore(block), operand));
} else {
// Conservatively consider internally produced arguments as entry
// points.
@@ -240,7 +249,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
- ProgramPoint point, RegionBranchOpInterface branch,
+ ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
@@ -270,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
- if (llvm::dyn_cast_if_present<Operation *>(point)) {
+ if (!point->isBlockStart()) {
if (!inputs.empty())
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
visitNonControlFlowArgumentsImpl(
@@ -281,7 +290,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
} else {
if (!inputs.empty())
firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
- Region *region = point.get<Block *>()->getParent();
+ Region *region = point->getBlock()->getParent();
visitNonControlFlowArgumentsImpl(
branch,
RegionSuccessor(region, region->getArguments().slice(
@@ -296,7 +305,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
}
const AbstractSparseLattice *
-AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point,
Value value) {
AbstractSparseLattice *state = getLatticeElement(value);
addDependency(state, point);
@@ -336,7 +345,8 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region) {
- getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+ getOrCreate<Executable>(getProgramPointBefore(&block))
+ ->blockContentSubscribe(this);
// Initialize ops in reverse order, so we can do as much initial
// propagation as possible without having to go through the
// solver queue.
@@ -349,14 +359,14 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
}
LogicalResult
-AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
- return visitOperation(op);
+AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
// For backward dataflow, we don't have to do any work for the blocks
// themselves. CFG edges between blocks are processed by the BranchOp
// logic in `visitOperation`, and entry blocks for functions are tied
// to the CallOp arguments by visitOperation.
- return success();
+ if (point->isBlockStart())
+ return success();
+ return visitOperation(point->getPrevOp());
}
SmallVector<AbstractSparseLattice *>
@@ -372,7 +382,7 @@ AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
SmallVector<const AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
- ProgramPoint point, ValueRange values) {
+ ProgramPoint *point, ValueRange values) {
SmallVector<const AbstractSparseLattice *> resultLattices;
resultLattices.reserve(values.size());
for (Value result : values) {
@@ -390,13 +400,14 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// If we're in a dead block, bail out.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
- getLatticeElementsFor(op, op->getResults());
+ getLatticeElementsFor(getProgramPointAfter(op), op->getResults());
// Block arguments of region branch operations flow back into the operands
// of the parent op
@@ -425,7 +436,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
detail::getBranchSuccessorArgument(
successorOperands, operand.getOperandNumber(), block)) {
meet(getLatticeElement(operand.get()),
- *getLatticeElementFor(op, *blockArg));
+ *getLatticeElementFor(getProgramPointAfter(op), *blockArg));
}
}
}
@@ -467,7 +478,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
for (auto [blockArg, argOpOperand] :
llvm::zip(block.getArguments(), argOpOperands)) {
meet(getLatticeElement(argOpOperand.get()),
- *getLatticeElementFor(op, blockArg));
+ *getLatticeElementFor(getProgramPointAfter(op), blockArg));
unaccounted.reset(argOpOperand.getOperandNumber());
}
@@ -502,12 +513,13 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
- const PredecessorState *callsites =
- getOrCreateFor<PredecessorState>(op, callable);
+ const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
if (callsites->allPredecessorsKnown()) {
for (Operation *call : callsites->getKnownPredecessors()) {
SmallVector<const AbstractSparseLattice *> callResultLattices =
- getLatticeElementsFor(op, call->getResults());
+ getLatticeElementsFor(getProgramPointAfter(op),
+ call->getResults());
for (auto [op, result] :
llvm::zip(operandLattices, callResultLattices))
meet(op, *result);
@@ -542,7 +554,8 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
- meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
+ meet(getLatticeElement(operand.get()),
+ *getLatticeElementFor(getProgramPointAfter(op), input));
unaccounted.reset(operand.getOperandNumber());
}
}
@@ -576,7 +589,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),
- *getLatticeElementFor(terminator, input));
+ *getLatticeElementFor(getProgramPointAfter(terminator), input));
unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
}
}
@@ -588,8 +601,8 @@ void AbstractSparseBackwardDataFlowAnalysis::
}
const AbstractSparseLattice *
-AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
- Value value) {
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
+ ProgramPoint *point, Value value) {
AbstractSparseLattice *state = getLatticeElement(value);
addDependency(state, point);
return state;