summary | refs | log | tree | commit | diff
path: root/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
diff options
context:
space:
mode:
author: River Riddle <riddleriver@gmail.com> 2022-09-06 20:47:57 -0700
committer: River Riddle <riddleriver@gmail.com> 2022-09-13 11:39:19 -0700
commit: 6ab2bcffe45e660a68493e6a7cd04b6f05da51dc (patch)
tree: 9ed79ceafd349129076de2d997bd2ccd66314325 /mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
parent: e166d2e00bc0624cd83d9c790bc4b8c80446d126 (diff)
[mlir:Bytecode] Add support for encoding resources
Resources are encoded in two separate sections similarly to attributes/types, one for the actual data and one for the data offsets. Unlike other sections, the resource sections are optional given that in many cases they won't be present. For testing, bytecode serialization is added for DenseResourceElementsAttr.

Differential Revision: https://reviews.llvm.org/D132729
Diffstat (limited to 'mlir/lib/Bytecode/Writer/BytecodeWriter.cpp')
-rw-r--r-- mlir/lib/Bytecode/Writer/BytecodeWriter.cpp 234
1 files changed, 225 insertions, 9 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5f34d9bb7a08..ff53cec15d77 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -24,6 +24,29 @@ using namespace mlir;
using namespace mlir::bytecode::detail;
//===----------------------------------------------------------------------===//
+// BytecodeWriterConfig
+//===----------------------------------------------------------------------===//
+
// Private state behind BytecodeWriterConfig. The config owns one instance
// through its `impl` member, and the writer reads it by const reference.
+struct BytecodeWriterConfig::Impl {
// Stores the producer string; the printer list starts out empty.
+ Impl(StringRef producer) : producer(producer) {}
+
// NOTE(review): held as a StringRef, so presumably the caller must keep the
// underlying characters alive for as long as this config is used — confirm.
+ /// The producer of the bytecode.
+ StringRef producer;
+
// Filled by BytecodeWriterConfig::attachResourcePrinter and iterated in order
// by BytecodeWriter::writeResourceSection.
+ /// A collection of non-dialect resource printers.
+ SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
+};
+
// Create a config whose state is a freshly allocated Impl seeded with
// `producer`.
+BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer)
+ : impl(std::make_unique<Impl>(producer)) {}
// Defaulted in this file, after Impl has been fully defined above.
+BytecodeWriterConfig::~BytecodeWriterConfig() = default;
+
// Take ownership of `printer` and append it to the external resource printer
// list; printers are kept in the order they were attached.
+void BytecodeWriterConfig::attachResourcePrinter(
+ std::unique_ptr<AsmResourcePrinter> printer) {
+ impl->externalResourcePrinters.emplace_back(std::move(printer));
+}
+
+//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -56,6 +79,48 @@ public:
currentResult[offset - prevResultSize] = value;
}
+ /// Emit the provided blob of data, which is owned by the caller and is
+ /// guaranteed to not die before the end of the bytecode process.
// The bytes are not stored in prevResultStorage: appendOwnedResult only
// records a view of `data`, which is why the caller must keep it alive.
+ void emitOwnedBlob(ArrayRef<uint8_t> data) {
+ // Push the current buffer before adding the provided data.
// Flushing first keeps the blob ordered after everything emitted so far.
// NOTE(review): currentResult is reused after the move; it is presumably
// empty at that point (appendResult either moves it out or skips an already
// empty buffer) — confirm.
+ appendResult(std::move(currentResult));
+ appendOwnedResult(data);
+ }
+
+ /// Emit the provided blob of data that has the given alignment, which is
+ /// owned by the caller and is guaranteed to not die before the end of the
+ /// bytecode process. The alignment value is also encoded, making it available
+ /// on load.
// Emission order: alignment (varint), data size (varint), padding bytes up to
// `alignment`, then the blob itself.
+ void emitOwnedBlobAndAlignment(ArrayRef<uint8_t> data, uint32_t alignment) {
+ emitVarInt(alignment);
+ emitVarInt(data.size());
+
// Pad after the two header varints so that the blob bytes, not the header,
// start at an aligned offset. alignTo also records `alignment` as a
// requirement of this emitter.
+ alignTo(alignment);
+ emitOwnedBlob(data);
+ }
// Convenience overload for `char` data: views the same bytes as uint8_t
// (no copy) and forwards to the ArrayRef<uint8_t> overload.
+ void emitOwnedBlobAndAlignment(ArrayRef<char> data, uint32_t alignment) {
+ ArrayRef<uint8_t> castedData(reinterpret_cast<const uint8_t *>(data.data()),
+ data.size());
+ emitOwnedBlobAndAlignment(castedData, alignment);
+ }
+
+ /// Align the emitter to the given alignment.
// `alignment` must be a power of two (asserted below); values below 2 are a
// no-op and are not recorded in requiredAlignment.
+ void alignTo(unsigned alignment) {
+ if (alignment < 2)
+ return;
+ assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment");
+
+ // Check to see if we need to emit any padding bytes to meet the desired
+ // alignment.
// The offset is measured with this emitter's own size(), so the padding is
// relative to the start of this emitter's output. NOTE(review): presumably
// emitSection is what makes that start aligned in the enclosing emitter —
// confirm.
+ size_t curOffset = size();
+ size_t paddingSize = llvm::alignTo(curOffset, alignment) - curOffset;
+ while (paddingSize--)
+ emitByte(bytecode::kAlignmentByte);
+
+ // Keep track of the maximum required alignment.
+ requiredAlignment = std::max(requiredAlignment, alignment);
+ }
+
//===--------------------------------------------------------------------===//
// Integer Emission
@@ -119,15 +184,37 @@ public:
/// Emit a nested section of the given code, whose contents are encoded in the
/// provided emitter.
// Emitted layout: section code byte, body length (varint), then, only when
// padding is needed, the alignment (varint) and padding bytes, followed by
// the body taken from `emitter`.
void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) {
- // Emit the section code and length.
+ // Emit the section code and length. The high bit of the code is used to
+ // indicate whether the section alignment is present, so save an offset to
+ // it.
// codeOffset is an index into currentResult (not a global offset); it is
// used below to patch the code byte once it is known whether padding was
// required.
+ uint64_t codeOffset = currentResult.size();
emitByte(code);
emitVarInt(emitter.size());
+ // Integrate the alignment of the section into this emitter if necessary.
+ unsigned emitterAlign = emitter.requiredAlignment;
+ if (emitterAlign > 1) {
// The misalignment test happens after the code and length were written, so
// it checks the offset at which the body would start.
+ if (size() & (emitterAlign - 1)) {
+ emitVarInt(emitterAlign);
+ alignTo(emitterAlign);
+
+ // Indicate that we needed to align the section, the high bit of the
+ // code field is used for this.
+ currentResult[codeOffset] |= 0b10000000;
+ } else {
+ // Otherwise, if we happen to be at a compatible offset, we just
+ // remember that we need this alignment.
+ requiredAlignment = std::max(requiredAlignment, emitterAlign);
+ }
+ }
+
// Push our current buffer and then merge the provided section body into
// ours.
appendResult(std::move(currentResult));
// The nested emitter's buffers are spliced in directly: storage vectors are
// moved over for ownership, while the ordered view list and the size are
// copied from the nested emitter. Re-appending each storage buffer (the
// removed line) would not carry over caller-owned blobs, which appear only in
// emitter.prevResultList and never in emitter.prevResultStorage.
for (std::vector<uint8_t> &result : emitter.prevResultStorage)
- appendResult(std::move(result));
+ prevResultStorage.push_back(std::move(result));
+ llvm::append_range(prevResultList, emitter.prevResultList);
+ prevResultSize += emitter.prevResultSize;
appendResult(std::move(emitter.currentResult));
}
@@ -140,9 +227,16 @@ private:
/// Append a new result buffer to the current contents.
// Takes ownership of `result` by moving it into prevResultStorage, then
// records a view of the stored buffer. Empty buffers are dropped.
void appendResult(std::vector<uint8_t> &&result) {
- prevResultSize += result.size();
+ if (result.empty())
+ return;
prevResultStorage.emplace_back(std::move(result));
- prevResultList.emplace_back(prevResultStorage.back());
+ appendOwnedResult(prevResultStorage.back());
+ }
// Record a view of `result` in prevResultList and add its size to
// prevResultSize, without storing the bytes in prevResultStorage. Whoever
// owns the bytes (prevResultStorage or an external caller) must keep them
// alive. Empty views are dropped.
+ void appendOwnedResult(ArrayRef<uint8_t> result) {
+ if (result.empty())
+ return;
+ prevResultSize += result.size();
+ prevResultList.emplace_back(result);
}
/// The result of the emitter currently being built. We refrain from building
@@ -157,6 +251,9 @@ private:
/// An up-to-date total size of all of the buffers within `prevResultList`.
/// This enables O(1) size checks of the current encoding.
size_t prevResultSize = 0;
+
+ /// The highest required alignment for the start of this section.
+ unsigned requiredAlignment = 1;
};
/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
@@ -250,7 +347,8 @@ public:
BytecodeWriter(Operation *op) : numberingState(op) {}
/// Write the bytecode for the given root operation.
- void write(Operation *rootOp, raw_ostream &os, StringRef producer);
+ void write(Operation *rootOp, raw_ostream &os,
+ const BytecodeWriterConfig::Impl &config);
private:
//===--------------------------------------------------------------------===//
@@ -272,6 +370,12 @@ private:
void writeIRSection(EncodingEmitter &emitter, Operation *op);
//===--------------------------------------------------------------------===//
+ // Resources
+
+ void writeResourceSection(Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config);
+
+ //===--------------------------------------------------------------------===//
// Strings
void writeStringSection(EncodingEmitter &emitter);
@@ -288,7 +392,7 @@ private:
} // namespace
void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig::Impl &config) {
EncodingEmitter emitter;
// Emit the bytecode file header. This is how we identify the output as a
@@ -299,7 +403,7 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
emitter.emitVarInt(bytecode::kVersion);
// Emit the producer.
- emitter.emitNulTerminatedString(producer);
+ emitter.emitNulTerminatedString(config.producer);
// Emit the dialect section.
writeDialectSection(emitter);
@@ -310,6 +414,9 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
// Emit the IR section.
writeIRSection(emitter, rootOp);
+ // Emit the resources section.
+ writeResourceSection(rootOp, emitter, config);
+
// Emit the string section.
writeStringSection(emitter);
@@ -386,6 +493,10 @@ public:
emitter.emitVarInt(numberingState.getNumber(type));
}
// Encode a resource handle as a varint holding the number that
// numberingState assigned to that resource.
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ emitter.emitVarInt(numberingState.getNumber(resource));
+ }
+
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
@@ -614,6 +725,111 @@ void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
}
//===----------------------------------------------------------------------===//
+// Resources
+
+namespace {
+/// This class represents a resource builder implementation for the MLIR
+/// bytecode format.
// Every build* method first encodes the value into `emitter` and then calls
// postProcessFn with the entry's key and kind, so the callback observes the
// emitter after that entry's bytes have been written.
+class ResourceBuilder : public AsmResourceBuilder {
+public:
// Callback invoked once per built entry with its key and kind.
+ using PostProcessFn = function_ref<void(StringRef, AsmResourceEntryKind)>;
+
// All three arguments are stored as references or a function_ref.
// NOTE(review): presumably the emitter, string section, and callable must
// all outlive this builder — confirm.
+ ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection,
+ PostProcessFn postProcessFn)
+ : emitter(emitter), stringSection(stringSection),
+ postProcessFn(postProcessFn) {}
+ ~ResourceBuilder() override = default;
+
// Blob entry: emitted by emitOwnedBlobAndAlignment, which references `data`
// rather than copying it, so `data` must stay alive until the bytecode has
// been written out.
+ void buildBlob(StringRef key, ArrayRef<char> data,
+ uint32_t dataAlignment) final {
+ emitter.emitOwnedBlobAndAlignment(data, dataAlignment);
+ postProcessFn(key, AsmResourceEntryKind::Blob);
+ }
// Bool entry: written as a single byte.
+ void buildBool(StringRef key, bool data) final {
+ emitter.emitByte(data);
+ postProcessFn(key, AsmResourceEntryKind::Bool);
+ }
// String entry: the text goes into the string section and only the value
// returned by insert() is written here as a varint (presumably the string's
// index in that section — confirm).
+ void buildString(StringRef key, StringRef data) final {
+ emitter.emitVarInt(stringSection.insert(data));
+ postProcessFn(key, AsmResourceEntryKind::String);
+ }
+
+private:
+ EncodingEmitter &emitter;
+ StringSectionBuilder &stringSection;
+ PostProcessFn postProcessFn;
+};
+} // namespace
+
// Build the two resource sections for `op` and append them to `emitter`:
// a data section (resourceEmitter) holding the encoded entries, and an
// offset section (resourceOffsetEmitter) listing, per group, each entry's
// key, encoded size, and kind. External printers from `config` are handled
// first, then the dialects known to numberingState.
+void BytecodeWriter::writeResourceSection(
+ Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config) {
+ EncodingEmitter resourceEmitter;
+ EncodingEmitter resourceOffsetEmitter;
// Size of resourceEmitter when the previous entry finished; used to turn
// absolute offsets into per-entry sizes.
+ uint64_t prevOffset = 0;
// Entries (key, kind, size) of the group currently being built; cleared
// before each group.
+ SmallVector<std::tuple<StringRef, AsmResourceEntryKind, uint64_t>>
+ curResourceEntries;
+
+ // Functor used to process the offset for a resource of `kind` defined by
+ // 'key'.
// The recorded size covers everything written to resourceEmitter since the
// previous entry; for blobs that includes the alignment and size varints and
// any padding bytes.
+ auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) {
+ uint64_t curOffset = resourceEmitter.size();
+ curResourceEntries.emplace_back(key, kind, curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Functor used to emit a resource group defined by 'key'.
// Group layout: group key, entry count, then per entry the string-section
// value for its key, its size, and its kind byte. Note that the `key` bound
// in the loop shadows the lambda parameter `key`; the group key has already
// been emitted by then.
+ auto emitResourceGroup = [&](uint64_t key) {
+ resourceOffsetEmitter.emitVarInt(key);
+ resourceOffsetEmitter.emitVarInt(curResourceEntries.size());
+ for (auto [key, kind, size] : curResourceEntries) {
+ resourceOffsetEmitter.emitVarInt(stringSection.insert(key));
+ resourceOffsetEmitter.emitVarInt(size);
+ resourceOffsetEmitter.emitByte(kind);
+ }
+ };
+
+ // Builder used to emit resources.
+ ResourceBuilder entryBuilder(resourceEmitter, stringSection,
+ appendResourceOffset);
+
+ // Emit the external resource entries.
// The printer count is written up front and a group is emitted for every
// external printer, even when it produced no entries.
+ resourceOffsetEmitter.emitVarInt(config.externalResourcePrinters.size());
+ for (const auto &printer : config.externalResourcePrinters) {
+ curResourceEntries.clear();
+ printer->buildResources(op, entryBuilder);
+ emitResourceGroup(stringSection.insert(printer->getName()));
+ }
+
+ // Emit the dialect resource entries.
// Unlike external printers, dialect groups carry no up-front count and are
// emitted only when they have at least one entry.
+ for (DialectNumbering &dialect : numberingState.getDialects()) {
+ if (!dialect.asmInterface)
+ continue;
+ curResourceEntries.clear();
+ dialect.asmInterface->buildResources(op, dialect.resources, entryBuilder);
+
+ // Emit the declaration resources for this dialect, these didn't get emitted
+ // by the interface. These resources don't have data attached, so just use a
+ // "blob" kind as a placeholder.
// Nothing is written to resourceEmitter for these, so each is recorded with
// a size of 0.
+ for (const auto &resource : dialect.resourceMap)
+ if (resource.second->isDeclaration)
+ appendResourceOffset(resource.first, AsmResourceEntryKind::Blob);
+
+ // Emit the resource group for this dialect.
+ if (!curResourceEntries.empty())
+ emitResourceGroup(dialect.number);
+ }
+
+ // If we didn't emit any resource groups, elide the resource sections.
// NOTE(review): the external printer count is written to
// resourceOffsetEmitter unconditionally above, so its size looks like it can
// never be 0 here and this early return appears unreachable — confirm
// whether the sections are meant to be elided when no groups were emitted.
+ if (resourceOffsetEmitter.size() == 0)
+ return;
+
// The offset section is emitted before the data section.
+ emitter.emitSection(bytecode::Section::kResourceOffset,
+ std::move(resourceOffsetEmitter));
+ emitter.emitSection(bytecode::Section::kResource, std::move(resourceEmitter));
+}
+
+//===----------------------------------------------------------------------===//
// Strings
void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
@@ -627,7 +843,7 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
//===----------------------------------------------------------------------===//
// Public entry point: build a BytecodeWriter for `op` and write its bytecode
// to `os`. After this change the writer is driven by the implementation
// state of `config` instead of a bare producer string.
void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig &config) {
BytecodeWriter writer(op);
- writer.write(op, os, producer);
+ writer.write(op, os, config.getImpl());
}