diff options
| author | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
|---|---|---|
| committer | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
| commit | 38500d63e14ce340236840f60d356cdefb56a52c (patch) | |
| tree | 17edbec446ce9b50d2f215a483b83afb293a635d /flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | |
| parent | 1a3d5daaef7a6a63448a497da3eff7fc9e23df26 (diff) | |
| parent | 27f30029741ecf023baece7b3dde1ff9011ffefc (diff) | |
Merge branch 'main' into users/meinersbur/flang_runtime_split-headersusers/meinersbur/flang_runtime_split-headers
Diffstat (limited to 'flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp')
| -rw-r--r-- | flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | 169 |
1 files changed, 163 insertions, 6 deletions
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 4575c90e34ac..ad7b806ae262 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -24,10 +24,14 @@ /// indirectly via a parent object. //===----------------------------------------------------------------------===// +#include "flang/Lower/DirectivesCommon.h" #include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/OpenMP/Passes.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/BuiltinDialect.h" @@ -411,10 +415,10 @@ class MapInfoFinalizationPass argIface ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs() : 0; - addOperands( - mapMutableOpRange, - llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()), - blockArgInsertIndex); + addOperands(mapMutableOpRange, + llvm::dyn_cast_if_present<mlir::omp::TargetOp>( + argIface.getOperation()), + blockArgInsertIndex); } if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) { @@ -466,8 +470,7 @@ class MapInfoFinalizationPass // operation (usually function) containing the MapInfoOp because this pass // will mutate siblings of MapInfoOp. void runOnOperation() override { - mlir::ModuleOp module = - mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation()); + mlir::ModuleOp module = getOperation(); if (!module) module = getOperation()->getParentOfType<mlir::ModuleOp>(); fir::KindMapping kindMap = fir::getKindMapping(module); @@ -486,6 +489,160 @@ class MapInfoFinalizationPass // iterations from previous function scopes. localBoxAllocas.clear(); + // First, walk `omp.map.info` ops to see if any record members should be + // implicitly mapped. + func->walk([&](mlir::omp::MapInfoOp op) { + mlir::Type underlyingType = + fir::unwrapRefType(op.getVarPtr().getType()); + + // TODO Test with and support more complicated cases; like arrays for + // records, for example. + if (!fir::isRecordWithAllocatableMember(underlyingType)) + return mlir::WalkResult::advance(); + + // TODO For now, only consider `omp.target` ops. Other ops that support + // `map` clauses will follow later. + mlir::omp::TargetOp target = + mlir::dyn_cast_if_present<mlir::omp::TargetOp>( + getFirstTargetUser(op)); + + if (!target) + return mlir::WalkResult::advance(); + + auto mapClauseOwner = + llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target); + + int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op); + assert(mapVarIdx >= 0 && + mapVarIdx < + static_cast<int64_t>(mapClauseOwner.getMapVars().size())); + + auto argIface = + llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target); + // TODO How should `map` block argument that correspond to: `private`, + // `use_device_addr`, `use_device_ptr`, be handled? + mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx]; + llvm::SetVector<mlir::Operation *> mapVarForwardSlice; + mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice); + + mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) { + // TODO Support coordinate_of ops. + // + // TODO Support call ops by recursively examining the forward slice of + // the corresponding parameter to the field in the called function. + return !mlir::isa<hlfir::DesignateOp>(sliceOp); + }); + + auto recordType = mlir::cast<fir::RecordType>(underlyingType); + llvm::SmallVector<mlir::Value> newMapOpsForFields; + llvm::SmallVector<int64_t> fieldIndicies; + + for (auto fieldMemTyPair : recordType.getTypeList()) { + auto &field = fieldMemTyPair.first; + auto memTy = fieldMemTyPair.second; + + bool shouldMapField = + llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) { + if (!fir::isAllocatableType(memTy)) + return false; + + auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp); + if (!designateOp) + return false; + + return designateOp.getComponent() && + designateOp.getComponent()->strref() == field; + }) != mapVarForwardSlice.end(); + + // TODO Handle recursive record types. Adapting + // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR + // entities might be helpful here. + + if (!shouldMapField) + continue; + + int64_t fieldIdx = recordType.getFieldIndex(field); + bool alreadyMapped = [&]() { + if (op.getMembersIndexAttr()) + for (auto indexList : op.getMembersIndexAttr()) { + auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList); + if (indexListAttr.size() == 1 && + mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() == + fieldIdx) + return true; + } + + return false; + }(); + + if (alreadyMapped) + continue; + + builder.setInsertionPoint(op); + mlir::Value fieldIdxVal = builder.createIntegerConstant( + op.getLoc(), mlir::IndexType::get(builder.getContext()), + fieldIdx); + auto fieldCoord = builder.create<fir::CoordinateOp>( + op.getLoc(), builder.getRefType(memTy), op.getVarPtr(), + fieldIdxVal); + Fortran::lower::AddrAndBoundsInfo info = + Fortran::lower::getDataOperandBaseAddr( + builder, fieldCoord, /*isOptional=*/false, op.getLoc()); + llvm::SmallVector<mlir::Value> bounds = + Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>( + builder, info, + hlfir::translateToExtendedValue(op.getLoc(), builder, + hlfir::Entity{fieldCoord}) + .first, + /*dataExvIsAssumedSize=*/false, op.getLoc()); + + mlir::omp::MapInfoOp fieldMapOp = + builder.create<mlir::omp::MapInfoOp>( + op.getLoc(), fieldCoord.getResult().getType(), + fieldCoord.getResult(), + mlir::TypeAttr::get( + fir::unwrapRefType(fieldCoord.getResult().getType())), + /*varPtrPtr=*/mlir::Value{}, + /*members=*/mlir::ValueRange{}, + /*members_index=*/mlir::ArrayAttr{}, + /*bounds=*/bounds, op.getMapTypeAttr(), + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( + mlir::omp::VariableCaptureKind::ByRef), + builder.getStringAttr(op.getNameAttr().strref() + "." + + field + ".implicit_map"), + /*partial_map=*/builder.getBoolAttr(false)); + newMapOpsForFields.emplace_back(fieldMapOp); + fieldIndicies.emplace_back(fieldIdx); + } + + if (newMapOpsForFields.empty()) + return mlir::WalkResult::advance(); + + op.getMembersMutable().append(newMapOpsForFields); + llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices; + mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr(); + + if (oldMembersIdxAttr) + for (mlir::Attribute indexList : oldMembersIdxAttr) { + llvm::SmallVector<int64_t> listVec; + + for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList)) + listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt()); + + newMemberIndices.emplace_back(std::move(listVec)); + } + + for (int64_t newFieldIdx : fieldIndicies) + newMemberIndices.emplace_back( + llvm::SmallVector<int64_t>(1, newFieldIdx)); + + op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices)); + op.setPartialMap(true); + + return mlir::WalkResult::advance(); + }); + func->walk([&](mlir::omp::MapInfoOp op) { // TODO: Currently only supports a single user for the MapInfoOp. This // is fine for the moment, as the Fortran frontend will generate a |
