diff options
Diffstat (limited to 'mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp')
| -rw-r--r-- | mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp index 5f484294268a..d1227b045d4e 100644 --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -32,6 +32,9 @@ constexpr static llvm::StringLiteral kGlobalKeyName = constexpr static llvm::StringLiteral kStackAlignmentKeyName = "dltest.stack_alignment"; +constexpr static llvm::StringLiteral kTargetSystemDescAttrName = + "dl_target_sys_desc_test.target_system_spec"; + /// Trivial array storage for the custom data layout spec attribute, just a list /// of entries. class DataLayoutSpecStorage : public AttributeStorage { @@ -91,6 +94,52 @@ struct CustomDataLayoutSpec } }; +class TargetSystemSpecStorage : public AttributeStorage { +public: + using KeyTy = ArrayRef<DeviceIDTargetDeviceSpecPair>; + + TargetSystemSpecStorage(ArrayRef<DeviceIDTargetDeviceSpecPair> entries) + : entries(entries) {} + + bool operator==(const KeyTy &key) const { return key == entries; } + + static TargetSystemSpecStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate<TargetSystemSpecStorage>()) + TargetSystemSpecStorage(allocator.copyInto(key)); + } + + ArrayRef<DeviceIDTargetDeviceSpecPair> entries; +}; + +struct CustomTargetSystemSpec + : public Attribute::AttrBase<CustomTargetSystemSpec, Attribute, + TargetSystemSpecStorage, + TargetSystemSpecInterface::Trait> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec) + + using Base::Base; + + static constexpr StringLiteral name = "test.custom_target_system_spec"; + + static CustomTargetSystemSpec + get(MLIRContext *ctx, ArrayRef<DeviceIDTargetDeviceSpecPair> entries) { + return Base::get(ctx, entries); + } + DeviceIDTargetDeviceSpecPairListRef getEntries() const { + return getImpl()->entries; + } + LogicalResult verifySpec(Location loc) { return success(); } + std::optional<TargetDeviceSpecInterface> + getDeviceSpecForDeviceID(TargetSystemSpecInterface::DeviceID deviceID) { + for (const auto &entry : getEntries()) { + if (entry.first == deviceID) + return entry.second; + } + return std::nullopt; + } +}; + /// A type subject to data layout that exits the program if it is queried more /// than once. Handy to check if the cache works. struct SingleQueryType @@ -197,6 +246,11 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> { return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName); } + TargetSystemSpecInterface getTargetSystemSpec() { + return getOperation()->getAttrOfType<TargetSystemSpecInterface>( + kTargetSystemDescAttrName); + } + static llvm::TypeSize getTypeSizeInBits(Type type, const DataLayout &dataLayout, DataLayoutEntryListRef params) { @@ -244,6 +298,11 @@ struct OpWith7BitByte return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName); } + TargetSystemSpecInterface getTargetSystemSpec() { + return getOperation()->getAttrOfType<TargetSystemSpecInterface>( + kTargetSystemDescAttrName); + } + // Bytes are assumed to be 7-bit here. static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout, DataLayoutEntryListRef params) { @@ -308,6 +367,74 @@ struct DLTestDialect : Dialect { } }; +/// A dialect to test DLTI's target system spec and related attributes +struct DLTargetSystemDescTestDialect : public Dialect { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect) + + explicit DLTargetSystemDescTestDialect(MLIRContext *ctx) + : Dialect(getDialectNamespace(), ctx, + TypeID::get<DLTargetSystemDescTestDialect>()) { + ctx->getOrLoadDialect<DLTIDialect>(); + addAttributes<CustomTargetSystemSpec>(); + } + static StringRef getDialectNamespace() { return "dl_target_sys_desc_test"; } + + void printAttribute(Attribute attr, + DialectAsmPrinter &printer) const override { + printer << "target_system_spec<"; + llvm::interleaveComma( + cast<CustomTargetSystemSpec>(attr).getEntries(), printer, + [&](const auto &it) { printer << it.first << ":" << it.second; }); + printer << ">"; + } + + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override { + bool ok = succeeded(parser.parseKeyword("target_system_spec")) && + succeeded(parser.parseLess()); + (void)ok; + assert(ok); + if (succeeded(parser.parseOptionalGreater())) + return CustomTargetSystemSpec::get(parser.getContext(), {}); + + auto parseDeviceIDTargetDeviceSpecPair = + [&](AsmParser &parser) -> FailureOr<DeviceIDTargetDeviceSpecPair> { + std::string deviceID; + if (failed(parser.parseString(&deviceID))) { + parser.emitError(parser.getCurrentLocation()) + << "DeviceID is missing, or is not of string type"; + return failure(); + } + if (failed(parser.parseColon())) { + parser.emitError(parser.getCurrentLocation()) << "Missing colon"; + return failure(); + } + + TargetDeviceSpecInterface targetDeviceSpec; + if (failed(parser.parseAttribute(targetDeviceSpec))) { + parser.emitError(parser.getCurrentLocation()) + << "Error in parsing target device spec"; + return failure(); + } + return std::make_pair(parser.getBuilder().getStringAttr(deviceID), + targetDeviceSpec); + }; + + SmallVector<DeviceIDTargetDeviceSpecPair> entries; + ok = succeeded(parser.parseCommaSeparatedList([&]() { + auto deviceIDAndTargetDeviceSpecPair = + parseDeviceIDTargetDeviceSpecPair(parser); + ok = succeeded(deviceIDAndTargetDeviceSpecPair); + assert(ok); + entries.push_back(*deviceIDAndTargetDeviceSpecPair); + return success(); + })); + assert(ok); + ok = succeeded(parser.parseGreater()); + assert(ok); + return CustomTargetSystemSpec::get(parser.getContext(), entries); + } +}; + } // namespace TEST(DataLayout, FallbackDefault) { @@ -367,6 +494,15 @@ TEST(DataLayout, NullSpec) { EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); EXPECT_EQ(layout.getStackAlignment(), 0u); + + EXPECT_EQ(layout.getDevicePropertyValue( + Builder(&ctx).getStringAttr("CPU" /* device ID*/), + Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")), + std::nullopt); + EXPECT_EQ(layout.getDevicePropertyValue( + Builder(&ctx).getStringAttr("CPU" /* device ID*/), + Builder(&ctx).getStringAttr("max_vector_width")), + std::nullopt); } TEST(DataLayout, EmptySpec) { @@ -398,6 +534,15 @@ TEST(DataLayout, EmptySpec) { EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); EXPECT_EQ(layout.getStackAlignment(), 0u); + + EXPECT_EQ(layout.getDevicePropertyValue( + Builder(&ctx).getStringAttr("CPU" /* device ID*/), + Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")), + std::nullopt); + EXPECT_EQ(layout.getDevicePropertyValue( + Builder(&ctx).getStringAttr("CPU" /* device ID*/), + Builder(&ctx).getStringAttr("max_vector_width")), + std::nullopt); } TEST(DataLayout, SpecWithEntries) { @@ -449,6 +594,32 @@ TEST(DataLayout, SpecWithEntries) { EXPECT_EQ(layout.getStackAlignment(), 128u); } +TEST(DataLayout, SpecWithTargetSystemDescEntries) { + const char *ir = R"MLIR( + module attributes { dl_target_sys_desc_test.target_system_spec = + #dl_target_sys_desc_test.target_system_spec< + "CPU": #dlti.target_device_spec< + #dlti.dl_entry<"L1_cache_size_in_bytes", "4096">, + #dlti.dl_entry<"max_vector_op_width", "128">> + > } {} + )MLIR"; + + DialectRegistry registry; + registry.insert<DLTIDialect, DLTargetSystemDescTestDialect>(); + MLIRContext ctx(registry); + + OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); + DataLayout layout(*module); + EXPECT_EQ(layout.getDevicePropertyValue( + Builder(&ctx).getStringAttr("CPU") /* device ID*/, + Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")), + std::optional<Attribute>(Builder(&ctx).getStringAttr("4096"))); + EXPECT_EQ(layout.getDevicePropertyValue( + Builder(&ctx).getStringAttr("CPU") /* device ID*/, + Builder(&ctx).getStringAttr("max_vector_op_width")), + std::optional<Attribute>(Builder(&ctx).getStringAttr("128"))); +} + TEST(DataLayout, Caching) { const char *ir = R"MLIR( "dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> () |
