summaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode
diff options
context:
space:
mode:
authorMatteo Franciolini <m_franciolini@apple.com>2023-10-31 15:41:29 -0700
committerGitHub <noreply@github.com>2023-10-31 15:41:29 -0700
commit7ad9e9dcf518431a8ecedcc06b09df6c799658ef (patch)
tree21e31fda80c0f77c744a3503d2f2873d03697e4d /mlir/lib/Bytecode
parent5888dee7d04748744743a35d3aef030018bdc275 (diff)
[mlir][bytecode] Implements back deployment capability for MLIR dialects (#70724)
When emitting bytecode, clients can specify a target dialect version to emit in `BytecodeWriterConfig`. This exposes a target dialect version to the DialectBytecodeWriter, which can be queried by name and used to back-deploy attributes, types, and properties.
Diffstat (limited to 'mlir/lib/Bytecode')
-rw-r--r--mlir/lib/Bytecode/Writer/BytecodeWriter.cpp42
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.cpp27
2 files changed, 57 insertions, 12 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5628ff6c54af..01dcea1ca384 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -39,6 +39,9 @@ struct BytecodeWriterConfig::Impl {
/// Note: This only differs from kVersion if a specific version is set.
int64_t bytecodeVersion = bytecode::kVersion;
+ /// A map containing dialect version information for each dialect to emit.
+ llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap;
+
/// The producer of the bytecode.
StringRef producer;
@@ -94,6 +97,19 @@ int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
return impl->bytecodeVersion;
}
+llvm::StringMap<std::unique_ptr<DialectVersion>> &
+BytecodeWriterConfig::getDialectVersionMap() const {
+ return impl->dialectVersionMap;
+}
+
+void BytecodeWriterConfig::setDialectVersion(
+ llvm::StringRef dialectName,
+ std::unique_ptr<DialectVersion> dialectVersion) const {
+ assert(!impl->dialectVersionMap.contains(dialectName) &&
+ "cannot override a previously set dialect version");
+ impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)});
+}
+
//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -340,12 +356,16 @@ private:
} // namespace
class DialectWriter : public DialectBytecodeWriter {
+ using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>;
+
public:
DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
IRNumberingState &numberingState,
- StringSectionBuilder &stringSection)
+ StringSectionBuilder &stringSection,
+ const DialectVersionMapT &dialectVersionMap)
: bytecodeVersion(bytecodeVersion), emitter(emitter),
- numberingState(numberingState), stringSection(stringSection) {}
+ numberingState(numberingState), stringSection(stringSection),
+ dialectVersionMap(dialectVersionMap) {}
//===--------------------------------------------------------------------===//
// IR
@@ -421,11 +441,20 @@ public:
int64_t getBytecodeVersion() const override { return bytecodeVersion; }
+ FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const override {
+ auto dialectEntry = dialectVersionMap.find(dialectName);
+ if (dialectEntry == dialectVersionMap.end())
+ return failure();
+ return dialectEntry->getValue().get();
+ }
+
private:
int64_t bytecodeVersion;
EncodingEmitter &emitter;
IRNumberingState &numberingState;
StringSectionBuilder &stringSection;
+ const DialectVersionMapT &dialectVersionMap;
};
namespace {
@@ -458,7 +487,8 @@ public:
EncodingEmitter emitter;
DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
- numberingState, stringSection);
+ numberingState, stringSection,
+ config.dialectVersionMap);
auto iface = cast<BytecodeOpInterface>(op);
iface.writeProperties(propertiesWriter);
scratch.clear();
@@ -751,7 +781,8 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
if (dialect.interface) {
// The writer used when emitting using a custom bytecode encoding.
DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
- numberingState, stringSection);
+ numberingState, stringSection,
+ config.dialectVersionMap);
dialect.interface->writeVersion(versionWriter);
}
@@ -809,7 +840,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
}
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
- numberingState, stringSection);
+ numberingState, stringSection,
+ config.dialectVersionMap);
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
for (const auto &callback : config.typeWriterCallbacks) {
if (succeeded(callback->write(entryValue, dialectWriter)))
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 74c45723c222..036a9477cce6 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -12,7 +12,6 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
-#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::bytecode::detail;
@@ -22,7 +21,10 @@ using namespace mlir::bytecode::detail;
//===----------------------------------------------------------------------===//
struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
- NumberingDialectWriter(IRNumberingState &state) : state(state) {}
+ NumberingDialectWriter(
+ IRNumberingState &state,
+ llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
+ : state(state), dialectVersionMap(dialectVersionMap) {}
void writeAttribute(Attribute attr) override { state.number(attr); }
void writeOptionalAttribute(Attribute attr) override {
@@ -51,8 +53,19 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
return state.getDesiredBytecodeVersion();
}
+ FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const override {
+ auto dialectEntry = dialectVersionMap.find(dialectName);
+ if (dialectEntry == dialectVersionMap.end())
+ return failure();
+ return dialectEntry->getValue().get();
+ }
+
/// The parent numbering state that is populated by this writer.
IRNumberingState &state;
+
+ /// A map containing dialect version information for each dialect to emit.
+ llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
};
//===----------------------------------------------------------------------===//
@@ -318,7 +331,7 @@ void IRNumberingState::number(Attribute attr) {
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
// Try overriding emission with callbacks.
for (const auto &callback : config.getAttributeWriterCallbacks()) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
// The client has the ability to override the group name through the
// callback.
std::optional<StringRef> groupNameOverride;
@@ -330,7 +343,7 @@ void IRNumberingState::number(Attribute attr) {
}
if (const auto *interface = numbering->dialect->interface) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
if (succeeded(interface->writeAttribute(attr, writer)))
return;
}
@@ -426,7 +439,7 @@ void IRNumberingState::number(Operation &op) {
if (op.isRegistered()) {
// Operation that have properties *must* implement this interface.
auto iface = cast<BytecodeOpInterface>(op);
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
iface.writeProperties(writer);
} else {
// Unregistered op are storing properties as an optional attribute.
@@ -481,7 +494,7 @@ void IRNumberingState::number(Type type) {
if (!type.hasTrait<TypeTrait::IsMutable>()) {
// Try overriding emission with callbacks.
for (const auto &callback : config.getTypeWriterCallbacks()) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
// The client has the ability to override the group name through the
// callback.
std::optional<StringRef> groupNameOverride;
@@ -495,7 +508,7 @@ void IRNumberingState::number(Type type) {
// If this attribute will be emitted using the bytecode format, perform a
// dummy writing to number any nested components.
if (const auto *interface = numbering->dialect->interface) {
- NumberingDialectWriter writer(*this);
+ NumberingDialectWriter writer(*this, config.getDialectVersionMap());
if (succeeded(interface->writeType(type, writer)))
return;
}