diff options
| author | Mehdi Amini <joker.eph@gmail.com> | 2023-07-28 10:43:51 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-07-28 10:44:02 -0700 |
| commit | b299ec16661f653df66cdaf161cdc5441bc9803c (patch) | |
| tree | 871d39e32caf5b3cacdc546905d4ee7731b8053e /mlir/lib/Bytecode | |
| parent | bb65caf90ae1ade0ab1896c8e781cff34b34a846 (diff) | |
Expose callbacks for encoding of types/attributes
[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode.
Two callbacks are exposed, respectively, to the BytecodeWriterConfig and to the ParserConfig. At bytecode parsing/printing, clients have the ability to specify a callback to be used to optionally read/write the encoding. On failure, fallback path will execute the default parsers and printers for the dialect.
Testing shows how to leverage this functionality to support back-deployment and backward-compatibility usecases when roundtripping to bytecode a client dialect with type/attributes dependencies on upstream.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D153383
Diffstat (limited to 'mlir/lib/Bytecode')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 181 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 91 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/IRNumbering.cpp | 40 |
3 files changed, 219 insertions, 93 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 0639baf10b0b..91e47c4c0e47 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -451,7 +451,7 @@ struct BytecodeDialect { /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. - LogicalResult load(DialectReader &reader, MLIRContext *ctx); + LogicalResult load(const DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -505,10 +505,11 @@ struct BytecodeOperationName { /// Parse a single dialect group encoded in the byte stream. static LogicalResult parseDialectGrouping( - EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, + EncodingReader &reader, + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { // Parse the dialect and the number of entries in the group. - BytecodeDialect *dialect; + std::unique_ptr<BytecodeDialect> *dialect; if (failed(parseEntry(reader, dialects, dialect, "dialect"))) return failure(); uint64_t numEntries; @@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping( return failure(); for (uint64_t i = 0; i < numEntries; ++i) - if (failed(entryCallback(dialect))) + if (failed(entryCallback(dialect->get()))) return failure(); return success(); } @@ -532,7 +533,7 @@ public: /// Initialize the resource section reader with the given section data. LogicalResult initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef<BytecodeDialect> dialects, + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); @@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty, LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, - MutableArrayRef<BytecodeDialect> dialects, + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { @@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize( // Read the dialect resources from the bytecode. MLIRContext *ctx = fileLoc->getContext(); while (!offsetReader.empty()) { - BytecodeDialect *dialect; + std::unique_ptr<BytecodeDialect> *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(dialectReader, ctx))) + failed((*dialect)->load(dialectReader, ctx))) return failure(); - Dialect *loadedDialect = dialect->getLoadedDialect(); + Dialect *loadedDialect = (*dialect)->getLoadedDialect(); if (!loadedDialect) { return resourceReader.emitError() - << "dialect '" << dialect->name << "' is unknown"; + << "dialect '" << (*dialect)->name << "' is unknown"; } const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); if (!handler) { return resourceReader.emitError() - << "unexpected resources for dialect '" << dialect->name << "'"; + << "unexpected resources for dialect '" << (*dialect)->name << "'"; } // Ensure that each resource is declared before being processed. @@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize( if (failed(handle)) { return resourceReader.emitError() << "unknown 'resource' key '" << key << "' for dialect '" - << dialect->name << "'"; + << (*dialect)->name << "'"; } dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); dialectResources.push_back(*handle); @@ -796,15 +797,19 @@ class AttrTypeReader { public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc, - uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, + const llvm::StringMap<BytecodeDialect *> &dialectsMap, + uint64_t &bytecodeVersion, Location fileLoc, + const ParserConfig &config) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} + dialectsMap(dialectsMap), fileLoc(fileLoc), + bytecodeVersion(bytecodeVersion), parserConfig(config) {} /// Initialize the attribute and type information within the reader. - LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, - ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData); + LogicalResult + initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData); /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. @@ -878,6 +883,10 @@ private: /// parsing custom encoded attribute/type entries. ResourceSectionReader &resourceReader; + /// The map of the loaded dialects used to retrieve dialect information, such + /// as the dialect version. + const llvm::StringMap<BytecodeDialect *> &dialectsMap; + /// The set of attribute and type entries. SmallVector<AttrEntry> attributes; SmallVector<TypeEntry> types; @@ -887,27 +896,48 @@ private: /// Current bytecode version being used. uint64_t &bytecodeVersion; + + /// Reference to the parser configuration. + const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader, - uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, + const llvm::StringMap<BytecodeDialect *> &dialectsMap, + EncodingReader &reader, uint64_t &bytecodeVersion) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader), - bytecodeVersion(bytecodeVersion) {} + resourceReader(resourceReader), dialectsMap(dialectsMap), + reader(reader), bytecodeVersion(bytecodeVersion) {} - InFlightDiagnostic emitError(const Twine &msg) override { + InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); } + FailureOr<const DialectVersion *> + getDialectVersion(StringRef dialectName) const override { + // First check if the dialect is available in the map. + auto dialectEntry = dialectsMap.find(dialectName); + if (dialectEntry == dialectsMap.end()) + return failure(); + // If the dialect was found, try to load it. This will trigger reading the + // bytecode version from the version buffer if it wasn't already processed. + // Return failure if either of those two actions could not be completed. + if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) || + dialectEntry->getValue()->loadedVersion.get() == nullptr) + return failure(); + return dialectEntry->getValue()->loadedVersion.get(); + } + + MLIRContext *getContext() const override { return getLoc().getContext(); } + uint64_t getBytecodeVersion() const override { return bytecodeVersion; } - DialectReader withEncodingReader(EncodingReader &encReader) { + DialectReader withEncodingReader(EncodingReader &encReader) const { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader, bytecodeVersion); + dialectsMap, encReader, bytecodeVersion); } Location getLoc() const { return reader.getLoc(); } @@ -1010,6 +1040,7 @@ private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; ResourceSectionReader &resourceReader; + const llvm::StringMap<BytecodeDialect *> &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; }; @@ -1096,10 +1127,9 @@ private: }; } // namespace -LogicalResult -AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, - ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData) { +LogicalResult AttrTypeReader::initialize( + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { EncodingReader offsetReader(offsetSectionData, fileLoc); // Parse the number of attribute and type entries. @@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } + return success(); } @@ -1216,32 +1247,54 @@ template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader, - bytecodeVersion); + DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, + reader, bytecodeVersion); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); + + if constexpr (std::is_same_v<T, Type>) { + // Try parsing with callbacks first if available. + for (const auto &callback : + parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { + if (failed( + callback->read(dialectReader, entry.dialect->name, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if we failed to parse, so we can fall through the + // other parsing functions. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } else { + // Try parsing with callbacks first if available. + for (const auto &callback : + parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { + if (failed( + callback->read(dialectReader, entry.dialect->name, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if we failed to parse, so we can fall through the + // other parsing functions. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } + // Ensure that the dialect implements the bytecode interface. if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. If the dialect is versioned, parse - // using the versioned encoding readers. - if (entry.dialect->loadedVersion.get()) { - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType( - dialectReader, *entry.dialect->loadedVersion); - else - entry.entry = entry.dialect->interface->readAttribute( - dialectReader, *entry.dialect->loadedVersion); + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); - } else { - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); - } return success(!!entry.entry); } @@ -1262,7 +1315,8 @@ public: llvm::MemoryBufferRef buffer, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc, version), + attrTypeReader(stringReader, resourceReader, dialectsMap, version, + fileLoc, config), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1528,7 +1582,8 @@ private: StringRef producer; /// The table of IR units referenced within the bytecode file. - SmallVector<BytecodeDialect> dialects; + SmallVector<std::unique_ptr<BytecodeDialect>> dialects; + llvm::StringMap<BytecodeDialect *> dialectsMap; SmallVector<BytecodeOperationName> opNames; /// The reader used to process resources within the bytecode. @@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { //===----------------------------------------------------------------------===// // Dialect Section -LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { +LogicalResult BytecodeDialect::load(const DialectReader &reader, + MLIRContext *ctx) { if (dialect) return success(); Dialect *loadedDialect = ctx->getOrLoadDialect(name); @@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { + dialects[i] = std::make_unique<BytecodeDialect>(); /// Before version kDialectVersioning, there wasn't any versioning available /// for dialects, and the entryIdx represent the string itself. if (version < bytecode::kDialectVersioning) { - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) return failure(); continue; } + // Parse ID representing dialect and version. uint64_t dialectNameIdx; bool versionAvailable; @@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { versionAvailable))) return failure(); if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, - dialects[i].name))) + dialects[i]->name))) return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed( - sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + if (failed(sectionReader.parseSection(sectionID, + dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { emitError(fileLoc, "expected dialect version section"); return failure(); } } + dialectsMap[dialects[i]->name] = dialects[i].get(); } // Parse the operation names, which are grouped by dialect. @@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader, if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + dialectsMap, reader, version); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection( // Initialize the resource reader with the resource sections. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + dialectsMap, reader, version); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, "parsed use-list orders were invalid and could not be applied"); // Resolve dialect version. - for (const BytecodeDialect &byteCodeDialect : dialects) { + for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the // IR and perform upgrades. - if (!byteCodeDialect.loadedVersion) + if (!byteCodeDialect->loadedVersion) continue; - if (byteCodeDialect.interface && - failed(byteCodeDialect.interface->upgradeFromVersion( - *moduleOp, *byteCodeDialect.loadedVersion))) + if (byteCodeDialect->interface && + failed(byteCodeDialect->interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect->loadedVersion))) return failure(); } @@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, // interface and control the serialization. if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + dialectsMap, reader, version); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index d8f2cb106510..75315b5ec75e 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -18,15 +18,10 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/Endian.h" -#include <cstddef> -#include <cstdint> -#include <cstring> +#include "llvm/Support/raw_ostream.h" #include <optional> -#include <sys/types.h> #define DEBUG_TYPE "mlir-bytecode-writer" @@ -47,6 +42,12 @@ struct BytecodeWriterConfig::Impl { /// The producer of the bytecode. StringRef producer; + /// Printer callbacks used to emit custom type and attribute encodings. + llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> + attributeWriterCallbacks; + llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> + typeWriterCallbacks; + /// A collection of non-dialect resource printers. SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; }; @@ -60,6 +61,26 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, } BytecodeWriterConfig::~BytecodeWriterConfig() = default; +ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> +BytecodeWriterConfig::getAttributeWriterCallbacks() const { + return impl->attributeWriterCallbacks; +} + +ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> +BytecodeWriterConfig::getTypeWriterCallbacks() const { + return impl->typeWriterCallbacks; +} + +void BytecodeWriterConfig::attachAttributeCallback( + std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) { + impl->attributeWriterCallbacks.emplace_back(std::move(callback)); +} + +void BytecodeWriterConfig::attachTypeCallback( + std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) { + impl->typeWriterCallbacks.emplace_back(std::move(callback)); +} + void BytecodeWriterConfig::attachResourcePrinter( std::unique_ptr<AsmResourcePrinter> printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); @@ -774,32 +795,50 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { auto emitAttrOrType = [&](auto &entry) { auto entryValue = entry.getValue(); - // First, try to emit this entry using the dialect bytecode interface. - bool hasCustomEncoding = false; - if (const BytecodeDialectInterface *interface = entry.dialect->interface) { - // The writer used when emitting using a custom bytecode encoding. + auto emitAttrOrTypeRawImpl = [&]() -> void { + RawEmitterOstream(attrTypeEmitter) << entryValue; + attrTypeEmitter.emitByte(0); + }; + auto emitAttrOrTypeImpl = [&]() -> bool { + // TODO: We don't currently support custom encoded mutable types and + // attributes. + if (entryValue.template hasTrait<TypeTrait::IsMutable>() || + entryValue.template hasTrait<AttributeTrait::IsMutable>()) { + emitAttrOrTypeRawImpl(); + return false; + } + DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, numberingState, stringSection); - if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) { - // TODO: We don't currently support custom encoded mutable types. - hasCustomEncoding = - !entryValue.template hasTrait<TypeTrait::IsMutable>() && - succeeded(interface->writeType(entryValue, dialectWriter)); + for (const auto &callback : config.typeWriterCallbacks) { + if (succeeded(callback->write(entryValue, dialectWriter))) + return true; + } + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if (succeeded(interface->writeType(entryValue, dialectWriter))) + return true; + } } else { - // TODO: We don't currently support custom encoded mutable attributes. - hasCustomEncoding = - !entryValue.template hasTrait<AttributeTrait::IsMutable>() && - succeeded(interface->writeAttribute(entryValue, dialectWriter)); + for (const auto &callback : config.attributeWriterCallbacks) { + if (succeeded(callback->write(entryValue, dialectWriter))) + return true; + } + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if (succeeded(interface->writeAttribute(entryValue, dialectWriter))) + return true; + } } - } - // If the entry was not emitted using the dialect interface, emit it using - // the textual format. - if (!hasCustomEncoding) { - RawEmitterOstream(attrTypeEmitter) << entryValue; - attrTypeEmitter.emitByte(0); - } + // If the entry was not emitted using a callback or a dialect interface, + // emit it using the textual format. + emitAttrOrTypeRawImpl(); + return false; + }; + + bool hasCustomEncoding = emitAttrOrTypeImpl(); // Record the offset of this entry. uint64_t curOffset = attrTypeEmitter.size(); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index ef643ca6d74c..67f929059e47 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -314,9 +314,22 @@ void IRNumberingState::number(Attribute attr) { // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable attributes right now. - if (!attr.hasTrait<AttributeTrait::IsMutable>()) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (!attr.hasTrait<AttributeTrait::IsMutable>()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getAttributeWriterCallbacks()) { + NumberingDialectWriter writer(*this); + // The client has the ability to override the group name through the + // callback. + std::optional<StringRef> groupNameOverride; + if (succeeded(callback->write(attr, groupNameOverride, writer))) { + if (groupNameOverride.has_value()) + numbering->dialect = &numberDialect(*groupNameOverride); + return; + } + } + + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeAttribute(attr, writer))) return; @@ -464,9 +477,24 @@ void IRNumberingState::number(Type type) { // If this type will be emitted using the bytecode format, perform a dummy // writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable types right now. - if (!type.hasTrait<TypeTrait::IsMutable>()) { + // TODO: We don't allow custom encodings for mutable types right now. + if (!type.hasTrait<TypeTrait::IsMutable>()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getTypeWriterCallbacks()) { + NumberingDialectWriter writer(*this); + // The client has the ability to override the group name through the + // callback. + std::optional<StringRef> groupNameOverride; + if (succeeded(callback->write(type, groupNameOverride, writer))) { + if (groupNameOverride.has_value()) + numbering->dialect = &numberDialect(*groupNameOverride); + return; + } + } + + // If this attribute will be emitted using the bytecode format, perform a + // dummy writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeType(type, writer))) return; |
