diff options
Diffstat (limited to 'mlir/lib/Bytecode/Writer/BytecodeWriter.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 91 |
1 files changed, 65 insertions, 26 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index d8f2cb106510..75315b5ec75e 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -18,15 +18,10 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/Endian.h" -#include <cstddef> -#include <cstdint> -#include <cstring> +#include "llvm/Support/raw_ostream.h" #include <optional> -#include <sys/types.h> #define DEBUG_TYPE "mlir-bytecode-writer" @@ -47,6 +42,12 @@ struct BytecodeWriterConfig::Impl { /// The producer of the bytecode. StringRef producer; + /// Printer callbacks used to emit custom type and attribute encodings. + llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> + attributeWriterCallbacks; + llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> + typeWriterCallbacks; + /// A collection of non-dialect resource printers. SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; }; @@ -60,6 +61,26 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, } BytecodeWriterConfig::~BytecodeWriterConfig() = default; +ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> +BytecodeWriterConfig::getAttributeWriterCallbacks() const { + return impl->attributeWriterCallbacks; +} + +ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> +BytecodeWriterConfig::getTypeWriterCallbacks() const { + return impl->typeWriterCallbacks; +} + +void BytecodeWriterConfig::attachAttributeCallback( + std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) { + impl->attributeWriterCallbacks.emplace_back(std::move(callback)); +} + +void BytecodeWriterConfig::attachTypeCallback( + std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) { + impl->typeWriterCallbacks.emplace_back(std::move(callback)); +} + void BytecodeWriterConfig::attachResourcePrinter( std::unique_ptr<AsmResourcePrinter> printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); @@ -774,32 +795,50 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { auto emitAttrOrType = [&](auto &entry) { auto entryValue = entry.getValue(); - // First, try to emit this entry using the dialect bytecode interface. - bool hasCustomEncoding = false; - if (const BytecodeDialectInterface *interface = entry.dialect->interface) { - // The writer used when emitting using a custom bytecode encoding. + auto emitAttrOrTypeRawImpl = [&]() -> void { + RawEmitterOstream(attrTypeEmitter) << entryValue; + attrTypeEmitter.emitByte(0); + }; + auto emitAttrOrTypeImpl = [&]() -> bool { + // TODO: We don't currently support custom encoded mutable types and + // attributes. + if (entryValue.template hasTrait<TypeTrait::IsMutable>() || + entryValue.template hasTrait<AttributeTrait::IsMutable>()) { + emitAttrOrTypeRawImpl(); + return false; + } + DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, numberingState, stringSection); - if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) { - // TODO: We don't currently support custom encoded mutable types. - hasCustomEncoding = - !entryValue.template hasTrait<TypeTrait::IsMutable>() && - succeeded(interface->writeType(entryValue, dialectWriter)); + for (const auto &callback : config.typeWriterCallbacks) { + if (succeeded(callback->write(entryValue, dialectWriter))) + return true; + } + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if (succeeded(interface->writeType(entryValue, dialectWriter))) + return true; + } } else { - // TODO: We don't currently support custom encoded mutable attributes. - hasCustomEncoding = - !entryValue.template hasTrait<AttributeTrait::IsMutable>() && - succeeded(interface->writeAttribute(entryValue, dialectWriter)); + for (const auto &callback : config.attributeWriterCallbacks) { + if (succeeded(callback->write(entryValue, dialectWriter))) + return true; + } + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if (succeeded(interface->writeAttribute(entryValue, dialectWriter))) + return true; + } } - } - // If the entry was not emitted using the dialect interface, emit it using - // the textual format. - if (!hasCustomEncoding) { - RawEmitterOstream(attrTypeEmitter) << entryValue; - attrTypeEmitter.emitByte(0); - } + // If the entry was not emitted using a callback or a dialect interface, + // emit it using the textual format. + emitAttrOrTypeRawImpl(); + return false; + }; + + bool hasCustomEncoding = emitAttrOrTypeImpl(); // Record the offset of this entry. uint64_t curOffset = attrTypeEmitter.size(); |
