diff options
| author | Matteo Franciolini <m_franciolini@apple.com> | 2023-10-31 15:41:29 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-31 15:41:29 -0700 |
| commit | 7ad9e9dcf518431a8ecedcc06b09df6c799658ef (patch) | |
| tree | 21e31fda80c0f77c744a3503d2f2873d03697e4d /mlir/lib/Bytecode | |
| parent | 5888dee7d04748744743a35d3aef030018bdc275 (diff) | |
[mlir][bytecode] Implements back deployment capability for MLIR dialects (#70724)
When emitting bytecode, clients can specify a target dialect version to
emit in `BytecodeWriterConfig`. This exposes a target dialect version to
the DialectBytecodeWriter, which can be queried by name and used to
back-deploy attributes, types, and properties.
Diffstat (limited to 'mlir/lib/Bytecode')
| -rw-r--r-- | mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 42 | ||||
| -rw-r--r-- | mlir/lib/Bytecode/Writer/IRNumbering.cpp | 27 |
2 files changed, 57 insertions, 12 deletions
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index 5628ff6c54af..01dcea1ca384 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -39,6 +39,9 @@ struct BytecodeWriterConfig::Impl { /// Note: This only differs from kVersion if a specific version is set. int64_t bytecodeVersion = bytecode::kVersion; + /// A map containing dialect version information for each dialect to emit. + llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap; + /// The producer of the bytecode. StringRef producer; @@ -94,6 +97,19 @@ int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const { return impl->bytecodeVersion; } +llvm::StringMap<std::unique_ptr<DialectVersion>> & +BytecodeWriterConfig::getDialectVersionMap() const { + return impl->dialectVersionMap; +} + +void BytecodeWriterConfig::setDialectVersion( + llvm::StringRef dialectName, + std::unique_ptr<DialectVersion> dialectVersion) const { + assert(!impl->dialectVersionMap.contains(dialectName) && + "cannot override a previously set dialect version"); + impl->dialectVersionMap.insert({dialectName, std::move(dialectVersion)}); +} + //===----------------------------------------------------------------------===// // EncodingEmitter //===----------------------------------------------------------------------===// @@ -340,12 +356,16 @@ private: } // namespace class DialectWriter : public DialectBytecodeWriter { + using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>; + public: DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter, IRNumberingState &numberingState, - StringSectionBuilder &stringSection) + StringSectionBuilder &stringSection, + const DialectVersionMapT &dialectVersionMap) : bytecodeVersion(bytecodeVersion), emitter(emitter), - numberingState(numberingState), stringSection(stringSection) {} + numberingState(numberingState), stringSection(stringSection), + dialectVersionMap(dialectVersionMap) {} //===--------------------------------------------------------------------===// // IR @@ -421,11 +441,20 @@ public: int64_t getBytecodeVersion() const override { return bytecodeVersion; } + FailureOr<const DialectVersion *> + getDialectVersion(StringRef dialectName) const override { + auto dialectEntry = dialectVersionMap.find(dialectName); + if (dialectEntry == dialectVersionMap.end()) + return failure(); + return dialectEntry->getValue().get(); + } + private: int64_t bytecodeVersion; EncodingEmitter &emitter; IRNumberingState &numberingState; StringSectionBuilder &stringSection; + const DialectVersionMapT &dialectVersionMap; }; namespace { @@ -458,7 +487,8 @@ public: EncodingEmitter emitter; DialectWriter propertiesWriter(config.bytecodeVersion, emitter, - numberingState, stringSection); + numberingState, stringSection, + config.dialectVersionMap); auto iface = cast<BytecodeOpInterface>(op); iface.writeProperties(propertiesWriter); scratch.clear(); @@ -751,7 +781,8 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { if (dialect.interface) { // The writer used when emitting using a custom bytecode encoding. DialectWriter versionWriter(config.bytecodeVersion, versionEmitter, - numberingState, stringSection); + numberingState, stringSection, + config.dialectVersionMap); dialect.interface->writeVersion(versionWriter); } @@ -809,7 +840,8 @@ void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { } DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, - numberingState, stringSection); + numberingState, stringSection, + config.dialectVersionMap); if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) { for (const auto &callback : config.typeWriterCallbacks) { if (succeeded(callback->write(entryValue, dialectWriter))) diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index 74c45723c222..036a9477cce6 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -12,7 +12,6 @@ #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::bytecode::detail; @@ -22,7 +21,10 @@ using namespace mlir::bytecode::detail; //===----------------------------------------------------------------------===// struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { - NumberingDialectWriter(IRNumberingState &state) : state(state) {} + NumberingDialectWriter( + IRNumberingState &state, + llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap) + : state(state), dialectVersionMap(dialectVersionMap) {} void writeAttribute(Attribute attr) override { state.number(attr); } void writeOptionalAttribute(Attribute attr) override { @@ -51,8 +53,19 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { return state.getDesiredBytecodeVersion(); } + FailureOr<const DialectVersion *> + getDialectVersion(StringRef dialectName) const override { + auto dialectEntry = dialectVersionMap.find(dialectName); + if (dialectEntry == dialectVersionMap.end()) + return failure(); + return dialectEntry->getValue().get(); + } + /// The parent numbering state that is populated by this writer. IRNumberingState &state; + + /// A map containing dialect version information for each dialect to emit. + llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap; }; //===----------------------------------------------------------------------===// @@ -318,7 +331,7 @@ void IRNumberingState::number(Attribute attr) { if (!attr.hasTrait<AttributeTrait::IsMutable>()) { // Try overriding emission with callbacks. for (const auto &callback : config.getAttributeWriterCallbacks()) { - NumberingDialectWriter writer(*this); + NumberingDialectWriter writer(*this, config.getDialectVersionMap()); // The client has the ability to override the group name through the // callback. std::optional<StringRef> groupNameOverride; @@ -330,7 +343,7 @@ void IRNumberingState::number(Attribute attr) { } if (const auto *interface = numbering->dialect->interface) { - NumberingDialectWriter writer(*this); + NumberingDialectWriter writer(*this, config.getDialectVersionMap()); if (succeeded(interface->writeAttribute(attr, writer))) return; } @@ -426,7 +439,7 @@ void IRNumberingState::number(Operation &op) { if (op.isRegistered()) { // Operation that have properties *must* implement this interface. auto iface = cast<BytecodeOpInterface>(op); - NumberingDialectWriter writer(*this); + NumberingDialectWriter writer(*this, config.getDialectVersionMap()); iface.writeProperties(writer); } else { // Unregistered op are storing properties as an optional attribute. @@ -481,7 +494,7 @@ void IRNumberingState::number(Type type) { if (!type.hasTrait<TypeTrait::IsMutable>()) { // Try overriding emission with callbacks. for (const auto &callback : config.getTypeWriterCallbacks()) { - NumberingDialectWriter writer(*this); + NumberingDialectWriter writer(*this, config.getDialectVersionMap()); // The client has the ability to override the group name through the // callback. std::optional<StringRef> groupNameOverride; @@ -495,7 +508,7 @@ void IRNumberingState::number(Type type) { // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. if (const auto *interface = numbering->dialect->interface) { - NumberingDialectWriter writer(*this); + NumberingDialectWriter writer(*this, config.getDialectVersionMap()); if (succeeded(interface->writeType(type, writer))) return; } |
