summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorRazvan Lupusoru <razvan.lupusoru@gmail.com>2025-11-18 16:04:11 -0800
committerGitHub <noreply@github.com>2025-11-18 16:04:11 -0800
commit0a96b240fcb715c082ab9b4cab6fddae02065602 (patch)
tree7ec8f7793f8656253f506f2142bd395fec5be4e6 /mlir
parent1262acf4ecc9f55d0699705c7810bbf84d3da09e (diff)
[mlir][acc][flang] Introduce OpenACC interfaces for globals (#168614)
Introduce two new OpenACC operation interfaces for identifying global variables and their address computations: - `GlobalVariableOpInterface`: Identifies operations that define global variables. Provides an `isConstant()` method to query whether the global is constant. - `AddressOfGlobalOpInterface`: Identifies operations that compute the address of a global variable. Provides a `getSymbol()` method to retrieve the symbol reference. This is being done in preparation for `ACCImplicitDeclare` pass which will automatically ensure that `acc declare` is applied to globals when needed. The following operations now implement these interfaces: - `memref::GlobalOp` implements `GlobalVariableOpInterface` - `memref::GetGlobalOp` implements `AddressOfGlobalOpInterface` - `fir::GlobalOp` implements `GlobalVariableOpInterface` - `fir::AddrOfOp` implements `AddressOfGlobalOpInterface`
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td31
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp23
-rw-r--r--mlir/unittests/Dialect/OpenACC/CMakeLists.txt1
-rw-r--r--mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp95
4 files changed, 150 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 054c13a88a55..6b0c84d31d1b 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -44,4 +44,35 @@ def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface
];
}
+def AddressOfGlobalOpInterface : OpInterface<"AddressOfGlobalOpInterface"> {
+ let cppNamespace = "::mlir::acc";
+
+ let description = [{
+ An interface for operations that compute the address of a global variable
+ or symbol.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Get the symbol reference to the global", "::mlir::SymbolRefAttr",
+ "getSymbol", (ins)>,
+ ];
+}
+
+def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
+ let cppNamespace = "::mlir::acc";
+
+ let description = [{
+ An interface for operations that define global variables. This interface
+ provides a uniform way to query properties of global variables across
+ different dialects.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Check if the global variable is constant", "bool",
+ "isConstant", (ins), [{
+ return false;
+ }]>,
+ ];
+}
+
#endif // OPENACC_OPS_INTERFACES
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 8c9c137b8aeb..5749e6ded73b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -211,6 +211,24 @@ struct LLVMPointerPointerLikeModel
Type getElementType(Type pointer) const { return Type(); }
};
+struct MemrefAddressOfGlobalModel
+ : public AddressOfGlobalOpInterface::ExternalModel<
+ MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
+ SymbolRefAttr getSymbol(Operation *op) const {
+ auto getGlobalOp = cast<memref::GetGlobalOp>(op);
+ return getGlobalOp.getNameAttr();
+ }
+};
+
+struct MemrefGlobalVariableModel
+ : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
+ memref::GlobalOp> {
+ bool isConstant(Operation *op) const {
+ auto globalOp = cast<memref::GlobalOp>(op);
+ return globalOp.getConstant();
+ }
+};
+
/// Helper function for any of the times we need to modify an ArrayAttr based on
/// a device type list. Returns a new ArrayAttr with all of the
/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
@@ -302,6 +320,11 @@ void OpenACCDialect::initialize() {
MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
+
+ // Attach operation interfaces
+ memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
+ *getContext());
+ memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 177c8680b004..c8c2bb96b053 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIROpenACCTests
OpenACCOpsTest.cpp
+ OpenACCOpsInterfacesTest.cpp
OpenACCUtilsTest.cpp
)
mlir_target_link_libraries(MLIROpenACCTests
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp
new file mode 100644
index 000000000000..261f5c513ea2
--- /dev/null
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp
@@ -0,0 +1,95 @@
+//===- OpenACCOpsInterfacesTest.cpp - Unit tests for OpenACC interfaces --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+class OpenACCOpsInterfacesTest : public ::testing::Test {
+protected:
+ OpenACCOpsInterfacesTest()
+ : context(), builder(&context), loc(UnknownLoc::get(&context)) {
+ context.loadDialect<acc::OpenACCDialect, memref::MemRefDialect>();
+ }
+
+ MLIRContext context;
+ OpBuilder builder;
+ Location loc;
+};
+
+//===----------------------------------------------------------------------===//
+// GlobalVariableOpInterface Tests
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceNonConstant) {
+ // Test that a non-constant global returns false for isConstant()
+
+ auto memrefType = MemRefType::get({10}, builder.getF32Type());
+ OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create(
+ builder, loc,
+ /*sym_name=*/builder.getStringAttr("mutable_global"),
+ /*sym_visibility=*/builder.getStringAttr("private"),
+ /*type=*/TypeAttr::get(memrefType),
+ /*initial_value=*/Attribute(),
+ /*constant=*/UnitAttr(),
+ /*alignment=*/IntegerAttr());
+
+ auto globalVarIface =
+ dyn_cast<GlobalVariableOpInterface>(globalOp->getOperation());
+ ASSERT_TRUE(globalVarIface != nullptr);
+ EXPECT_FALSE(globalVarIface.isConstant());
+}
+
+TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceConstant) {
+ // Test that a constant global returns true for isConstant()
+
+ auto memrefType = MemRefType::get({5}, builder.getI32Type());
+ OwningOpRef<memref::GlobalOp> constantGlobalOp = memref::GlobalOp::create(
+ builder, loc,
+ /*sym_name=*/builder.getStringAttr("constant_global"),
+ /*sym_visibility=*/builder.getStringAttr("public"),
+ /*type=*/TypeAttr::get(memrefType),
+ /*initial_value=*/Attribute(),
+ /*constant=*/builder.getUnitAttr(),
+ /*alignment=*/IntegerAttr());
+
+ auto globalVarIface =
+ dyn_cast<GlobalVariableOpInterface>(constantGlobalOp->getOperation());
+ ASSERT_TRUE(globalVarIface != nullptr);
+ EXPECT_TRUE(globalVarIface.isConstant());
+}
+
+//===----------------------------------------------------------------------===//
+// AddressOfGlobalOpInterface Tests
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCOpsInterfacesTest, AddressOfGlobalOpInterfaceGetSymbol) {
+ // Test that getSymbol() returns the correct symbol reference
+
+ auto memrefType = MemRefType::get({5}, builder.getI32Type());
+ const auto *symbolName = "test_global_symbol";
+
+ OwningOpRef<memref::GetGlobalOp> getGlobalOp = memref::GetGlobalOp::create(
+ builder, loc, memrefType, FlatSymbolRefAttr::get(&context, symbolName));
+
+ auto addrOfGlobalIface =
+ dyn_cast<AddressOfGlobalOpInterface>(getGlobalOp->getOperation());
+ ASSERT_TRUE(addrOfGlobalIface != nullptr);
+ EXPECT_EQ(addrOfGlobalIface.getSymbol().getLeafReference(), symbolName);
+}