diff options
| author | Razvan Lupusoru <razvan.lupusoru@gmail.com> | 2025-11-18 16:04:11 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-18 16:04:11 -0800 |
| commit | 0a96b240fcb715c082ab9b4cab6fddae02065602 (patch) | |
| tree | 7ec8f7793f8656253f506f2142bd395fec5be4e6 /mlir | |
| parent | 1262acf4ecc9f55d0699705c7810bbf84d3da09e (diff) | |
[mlir][acc][flang] Introduce OpenACC interfaces for globals (#168614)
Introduce two new OpenACC operation interfaces for identifying global
variables and their address computations:
- `GlobalVariableOpInterface`: Identifies operations that define global
variables. Provides an `isConstant()` method to query whether the global
is constant.
- `AddressOfGlobalOpInterface`: Identifies operations that compute the
address of a global variable. Provides a `getSymbol()` method to
retrieve the symbol reference.
This is being done in preparation for `ACCImplicitDeclare` pass which
will automatically ensure that `acc declare` is applied to globals when
needed.
The following operations now implement these interfaces:
- `memref::GlobalOp` implements `GlobalVariableOpInterface`
- `memref::GetGlobalOp` implements `AddressOfGlobalOpInterface`
- `fir::GlobalOp` implements `GlobalVariableOpInterface`
- `fir::AddrOfOp` implements `AddressOfGlobalOpInterface`
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td | 31 | ||||
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 23 | ||||
| -rw-r--r-- | mlir/unittests/Dialect/OpenACC/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp | 95 |
4 files changed, 150 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td index 054c13a88a55..6b0c84d31d1b 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td @@ -44,4 +44,35 @@ def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface ]; } +def AddressOfGlobalOpInterface : OpInterface<"AddressOfGlobalOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that compute the address of a global variable + or symbol. + }]; + + let methods = [ + InterfaceMethod<"Get the symbol reference to the global", "::mlir::SymbolRefAttr", + "getSymbol", (ins)>, + ]; +} + +def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that define global variables. This interface + provides a uniform way to query properties of global variables across + different dialects. + }]; + + let methods = [ + InterfaceMethod<"Check if the global variable is constant", "bool", + "isConstant", (ins), [{ + return false; + }]>, + ]; +} + #endif // OPENACC_OPS_INTERFACES diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 8c9c137b8aeb..5749e6ded73b 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -211,6 +211,24 @@ struct LLVMPointerPointerLikeModel Type getElementType(Type pointer) const { return Type(); } }; +struct MemrefAddressOfGlobalModel + : public AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast<memref::GetGlobalOp>(op); + return getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel, + memref::GlobalOp> { + bool isConstant(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + return globalOp.getConstant(); + } +}; + /// Helper function for any of the times we need to modify an ArrayAttr based on /// a device type list. Returns a new ArrayAttr with all of the /// existingDeviceTypes, plus the effective new ones(or an added none if hte new @@ -302,6 +320,11 @@ void OpenACCDialect::initialize() { MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>( + *getContext()); + memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt index 177c8680b004..c8c2bb96b053 100644 --- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIROpenACCTests OpenACCOpsTest.cpp + OpenACCOpsInterfacesTest.cpp OpenACCUtilsTest.cpp ) mlir_target_link_libraries(MLIROpenACCTests diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp new file mode 100644 index 000000000000..261f5c513ea2 --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp @@ -0,0 +1,95 @@ +//===- OpenACCOpsInterfacesTest.cpp - Unit tests for OpenACC interfaces --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::acc; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class OpenACCOpsInterfacesTest : public ::testing::Test { +protected: + OpenACCOpsInterfacesTest() + : context(), builder(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect<acc::OpenACCDialect, memref::MemRefDialect>(); + } + + MLIRContext context; + OpBuilder builder; + Location loc; +}; + +//===----------------------------------------------------------------------===// +// GlobalVariableOpInterface Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceNonConstant) { + // Test that a non-constant global returns false for isConstant() + + auto memrefType = MemRefType::get({10}, builder.getF32Type()); + OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("mutable_global"), + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/UnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast<GlobalVariableOpInterface>(globalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + EXPECT_FALSE(globalVarIface.isConstant()); +} + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceConstant) { + // Test that a constant global returns true for isConstant() + + auto memrefType = MemRefType::get({5}, builder.getI32Type()); + OwningOpRef<memref::GlobalOp> constantGlobalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("constant_global"), + /*sym_visibility=*/builder.getStringAttr("public"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/builder.getUnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast<GlobalVariableOpInterface>(constantGlobalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + EXPECT_TRUE(globalVarIface.isConstant()); +} + +//===----------------------------------------------------------------------===// +// AddressOfGlobalOpInterface Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCOpsInterfacesTest, AddressOfGlobalOpInterfaceGetSymbol) { + // Test that getSymbol() returns the correct symbol reference + + auto memrefType = MemRefType::get({5}, builder.getI32Type()); + const auto *symbolName = "test_global_symbol"; + + OwningOpRef<memref::GetGlobalOp> getGlobalOp = memref::GetGlobalOp::create( + builder, loc, memrefType, FlatSymbolRefAttr::get(&context, symbolName)); + + auto addrOfGlobalIface = + dyn_cast<AddressOfGlobalOpInterface>(getGlobalOp->getOperation()); + ASSERT_TRUE(addrOfGlobalIface != nullptr); + EXPECT_EQ(addrOfGlobalIface.getSymbol().getLeafReference(), symbolName); +} |
