diff options
Diffstat (limited to 'mlir/lib/Bytecode/Writer/IRNumbering.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Writer/IRNumbering.cpp | 40 |
1 files changed, 34 insertions, 6 deletions
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index ef643ca6d74c..67f929059e47 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -314,9 +314,22 @@ void IRNumberingState::number(Attribute attr) { // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable attributes right now. - if (!attr.hasTrait<AttributeTrait::IsMutable>()) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (!attr.hasTrait<AttributeTrait::IsMutable>()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getAttributeWriterCallbacks()) { + NumberingDialectWriter writer(*this); + // The client has the ability to override the group name through the + // callback. + std::optional<StringRef> groupNameOverride; + if (succeeded(callback->write(attr, groupNameOverride, writer))) { + if (groupNameOverride.has_value()) + numbering->dialect = &numberDialect(*groupNameOverride); + return; + } + } + + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeAttribute(attr, writer))) return; @@ -464,9 +477,24 @@ void IRNumberingState::number(Type type) { // If this type will be emitted using the bytecode format, perform a dummy // writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable types right now. - if (!type.hasTrait<TypeTrait::IsMutable>()) { + // TODO: We don't allow custom encodings for mutable types right now. + if (!type.hasTrait<TypeTrait::IsMutable>()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getTypeWriterCallbacks()) { + NumberingDialectWriter writer(*this); + // The client has the ability to override the group name through the + // callback. + std::optional<StringRef> groupNameOverride; + if (succeeded(callback->write(type, groupNameOverride, writer))) { + if (groupNameOverride.has_value()) + numbering->dialect = &numberDialect(*groupNameOverride); + return; + } + } + + // If this attribute will be emitted using the bytecode format, perform a + // dummy writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeType(type, writer))) return; |
