summaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp')
-rw-r--r--flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp93
1 files changed, 39 insertions, 54 deletions
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 2bbd8034fa52..bd07d7fe01b8 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -43,7 +43,6 @@
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstddef>
@@ -350,7 +349,7 @@ class MapInfoFinalizationPass
/// the descriptor map onto the base address map.
mlir::omp::MapInfoOp genBaseAddrMap(mlir::Value descriptor,
mlir::OperandRange bounds,
- int64_t mapType,
+ mlir::omp::ClauseMapFlags mapType,
fir::FirOpBuilder &builder) {
mlir::Location loc = descriptor.getLoc();
mlir::Value baseAddrAddr = fir::BoxOffsetOp::create(
@@ -368,7 +367,7 @@ class MapInfoFinalizationPass
return mlir::omp::MapInfoOp::create(
builder, loc, baseAddrAddr.getType(), descriptor,
mlir::TypeAttr::get(underlyingVarType),
- builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{},
@@ -428,22 +427,22 @@ class MapInfoFinalizationPass
/// allowing `to` mappings, and `target update` not allowing both `to` and
/// `from` simultaneously. We currently try to maintain the `implicit` flag
/// where necessary, although it does not seem strictly required.
- unsigned long getDescriptorMapType(unsigned long mapTypeFlag,
- mlir::Operation *target) {
- using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
+ mlir::omp::ClauseMapFlags
+ getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
+ mlir::Operation *target) {
+ using mapFlags = mlir::omp::ClauseMapFlags;
if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp,
mlir::omp::TargetUpdateOp>(target))
return mapTypeFlag;
- mapFlags flags = mapFlags::OMP_MAP_TO |
- (mapFlags(mapTypeFlag) &
- (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS));
+ mapFlags flags =
+ mapFlags::to | (mapTypeFlag & (mapFlags::implicit | mapFlags::always));
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
// to ensure device-local placement where required by tests relying on USM +
// close semantics.
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
- flags |= mapFlags::OMP_MAP_CLOSE;
- return llvm::to_underlying(flags);
+ flags |= mapFlags::close;
+ return flags;
}
/// Check if the mapOp is present in the HasDeviceAddr clause on
@@ -493,11 +492,6 @@ class MapInfoFinalizationPass
mlir::Value boxAddr = fir::BoxOffsetOp::create(
builder, loc, op.getVarPtr(), fir::BoxFieldAttr::base_addr);
- uint64_t mapTypeToImplicit = static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
-
mlir::ArrayAttr newMembersAttr;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}};
newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
@@ -506,8 +500,9 @@ class MapInfoFinalizationPass
mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create(
builder, op.getLoc(), varPtr.getType(), varPtr,
mlir::TypeAttr::get(boxCharType.getEleTy()),
- builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
- mapTypeToImplicit),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ mlir::omp::ClauseMapFlags::to |
+ mlir::omp::ClauseMapFlags::implicit),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/boxAddr,
@@ -568,12 +563,9 @@ class MapInfoFinalizationPass
mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
// Force CLOSE in USM paths so the pointer gets device-local placement
// when required by tests relying on USM + close semantics.
- uint64_t mapTypeVal =
- op.getMapType() |
- llvm::to_underlying(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
- mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr(
- builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal);
+ mlir::omp::ClauseMapFlagsAttr mapTypeAttr =
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ op.getMapType() | mlir::omp::ClauseMapFlags::close);
mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create(
builder, loc, coord.getType(), coord,
@@ -683,17 +675,16 @@ class MapInfoFinalizationPass
// one place in the code may differ from that address in another place.
// The contents of the descriptor (the base address in particular) will
// remain unchanged though.
- uint64_t mapType = op.getMapType();
+ mlir::omp::ClauseMapFlags mapType = op.getMapType();
if (isHasDeviceAddrFlag) {
- mapType |= llvm::to_underlying(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
+ mapType |= mlir::omp::ClauseMapFlags::always;
}
mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create(
builder, op->getLoc(), op.getResult().getType(), descriptor,
mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
- builder.getIntegerAttr(builder.getIntegerType(64, false),
- getDescriptorMapType(mapType, target)),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ getDescriptorMapType(mapType, target)),
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
@@ -892,20 +883,16 @@ class MapInfoFinalizationPass
if (explicitMappingPresent(op, targetDataOp))
return;
- mlir::omp::MapInfoOp newDescParentMapOp =
- builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), op.getResult().getType(), op.getVarPtr(),
- op.getVarTypeAttr(),
- builder.getIntegerAttr(
- builder.getIntegerType(64, false),
- llvm::to_underlying(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
- op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
- mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
- /*bounds=*/mlir::SmallVector<mlir::Value>{},
- /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
- /*partial_map=*/builder.getBoolAttr(false));
+ mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create(
+ builder, op->getLoc(), op.getResult().getType(), op.getVarPtr(),
+ op.getVarTypeAttr(),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::always),
+ op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
+ mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
+ /*bounds=*/mlir::SmallVector<mlir::Value>{},
+ /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ /*partial_map=*/builder.getBoolAttr(false));
targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
}
@@ -957,14 +944,13 @@ class MapInfoFinalizationPass
// need to see how well this alteration works.
auto loadBaseAddr =
builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr());
- mlir::omp::MapInfoOp newBaseAddrMapOp =
- builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
- baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
- baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members,
- membersAttr, baseAddr.getBounds(),
- /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
- /*partial_map=*/builder.getBoolAttr(false));
+ mlir::omp::MapInfoOp newBaseAddrMapOp = mlir::omp::MapInfoOp::create(
+ builder, op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
+ baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
+ baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, membersAttr,
+ baseAddr.getBounds(),
+ /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ /*partial_map=*/builder.getBoolAttr(false));
op.replaceAllUsesWith(newBaseAddrMapOp.getResult());
op->erase();
baseAddr.erase();
@@ -1240,9 +1226,8 @@ class MapInfoFinalizationPass
// we need to change this check for early return OR live with
// over-mapping.
bool hasImplicitMap =
- (llvm::omp::OpenMPOffloadMappingFlags(op.getMapType()) &
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT) ==
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+ (op.getMapType() & mlir::omp::ClauseMapFlags::implicit) ==
+ mlir::omp::ClauseMapFlags::implicit;
if (hasImplicitMap)
return;