summary refs log tree commit diff
path: root/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bytecode/Writer/BytecodeWriter.cpp')
-rw-r--r-- mlir/lib/Bytecode/Writer/BytecodeWriter.cpp 234
1 files changed, 225 insertions, 9 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5f34d9bb7a08..ff53cec15d77 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -24,6 +24,29 @@ using namespace mlir;
using namespace mlir::bytecode::detail;
//===----------------------------------------------------------------------===//
+// BytecodeWriterConfig
+//===----------------------------------------------------------------------===//
+
+/// Private implementation of BytecodeWriterConfig. It carries the state the
+/// bytecode writer reads: the producer string and the external resource
+/// printers.
+struct BytecodeWriterConfig::Impl {
+ Impl(StringRef producer) : producer(producer) {}
+
+ /// The producer of the bytecode.
+ /// NOTE(review): held as a StringRef, so the characters are presumably not
+ /// copied and must outlive this object -- confirm against callers.
+ StringRef producer;
+
+ /// A collection of non-dialect resource printers.
+ /// Entries are appended by `attachResourcePrinter` and owned by this object.
+ SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
+};
+
+/// Create a config for the given producer by allocating the private `Impl`
+/// that stores it.
+BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer)
+ : impl(std::make_unique<Impl>(producer)) {}
+/// Defaulted out-of-line. NOTE(review): presumably because `Impl` is only a
+/// complete type in this file -- confirm against the header.
+BytecodeWriterConfig::~BytecodeWriterConfig() = default;
+
+/// Attach an external (non-dialect) resource printer to this config.
+/// Ownership of `printer` moves into the config. Printers are kept in
+/// attachment order, which is the order `writeResourceSection` visits them.
+void BytecodeWriterConfig::attachResourcePrinter(
+ std::unique_ptr<AsmResourcePrinter> printer) {
+ impl->externalResourcePrinters.emplace_back(std::move(printer));
+}
+
+//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -56,6 +79,48 @@ public:
currentResult[offset - prevResultSize] = value;
}
+ /// Emit the provided blob of data, which is owned by the caller and is
+ /// guaranteed to not die before the end of the bytecode process.
+ /// The blob is not placed in `prevResultStorage`; only the `ArrayRef` is
+ /// recorded (see `appendOwnedResult`), hence the lifetime requirement.
+ void emitOwnedBlob(ArrayRef<uint8_t> data) {
+ // Push the current buffer before adding the provided data.
+ // This keeps the blob ordered after everything emitted so far.
+ appendResult(std::move(currentResult));
+ appendOwnedResult(data);
+ }
+
+ /// Emit the provided blob of data that has the given alignment, which is
+ /// owned by the caller and is guaranteed to not die before the end of the
+ /// bytecode process. The alignment value is also encoded, making it available
+ /// on load.
+ ///
+ /// Emission order: varint alignment, varint byte size, padding up to
+ /// `alignment` (see `alignTo`), then the blob bytes.
+ void emitOwnedBlobAndAlignment(ArrayRef<uint8_t> data, uint32_t alignment) {
+ emitVarInt(alignment);
+ emitVarInt(data.size());
+
+ // Padding goes after the two varints so that the data itself starts at an
+ // aligned offset within this emitter.
+ alignTo(alignment);
+ emitOwnedBlob(data);
+ }
+ /// Overload for `char` data: reinterprets the same bytes as `uint8_t`
+ /// (no copy is made) and forwards to the `uint8_t` overload.
+ void emitOwnedBlobAndAlignment(ArrayRef<char> data, uint32_t alignment) {
+ ArrayRef<uint8_t> castedData(reinterpret_cast<const uint8_t *>(data.data()),
+ data.size());
+ emitOwnedBlobAndAlignment(castedData, alignment);
+ }
+
+ /// Align the emitter to the given alignment.
+ ///
+ /// Alignments below 2 are a no-op and are not recorded. Otherwise
+ /// `alignment` must be a power of two (asserted). Padding bytes are emitted
+ /// until `size()` is a multiple of `alignment`, and the largest alignment
+ /// seen so far is kept in `requiredAlignment`.
+ void alignTo(unsigned alignment) {
+ if (alignment < 2)
+ return;
+ assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment");
+
+ // Check to see if we need to emit any padding bytes to meet the desired
+ // alignment.
+ // The offset used is this emitter's own `size()`, so the padding is
+ // relative to the start of this emitter.
+ size_t curOffset = size();
+ size_t paddingSize = llvm::alignTo(curOffset, alignment) - curOffset;
+ while (paddingSize--)
+ emitByte(bytecode::kAlignmentByte);
+
+ // Keep track of the maximum required alignment.
+ // `emitSection` reads this to align the start of a nested section.
+ requiredAlignment = std::max(requiredAlignment, alignment);
+ }
+
//===--------------------------------------------------------------------===//
// Integer Emission
@@ -119,15 +184,37 @@ public:
/// Emit a nested section of the given code, whose contents are encoded in the
/// provided emitter.
+ ///
+ /// Emission order: the section code byte, the varint size of `emitter`, and,
+ /// only when padding was needed, the varint alignment followed by the padding
+ /// bytes. The section body comes last.
void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) {
- // Emit the section code and length.
+ // Emit the section code and length. The high bit of the code is used to
+ // indicate whether the section alignment is present, so save an offset to
+ // it.
+ uint64_t codeOffset = currentResult.size();
emitByte(code);
emitVarInt(emitter.size());
+ // Integrate the alignment of the section into this emitter if necessary.
+ unsigned emitterAlign = emitter.requiredAlignment;
+ if (emitterAlign > 1) {
+ if (size() & (emitterAlign - 1)) {
+ emitVarInt(emitterAlign);
+ alignTo(emitterAlign);
+
+ // Indicate that we needed to align the section, the high bit of the
+ // code field is used for this.
+ currentResult[codeOffset] |= 0b10000000;
+ } else {
+ // Otherwise, if we happen to be at a compatible offset, we just
+ // remember that we need this alignment.
+ requiredAlignment = std::max(requiredAlignment, emitterAlign);
+ }
+ }
+
// Push our current buffer and then merge the provided section body into
// ours.
appendResult(std::move(currentResult));
+ // The section's buffers are spliced in directly instead of being re-appended
+ // one by one. `emitter.prevResultList` may also reference caller-owned blobs
+ // that are not in `emitter.prevResultStorage`, so the list is copied as-is
+ // to keep the original ordering.
+ // NOTE(review): this relies on the moved buffers keeping their data
+ // pointers, so the copied list entries stay valid -- confirm the container
+ // types.
for (std::vector<uint8_t> &result : emitter.prevResultStorage)
- appendResult(std::move(result));
+ prevResultStorage.push_back(std::move(result));
+ llvm::append_range(prevResultList, emitter.prevResultList);
+ prevResultSize += emitter.prevResultSize;
appendResult(std::move(emitter.currentResult));
}
@@ -140,9 +227,16 @@ private:
/// Append a new result buffer to the current contents.
+ /// The buffer is moved into `prevResultStorage`, which takes ownership, and
+ /// is then recorded through `appendOwnedResult`. Empty buffers are skipped
+ /// and leave `result` untouched.
void appendResult(std::vector<uint8_t> &&result) {
- prevResultSize += result.size();
+ if (result.empty())
+ return;
prevResultStorage.emplace_back(std::move(result));
- prevResultList.emplace_back(prevResultStorage.back());
+ appendOwnedResult(prevResultStorage.back());
+ }
+ /// Record `result` as the next chunk of output without taking ownership of
+ /// its bytes: the range is appended to `prevResultList` and its size is added
+ /// to `prevResultSize`. Empty ranges are ignored.
+ void appendOwnedResult(ArrayRef<uint8_t> result) {
+ if (result.empty())
+ return;
+ prevResultSize += result.size();
+ prevResultList.emplace_back(result);
}
/// The result of the emitter currently being built. We refrain from building
@@ -157,6 +251,9 @@ private:
/// An up-to-date total size of all of the buffers within `prevResultList`.
/// This enables O(1) size checks of the current encoding.
size_t prevResultSize = 0;
+
+ /// The highest required alignment for the start of this section.
+ unsigned requiredAlignment = 1;
};
/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
@@ -250,7 +347,8 @@ public:
BytecodeWriter(Operation *op) : numberingState(op) {}
/// Write the bytecode for the given root operation.
- void write(Operation *rootOp, raw_ostream &os, StringRef producer);
+ void write(Operation *rootOp, raw_ostream &os,
+ const BytecodeWriterConfig::Impl &config);
private:
//===--------------------------------------------------------------------===//
@@ -272,6 +370,12 @@ private:
void writeIRSection(EncodingEmitter &emitter, Operation *op);
//===--------------------------------------------------------------------===//
+ // Resources
+
+ void writeResourceSection(Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config);
+
+ //===--------------------------------------------------------------------===//
// Strings
void writeStringSection(EncodingEmitter &emitter);
@@ -288,7 +392,7 @@ private:
} // namespace
void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig::Impl &config) {
EncodingEmitter emitter;
// Emit the bytecode file header. This is how we identify the output as a
@@ -299,7 +403,7 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
emitter.emitVarInt(bytecode::kVersion);
// Emit the producer.
- emitter.emitNulTerminatedString(producer);
+ emitter.emitNulTerminatedString(config.producer);
// Emit the dialect section.
writeDialectSection(emitter);
@@ -310,6 +414,9 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
// Emit the IR section.
writeIRSection(emitter, rootOp);
+ // Emit the resources section.
+ writeResourceSection(rootOp, emitter, config);
+
// Emit the string section.
writeStringSection(emitter);
@@ -386,6 +493,10 @@ public:
emitter.emitVarInt(numberingState.getNumber(type));
}
+ /// Write a reference to a dialect resource, encoded as the varint number
+ /// that `numberingState` reports for the handle.
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ emitter.emitVarInt(numberingState.getNumber(resource));
+ }
+
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
@@ -614,6 +725,111 @@ void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
}
//===----------------------------------------------------------------------===//
+// Resources
+
+namespace {
+/// This class represents a resource builder implementation for the MLIR
+/// bytecode format.
+///
+/// Each `build*` method writes only the entry's value into `emitter`. The
+/// entry's key is not written here; it is handed, together with the entry
+/// kind, to `postProcessFn` after the value has been emitted.
+class ResourceBuilder : public AsmResourceBuilder {
+public:
+ /// Callback invoked after each entry is emitted, with its key and kind.
+ using PostProcessFn = function_ref<void(StringRef, AsmResourceEntryKind)>;
+
+ /// NOTE(review): all three arguments are kept by reference / `function_ref`,
+ /// so they presumably must outlive this builder -- confirm.
+ ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection,
+ PostProcessFn postProcessFn)
+ : emitter(emitter), stringSection(stringSection),
+ postProcessFn(postProcessFn) {}
+ ~ResourceBuilder() override = default;
+
+ /// Emit a blob entry: its alignment, size, padding and data (see
+ /// `emitOwnedBlobAndAlignment`).
+ void buildBlob(StringRef key, ArrayRef<char> data,
+ uint32_t dataAlignment) final {
+ emitter.emitOwnedBlobAndAlignment(data, dataAlignment);
+ postProcessFn(key, AsmResourceEntryKind::Blob);
+ }
+ /// Emit a boolean entry as a single byte.
+ void buildBool(StringRef key, bool data) final {
+ emitter.emitByte(data);
+ postProcessFn(key, AsmResourceEntryKind::Bool);
+ }
+ /// Emit a string entry as the varint value returned by inserting `data` into
+ /// the string section.
+ void buildString(StringRef key, StringRef data) final {
+ emitter.emitVarInt(stringSection.insert(data));
+ postProcessFn(key, AsmResourceEntryKind::String);
+ }
+
+private:
+ /// Destination for the resource values.
+ EncodingEmitter &emitter;
+ /// String section used to encode string-valued entries.
+ StringSectionBuilder &stringSection;
+ /// Hook run after every emitted entry.
+ PostProcessFn postProcessFn;
+};
+} // namespace
+
+/// Emit the resource data and the table describing it as two nested sections
+/// of `emitter`: `kResourceOffset` (group and entry metadata) followed by
+/// `kResource` (the entry values).
+///
+/// The offset section starts with the number of external printers, followed by
+/// one group per external printer and one group per dialect that produced
+/// entries. Each group is: varint group key, varint entry count, then per
+/// entry a varint key string index, a varint size and a kind byte.
+void BytecodeWriter::writeResourceSection(
+ Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config) {
+ EncodingEmitter resourceEmitter;
+ EncodingEmitter resourceOffsetEmitter;
+ uint64_t prevOffset = 0;
+ SmallVector<std::tuple<StringRef, AsmResourceEntryKind, uint64_t>>
+ curResourceEntries;
+
+ // Functor used to process the offset for a resource of `kind` defined by
+ // 'key'.
+ // The recorded value is the number of bytes added to `resourceEmitter` since
+ // the previous entry, so it includes any alignment padding emitted for a
+ // blob.
+ auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) {
+ uint64_t curOffset = resourceEmitter.size();
+ curResourceEntries.emplace_back(key, kind, curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Functor used to emit a resource group defined by 'key'.
+ // Note that the loop's structured binding `key` shadows the group `key`
+ // parameter inside the loop body.
+ auto emitResourceGroup = [&](uint64_t key) {
+ resourceOffsetEmitter.emitVarInt(key);
+ resourceOffsetEmitter.emitVarInt(curResourceEntries.size());
+ for (auto [key, kind, size] : curResourceEntries) {
+ resourceOffsetEmitter.emitVarInt(stringSection.insert(key));
+ resourceOffsetEmitter.emitVarInt(size);
+ resourceOffsetEmitter.emitByte(kind);
+ }
+ };
+
+ // Builder used to emit resources.
+ ResourceBuilder entryBuilder(resourceEmitter, stringSection,
+ appendResourceOffset);
+
+ // Emit the external resource entries.
+ // A group is emitted for every external printer, even when it produced no
+ // entries; the group key is the string-section value of the printer name.
+ resourceOffsetEmitter.emitVarInt(config.externalResourcePrinters.size());
+ for (const auto &printer : config.externalResourcePrinters) {
+ curResourceEntries.clear();
+ printer->buildResources(op, entryBuilder);
+ emitResourceGroup(stringSection.insert(printer->getName()));
+ }
+
+ // Emit the dialect resource entries.
+ for (DialectNumbering &dialect : numberingState.getDialects()) {
+ if (!dialect.asmInterface)
+ continue;
+ curResourceEntries.clear();
+ dialect.asmInterface->buildResources(op, dialect.resources, entryBuilder);
+
+ // Emit the declaration resources for this dialect, these didn't get emitted
+ // by the interface. These resources don't have data attached, so just use a
+ // "blob" kind as a placeholder.
+ // Nothing is written to `resourceEmitter` for them, so each is recorded
+ // with a size of zero.
+ for (const auto &resource : dialect.resourceMap)
+ if (resource.second->isDeclaration)
+ appendResourceOffset(resource.first, AsmResourceEntryKind::Blob);
+
+ // Emit the resource group for this dialect.
+ // Unlike external printers, dialect groups are keyed by the dialect number
+ // and are skipped when empty.
+ if (!curResourceEntries.empty())
+ emitResourceGroup(dialect.number);
+ }
+
+ // If we didn't emit any resource groups, elide the resource sections.
+ // NOTE(review): the external printer count is written to
+ // `resourceOffsetEmitter` unconditionally above, so, assuming `emitVarInt`
+ // always emits at least one byte, this size can never be zero and the
+ // sections are never elided. Confirm the intended behavior, and what the
+ // reader accepts, before changing it.
+ if (resourceOffsetEmitter.size() == 0)
+ return;
+
+ emitter.emitSection(bytecode::Section::kResourceOffset,
+ std::move(resourceOffsetEmitter));
+ emitter.emitSection(bytecode::Section::kResource, std::move(resourceEmitter));
+}
+
+//===----------------------------------------------------------------------===//
// Strings
void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
@@ -627,7 +843,7 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
//===----------------------------------------------------------------------===//
+/// Write the bytecode for `op` to `os`. The producer string and the external
+/// resource printers are taken from the implementation object of `config`.
void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig &config) {
BytecodeWriter writer(op);
- writer.write(op, os, producer);
+ writer.write(op, os, config.getImpl());
}