summaryrefslogtreecommitdiff
path: root/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp')
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp75
1 files changed, 58 insertions, 17 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8de49dd397d2..f28454075f1d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -357,14 +357,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("priority");
};
auto checkPrivate = [&todo](auto op, LogicalResult &result) {
- if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
- // Privatization is supported only for included target tasks.
- if (!op.getPrivateVars().empty() && op.getNowait())
- result = todo("privatization for deferred target tasks");
- } else {
- if (!op.getPrivateVars().empty() || op.getPrivateSyms())
- result = todo("privatization");
- }
+ if (!op.getPrivateVars().empty() || op.getPrivateSyms())
+ result = todo("privatization");
};
auto checkReduction = [&todo](auto op, LogicalResult &result) {
if (isa<omp::TeamsOp>(op))
@@ -451,7 +445,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkDevice(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
- checkPrivate(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
@@ -3833,6 +3826,58 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
+// Convert the MLIR map flag set to the runtime map flag set for embedding
+// in LLVM-IR. This is important as the two bit-flag lists do not correspond
+// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
+// Certain flags are discarded here such as RefPtee and co.
+static llvm::omp::OpenMPOffloadMappingFlags
+convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
+ auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
+ return (mlirFlags & flag) == flag;
+ };
+
+ llvm::omp::OpenMPOffloadMappingFlags mapType =
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::to))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::from))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::always))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::del))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::return_param))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::priv))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::literal))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::implicit))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::close))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::present))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::attach))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
+
+ return mapType;
+}
+
static void collectMapDataFromMapOperands(
MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
@@ -3880,8 +3925,7 @@ static void collectMapDataFromMapOperands(
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
mapData.BaseType.back(), builder, moduleTranslation));
mapData.MapClause.push_back(mapOp.getOperation());
- mapData.Types.push_back(
- llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
+ mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
mapData.Names.push_back(LLVM::createMappingInformation(
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
@@ -3950,8 +3994,7 @@ static void collectMapDataFromMapOperands(
Value offloadPtr =
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
- auto mapType =
- static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
+ auto mapType = convertClauseMapFlags(mapOp.getMapType());
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
mapData.OriginalValue.push_back(origValue);
@@ -4299,8 +4342,7 @@ static void processMapMembersWithParent(
// in part as we currently have substantially less information on the data
// being mapped at this stage.
if (checkIfPointerMap(memberClause)) {
- auto mapFlag =
- llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
+ auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
@@ -4319,8 +4361,7 @@ static void processMapMembersWithParent(
// Same MemberOfFlag to indicate its link with parent and other members
// of.
- auto mapFlag =
- llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
+ auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);