summaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2023-07-24 19:58:59 -0700
committerRiver Riddle <riddleriver@gmail.com>2023-07-25 15:55:34 -0700
commit4af01bf95628f8ec674277fd1610eac172598cea (patch)
tree05502d4271b95ce88fee36e8a4eb13b7d2485204 /mlir/lib/Bytecode
parent5ab6589551c1a4b81458c766266374a5d4d2aaca (diff)
[mlir:bytecode] Support lazy loading dynamically isolated regions
We currently only support lazy loading for regions that statically implement the IsolatedFromAbove trait, but that limits the amount of operations that can be lazily loaded. This review lifts that restriction by computing which operations have isolated regions when numbering, allowing any operation to be lazily loaded as long as it doesn't use values defined above. Differential Revision: https://reviews.llvm.org/D156199
Diffstat (limited to 'mlir/lib/Bytecode')
-rw-r--r--mlir/lib/Bytecode/Writer/BytecodeWriter.cpp2
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.cpp146
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.h36
3 files changed, 162 insertions, 22 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 401629c73965..d8f2cb106510 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -942,7 +942,7 @@ LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
// emitting the regions first (e.g. if the regions are huge, backpatching the
// op encoding mask is more annoying).
if (numRegions) {
- bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+ bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op);
emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
// If the region is not isolated from above, or we are emitting bytecode
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 284b3c02f1f2..788cf5b201f0 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -115,19 +115,29 @@ static void groupByDialectPerByte(T range) {
IRNumberingState::IRNumberingState(Operation *op,
const BytecodeWriterConfig &config)
: config(config) {
- // Compute a global operation ID numbering according to the pre-order walk of
- // the IR. This is used as reference to construct use-list orders.
- unsigned operationID = 0;
- op->walk<WalkOrder::PreOrder>(
- [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+ computeGlobalNumberingState(op);
// Number the root operation.
number(*op);
- // Push all of the regions of the root operation onto the worklist.
+ // A worklist of region contexts to number and the next value id before that
+ // region.
SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
- for (Region &region : op->getRegions())
- numberContext.emplace_back(&region, nextValueID);
+
+ // Functor to push the regions of the given operation onto the numbering
+ // context.
+ auto addOpRegionsToNumber = [&](Operation *op) {
+ MutableArrayRef<Region> regions = op->getRegions();
+ if (regions.empty())
+ return;
+
+ // Isolated regions don't share value numbers with their parent, so we can
+ // start numbering these regions at zero.
+ unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
+ for (Region &region : regions)
+ numberContext.emplace_back(&region, opFirstValueID);
+ };
+ addOpRegionsToNumber(op);
// Iteratively process each of the nested regions.
while (!numberContext.empty()) {
@@ -136,14 +146,8 @@ IRNumberingState::IRNumberingState(Operation *op,
number(*region);
// Traverse into nested regions.
- for (Operation &op : region->getOps()) {
- // Isolated regions don't share value numbers with their parent, so we can
- // start numbering these regions at zero.
- unsigned opFirstValueID =
- op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
- for (Region &region : op.getRegions())
- numberContext.emplace_back(&region, opFirstValueID);
- }
+ for (Operation &op : region->getOps())
+ addOpRegionsToNumber(&op);
}
// Number each of the dialects. For now this is just in the order they were
@@ -178,6 +182,116 @@ IRNumberingState::IRNumberingState(Operation *op,
finalizeDialectResourceNumberings(op);
}
+void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
+ // A simple state struct tracking data used when walking operations.
+ struct StackState {
+ /// The operation currently being walked.
+ Operation *op;
+
+ /// The numbering of the operation.
+ OperationNumbering *numbering;
+
+ /// A flag indicating if the current state or one of its parents has
+ /// unresolved isolation status. This is tracked separately from the
+ /// isIsolatedFromAbove bit on `numbering` because we need to be able to
+ /// handle the given case:
+ /// top.op {
+ /// %value = ...
+ /// middle.op {
+ /// %value2 = ...
+ /// inner.op {
+ /// // Here we mark `inner.op` as not isolated. Note `middle.op`
+ /// // isn't known not isolated yet.
+ /// use.op %value2
+ ///
+ /// // Here inner.op is already known to be non-isolated, but
+ /// // `middle.op` is now also discovered to be non-isolated.
+ /// use.op %value
+ /// }
+ /// }
+ /// }
+ bool hasUnresolvedIsolation;
+ };
+
+ // Compute a global operation ID numbering according to the pre-order walk of
+ // the IR. This is used as reference to construct use-list orders.
+ unsigned operationID = 0;
+
+ // Walk each of the operations within the IR, tracking a stack of operations
+ // as we recurse into nested regions. This walk method hooks in at two stages
+ // during the walk:
+ //
+ // BeforeAllRegions:
+ // Here we generate a numbering for the operation and push it onto the
+ // stack if it has regions. We also compute the isolation status of parent
+ // regions at this stage. This is done by checking the parent regions of
+ // operands used by the operation, and marking each region between the
+ // the operand region and the current as not isolated. See
+ // StackState::hasUnresolvedIsolation above for an example.
+ //
+ // AfterAllRegions:
+ // Here we pop the operation from the stack, and if it hasn't been marked
+ // as non-isolated, we mark it as so. A non-isolated use would have been
+ // found while walking the regions, so it is safe to mark the operation at
+ // this point.
+ //
+ SmallVector<StackState> opStack;
+ rootOp->walk([&](Operation *op, const WalkStage &stage) {
+ // After visiting all nested regions, we pop the operation from the stack.
+ if (stage.isAfterAllRegions()) {
+ // If no non-isolated uses were found, we can safely mark this operation
+ // as isolated from above.
+ OperationNumbering *numbering = opStack.pop_back_val().numbering;
+ if (!numbering->isIsolatedFromAbove.has_value())
+ numbering->isIsolatedFromAbove = true;
+ return;
+ }
+
+ // When visiting before nested regions, we process "IsolatedFromAbove"
+ // checks and compute the number for this operation.
+ if (!stage.isBeforeAllRegions())
+ return;
+ // Update the isolation status of parent regions if any have yet to be
+ // resolved.
+ if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
+ Region *parentRegion = op->getParentRegion();
+ for (Value operand : op->getOperands()) {
+ Region *operandRegion = operand.getParentRegion();
+ if (operandRegion == parentRegion)
+ continue;
+ // We've found a use of an operand outside of the current region,
+ // walk the operation stack searching for the parent operation,
+ // marking every region on the way as not isolated.
+ Operation *operandContainerOp = operandRegion->getParentOp();
+ auto it = std::find_if(
+ opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
+ // We only need to mark up to the container region, or the first
+ // that has an unresolved status.
+ return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
+ });
+ assert(it != opStack.rend() && "expected to find the container");
+ for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
+ // If we stopped at a region that knows its isolation status, we can
+ // stop updating the isolation status for the parent regions.
+ state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
+ state.numbering->isIsolatedFromAbove = false;
+ }
+ }
+ }
+
+ // Compute the number for this op and push it onto the stack.
+ auto *numbering =
+ new (opAllocator.Allocate()) OperationNumbering(operationID++);
+ if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+ numbering->isIsolatedFromAbove = true;
+ operations.try_emplace(op, numbering);
+ if (op->getNumRegions()) {
+ opStack.emplace_back(StackState{
+ op, numbering, !numbering->isIsolatedFromAbove.has_value()});
+ }
+ });
+}
+
void IRNumberingState::number(Attribute attr) {
auto it = attrs.insert({attr, nullptr});
if (!it.second) {
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index ca30078f3468..eab75f50d2ee 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -127,6 +127,22 @@ struct DialectNumbering {
};
//===----------------------------------------------------------------------===//
+// Operation Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents the numbering entry of an operation.
+struct OperationNumbering {
+ OperationNumbering(unsigned number) : number(number) {}
+
+ /// The number assigned to this operation.
+ unsigned number;
+
+ /// A flag indicating if this operation's regions are isolated. If unset, the
+ /// operation isn't yet known to be isolated.
+ std::optional<bool> isIsolatedFromAbove;
+};
+
+//===----------------------------------------------------------------------===//
// IRNumberingState
//===----------------------------------------------------------------------===//
@@ -154,8 +170,8 @@ public:
return blockIDs[block];
}
unsigned getNumber(Operation *op) {
- assert(operationIDs.count(op) && "operation not numbered");
- return operationIDs[op];
+ assert(operations.count(op) && "operation not numbered");
+ return operations[op]->number;
}
unsigned getNumber(OperationName opName) {
assert(opNames.count(opName) && "opName not numbered");
@@ -186,14 +202,23 @@ public:
return blockOperationCounts[block];
}
+ /// Return if the given operation is isolated from above.
+ bool isIsolatedFromAbove(Operation *op) {
+ assert(operations.count(op) && "operation not numbered");
+ return operations[op]->isIsolatedFromAbove.value_or(false);
+ }
+
/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;
-
+
private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
struct NumberingDialectWriter;
+ /// Compute the global numbering state for the given root operation.
+ void computeGlobalNumberingState(Operation *rootOp);
+
/// Number the given IR unit for bytecode emission.
void number(Attribute attr);
void number(Block &block);
@@ -212,6 +237,7 @@ private:
/// Mapping from IR to the respective numbering entries.
DenseMap<Attribute, AttributeNumbering *> attrs;
+ DenseMap<Operation *, OperationNumbering *> operations;
DenseMap<OperationName, OpNameNumbering *> opNames;
DenseMap<Type, TypeNumbering *> types;
DenseMap<Dialect *, DialectNumbering *> registeredDialects;
@@ -228,12 +254,12 @@ private:
/// Allocators used for the various numbering entries.
llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
+ llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
- /// The value ID for each Operation, Block and Value.
- DenseMap<Operation *, unsigned> operationIDs;
+ /// The value ID for each Block and Value.
DenseMap<Block *, unsigned> blockIDs;
DenseMap<Value, unsigned> valueIDs;