diff options
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 381 |
1 files changed, 370 insertions, 11 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 26510a8d58c4..5e10dfad355a 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" @@ -40,11 +41,32 @@ static std::string toString(bytecode::Section::ID sectionID) { return "AttrTypeOffset (3)"; case bytecode::Section::kIR: return "IR (4)"; + case bytecode::Section::kResource: + return "Resource (5)"; + case bytecode::Section::kResourceOffset: + return "ResourceOffset (6)"; default: return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); } } +/// Returns true if the given top-level section ID is optional. +static bool isSectionOptional(bytecode::Section::ID sectionID) { + switch (sectionID) { + case bytecode::Section::kString: + case bytecode::Section::kDialect: + case bytecode::Section::kAttrType: + case bytecode::Section::kAttrTypeOffset: + case bytecode::Section::kIR: + return false; + case bytecode::Section::kResource: + case bytecode::Section::kResourceOffset: + return true; + default: + llvm_unreachable("unknown section ID"); + } +} + //===----------------------------------------------------------------------===// // EncodingReader //===----------------------------------------------------------------------===// @@ -65,11 +87,34 @@ public: /// Returns the remaining size of the bytecode. size_t size() const { return dataEnd - dataIt; } + /// Align the current reader position to the specified alignment. + LogicalResult alignTo(unsigned alignment) { + if (!llvm::isPowerOf2_32(alignment)) + return emitError("expected alignment to be a power-of-two"); + + // Shift the reader position to the next alignment boundary. 
+ while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { + uint8_t padding; + if (failed(parseByte(padding))) + return failure(); + if (padding != bytecode::kAlignmentByte) { + return emitError("expected alignment byte (0xCB), but got: '0x" + + llvm::utohexstr(padding) + "'"); + } + } + + // TODO: Check that the current data pointer is actually at the expected + // alignment. + + return success(); + } + /// Emit an error using the given arguments. template <typename... Args> InFlightDiagnostic emitError(Args &&...args) const { return ::emitError(fileLoc).append(std::forward<Args>(args)...); } + InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } /// Parse a single byte from the stream. template <typename T> @@ -101,6 +146,17 @@ public: return success(); } + /// Parse an aligned blob of data, where the alignment was encoded alongside + /// the data. + LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, + uint64_t &alignment) { + uint64_t dataSize; + if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || + failed(alignTo(alignment))) + return failure(); + return parseBytes(dataSize, data); + } + /// Parse a variable length encoded integer from the byte stream. The first /// encoded byte contains a prefix in the low bits indicating the encoded /// length of the value. This length prefix is a bit sequence of '0's followed @@ -177,13 +233,31 @@ public: /// contents of the section in `sectionData`. LogicalResult parseSection(bytecode::Section::ID &sectionID, ArrayRef<uint8_t> &sectionData) { + uint8_t sectionIDAndHasAlignment; uint64_t length; - if (failed(parseByte(sectionID)) || failed(parseVarInt(length))) + if (failed(parseByte(sectionIDAndHasAlignment)) || + failed(parseVarInt(length))) return failure(); + + // Extract the section ID and whether the section is aligned. The high bit + // of the ID is the alignment flag. 
+ sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & + 0b01111111); + bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; + + // Check that the section is actually valid before trying to process its + // data. if (sectionID >= bytecode::Section::kNumSections) return emitError("invalid section ID: ", unsigned(sectionID)); - // Parse the actua section data now that we have its length. + // Process the section alignment if present. + if (hasAlignment) { + uint64_t alignment; + if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) + return failure(); + } + + // Parse the actual section data. return parseBytes(static_cast<size_t>(length), sectionData); } @@ -346,6 +420,14 @@ struct BytecodeDialect { return success(); } + /// Return the loaded dialect, or nullptr if the dialect is unknown. This can + /// only be called after `load`. + Dialect *getLoadedDialect() const { + assert(dialect && + "expected `load` to be invoked before `getLoadedDialect`"); + return *dialect; + } + /// The loaded dialect entry. This field is None if we haven't attempted to /// load, nullptr if we failed to load, otherwise the loaded dialect. Optional<Dialect *> dialect; @@ -394,6 +476,225 @@ static LogicalResult parseDialectGrouping( } //===----------------------------------------------------------------------===// +// ResourceSectionReader +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to read the resource section from the bytecode. +class ResourceSectionReader { +public: + /// Initialize the resource section reader with the given section data. + LogicalResult initialize(Location fileLoc, const ParserConfig &config, + MutableArrayRef<BytecodeDialect> dialects, + StringSectionReader &stringReader, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData); + + /// Parse a dialect resource handle from the resource section. 
+ LogicalResult parseResourceHandle(EncodingReader &reader, + AsmDialectResourceHandle &result) { + return parseEntry(reader, dialectResources, result, "resource handle"); + } + +private: + /// The table of dialect resources within the bytecode file. + SmallVector<AsmDialectResourceHandle> dialectResources; +}; + +class ParsedResourceEntry : public AsmParsedResourceEntry { +public: + ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, + EncodingReader &reader, StringSectionReader &stringReader) + : key(key), kind(kind), reader(reader), stringReader(stringReader) {} + ~ParsedResourceEntry() override = default; + + StringRef getKey() const final { return key; } + + InFlightDiagnostic emitError() const final { return reader.emitError(); } + + AsmResourceEntryKind getKind() const final { return kind; } + + FailureOr<bool> parseAsBool() const final { + if (kind != AsmResourceEntryKind::Bool) + return emitError() << "expected a bool resource entry, but found a " + << toString(kind) << " entry instead"; + + bool value; + if (failed(reader.parseByte(value))) + return failure(); + return value; + } + FailureOr<std::string> parseAsString() const final { + if (kind != AsmResourceEntryKind::String) + return emitError() << "expected a string resource entry, but found a " + << toString(kind) << " entry instead"; + + StringRef string; + if (failed(stringReader.parseString(reader, string))) + return failure(); + return string.str(); + } + + FailureOr<AsmResourceBlob> + parseAsBlob(BlobAllocatorFn allocator) const final { + if (kind != AsmResourceEntryKind::Blob) + return emitError() << "expected a blob resource entry, but found a " + << toString(kind) << " entry instead"; + + ArrayRef<uint8_t> data; + uint64_t alignment; + if (failed(reader.parseBlobAndAlignment(data, alignment))) + return failure(); + + // Allocate memory for the blob using the provided allocator and copy the + // data into it. 
+ // FIXME: If the current holder of the bytecode can ensure its lifetime + // (e.g. when mmap'd), we should not copy the data. We should use the data + // from the bytecode directly. + AsmResourceBlob blob = allocator(data.size(), alignment); + assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && + blob.isMutable() && + "blob allocator did not return a properly aligned address"); + memcpy(blob.getMutableData().data(), data.data(), data.size()); + return blob; + } + +private: + StringRef key; + AsmResourceEntryKind kind; + EncodingReader &reader; + StringSectionReader &stringReader; +}; +} // namespace + +template <typename T> +static LogicalResult +parseResourceGroup(Location fileLoc, bool allowEmpty, + EncodingReader &offsetReader, EncodingReader &resourceReader, + StringSectionReader &stringReader, T *handler, + function_ref<LogicalResult(StringRef)> processKeyFn = {}) { + uint64_t numResources; + if (failed(offsetReader.parseVarInt(numResources))) + return failure(); + + for (uint64_t i = 0; i < numResources; ++i) { + StringRef key; + AsmResourceEntryKind kind; + uint64_t resourceOffset; + ArrayRef<uint8_t> data; + if (failed(stringReader.parseString(offsetReader, key)) || + failed(offsetReader.parseVarInt(resourceOffset)) || + failed(offsetReader.parseByte(kind)) || + failed(resourceReader.parseBytes(resourceOffset, data))) + return failure(); + + // Process the resource key. + if ((processKeyFn && failed(processKeyFn(key)))) + return failure(); + + // If the resource data is empty and we allow it, don't error out when + // parsing below, just skip it. + if (allowEmpty && data.empty()) + continue; + + // Ignore the entry if we don't have a valid handler. + if (!handler) + continue; + + // Otherwise, parse the resource value. 
+ EncodingReader entryReader(data, fileLoc); + ParsedResourceEntry entry(key, kind, entryReader, stringReader); + if (failed(handler->parseResource(entry))) + return failure(); + if (!entryReader.empty()) { + return entryReader.emitError( + "unexpected trailing bytes in resource entry '", key, "'"); + } + } + return success(); +} + +LogicalResult +ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config, + MutableArrayRef<BytecodeDialect> dialects, + StringSectionReader &stringReader, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData) { + EncodingReader resourceReader(sectionData, fileLoc); + EncodingReader offsetReader(offsetSectionData, fileLoc); + + // Read the number of external resource providers. + uint64_t numExternalResourceGroups; + if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) + return failure(); + + // Utility functor that dispatches to `parseResourceGroup`, but implicitly + // provides most of the arguments. + auto parseGroup = [&](auto *handler, bool allowEmpty = false, + function_ref<LogicalResult(StringRef)> keyFn = {}) { + return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, + stringReader, handler, keyFn); + }; + + // Read the external resources from the bytecode. + for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { + StringRef key; + if (failed(stringReader.parseString(offsetReader, key))) + return failure(); + + // Get the handler for these resources. + // TODO: Should we require handling external resources in some scenarios? + AsmResourceParser *handler = config.getResourceParser(key); + if (!handler) { + emitWarning(fileLoc) << "ignoring unknown external resources for '" << key + << "'"; + } + + if (failed(parseGroup(handler))) + return failure(); + } + + // Read the dialect resources from the bytecode. 
+ MLIRContext *ctx = fileLoc->getContext(); + while (!offsetReader.empty()) { + BytecodeDialect *dialect; + if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || + failed(dialect->load(resourceReader, ctx))) + return failure(); + Dialect *loadedDialect = dialect->getLoadedDialect(); + if (!loadedDialect) { + return resourceReader.emitError() + << "dialect '" << dialect->name << "' is unknown"; + } + const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); + if (!handler) { + return resourceReader.emitError() + << "unexpected resources for dialect '" << dialect->name << "'"; + } + + // Ensure that each resource is declared before being processed. + auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { + FailureOr<AsmDialectResourceHandle> handle = + handler->declareResource(key); + if (failed(handle)) { + return resourceReader.emitError() + << "unknown 'resource' key '" << key << "' for dialect '" + << dialect->name << "'"; + } + dialectResources.push_back(*handle); + return success(); + }; + + // Parse the resources for this dialect. We allow empty resources because we + // just treat these as declarations. + if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) + return failure(); + } + + return success(); +} + +//===----------------------------------------------------------------------===// // Attribute/Type Reader //===----------------------------------------------------------------------===// @@ -419,8 +720,10 @@ class AttrTypeReader { using TypeEntry = Entry<Type>; public: - AttrTypeReader(StringSectionReader &stringReader, Location fileLoc) - : stringReader(stringReader), fileLoc(fileLoc) {} + AttrTypeReader(StringSectionReader &stringReader, + ResourceSectionReader &resourceReader, Location fileLoc) + : stringReader(stringReader), resourceReader(resourceReader), + fileLoc(fileLoc) {} /// Initialize the attribute and type information within the reader. 
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, @@ -483,6 +786,10 @@ private: /// custom encoded attribute/type entries. StringSectionReader &stringReader; + /// The resource section reader used to resolve resource references when + /// parsing custom encoded attribute/type entries. + ResourceSectionReader &resourceReader; + /// The set of attribute and type entries. SmallVector<AttrEntry> attributes; SmallVector<TypeEntry> types; @@ -494,9 +801,10 @@ private: class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, - StringSectionReader &stringReader, EncodingReader &reader) + StringSectionReader &stringReader, + ResourceSectionReader &resourceReader, EncodingReader &reader) : attrTypeReader(attrTypeReader), stringReader(stringReader), - reader(reader) {} + resourceReader(resourceReader), reader(reader) {} InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); @@ -514,6 +822,13 @@ public: return attrTypeReader.parseType(reader, result); } + FailureOr<AsmDialectResourceHandle> readResourceHandle() override { + AsmDialectResourceHandle handle; + if (failed(resourceReader.parseResourceHandle(reader, handle))) + return failure(); + return handle; + } + //===--------------------------------------------------------------------===// // Primitives //===--------------------------------------------------------------------===// @@ -575,6 +890,7 @@ public: private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; + ResourceSectionReader &resourceReader; EncodingReader &reader; }; } // namespace @@ -707,7 +1023,7 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, } // Ask the dialect to parse the entry. 
- DialectReader dialectReader(*this, stringReader, reader); + DialectReader dialectReader(*this, stringReader, resourceReader, reader); if constexpr (std::is_same_v<T, Type>) entry.entry = entry.dialect->interface->readType(dialectReader); else @@ -724,7 +1040,8 @@ namespace { class BytecodeReader { public: BytecodeReader(Location fileLoc, const ParserConfig &config) - : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc), + : config(config), fileLoc(fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -762,6 +1079,13 @@ private: } //===--------------------------------------------------------------------===// + // Resource Section + + LogicalResult + parseResourceSection(Optional<ArrayRef<uint8_t>> resourceData, + Optional<ArrayRef<uint8_t>> resourceOffsetData); + + //===--------------------------------------------------------------------===// // IR Section /// This struct represents the current read state of a range of regions. This @@ -863,6 +1187,9 @@ private: SmallVector<BytecodeDialect> dialects; SmallVector<BytecodeOperationName> opNames; + /// The reader used to process resources within the bytecode. + ResourceSectionReader resourceReader; + /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; @@ -914,11 +1241,12 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { } sectionDatas[sectionID] = sectionData; } - // Check that all of the sections were found. + // Check that all of the required sections were found. 
for (int i = 0; i < bytecode::Section::kNumSections; ++i) { - if (!sectionDatas[i]) { + bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); + if (!sectionDatas[i] && !isSectionOptional(sectionID)) { return reader.emitError("missing data for top-level section: ", - toString(bytecode::Section::ID(i))); + toString(sectionID)); } } @@ -931,6 +1259,12 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); + // Process the resource section if present. + if (failed(parseResourceSection( + sectionDatas[bytecode::Section::kResource], + sectionDatas[bytecode::Section::kResourceOffset]))) + return failure(); + // Process the attribute and type section. if (failed(attrTypeReader.initialize( dialects, *sectionDatas[bytecode::Section::kAttrType], @@ -1009,6 +1343,31 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { } //===----------------------------------------------------------------------===// +// Resource Section + +LogicalResult BytecodeReader::parseResourceSection( + Optional<ArrayRef<uint8_t>> resourceData, + Optional<ArrayRef<uint8_t>> resourceOffsetData) { + // Ensure both sections are either present or not. + if (resourceData.has_value() != resourceOffsetData.has_value()) { + if (resourceOffsetData) + return emitError(fileLoc, "unexpected resource offset section when " + "resource section is not present"); + return emitError( + fileLoc, + "expected resource offset section when resource section is present"); + } + + // If the resource sections are absent, there is nothing to do. + if (!resourceData) + return success(); + + // Initialize the resource reader with the resource sections. 
+ return resourceReader.initialize(fileLoc, config, dialects, stringReader, + *resourceData, *resourceOffsetData); +} + +//===----------------------------------------------------------------------===// // IR Section LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, |
