summaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bytecode/Writer/BytecodeWriter.cpp')
-rw-r--r--mlir/lib/Bytecode/Writer/BytecodeWriter.cpp91
1 files changed, 65 insertions, 26 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index d8f2cb106510..75315b5ec75e 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -18,15 +18,10 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/CachedHashString.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Endian.h"
-#include <cstddef>
-#include <cstdint>
-#include <cstring>
+#include "llvm/Support/raw_ostream.h"
#include <optional>
-#include <sys/types.h>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -47,6 +42,12 @@ struct BytecodeWriterConfig::Impl {
/// The producer of the bytecode.
StringRef producer;
+ /// Printer callbacks used to emit custom type and attribute encodings.
+ llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+ attributeWriterCallbacks;
+ llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+ typeWriterCallbacks;
+
/// A collection of non-dialect resource printers.
SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
};
@@ -60,6 +61,26 @@ BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
}
BytecodeWriterConfig::~BytecodeWriterConfig() = default;
+ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
+BytecodeWriterConfig::getAttributeWriterCallbacks() const {
+ return impl->attributeWriterCallbacks;
+}
+
+ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
+BytecodeWriterConfig::getTypeWriterCallbacks() const {
+ return impl->typeWriterCallbacks;
+}
+
+void BytecodeWriterConfig::attachAttributeCallback(
+ std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) {
+ impl->attributeWriterCallbacks.emplace_back(std::move(callback));
+}
+
+void BytecodeWriterConfig::attachTypeCallback(
+ std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) {
+ impl->typeWriterCallbacks.emplace_back(std::move(callback));
+}
+
void BytecodeWriterConfig::attachResourcePrinter(
std::unique_ptr<AsmResourcePrinter> printer) {
impl->externalResourcePrinters.emplace_back(std::move(printer));
@@ -774,32 +795,50 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
auto emitAttrOrType = [&](auto &entry) {
auto entryValue = entry.getValue();
- // First, try to emit this entry using the dialect bytecode interface.
- bool hasCustomEncoding = false;
- if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
- // The writer used when emitting using a custom bytecode encoding.
+ auto emitAttrOrTypeRawImpl = [&]() -> void {
+ RawEmitterOstream(attrTypeEmitter) << entryValue;
+ attrTypeEmitter.emitByte(0);
+ };
+ auto emitAttrOrTypeImpl = [&]() -> bool {
+ // TODO: We don't currently support custom encoded mutable types and
+ // attributes.
+ if (entryValue.template hasTrait<TypeTrait::IsMutable>() ||
+ entryValue.template hasTrait<AttributeTrait::IsMutable>()) {
+ emitAttrOrTypeRawImpl();
+ return false;
+ }
+
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
numberingState, stringSection);
-
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
- // TODO: We don't currently support custom encoded mutable types.
- hasCustomEncoding =
- !entryValue.template hasTrait<TypeTrait::IsMutable>() &&
- succeeded(interface->writeType(entryValue, dialectWriter));
+ for (const auto &callback : config.typeWriterCallbacks) {
+ if (succeeded(callback->write(entryValue, dialectWriter)))
+ return true;
+ }
+ if (const BytecodeDialectInterface *interface =
+ entry.dialect->interface) {
+ if (succeeded(interface->writeType(entryValue, dialectWriter)))
+ return true;
+ }
} else {
- // TODO: We don't currently support custom encoded mutable attributes.
- hasCustomEncoding =
- !entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
- succeeded(interface->writeAttribute(entryValue, dialectWriter));
+ for (const auto &callback : config.attributeWriterCallbacks) {
+ if (succeeded(callback->write(entryValue, dialectWriter)))
+ return true;
+ }
+ if (const BytecodeDialectInterface *interface =
+ entry.dialect->interface) {
+ if (succeeded(interface->writeAttribute(entryValue, dialectWriter)))
+ return true;
+ }
}
- }
- // If the entry was not emitted using the dialect interface, emit it using
- // the textual format.
- if (!hasCustomEncoding) {
- RawEmitterOstream(attrTypeEmitter) << entryValue;
- attrTypeEmitter.emitByte(0);
- }
+ // If the entry was not emitted using a callback or a dialect interface,
+ // emit it using the textual format.
+ emitAttrOrTypeRawImpl();
+ return false;
+ };
+
+ bool hasCustomEncoding = emitAttrOrTypeImpl();
// Record the offset of this entry.
uint64_t curOffset = attrTypeEmitter.size();