summaryrefslogtreecommitdiff
path: root/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp')
-rw-r--r--mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp171
1 files changed, 171 insertions, 0 deletions
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 5f484294268a..d1227b045d4e 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -32,6 +32,9 @@ constexpr static llvm::StringLiteral kGlobalKeyName =
constexpr static llvm::StringLiteral kStackAlignmentKeyName =
"dltest.stack_alignment";
+constexpr static llvm::StringLiteral kTargetSystemDescAttrName =
+ "dl_target_sys_desc_test.target_system_spec";
+
/// Trivial array storage for the custom data layout spec attribute, just a list
/// of entries.
class DataLayoutSpecStorage : public AttributeStorage {
@@ -91,6 +94,52 @@ struct CustomDataLayoutSpec
}
};
+class TargetSystemSpecStorage : public AttributeStorage {
+public:
+ using KeyTy = ArrayRef<DeviceIDTargetDeviceSpecPair>;
+
+ TargetSystemSpecStorage(ArrayRef<DeviceIDTargetDeviceSpecPair> entries)
+ : entries(entries) {}
+
+ bool operator==(const KeyTy &key) const { return key == entries; }
+
+ static TargetSystemSpecStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ return new (allocator.allocate<TargetSystemSpecStorage>())
+ TargetSystemSpecStorage(allocator.copyInto(key));
+ }
+
+ ArrayRef<DeviceIDTargetDeviceSpecPair> entries;
+};
+
+struct CustomTargetSystemSpec
+ : public Attribute::AttrBase<CustomTargetSystemSpec, Attribute,
+ TargetSystemSpecStorage,
+ TargetSystemSpecInterface::Trait> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
+
+ using Base::Base;
+
+ static constexpr StringLiteral name = "test.custom_target_system_spec";
+
+ static CustomTargetSystemSpec
+ get(MLIRContext *ctx, ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
+ return Base::get(ctx, entries);
+ }
+ DeviceIDTargetDeviceSpecPairListRef getEntries() const {
+ return getImpl()->entries;
+ }
+ LogicalResult verifySpec(Location loc) { return success(); }
+ std::optional<TargetDeviceSpecInterface>
+ getDeviceSpecForDeviceID(TargetSystemSpecInterface::DeviceID deviceID) {
+ for (const auto &entry : getEntries()) {
+ if (entry.first == deviceID)
+ return entry.second;
+ }
+ return std::nullopt;
+ }
+};
+
/// A type subject to data layout that exits the program if it is queried more
/// than once. Handy to check if the cache works.
struct SingleQueryType
@@ -197,6 +246,11 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
}
+ TargetSystemSpecInterface getTargetSystemSpec() {
+ return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
+ kTargetSystemDescAttrName);
+ }
+
static llvm::TypeSize getTypeSizeInBits(Type type,
const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
@@ -244,6 +298,11 @@ struct OpWith7BitByte
return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
}
+ TargetSystemSpecInterface getTargetSystemSpec() {
+ return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
+ kTargetSystemDescAttrName);
+ }
+
// Bytes are assumed to be 7-bit here.
static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
@@ -308,6 +367,74 @@ struct DLTestDialect : Dialect {
}
};
+/// A dialect to test DLTI's target system spec and related attributes
+struct DLTargetSystemDescTestDialect : public Dialect {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect)
+
+ explicit DLTargetSystemDescTestDialect(MLIRContext *ctx)
+ : Dialect(getDialectNamespace(), ctx,
+ TypeID::get<DLTargetSystemDescTestDialect>()) {
+ ctx->getOrLoadDialect<DLTIDialect>();
+ addAttributes<CustomTargetSystemSpec>();
+ }
+ static StringRef getDialectNamespace() { return "dl_target_sys_desc_test"; }
+
+ void printAttribute(Attribute attr,
+ DialectAsmPrinter &printer) const override {
+ printer << "target_system_spec<";
+ llvm::interleaveComma(
+ cast<CustomTargetSystemSpec>(attr).getEntries(), printer,
+ [&](const auto &it) { printer << it.first << ":" << it.second; });
+ printer << ">";
+ }
+
+ Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+ bool ok = succeeded(parser.parseKeyword("target_system_spec")) &&
+ succeeded(parser.parseLess());
+ (void)ok;
+ assert(ok);
+ if (succeeded(parser.parseOptionalGreater()))
+ return CustomTargetSystemSpec::get(parser.getContext(), {});
+
+ auto parseDeviceIDTargetDeviceSpecPair =
+ [&](AsmParser &parser) -> FailureOr<DeviceIDTargetDeviceSpecPair> {
+ std::string deviceID;
+ if (failed(parser.parseString(&deviceID))) {
+ parser.emitError(parser.getCurrentLocation())
+ << "DeviceID is missing, or is not of string type";
+ return failure();
+ }
+ if (failed(parser.parseColon())) {
+ parser.emitError(parser.getCurrentLocation()) << "Missing colon";
+ return failure();
+ }
+
+ TargetDeviceSpecInterface targetDeviceSpec;
+ if (failed(parser.parseAttribute(targetDeviceSpec))) {
+ parser.emitError(parser.getCurrentLocation())
+ << "Error in parsing target device spec";
+ return failure();
+ }
+ return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
+ targetDeviceSpec);
+ };
+
+ SmallVector<DeviceIDTargetDeviceSpecPair> entries;
+ ok = succeeded(parser.parseCommaSeparatedList([&]() {
+ auto deviceIDAndTargetDeviceSpecPair =
+ parseDeviceIDTargetDeviceSpecPair(parser);
+ ok = succeeded(deviceIDAndTargetDeviceSpecPair);
+ assert(ok);
+ entries.push_back(*deviceIDAndTargetDeviceSpecPair);
+ return success();
+ }));
+ assert(ok);
+ ok = succeeded(parser.parseGreater());
+ assert(ok);
+ return CustomTargetSystemSpec::get(parser.getContext(), entries);
+ }
+};
+
} // namespace
TEST(DataLayout, FallbackDefault) {
@@ -367,6 +494,15 @@ TEST(DataLayout, NullSpec) {
EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
EXPECT_EQ(layout.getStackAlignment(), 0u);
+
+ EXPECT_EQ(layout.getDevicePropertyValue(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/),
+ Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
+ std::nullopt);
+ EXPECT_EQ(layout.getDevicePropertyValue(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/),
+ Builder(&ctx).getStringAttr("max_vector_width")),
+ std::nullopt);
}
TEST(DataLayout, EmptySpec) {
@@ -398,6 +534,15 @@ TEST(DataLayout, EmptySpec) {
EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
EXPECT_EQ(layout.getStackAlignment(), 0u);
+
+ EXPECT_EQ(layout.getDevicePropertyValue(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/),
+ Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
+ std::nullopt);
+ EXPECT_EQ(layout.getDevicePropertyValue(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/),
+ Builder(&ctx).getStringAttr("max_vector_width")),
+ std::nullopt);
}
TEST(DataLayout, SpecWithEntries) {
@@ -449,6 +594,32 @@ TEST(DataLayout, SpecWithEntries) {
EXPECT_EQ(layout.getStackAlignment(), 128u);
}
+TEST(DataLayout, SpecWithTargetSystemDescEntries) {
+ const char *ir = R"MLIR(
+ module attributes { dl_target_sys_desc_test.target_system_spec =
+ #dl_target_sys_desc_test.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"L1_cache_size_in_bytes", "4096">,
+ #dlti.dl_entry<"max_vector_op_width", "128">>
+ > } {}
+ )MLIR";
+
+ DialectRegistry registry;
+ registry.insert<DLTIDialect, DLTargetSystemDescTestDialect>();
+ MLIRContext ctx(registry);
+
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ DataLayout layout(*module);
+ EXPECT_EQ(layout.getDevicePropertyValue(
+ Builder(&ctx).getStringAttr("CPU") /* device ID*/,
+ Builder(&ctx).getStringAttr("L1_cache_size_in_bytes")),
+ std::optional<Attribute>(Builder(&ctx).getStringAttr("4096")));
+ EXPECT_EQ(layout.getDevicePropertyValue(
+ Builder(&ctx).getStringAttr("CPU") /* device ID*/,
+ Builder(&ctx).getStringAttr("max_vector_op_width")),
+ std::optional<Attribute>(Builder(&ctx).getStringAttr("128")));
+}
+
TEST(DataLayout, Caching) {
const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()