diff options
| author | Mehdi Amini <joker.eph@gmail.com> | 2023-05-25 21:04:35 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-05-26 17:45:01 -0700 |
| commit | 660f714e26999d266232a1fbb02712bb879bd34e (patch) | |
| tree | a3a473f8ac64651140d855c2d6521cada262fd65 /mlir/lib/Bytecode | |
| parent | f354e971b09c244147ff59eb65b34487755598c0 (diff) | |
[MLIR] Add native Bytecode support for properties
This is adding a new interface (`BytecodeOpInterface`) to allow operations to
opt-in skipping conversion to attribute and serializing properties to native
bytecode.
The scheme relies on a new section where properties are stored in sequence
{ size, serialize_properties }, ...
The operations are storing the index of a properties, a table of offset is
built when loading the properties section the first time.
This is a re-commit of 837d1ce0dc which conflicted with another patch upgrading
the bytecode and the collision wasn't properly resolved before.
Differential Revision: https://reviews.llvm.org/D151065
Diffstat (limited to 'mlir/lib/Bytecode')
| -rw-r--r-- | mlir/lib/Bytecode/BytecodeOpInterface.cpp | 17 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/CMakeLists.txt | 11 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 198 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 236 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/IRNumbering.cpp | 31 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/IRNumbering.h | 7 |
7 files changed, 457 insertions, 44 deletions
diff --git a/mlir/lib/Bytecode/BytecodeOpInterface.cpp b/mlir/lib/Bytecode/BytecodeOpInterface.cpp new file mode 100644 index 000000000000..e767f57b9bb4 --- /dev/null +++ b/mlir/lib/Bytecode/BytecodeOpInterface.cpp @@ -0,0 +1,17 @@ +//===- BytecodeOpInterface.cpp - Bytecode Op Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// BytecodeOpInterface +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.cpp.inc" diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt index ff7e290cad1b..c89415f60d12 100644 --- a/mlir/lib/Bytecode/CMakeLists.txt +++ b/mlir/lib/Bytecode/CMakeLists.txt @@ -1,2 +1,13 @@ add_subdirectory(Reader) add_subdirectory(Writer) + +add_mlir_library(MLIRBytecodeOpInterface + BytecodeOpInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index ca05eac1e3e1..b4fe53e33279 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -11,6 +11,7 @@ #include "mlir/Bytecode/BytecodeReader.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Bytecode/Encoding.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" @@ -20,6 +21,7 @@ #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -28,6 +30,7 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include <cstddef> #include <list> #include <memory> #include <numeric> @@ -56,13 +59,15 @@ static std::string toString(bytecode::Section::ID sectionID) { return "ResourceOffset (6)"; case bytecode::Section::kDialectVersions: return "DialectVersions (7)"; + case bytecode::Section::kProperties: + return "Properties (8)"; default: return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); } } /// Returns true if the given top-level section ID is optional. -static bool isSectionOptional(bytecode::Section::ID sectionID) { +static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { switch (sectionID) { case bytecode::Section::kString: case bytecode::Section::kDialect: @@ -74,6 +79,8 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) { case bytecode::Section::kResourceOffset: case bytecode::Section::kDialectVersions: return true; + case bytecode::Section::kProperties: + return version < 5; default: llvm_unreachable("unknown section ID"); } @@ -364,6 +371,17 @@ public: /// Parse a shared string from the string section. The shared string is /// encoded using an index to a corresponding string in the string section. + /// This variant parses a flag compressed with the index. + LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, + bool &flag) { + uint64_t entryIdx; + if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) + return failure(); + return parseStringAtIndex(reader, entryIdx, result); + } + + /// Parse a shared string from the string section. The shared string is + /// encoded using an index to a corresponding string in the string section. LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, StringRef &result) { return resolveEntry(reader, strings, index, result, "string"); @@ -459,8 +477,9 @@ struct BytecodeDialect { /// This struct represents an operation name entry within the bytecode. struct BytecodeOperationName { - BytecodeOperationName(BytecodeDialect *dialect, StringRef name) - : dialect(dialect), name(name) {} + BytecodeOperationName(BytecodeDialect *dialect, StringRef name, + std::optional<bool> wasRegistered) + : dialect(dialect), name(name), wasRegistered(wasRegistered) {} /// The loaded operation name, or std::nullopt if it hasn't been processed /// yet. @@ -471,6 +490,10 @@ struct BytecodeOperationName { /// The name of the operation, without the dialect prefix. StringRef name; + + /// Whether this operation was registered when the bytecode was produced. + /// This flag is populated when bytecode version >=5. + std::optional<bool> wasRegistered; }; } // namespace @@ -791,6 +814,18 @@ public: result = resolveAttribute(attrIdx); return success(!!result); } + LogicalResult parseOptionalAttribute(EncodingReader &reader, + Attribute &result) { + uint64_t attrIdx; + bool flag; + if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) + return failure(); + if (!flag) + return success(); + result = resolveAttribute(attrIdx); + return success(!!result); + } + LogicalResult parseType(EncodingReader &reader, Type &result) { uint64_t typeIdx; if (failed(reader.parseVarInt(typeIdx))) @@ -870,7 +905,9 @@ public: LogicalResult readAttribute(Attribute &result) override { return attrTypeReader.parseAttribute(reader, result); } - + LogicalResult readOptionalAttribute(Attribute &result) override { + return attrTypeReader.parseOptionalAttribute(reader, result); + } LogicalResult readType(Type &result) override { return attrTypeReader.parseType(reader, result); } @@ -957,6 +994,87 @@ private: ResourceSectionReader &resourceReader; EncodingReader &reader; }; + +/// Wraps the properties section and handles reading properties out of it. +class PropertiesSectionReader { +public: + /// Initialize the properties section reader with the given section data. + LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { + if (sectionData.empty()) + return success(); + EncodingReader propReader(sectionData, fileLoc); + size_t count; + if (failed(propReader.parseVarInt(count))) + return failure(); + // Parse the raw properties buffer. + if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) + return failure(); + + EncodingReader offsetsReader(propertiesBuffers, fileLoc); + offsetTable.reserve(count); + for (auto idx : llvm::seq<int64_t>(0, count)) { + (void)idx; + offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); + ArrayRef<uint8_t> rawProperties; + size_t dataSize; + if (failed(offsetsReader.parseVarInt(dataSize)) || + failed(offsetsReader.parseBytes(dataSize, rawProperties))) + return failure(); + } + if (!offsetsReader.empty()) + return offsetsReader.emitError() + << "Broken properties section: didn't exhaust the offsets table"; + return success(); + } + + LogicalResult read(Location fileLoc, DialectReader &dialectReader, + OperationName *opName, OperationState &opState) { + uint64_t propertiesIdx; + if (failed(dialectReader.readVarInt(propertiesIdx))) + return failure(); + if (propertiesIdx >= offsetTable.size()) + return dialectReader.emitError("Properties idx out-of-bound for ") + << opName->getStringRef(); + size_t propertiesOffset = offsetTable[propertiesIdx]; + if (propertiesIdx >= propertiesBuffers.size()) + return dialectReader.emitError("Properties offset out-of-bound for ") + << opName->getStringRef(); + + // Acquire the sub-buffer that represent the requested properties. + ArrayRef<char> rawProperties; + { + // "Seek" to the requested offset by getting a new reader with the right + // sub-buffer. + EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), + fileLoc); + // Properties are stored as a sequence of {size + raw_data}. + if (failed( + dialectReader.withEncodingReader(reader).readBlob(rawProperties))) + return failure(); + } + // Setup a new reader to read from the `rawProperties` sub-buffer. + EncodingReader reader( + StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); + DialectReader propReader = dialectReader.withEncodingReader(reader); + + auto *iface = opName->getInterface<BytecodeOpInterface>(); + if (iface) + return iface->readProperties(propReader, opState); + if (opName->isRegistered()) + return propReader.emitError( + "has properties but missing BytecodeOpInterface for ") + << opName->getStringRef(); + // Unregistered op are storing properties as an attribute. + return propReader.readAttribute(opState.propertiesAttr); + } + +private: + /// The properties buffer referenced within the bytecode file. + ArrayRef<uint8_t> propertiesBuffers; + + /// Table of offset in the buffer above. + SmallVector<int64_t> offsetTable; +}; } // namespace LogicalResult @@ -1194,7 +1312,9 @@ private: lazyLoadableOps.erase(it->getSecond()); lazyLoadableOpsMap.erase(it); auto result = parseRegions(regionStack, regionStack.back()); - assert(regionStack.empty()); + assert((regionStack.empty() || failed(result)) && + "broken invariant: regionStack should be empty when parseRegions " + "succeeds"); return result; } @@ -1209,8 +1329,11 @@ private: LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); - /// Parse an operation name reference using the given reader. - FailureOr<OperationName> parseOpName(EncodingReader &reader); + /// Parse an operation name reference using the given reader, and set the + /// `wasRegistered` flag that indicates if the bytecode was produced by a + /// context where opName was registered. + FailureOr<OperationName> parseOpName(EncodingReader &reader, + std::optional<bool> &wasRegistered); //===--------------------------------------------------------------------===// // Attribute/Type Section @@ -1398,6 +1521,9 @@ private: /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; + /// The table of properties referenced by the operation in the bytecode file. + PropertiesSectionReader propertiesReader; + /// The current set of available IR value scopes. std::vector<ValueScope> valueScopes; @@ -1466,7 +1592,7 @@ LogicalResult BytecodeReader::Impl::read( // Check that all of the required sections were found. for (int i = 0; i < bytecode::Section::kNumSections; ++i) { bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); - if (!sectionDatas[i] && !isSectionOptional(sectionID)) { + if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { return reader.emitError("missing data for top-level section: ", ::toString(sectionID)); } @@ -1477,6 +1603,12 @@ LogicalResult BytecodeReader::Impl::read( fileLoc, *sectionDatas[bytecode::Section::kString]))) return failure(); + // Process the properties section. + if (sectionDatas[bytecode::Section::kProperties] && + failed(propertiesReader.initialize( + fileLoc, *sectionDatas[bytecode::Section::kProperties]))) + return failure(); + // Process the dialect section. if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); @@ -1598,9 +1730,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { StringRef opName; - if (failed(stringReader.parseString(sectionReader, opName))) - return failure(); - opNames.emplace_back(dialect, opName); + std::optional<bool> wasRegistered; + // Prior to version 5, the information about wheter an op was registered or + // not wasn't encoded. + if (version < 5) { + if (failed(stringReader.parseString(sectionReader, opName))) + return failure(); + } else { + bool wasRegisteredFlag; + if (failed(stringReader.parseStringWithFlag(sectionReader, opName, + wasRegisteredFlag))) + return failure(); + wasRegistered = wasRegisteredFlag; + } + opNames.emplace_back(dialect, opName, wasRegistered); return success(); }; // Avoid re-allocation in bytecode version > 3 where the number of ops are @@ -1618,11 +1761,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { } FailureOr<OperationName> -BytecodeReader::Impl::parseOpName(EncodingReader &reader) { +BytecodeReader::Impl::parseOpName(EncodingReader &reader, + std::optional<bool> &wasRegistered) { BytecodeOperationName *opName = nullptr; if (failed(parseEntry(reader, opNames, opName, "operation name"))) return failure(); - + wasRegistered = opName->wasRegistered; // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { @@ -1994,7 +2138,8 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, bool &isIsolatedFromAbove) { // Parse the name of the operation. - FailureOr<OperationName> opName = parseOpName(reader); + std::optional<bool> wasRegistered; + FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); if (failed(opName)) return failure(); @@ -2021,6 +2166,31 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, opState.attributes = dictAttr; } + if (opMask & bytecode::OpEncodingMask::kHasProperties) { + // kHasProperties wasn't emitted in older bytecode, we should never get + // there without also having the `wasRegistered` flag available. + if (!wasRegistered) + return emitError(fileLoc, + "Unexpected missing `wasRegistered` opname flag at " + "bytecode version ") + << version << " with properties."; + // When an operation is emitted without being registered, the properties are + // stored as an attribute. Otherwise the op must implement the bytecode + // interface and control the serialization. + if (wasRegistered) { + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); + if (failed( + propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) + return failure(); + } else { + // If the operation wasn't registered when it was emitted, the properties + // was serialized as an attribute. + if (failed(parseAttribute(reader, opState.propertiesAttr))) + return failure(); + } + } + /// Parse the results of the operation. if (opMask & bytecode::OpEncodingMask::kHasResults) { uint64_t numResults; diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index 93484913548a..515391d5634c 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -9,11 +9,23 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Bytecode/Encoding.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <optional> +#include <sys/types.h> #define DEBUG_TYPE "mlir-bytecode-writer" @@ -58,6 +70,10 @@ void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) { std::min<int64_t>(bytecodeVersion, bytecode::kVersion); } +int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const { + return impl->bytecodeVersion; +} + //===----------------------------------------------------------------------===// // EncodingEmitter //===----------------------------------------------------------------------===// @@ -318,6 +334,14 @@ public: void writeAttribute(Attribute attr) override { emitter.emitVarInt(numberingState.getNumber(attr)); } + void writeOptionalAttribute(Attribute attr) override { + if (!attr) { + emitter.emitVarInt(0); + return; + } + emitter.emitVarIntWithFlag(numberingState.getNumber(attr), true); + } + void writeType(Type type) override { emitter.emitVarInt(numberingState.getNumber(type)); } @@ -382,6 +406,105 @@ private: StringSectionBuilder &stringSection; }; +namespace { +class PropertiesSectionBuilder { +public: + PropertiesSectionBuilder(IRNumberingState &numberingState, + StringSectionBuilder &stringSection, + const BytecodeWriterConfig::Impl &config) + : numberingState(numberingState), stringSection(stringSection), + config(config) {} + + /// Emit the op properties in the properties section and return the index of + /// the properties within the section. Return -1 if no properties was emitted. + std::optional<ssize_t> emit(Operation *op) { + EncodingEmitter propertiesEmitter; + if (!op->getPropertiesStorageSize()) + return std::nullopt; + if (!op->isRegistered()) { + // Unregistered op are storing properties as an optional attribute. + Attribute prop = *op->getPropertiesStorage().as<Attribute *>(); + if (!prop) + return std::nullopt; + EncodingEmitter sizeEmitter; + sizeEmitter.emitVarInt(numberingState.getNumber(prop)); + scratch.clear(); + llvm::raw_svector_ostream os(scratch); + sizeEmitter.writeTo(os); + return emit(scratch); + } + + EncodingEmitter emitter; + DialectWriter propertiesWriter(config.bytecodeVersion, emitter, + numberingState, stringSection); + auto iface = cast<BytecodeOpInterface>(op); + iface.writeProperties(propertiesWriter); + scratch.clear(); + llvm::raw_svector_ostream os(scratch); + emitter.writeTo(os); + return emit(scratch); + } + + /// Write the current set of properties to the given emitter. + void write(EncodingEmitter &emitter) { + emitter.emitVarInt(propertiesStorage.size()); + if (propertiesStorage.empty()) + return; + for (const auto &storage : propertiesStorage) { + if (storage.empty()) { + emitter.emitBytes(ArrayRef<uint8_t>()); + continue; + } + emitter.emitBytes(ArrayRef(reinterpret_cast<const uint8_t *>(&storage[0]), + storage.size())); + } + } + + /// Returns true if the section is empty. + bool empty() { return propertiesStorage.empty(); } + +private: + /// Emit raw data and returns the offset in the internal buffer. + /// Data are deduplicated and will be copied in the internal buffer only if + /// they don't exist there already. + ssize_t emit(ArrayRef<char> rawProperties) { + // Populate a scratch buffer with the properties size. + SmallVector<char> sizeScratch; + { + EncodingEmitter sizeEmitter; + sizeEmitter.emitVarInt(rawProperties.size()); + llvm::raw_svector_ostream os(sizeScratch); + sizeEmitter.writeTo(os); + } + // Append a new storage to the table now. + size_t index = propertiesStorage.size(); + propertiesStorage.emplace_back(); + std::vector<char> &newStorage = propertiesStorage.back(); + size_t propertiesSize = sizeScratch.size() + rawProperties.size(); + newStorage.reserve(propertiesSize); + newStorage.insert(newStorage.end(), sizeScratch.begin(), sizeScratch.end()); + newStorage.insert(newStorage.end(), rawProperties.begin(), + rawProperties.end()); + + // Try to de-duplicate the new serialized properties. + // If the properties is a duplicate, pop it back from the storage. + auto inserted = propertiesUniquing.insert( + std::make_pair(ArrayRef<char>(newStorage), index)); + if (!inserted.second) + propertiesStorage.pop_back(); + return inserted.first->getSecond(); + } + + /// Storage for properties. + std::vector<std::vector<char>> propertiesStorage; + SmallVector<char> scratch; + DenseMap<ArrayRef<char>, int64_t> propertiesUniquing; + IRNumberingState &numberingState; + StringSectionBuilder &stringSection; + const BytecodeWriterConfig::Impl &config; +}; +} // namespace + /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need /// to go through an intermediate buffer when interacting with code that wants a /// raw_ostream. @@ -435,11 +558,12 @@ void EncodingEmitter::emitMultiByteVarInt(uint64_t value) { namespace { class BytecodeWriter { public: - BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config) - : numberingState(op), config(config) {} + BytecodeWriter(Operation *op, const BytecodeWriterConfig &config) + : numberingState(op, config), config(config.getImpl()), + propertiesSection(numberingState, stringSection, config.getImpl()) {} /// Write the bytecode for the given root operation. - void write(Operation *rootOp, raw_ostream &os); + LogicalResult write(Operation *rootOp, raw_ostream &os); private: //===--------------------------------------------------------------------===// @@ -455,10 +579,10 @@ private: //===--------------------------------------------------------------------===// // Operations - void writeBlock(EncodingEmitter &emitter, Block *block); - void writeOp(EncodingEmitter &emitter, Operation *op); - void writeRegion(EncodingEmitter &emitter, Region *region); - void writeIRSection(EncodingEmitter &emitter, Operation *op); + LogicalResult writeBlock(EncodingEmitter &emitter, Block *block); + LogicalResult writeOp(EncodingEmitter &emitter, Operation *op); + LogicalResult writeRegion(EncodingEmitter &emitter, Region *region); + LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op); //===--------------------------------------------------------------------===// // Resources @@ -471,6 +595,11 @@ private: void writeStringSection(EncodingEmitter &emitter); //===--------------------------------------------------------------------===// + // Properties + + void writePropertiesSection(EncodingEmitter &emitter); + + //===--------------------------------------------------------------------===// // Helpers void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask, @@ -487,10 +616,13 @@ private: /// Configuration dictating bytecode emission. const BytecodeWriterConfig::Impl &config; + + /// Storage for the properties section + PropertiesSectionBuilder propertiesSection; }; } // namespace -void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { +LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { EncodingEmitter emitter; // Emit the bytecode file header. This is how we identify the output as a @@ -510,7 +642,8 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { writeAttrTypeSection(emitter); // Emit the IR section. - writeIRSection(emitter, rootOp); + if (failed(writeIRSection(emitter, rootOp))) + return failure(); // Emit the resources section. writeResourceSection(rootOp, emitter); @@ -518,8 +651,17 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { // Emit the string section. writeStringSection(emitter); + // Emit the properties section. + if (config.bytecodeVersion >= 5) + writePropertiesSection(emitter); + else if (!propertiesSection.empty()) + return rootOp->emitError( + "unexpected properties emitted incompatible with bytecode <5"); + // Write the generated bytecode to the provided output stream. emitter.writeTo(os); + + return success(); } //===----------------------------------------------------------------------===// @@ -590,7 +732,11 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { // Emit the referenced operation names grouped by dialect. auto emitOpName = [&](OpNameNumbering &name) { - dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect())); + size_t stringId = stringSection.insert(name.name.stripDialect()); + if (config.bytecodeVersion < 5) + dialectEmitter.emitVarInt(stringId); + else + dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered()); }; writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName); @@ -659,7 +805,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { //===----------------------------------------------------------------------===// // Operations -void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) { +LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, + Block *block) { ArrayRef<BlockArgument> args = block->getArguments(); bool hasArgs = !args.empty(); @@ -696,10 +843,12 @@ void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) { // Emit the operations within the block. for (Operation &op : *block) - writeOp(emitter, &op); + if (failed(writeOp(emitter, &op))) + return failure(); + return success(); } -void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { +LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { emitter.emitVarInt(numberingState.getNumber(op->getName())); // Emit a mask for the operation components. We need to fill this in later @@ -713,10 +862,24 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { emitter.emitVarInt(numberingState.getNumber(op->getLoc())); // Emit the attributes of this operation. - DictionaryAttr attrs = op->getAttrDictionary(); + DictionaryAttr attrs = op->getDiscardableAttrDictionary(); + // Allow deployment to version <5 by merging inherent attribute with the + // discardable ones. We should fail if there are any conflicts. + if (config.bytecodeVersion < 5) + attrs = op->getAttrDictionary(); if (!attrs.empty()) { opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; - emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary())); + emitter.emitVarInt(numberingState.getNumber(attrs)); + } + + // Emit the properties of this operation, for now we still support deployment + // to version <5. + if (config.bytecodeVersion >= 5) { + std::optional<ssize_t> propertiesId = propertiesSection.emit(op); + if (propertiesId.has_value()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasProperties; + emitter.emitVarInt(*propertiesId); + } } // Emit the result types of the operation. @@ -768,15 +931,18 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { // If the region is not isolated from above, or we are emitting bytecode // targeting version <2, we don't use a section. if (!isIsolatedFromAbove || config.bytecodeVersion < 2) { - writeRegion(emitter, ®ion); + if (failed(writeRegion(emitter, ®ion))) + return failure(); continue; } EncodingEmitter regionEmitter; - writeRegion(regionEmitter, ®ion); + if (failed(writeRegion(regionEmitter, ®ion))) + return failure(); emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter)); } } + return success(); } void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, @@ -867,11 +1033,14 @@ void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, } } -void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) { +LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter, + Region *region) { // If the region is empty, we only need to emit the number of blocks (which is // zero). - if (region->empty()) - return emitter.emitVarInt(/*numBlocks*/ 0); + if (region->empty()) { + emitter.emitVarInt(/*numBlocks*/ 0); + return success(); + } // Emit the number of blocks and values within the region. unsigned numBlocks, numValues; @@ -881,10 +1050,13 @@ void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) { // Emit the blocks within the region. for (Block &block : *region) - writeBlock(emitter, &block); + if (failed(writeBlock(emitter, &block))) + return failure(); + return success(); } -void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) { +LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter, + Operation *op) { EncodingEmitter irEmitter; // Write the IR section the same way as a block with no arguments. Note that @@ -893,9 +1065,11 @@ void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) { irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false); // Emit the operations. - writeOp(irEmitter, op); + if (failed(writeOp(irEmitter, op))) + return failure(); emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter)); + return success(); } //===----------------------------------------------------------------------===// @@ -1012,13 +1186,21 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) { } //===----------------------------------------------------------------------===// +// Properties + +void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) { + EncodingEmitter propertiesEmitter; + propertiesSection.write(propertiesEmitter); + emitter.emitSection(bytecode::Section::kProperties, + std::move(propertiesEmitter)); +} + +//===----------------------------------------------------------------------===// // Entry Points //===----------------------------------------------------------------------===// LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config) { - BytecodeWriter writer(op, config.getImpl()); - writer.write(op, os); - // Currently there is no failure case. - return success(); + BytecodeWriter writer(op, config); + return writer.write(op, os); } diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt index 7d260568249a..45d7f2158097 100644 --- a/mlir/lib/Bytecode/Writer/CMakeLists.txt +++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt @@ -8,4 +8,5 @@ add_mlir_library(MLIRBytecodeWriter LINK_LIBS PUBLIC MLIRIR MLIRSupport + MLIRBytecodeOpInterface ) diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index 129437cf0245..36f7a268a6a1 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -8,6 +8,7 @@ #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -24,6 +25,10 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { NumberingDialectWriter(IRNumberingState &state) : state(state) {} void writeAttribute(Attribute attr) override { state.number(attr); } + void writeOptionalAttribute(Attribute attr) override { + if (attr) + state.number(attr); + } void writeType(Type type) override { state.number(type); } void writeResourceHandle(const AsmDialectResourceHandle &resource) override { state.number(resource.getDialect(), resource); @@ -106,7 +111,9 @@ static void groupByDialectPerByte(T range) { value->number = idx; } -IRNumberingState::IRNumberingState(Operation *op) { +IRNumberingState::IRNumberingState(Operation *op, + const BytecodeWriterConfig &config) + : config(config) { // Compute a global operation ID numbering according to the pre-order walk of // the IR. This is used as reference to construct use-list orders. unsigned operationID = 0; @@ -276,10 +283,30 @@ void IRNumberingState::number(Operation &op) { } // Only number the operation's dictionary if it isn't empty. - DictionaryAttr dictAttr = op.getAttrDictionary(); + DictionaryAttr dictAttr = op.getDiscardableAttrDictionary(); + // Prior to version 5 we need to number also the merged dictionnary + // containing both the inherent and discardable attribute. + if (config.getDesiredBytecodeVersion() < 5) + dictAttr = op.getAttrDictionary(); if (!dictAttr.empty()) number(dictAttr); + // Visit the operation properties (if any) to make sure referenced attributes + // are numbered. + if (config.getDesiredBytecodeVersion() >= 5 && + op.getPropertiesStorageSize()) { + if (op.isRegistered()) { + // Operation that have properties *must* implement this interface. + auto iface = cast<BytecodeOpInterface>(op); + NumberingDialectWriter writer(*this); + iface.writeProperties(writer); + } else { + // Unregistered op are storing properties as an optional attribute. + if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>()) + number(prop); + } + } + number(op.getLoc()); } diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h index 91f0be05b36d..329ca2db8a80 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -18,6 +18,8 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/CodeGen/NonRelocatableStringpool.h" +#include <cstdint> namespace mlir { class BytecodeDialectInterface; @@ -133,7 +135,7 @@ struct DialectNumbering { /// emission. class IRNumberingState { public: - IRNumberingState(Operation *op); + IRNumberingState(Operation *op, const BytecodeWriterConfig &config); /// Return the numbered dialects. auto getDialects() { @@ -241,6 +243,9 @@ private: /// The next value ID to assign when numbering. unsigned nextValueID = 0; + + // Configuration: useful to query the required version to emit. + const BytecodeWriterConfig &config; }; } // namespace detail } // namespace bytecode |
