summaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode
diff options
context:
space:
mode:
authorMehdi Amini <joker.eph@gmail.com>2023-05-25 21:04:35 -0700
committerMehdi Amini <joker.eph@gmail.com>2023-05-26 17:45:01 -0700
commit660f714e26999d266232a1fbb02712bb879bd34e (patch)
treea3a473f8ac64651140d855c2d6521cada262fd65 /mlir/lib/Bytecode
parentf354e971b09c244147ff59eb65b34487755598c0 (diff)
[MLIR] Add native Bytecode support for properties
This adds a new interface (`BytecodeOpInterface`) that lets operations opt in to skipping the conversion to an attribute and instead serialize their properties to native bytecode. The scheme relies on a new section where properties are stored as a sequence of { size, serialized_properties } entries. Each operation stores the index of its properties entry; a table of offsets is built when the properties section is first loaded. This is a re-commit of 837d1ce0dc, which conflicted with another patch upgrading the bytecode; the collision wasn't properly resolved before. Differential Revision: https://reviews.llvm.org/D151065
Diffstat (limited to 'mlir/lib/Bytecode')
-rw-r--r--mlir/lib/Bytecode/BytecodeOpInterface.cpp17
-rw-r--r--mlir/lib/Bytecode/CMakeLists.txt11
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp198
-rw-r--r--mlir/lib/Bytecode/Writer/BytecodeWriter.cpp236
-rw-r--r--mlir/lib/Bytecode/Writer/CMakeLists.txt1
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.cpp31
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.h7
7 files changed, 457 insertions, 44 deletions
diff --git a/mlir/lib/Bytecode/BytecodeOpInterface.cpp b/mlir/lib/Bytecode/BytecodeOpInterface.cpp
new file mode 100644
index 000000000000..e767f57b9bb4
--- /dev/null
+++ b/mlir/lib/Bytecode/BytecodeOpInterface.cpp
@@ -0,0 +1,17 @@
+//===- BytecodeOpInterface.cpp - Bytecode Op Interfaces -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// BytecodeOpInterface
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeOpInterface.cpp.inc"
diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt
index ff7e290cad1b..c89415f60d12 100644
--- a/mlir/lib/Bytecode/CMakeLists.txt
+++ b/mlir/lib/Bytecode/CMakeLists.txt
@@ -1,2 +1,13 @@
add_subdirectory(Reader)
add_subdirectory(Writer)
+
+add_mlir_library(MLIRBytecodeOpInterface
+ BytecodeOpInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+ )
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index ca05eac1e3e1..b4fe53e33279 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -11,6 +11,7 @@
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@@ -20,6 +21,7 @@
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
@@ -28,6 +30,7 @@
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
+#include <cstddef>
#include <list>
#include <memory>
#include <numeric>
@@ -56,13 +59,15 @@ static std::string toString(bytecode::Section::ID sectionID) {
return "ResourceOffset (6)";
case bytecode::Section::kDialectVersions:
return "DialectVersions (7)";
+ case bytecode::Section::kProperties:
+ return "Properties (8)";
default:
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
}
/// Returns true if the given top-level section ID is optional.
-static bool isSectionOptional(bytecode::Section::ID sectionID) {
+static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
switch (sectionID) {
case bytecode::Section::kString:
case bytecode::Section::kDialect:
@@ -74,6 +79,8 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) {
case bytecode::Section::kResourceOffset:
case bytecode::Section::kDialectVersions:
return true;
+ case bytecode::Section::kProperties:
+ return version < 5;
default:
llvm_unreachable("unknown section ID");
}
@@ -364,6 +371,17 @@ public:
/// Parse a shared string from the string section. The shared string is
/// encoded using an index to a corresponding string in the string section.
+ /// This variant parses a flag compressed with the index.
+  LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
+                                    bool &flag) {
+    // The string index and the flag are decoded together by
+    // `parseVarIntWithFlag`.
+    uint64_t stringIdx;
+    LogicalResult idxResult = reader.parseVarIntWithFlag(stringIdx, flag);
+    if (failed(idxResult))
+      return failure();
+    // Resolve the decoded index to the string it refers to.
+    return parseStringAtIndex(reader, stringIdx, result);
+  }
+
+ /// Parse a shared string from the string section. The shared string is
+ /// encoded using an index to a corresponding string in the string section.
LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
StringRef &result) {
return resolveEntry(reader, strings, index, result, "string");
@@ -459,8 +477,9 @@ struct BytecodeDialect {
/// This struct represents an operation name entry within the bytecode.
struct BytecodeOperationName {
- BytecodeOperationName(BytecodeDialect *dialect, StringRef name)
- : dialect(dialect), name(name) {}
+ BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
+ std::optional<bool> wasRegistered)
+ : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
/// The loaded operation name, or std::nullopt if it hasn't been processed
/// yet.
@@ -471,6 +490,10 @@ struct BytecodeOperationName {
/// The name of the operation, without the dialect prefix.
StringRef name;
+
+ /// Whether this operation was registered when the bytecode was produced.
+ /// This flag is populated when bytecode version >=5.
+ std::optional<bool> wasRegistered;
};
} // namespace
@@ -791,6 +814,18 @@ public:
result = resolveAttribute(attrIdx);
return success(!!result);
}
+  /// Parse an attribute reference that may be absent. The index is decoded
+  /// together with a flag; when the flag is unset, `result` is left untouched
+  /// and success is returned. Otherwise `result` is resolved from the index,
+  /// and failure is returned if it resolves to a null attribute.
+  LogicalResult parseOptionalAttribute(EncodingReader &reader,
+                                       Attribute &result) {
+    uint64_t attrIdx;
+    bool isPresent;
+    if (failed(reader.parseVarIntWithFlag(attrIdx, isPresent)))
+      return failure();
+    if (isPresent) {
+      result = resolveAttribute(attrIdx);
+      return success(!!result);
+    }
+    return success();
+  }
+
LogicalResult parseType(EncodingReader &reader, Type &result) {
uint64_t typeIdx;
if (failed(reader.parseVarInt(typeIdx)))
@@ -870,7 +905,9 @@ public:
LogicalResult readAttribute(Attribute &result) override {
return attrTypeReader.parseAttribute(reader, result);
}
-
+  /// Read an attribute that may be absent by forwarding to
+  /// `attrTypeReader.parseOptionalAttribute` with this reader's stream.
+  LogicalResult readOptionalAttribute(Attribute &result) override {
+    LogicalResult parsed =
+        attrTypeReader.parseOptionalAttribute(reader, result);
+    return parsed;
+  }
LogicalResult readType(Type &result) override {
return attrTypeReader.parseType(reader, result);
}
@@ -957,6 +994,87 @@ private:
ResourceSectionReader &resourceReader;
EncodingReader &reader;
};
+
+/// Wraps the properties section and handles reading properties out of it.
+///
+/// The section holds a varint entry count followed by the entries, each stored
+/// as {varint size, raw bytes}. `initialize` walks the entries once to record
+/// the byte offset of each one; `read` uses the index stored with an operation
+/// to find its entry and decode it.
+class PropertiesSectionReader {
+public:
+  /// Initialize the properties section reader with the given section data.
+  /// An empty section is accepted and leaves the reader with no entries.
+  /// Fails if the entry count or any entry cannot be parsed, or if bytes are
+  /// left over once the announced number of entries has been consumed.
+  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
+    if (sectionData.empty())
+      return success();
+    EncodingReader propReader(sectionData, fileLoc);
+    // Use a fixed-width 64-bit integer for varints, as the other varint reads
+    // in this file do, instead of the platform-dependent `size_t`.
+    uint64_t count;
+    if (failed(propReader.parseVarInt(count)))
+      return failure();
+    // Parse the raw properties buffer.
+    if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
+      return failure();
+
+    // Every entry occupies at least one byte (its size varint), so a count
+    // larger than the buffer can only come from a malformed file. Reject it
+    // before reserving memory for the offset table.
+    if (count > propertiesBuffers.size())
+      return propReader.emitError()
+             << "Broken properties section: entry count exceeds section size";
+
+    EncodingReader offsetsReader(propertiesBuffers, fileLoc);
+    offsetTable.reserve(count);
+    for (uint64_t idx = 0; idx < count; ++idx) {
+      // Record where this entry starts, then skip over its {size, bytes}.
+      offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
+      ArrayRef<uint8_t> rawProperties;
+      uint64_t dataSize;
+      if (failed(offsetsReader.parseVarInt(dataSize)) ||
+          failed(offsetsReader.parseBytes(dataSize, rawProperties)))
+        return failure();
+    }
+    if (!offsetsReader.empty())
+      return offsetsReader.emitError()
+             << "Broken properties section: didn't exhaust the offsets table";
+    return success();
+  }
+
+  /// Read the properties of one operation into `opState`.
+  ///
+  /// The entry index is parsed from `dialectReader` and mapped to a byte
+  /// offset through the offset table. The entry is then decoded through the
+  /// op's `BytecodeOpInterface` when it has one. A registered op without the
+  /// interface is an error; otherwise the entry is read as an attribute into
+  /// `opState.propertiesAttr`.
+  LogicalResult read(Location fileLoc, DialectReader &dialectReader,
+                     OperationName *opName, OperationState &opState) {
+    uint64_t propertiesIdx;
+    if (failed(dialectReader.readVarInt(propertiesIdx)))
+      return failure();
+    if (propertiesIdx >= offsetTable.size())
+      return dialectReader.emitError("Properties idx out-of-bound for ")
+             << opName->getStringRef();
+    size_t propertiesOffset = offsetTable[propertiesIdx];
+    // Validate the byte offset (not the entry index) against the buffer size
+    // before seeking into the buffer.
+    if (propertiesOffset >= propertiesBuffers.size())
+      return dialectReader.emitError("Properties offset out-of-bound for ")
+             << opName->getStringRef();
+
+    // Acquire the sub-buffer that represent the requested properties.
+    ArrayRef<char> rawProperties;
+    {
+      // "Seek" to the requested offset by getting a new reader with the right
+      // sub-buffer.
+      EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
+                            fileLoc);
+      // Properties are stored as a sequence of {size + raw_data}.
+      if (failed(
+              dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
+        return failure();
+    }
+    // Setup a new reader to read from the `rawProperties` sub-buffer.
+    EncodingReader reader(
+        StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
+    DialectReader propReader = dialectReader.withEncodingReader(reader);
+
+    auto *iface = opName->getInterface<BytecodeOpInterface>();
+    if (iface)
+      return iface->readProperties(propReader, opState);
+    if (opName->isRegistered())
+      return propReader.emitError(
+                 "has properties but missing BytecodeOpInterface for ")
+             << opName->getStringRef();
+    // Unregistered op are storing properties as an attribute.
+    return propReader.readAttribute(opState.propertiesAttr);
+  }
+
+private:
+  /// The properties buffer referenced within the bytecode file.
+  ArrayRef<uint8_t> propertiesBuffers;
+
+  /// Table of offset in the buffer above.
+  SmallVector<int64_t> offsetTable;
+};
} // namespace
LogicalResult
@@ -1194,7 +1312,9 @@ private:
lazyLoadableOps.erase(it->getSecond());
lazyLoadableOpsMap.erase(it);
auto result = parseRegions(regionStack, regionStack.back());
- assert(regionStack.empty());
+ assert((regionStack.empty() || failed(result)) &&
+ "broken invariant: regionStack should be empty when parseRegions "
+ "succeeds");
return result;
}
@@ -1209,8 +1329,11 @@ private:
LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
- /// Parse an operation name reference using the given reader.
- FailureOr<OperationName> parseOpName(EncodingReader &reader);
+ /// Parse an operation name reference using the given reader, and set the
+ /// `wasRegistered` flag that indicates if the bytecode was produced by a
+ /// context where opName was registered.
+ FailureOr<OperationName> parseOpName(EncodingReader &reader,
+ std::optional<bool> &wasRegistered);
//===--------------------------------------------------------------------===//
// Attribute/Type Section
@@ -1398,6 +1521,9 @@ private:
/// The table of strings referenced within the bytecode file.
StringSectionReader stringReader;
+ /// The table of properties referenced by the operation in the bytecode file.
+ PropertiesSectionReader propertiesReader;
+
/// The current set of available IR value scopes.
std::vector<ValueScope> valueScopes;
@@ -1466,7 +1592,7 @@ LogicalResult BytecodeReader::Impl::read(
// Check that all of the required sections were found.
for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
- if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
+ if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
return reader.emitError("missing data for top-level section: ",
::toString(sectionID));
}
@@ -1477,6 +1603,12 @@ LogicalResult BytecodeReader::Impl::read(
fileLoc, *sectionDatas[bytecode::Section::kString])))
return failure();
+ // Process the properties section.
+ if (sectionDatas[bytecode::Section::kProperties] &&
+ failed(propertiesReader.initialize(
+ fileLoc, *sectionDatas[bytecode::Section::kProperties])))
+ return failure();
+
// Process the dialect section.
if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
return failure();
@@ -1598,9 +1730,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse the operation names, which are grouped by dialect.
auto parseOpName = [&](BytecodeDialect *dialect) {
StringRef opName;
- if (failed(stringReader.parseString(sectionReader, opName)))
- return failure();
- opNames.emplace_back(dialect, opName);
+ std::optional<bool> wasRegistered;
+ // Prior to version 5, the information about whether an op was registered or
+ // not wasn't encoded.
+ if (version < 5) {
+ if (failed(stringReader.parseString(sectionReader, opName)))
+ return failure();
+ } else {
+ bool wasRegisteredFlag;
+ if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
+ wasRegisteredFlag)))
+ return failure();
+ wasRegistered = wasRegisteredFlag;
+ }
+ opNames.emplace_back(dialect, opName, wasRegistered);
return success();
};
// Avoid re-allocation in bytecode version > 3 where the number of ops are
@@ -1618,11 +1761,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
}
FailureOr<OperationName>
-BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
+BytecodeReader::Impl::parseOpName(EncodingReader &reader,
+ std::optional<bool> &wasRegistered) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
-
+ wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
@@ -1994,7 +2138,8 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove) {
// Parse the name of the operation.
- FailureOr<OperationName> opName = parseOpName(reader);
+ std::optional<bool> wasRegistered;
+ FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
if (failed(opName))
return failure();
@@ -2021,6 +2166,31 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
opState.attributes = dictAttr;
}
+ if (opMask & bytecode::OpEncodingMask::kHasProperties) {
+ // kHasProperties wasn't emitted in older bytecode, we should never get
+ // there without also having the `wasRegistered` flag available.
+ if (!wasRegistered)
+ return emitError(fileLoc,
+ "Unexpected missing `wasRegistered` opname flag at "
+ "bytecode version ")
+ << version << " with properties.";
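+ // When an operation is emitted without being registered, the properties are
+ // stored as an attribute. Otherwise the op must implement the bytecode
+ // interface and control the serialization.
+ // NOTE(review): `wasRegistered` is a std::optional<bool> already known to
+ // hold a value here, so the condition below tests presence rather than the
+ // flag itself and the else branch is unreachable as written.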
+ if (wasRegistered) {
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ reader);
+ if (failed(
+ propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
+ return failure();
+ } else {
+ // If the operation wasn't registered when it was emitted, the properties
+ // were serialized as an attribute.
+ if (failed(parseAttribute(reader, opState.propertiesAttr)))
+ return failure();
+ }
+ }
+
/// Parse the results of the operation.
if (opMask & bytecode::OpEncodingMask::kHasResults) {
uint64_t numResults;
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 93484913548a..515391d5634c 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -9,11 +9,23 @@
#include "mlir/Bytecode/BytecodeWriter.h"
#include "IRNumbering.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/CachedHashString.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <optional>
+#include <sys/types.h>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -58,6 +70,10 @@ void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) {
std::min<int64_t>(bytecodeVersion, bytecode::kVersion);
}
+/// Return the bytecode version stored in this configuration's implementation
+/// state.
+int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
+  const auto &state = *impl;
+  return state.bytecodeVersion;
+}
+
//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -318,6 +334,14 @@ public:
void writeAttribute(Attribute attr) override {
emitter.emitVarInt(numberingState.getNumber(attr));
}
+  /// Emit an attribute reference that may be null. A null attribute is written
+  /// as the varint 0; otherwise the attribute's number is written together
+  /// with a flag set to true.
+  void writeOptionalAttribute(Attribute attr) override {
+    if (attr)
+      emitter.emitVarIntWithFlag(numberingState.getNumber(attr), true);
+    else
+      emitter.emitVarInt(0);
+  }
+
void writeType(Type type) override {
emitter.emitVarInt(numberingState.getNumber(type));
}
@@ -382,6 +406,105 @@ private:
StringSectionBuilder &stringSection;
};
+namespace {
+/// Accumulates the serialized properties of operations for the bytecode
+/// properties section. Each entry is stored as {varint size, raw bytes};
+/// byte-identical entries are stored once and share the same index.
+class PropertiesSectionBuilder {
+public:
+  PropertiesSectionBuilder(IRNumberingState &numberingState,
+                           StringSectionBuilder &stringSection,
+                           const BytecodeWriterConfig::Impl &config)
+      : numberingState(numberingState), stringSection(stringSection),
+        config(config) {}
+
+  /// Emit the op properties in the properties section and return the index of
+  /// the properties within the section. Return std::nullopt if no properties
+  /// was emitted: either the op has no properties storage, or it is
+  /// unregistered and its properties attribute is null.
+  std::optional<ssize_t> emit(Operation *op) {
+    if (!op->getPropertiesStorageSize())
+      return std::nullopt;
+    if (!op->isRegistered()) {
+      // Unregistered op are storing properties as an optional attribute.
+      Attribute prop = *op->getPropertiesStorage().as<Attribute *>();
+      if (!prop)
+        return std::nullopt;
+      // The entry payload is the varint number assigned to the attribute.
+      EncodingEmitter attrEmitter;
+      attrEmitter.emitVarInt(numberingState.getNumber(prop));
+      scratch.clear();
+      llvm::raw_svector_ostream os(scratch);
+      attrEmitter.writeTo(os);
+      return emit(scratch);
+    }
+
+    // Registered ops serialize their properties through BytecodeOpInterface.
+    EncodingEmitter emitter;
+    DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
+                                   numberingState, stringSection);
+    auto iface = cast<BytecodeOpInterface>(op);
+    iface.writeProperties(propertiesWriter);
+    scratch.clear();
+    llvm::raw_svector_ostream os(scratch);
+    emitter.writeTo(os);
+    return emit(scratch);
+  }
+
+  /// Write the current set of properties to the given emitter: the number of
+  /// entries first, then the bytes of each entry in index order.
+  void write(EncodingEmitter &emitter) {
+    emitter.emitVarInt(propertiesStorage.size());
+    if (propertiesStorage.empty())
+      return;
+    for (const auto &storage : propertiesStorage) {
+      if (storage.empty()) {
+        emitter.emitBytes(ArrayRef<uint8_t>());
+        continue;
+      }
+      emitter.emitBytes(ArrayRef(
+          reinterpret_cast<const uint8_t *>(storage.data()), storage.size()));
+    }
+  }
+
+  /// Returns true if the section is empty.
+  bool empty() const { return propertiesStorage.empty(); }
+
+private:
+  /// Emit raw data and returns the index of the entry in the internal storage.
+  /// Data are deduplicated and will be kept in the internal storage only if
+  /// they don't exist there already.
+  ssize_t emit(ArrayRef<char> rawProperties) {
+    // Populate a scratch buffer with the properties size.
+    SmallVector<char> sizeScratch;
+    {
+      EncodingEmitter sizeEmitter;
+      sizeEmitter.emitVarInt(rawProperties.size());
+      llvm::raw_svector_ostream os(sizeScratch);
+      sizeEmitter.writeTo(os);
+    }
+    // Append a new storage to the table now.
+    size_t index = propertiesStorage.size();
+    propertiesStorage.emplace_back();
+    std::vector<char> &newStorage = propertiesStorage.back();
+    size_t propertiesSize = sizeScratch.size() + rawProperties.size();
+    newStorage.reserve(propertiesSize);
+    newStorage.insert(newStorage.end(), sizeScratch.begin(), sizeScratch.end());
+    newStorage.insert(newStorage.end(), rawProperties.begin(),
+                      rawProperties.end());
+
+    // Try to de-duplicate the new serialized properties.
+    // If the properties is a duplicate, pop it back from the storage.
+    auto inserted = propertiesUniquing.insert(
+        std::make_pair(ArrayRef<char>(newStorage), index));
+    if (!inserted.second)
+      propertiesStorage.pop_back();
+    return inserted.first->getSecond();
+  }
+
+  /// Storage for properties, one serialized {size, bytes} entry per index.
+  std::vector<std::vector<char>> propertiesStorage;
+  /// Reusable buffer holding the payload of the entry being emitted.
+  SmallVector<char> scratch;
+  /// Maps the bytes of each stored entry to its index, for de-duplication.
+  DenseMap<ArrayRef<char>, int64_t> propertiesUniquing;
+  IRNumberingState &numberingState;
+  StringSectionBuilder &stringSection;
+  const BytecodeWriterConfig::Impl &config;
+};
+} // namespace
+
/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
/// to go through an intermediate buffer when interacting with code that wants a
/// raw_ostream.
@@ -435,11 +558,12 @@ void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
namespace {
class BytecodeWriter {
public:
- BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config)
- : numberingState(op), config(config) {}
+ BytecodeWriter(Operation *op, const BytecodeWriterConfig &config)
+ : numberingState(op, config), config(config.getImpl()),
+ propertiesSection(numberingState, stringSection, config.getImpl()) {}
/// Write the bytecode for the given root operation.
- void write(Operation *rootOp, raw_ostream &os);
+ LogicalResult write(Operation *rootOp, raw_ostream &os);
private:
//===--------------------------------------------------------------------===//
@@ -455,10 +579,10 @@ private:
//===--------------------------------------------------------------------===//
// Operations
- void writeBlock(EncodingEmitter &emitter, Block *block);
- void writeOp(EncodingEmitter &emitter, Operation *op);
- void writeRegion(EncodingEmitter &emitter, Region *region);
- void writeIRSection(EncodingEmitter &emitter, Operation *op);
+ LogicalResult writeBlock(EncodingEmitter &emitter, Block *block);
+ LogicalResult writeOp(EncodingEmitter &emitter, Operation *op);
+ LogicalResult writeRegion(EncodingEmitter &emitter, Region *region);
+ LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op);
//===--------------------------------------------------------------------===//
// Resources
@@ -471,6 +595,11 @@ private:
void writeStringSection(EncodingEmitter &emitter);
//===--------------------------------------------------------------------===//
+ // Properties
+
+ void writePropertiesSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
// Helpers
void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask,
@@ -487,10 +616,13 @@ private:
/// Configuration dictating bytecode emission.
const BytecodeWriterConfig::Impl &config;
+
+ /// Storage for the properties section
+ PropertiesSectionBuilder propertiesSection;
};
} // namespace
-void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
+LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
EncodingEmitter emitter;
// Emit the bytecode file header. This is how we identify the output as a
@@ -510,7 +642,8 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
writeAttrTypeSection(emitter);
// Emit the IR section.
- writeIRSection(emitter, rootOp);
+ if (failed(writeIRSection(emitter, rootOp)))
+ return failure();
// Emit the resources section.
writeResourceSection(rootOp, emitter);
@@ -518,8 +651,17 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
// Emit the string section.
writeStringSection(emitter);
+ // Emit the properties section.
+ if (config.bytecodeVersion >= 5)
+ writePropertiesSection(emitter);
+ else if (!propertiesSection.empty())
+ return rootOp->emitError(
+ "unexpected properties emitted incompatible with bytecode <5");
+
// Write the generated bytecode to the provided output stream.
emitter.writeTo(os);
+
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -590,7 +732,11 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
- dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
+ size_t stringId = stringSection.insert(name.name.stripDialect());
+ if (config.bytecodeVersion < 5)
+ dialectEmitter.emitVarInt(stringId);
+ else
+ dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered());
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
@@ -659,7 +805,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
//===----------------------------------------------------------------------===//
// Operations
-void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
+LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
+ Block *block) {
ArrayRef<BlockArgument> args = block->getArguments();
bool hasArgs = !args.empty();
@@ -696,10 +843,12 @@ void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
// Emit the operations within the block.
for (Operation &op : *block)
- writeOp(emitter, &op);
+ if (failed(writeOp(emitter, &op)))
+ return failure();
+ return success();
}
-void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
+LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
emitter.emitVarInt(numberingState.getNumber(op->getName()));
// Emit a mask for the operation components. We need to fill this in later
@@ -713,10 +862,24 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
emitter.emitVarInt(numberingState.getNumber(op->getLoc()));
// Emit the attributes of this operation.
- DictionaryAttr attrs = op->getAttrDictionary();
+ DictionaryAttr attrs = op->getDiscardableAttrDictionary();
+ // Allow deployment to version <5 by merging inherent attribute with the
+ // discardable ones. We should fail if there are any conflicts.
+ if (config.bytecodeVersion < 5)
+ attrs = op->getAttrDictionary();
if (!attrs.empty()) {
opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs;
- emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary()));
+ emitter.emitVarInt(numberingState.getNumber(attrs));
+ }
+
+ // Emit the properties of this operation, for now we still support deployment
+ // to version <5.
+ if (config.bytecodeVersion >= 5) {
+ std::optional<ssize_t> propertiesId = propertiesSection.emit(op);
+ if (propertiesId.has_value()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasProperties;
+ emitter.emitVarInt(*propertiesId);
+ }
}
// Emit the result types of the operation.
@@ -768,15 +931,18 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
// If the region is not isolated from above, or we are emitting bytecode
// targeting version <2, we don't use a section.
if (!isIsolatedFromAbove || config.bytecodeVersion < 2) {
- writeRegion(emitter, &region);
+ if (failed(writeRegion(emitter, &region)))
+ return failure();
continue;
}
EncodingEmitter regionEmitter;
- writeRegion(regionEmitter, &region);
+ if (failed(writeRegion(regionEmitter, &region)))
+ return failure();
emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter));
}
}
+ return success();
}
void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter,
@@ -867,11 +1033,14 @@ void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter,
}
}
-void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) {
+LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter,
+ Region *region) {
// If the region is empty, we only need to emit the number of blocks (which is
// zero).
- if (region->empty())
- return emitter.emitVarInt(/*numBlocks*/ 0);
+ if (region->empty()) {
+ emitter.emitVarInt(/*numBlocks*/ 0);
+ return success();
+ }
// Emit the number of blocks and values within the region.
unsigned numBlocks, numValues;
@@ -881,10 +1050,13 @@ void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) {
// Emit the blocks within the region.
for (Block &block : *region)
- writeBlock(emitter, &block);
+ if (failed(writeBlock(emitter, &block)))
+ return failure();
+ return success();
}
-void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
+LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter,
+ Operation *op) {
EncodingEmitter irEmitter;
// Write the IR section the same way as a block with no arguments. Note that
@@ -893,9 +1065,11 @@ void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false);
// Emit the operations.
- writeOp(irEmitter, op);
+ if (failed(writeOp(irEmitter, op)))
+ return failure();
emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter));
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -1012,13 +1186,21 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
}
//===----------------------------------------------------------------------===//
+// Properties
+
+/// Serialize the accumulated properties into their own emitter and attach the
+/// result to `emitter` as the `kProperties` section.
+void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) {
+  EncodingEmitter sectionEmitter;
+  propertiesSection.write(sectionEmitter);
+  emitter.emitSection(bytecode::Section::kProperties,
+                      std::move(sectionEmitter));
+}
+
+//===----------------------------------------------------------------------===//
// Entry Points
//===----------------------------------------------------------------------===//
LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
- BytecodeWriter writer(op, config.getImpl());
- writer.write(op, os);
- // Currently there is no failure case.
- return success();
+ BytecodeWriter writer(op, config);
+ return writer.write(op, os);
}
diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt
index 7d260568249a..45d7f2158097 100644
--- a/mlir/lib/Bytecode/Writer/CMakeLists.txt
+++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt
@@ -8,4 +8,5 @@ add_mlir_library(MLIRBytecodeWriter
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
+ MLIRBytecodeOpInterface
)
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 129437cf0245..36f7a268a6a1 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -8,6 +8,7 @@
#include "IRNumbering.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -24,6 +25,10 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
NumberingDialectWriter(IRNumberingState &state) : state(state) {}
void writeAttribute(Attribute attr) override { state.number(attr); }
+  /// Number `attr` only when it is non-null; a null attribute is ignored.
+  void writeOptionalAttribute(Attribute attr) override {
+    if (!attr)
+      return;
+    state.number(attr);
+  }
void writeType(Type type) override { state.number(type); }
void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
state.number(resource.getDialect(), resource);
@@ -106,7 +111,9 @@ static void groupByDialectPerByte(T range) {
value->number = idx;
}
-IRNumberingState::IRNumberingState(Operation *op) {
+IRNumberingState::IRNumberingState(Operation *op,
+ const BytecodeWriterConfig &config)
+ : config(config) {
// Compute a global operation ID numbering according to the pre-order walk of
// the IR. This is used as reference to construct use-list orders.
unsigned operationID = 0;
@@ -276,10 +283,30 @@ void IRNumberingState::number(Operation &op) {
}
// Only number the operation's dictionary if it isn't empty.
- DictionaryAttr dictAttr = op.getAttrDictionary();
+ DictionaryAttr dictAttr = op.getDiscardableAttrDictionary();
+ // Prior to version 5 we also need to number the merged dictionary
+ // containing both the inherent and discardable attributes.
+ if (config.getDesiredBytecodeVersion() < 5)
+ dictAttr = op.getAttrDictionary();
if (!dictAttr.empty())
number(dictAttr);
+ // Visit the operation properties (if any) to make sure referenced attributes
+ // are numbered.
+ if (config.getDesiredBytecodeVersion() >= 5 &&
+ op.getPropertiesStorageSize()) {
+ if (op.isRegistered()) {
+ // Operation that have properties *must* implement this interface.
+ auto iface = cast<BytecodeOpInterface>(op);
+ NumberingDialectWriter writer(*this);
+ iface.writeProperties(writer);
+ } else {
+ // Unregistered op are storing properties as an optional attribute.
+ if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>())
+ number(prop);
+ }
+ }
+
number(op.getLoc());
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index 91f0be05b36d..329ca2db8a80 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -18,6 +18,8 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/CodeGen/NonRelocatableStringpool.h"
+#include <cstdint>
namespace mlir {
class BytecodeDialectInterface;
@@ -133,7 +135,7 @@ struct DialectNumbering {
/// emission.
class IRNumberingState {
public:
- IRNumberingState(Operation *op);
+ IRNumberingState(Operation *op, const BytecodeWriterConfig &config);
/// Return the numbered dialects.
auto getDialects() {
@@ -241,6 +243,9 @@ private:
/// The next value ID to assign when numbering.
unsigned nextValueID = 0;
+
+ // Configuration: useful to query the required version to emit.
+ const BytecodeWriterConfig &config;
};
} // namespace detail
} // namespace bytecode