diff options
| author | Florian Mayer <fmayer@google.com> | 2025-10-22 10:55:10 -0700 |
|---|---|---|
| committer | Florian Mayer <fmayer@google.com> | 2025-10-22 10:55:10 -0700 |
| commit | f5f8398d7fe18a968f5873518e87d5fdd8269359 (patch) | |
| tree | 347dff286c3b48b2336fb7a425adfceebd478116 /flang/lib | |
| parent | 73edaec4a6cd1212f9ae819c413d2cf58216d3b1 (diff) | |
| parent | a0abc0af0a0a90878822f8107d70dad6f7cdfc26 (diff) | |
Created using spr 1.3.7
Diffstat (limited to 'flang/lib')
38 files changed, 1082 insertions, 608 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 3b711ccbe786..a516a44204ca 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1766,7 +1766,7 @@ private: // to a crash due to a block with no terminator. See issue #126452. mlir::FunctionType funcType = builder->getFunction().getFunctionType(); mlir::Type resultType = funcType.getResult(0); - mlir::Value undefResult = builder->create<fir::UndefOp>(loc, resultType); + mlir::Value undefResult = fir::UndefOp::create(*builder, loc, resultType); genExitRoutine(false, undefResult); return; } @@ -4010,8 +4010,8 @@ private: // parameters and dynamic type. The selector cannot be a // POINTER/ALLOCATBLE as per F'2023 C1160. fir::ExtendedValue newExv; - llvm::SmallVector assumeSizeExtents{ - builder->createMinusOneInteger(loc, builder->getIndexType())}; + llvm::SmallVector<mlir::Value> assumeSizeExtents{ + fir::AssumedSizeExtentOp::create(*builder, loc)}; mlir::Value baseAddr = hlfir::genVariableRawAddress(loc, *builder, selector); const bool isVolatile = fir::isa_volatile_type(selector.getType()); @@ -4733,11 +4733,21 @@ private: return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType, {}); hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR( loc, *this, assign.rhs, localSymbols, rhsContext); + auto rhsBoxType = rhs.getBoxType(); // Create pointer descriptor value from the RHS. if (rhs.isMutableBox()) rhs = hlfir::Entity{fir::LoadOp::create(*builder, loc, rhs)}; - mlir::Value rhsBox = hlfir::genVariableBox( - loc, *builder, rhs, lhsBoxType.getBoxTypeWithNewShape(rhs.getRank())); + + // Use LHS type if LHS is not polymorphic. + fir::BaseBoxType targetBoxType; + if (assign.lhs.GetType()->IsPolymorphic()) + targetBoxType = rhsBoxType.getBoxTypeWithNewAttr( + fir::BaseBoxType::Attribute::Pointer); + else + targetBoxType = lhsBoxType.getBoxTypeWithNewShape(rhs.getRank()); + mlir::Value rhsBox = + hlfir::genVariableBox(loc, *builder, rhs, targetBoxType); + // Apply lower bounds or reshaping if any. if (const auto *lbExprs = std::get_if<Fortran::evaluate::Assignment::BoundsSpec>(&assign.u); diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index 00ec1b51e540..2517ab35d4ff 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -1711,7 +1711,7 @@ static void lowerExplicitLowerBounds( /// CFI_desc_t requirements in 18.5.3 point 5.). static mlir::Value getAssumedSizeExtent(mlir::Location loc, fir::FirOpBuilder &builder) { - return builder.createMinusOneInteger(loc, builder.getIndexType()); + return fir::AssumedSizeExtentOp::create(builder, loc); } /// Lower explicit extents into \p result if this is an explicit-shape or diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index b3e8b697df1e..1fc59c702fd8 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -718,6 +718,84 @@ static void genDataOperandOperations( } } +template <typename GlobalCtorOrDtorOp, typename EntryOp, typename DeclareOp, + typename ExitOp> +static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, + fir::FirOpBuilder &builder, + mlir::Location loc, fir::GlobalOp globalOp, + mlir::acc::DataClause clause, + const std::string &declareGlobalName, + bool implicit, std::stringstream &asFortran) { + GlobalCtorOrDtorOp declareGlobalOp = + GlobalCtorOrDtorOp::create(modBuilder, loc, declareGlobalName); + builder.createBlock(&declareGlobalOp.getRegion(), + declareGlobalOp.getRegion().end(), {}, {}); + builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back()); + + fir::AddrOfOp addrOp = fir::AddrOfOp::create( + builder, loc, fir::ReferenceType::get(globalOp.getType()), + globalOp.getSymbol()); + addDeclareAttr(builder, addrOp, clause); + + llvm::SmallVector<mlir::Value> bounds; + EntryOp entryOp = createDataEntryOp<EntryOp>( + builder, loc, addrOp.getResTy(), asFortran, bounds, + /*structured=*/false, implicit, clause, addrOp.getResTy().getType(), + /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); + if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>) + DeclareOp::create(builder, loc, + mlir::acc::DeclareTokenType::get(entryOp.getContext()), + mlir::ValueRange(entryOp.getAccVar())); + else + DeclareOp::create(builder, loc, mlir::Value{}, + mlir::ValueRange(entryOp.getAccVar())); + if constexpr (std::is_same_v<GlobalCtorOrDtorOp, + mlir::acc::GlobalDestructorOp>) { + if constexpr (std::is_same_v<ExitOp, mlir::acc::DeclareLinkOp>) { + // No destructor emission for declare link in this path to avoid + // complex var/varType/varPtrPtr signatures. The ctor registers the link. + } else if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> || + std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>) { + ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), + entryOp.getVar(), entryOp.getVarType(), + entryOp.getBounds(), entryOp.getAsyncOperands(), + entryOp.getAsyncOperandsDeviceTypeAttr(), + entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), + /*structured=*/false, /*implicit=*/false, + builder.getStringAttr(*entryOp.getName())); + } else { + ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), + entryOp.getBounds(), entryOp.getAsyncOperands(), + entryOp.getAsyncOperandsDeviceTypeAttr(), + entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), + /*structured=*/false, /*implicit=*/false, + builder.getStringAttr(*entryOp.getName())); + } + } + mlir::acc::TerminatorOp::create(builder, loc); + modBuilder.setInsertionPointAfter(declareGlobalOp); +} + +template <typename EntryOp, typename ExitOp> +static void +emitCtorDtorPair(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, + mlir::Location operandLocation, fir::GlobalOp globalOp, + mlir::acc::DataClause clause, std::stringstream &asFortran, + const std::string &ctorName) { + createDeclareGlobalOp<mlir::acc::GlobalConstructorOp, EntryOp, + mlir::acc::DeclareEnterOp, ExitOp>( + modBuilder, builder, operandLocation, globalOp, clause, ctorName, + /*implicit=*/false, asFortran); + + std::stringstream dtorName; + dtorName << globalOp.getSymName().str() << "_acc_dtor"; + createDeclareGlobalOp<mlir::acc::GlobalDestructorOp, + mlir::acc::GetDevicePtrOp, mlir::acc::DeclareExitOp, + ExitOp>(modBuilder, builder, operandLocation, globalOp, + clause, dtorName.str(), + /*implicit=*/false, asFortran); +} + template <typename EntryOp, typename ExitOp> static void genDeclareDataOperandOperations( const Fortran::parser::AccObjectList &objectList, @@ -733,6 +811,37 @@ static void genDeclareDataOperandOperations( std::stringstream asFortran; mlir::Location operandLocation = genOperandLocation(converter, accObject); Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); + // Handle COMMON/global symbols via module-level ctor/dtor path. + if (symbol.detailsIf<Fortran::semantics::CommonBlockDetails>() || + Fortran::semantics::FindCommonBlockContaining(symbol)) { + emitCommonGlobal( + converter, builder, accObject, dataClause, + [&](mlir::OpBuilder &modBuilder, mlir::Location loc, + fir::GlobalOp globalOp, mlir::acc::DataClause clause, + std::stringstream &asFortranStr, const std::string &ctorName) { + if constexpr (std::is_same_v<EntryOp, mlir::acc::DeclareLinkOp>) { + createDeclareGlobalOp< + mlir::acc::GlobalConstructorOp, mlir::acc::DeclareLinkOp, + mlir::acc::DeclareEnterOp, mlir::acc::DeclareLinkOp>( + modBuilder, builder, loc, globalOp, clause, ctorName, + /*implicit=*/false, asFortranStr); + } else if constexpr (std::is_same_v<EntryOp, mlir::acc::CreateOp> || + std::is_same_v<EntryOp, mlir::acc::CopyinOp> || + std::is_same_v< + EntryOp, + mlir::acc::DeclareDeviceResidentOp> || + std::is_same_v<ExitOp, mlir::acc::CopyoutOp>) { + emitCtorDtorPair<EntryOp, ExitOp>(modBuilder, builder, loc, + globalOp, clause, asFortranStr, + ctorName); + } else { + // No module-level ctor/dtor for this clause (e.g., deviceptr, + // present). Handled via structured declare region only. + return; + } + }); + continue; + } Fortran::semantics::MaybeExpr designator = Fortran::common::visit( [&](auto &&s) { return ea.Analyze(s); }, accObject.u); fir::factory::AddrAndBoundsInfo info = @@ -2257,6 +2366,23 @@ static void processDoLoopBounds( } } +static void remapCommonBlockMember( + Fortran::lower::AbstractConverter &converter, mlir::Location loc, + const Fortran::semantics::Symbol &member, + mlir::Value newCommonBlockBaseAddress, + const Fortran::semantics::Symbol &commonBlockSymbol, + llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &seenSymbols) { + if (seenSymbols.contains(&member)) + return; + mlir::Value accMemberValue = Fortran::lower::genCommonBlockMember( + converter, loc, member, newCommonBlockBaseAddress, + commonBlockSymbol.size()); + fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(member); + fir::ExtendedValue accExv = fir::substBase(hostExv, accMemberValue); + converter.bindSymbol(member, accExv); + seenSymbols.insert(&member); +} + /// Remap symbols that appeared in OpenACC data clauses to use the results of /// the corresponding data operations. This allows isolating symbol accesses /// inside the OpenACC region from accesses in the host and other regions while @@ -2282,14 +2408,39 @@ static void remapDataOperandSymbols( builder.setInsertionPointToStart(®ionOp.getRegion().front()); llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols; mlir::IRMapping mapper; + mlir::Location loc = regionOp.getLoc(); for (auto [value, symbol] : dataOperandSymbolPairs) { - - // If A symbol appears on several data clause, just map it to the first + // If a symbol appears on several data clause, just map it to the first // result (all data operations results for a symbol are pointing same // memory, so it does not matter which one is used). if (seenSymbols.contains(&symbol.get())) continue; seenSymbols.insert(&symbol.get()); + // When a common block appears in a directive, remap its members. + // Note: this will instantiate all common block members even if they are not + // used inside the region. If hlfir.declare DCE is not made possible, this + // could be improved to reduce IR noise. + if (const auto *commonBlock = symbol->template detailsIf< + Fortran::semantics::CommonBlockDetails>()) { + const Fortran::semantics::Scope &commonScope = symbol->owner(); + if (commonScope.equivalenceSets().empty()) { + for (auto member : commonBlock->objects()) + remapCommonBlockMember(converter, loc, *member, value, *symbol, + seenSymbols); + } else { + // Objects equivalenced with common block members still belong to the + // common block storage even if they are not part of the common block + // declaration. The easiest and most robust way to find all symbols + // belonging to the common block is to loop through the scope symbols + // and check if they belong to the common. + for (const auto &scopeSymbol : commonScope) + if (Fortran::semantics::FindCommonBlockContaining( + *scopeSymbol.second) == &symbol.get()) + remapCommonBlockMember(converter, loc, *scopeSymbol.second, value, + *symbol, seenSymbols); + } + continue; + } std::optional<fir::FortranVariableOpInterface> hostDef = symbolMap.lookupVariableDefinition(symbol); assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) && @@ -2306,10 +2457,8 @@ static void remapDataOperandSymbols( "box type mismatch between compute region variable and " "hlfir.declare input unexpected"); if (Fortran::semantics::IsOptional(symbol)) - TODO(regionOp.getLoc(), - "remapping OPTIONAL symbol in OpenACC compute region"); - auto rawValue = - fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value); + TODO(loc, "remapping OPTIONAL symbol in OpenACC compute region"); + auto rawValue = fir::BoxAddrOp::create(builder, loc, hostType, value); mapper.map(hostInput, rawValue); } else { assert(!llvm::isa<fir::BaseBoxType>(hostType) && @@ -2321,8 +2470,7 @@ static void remapDataOperandSymbols( assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) && "compute region variable and host variable should both be raw " "addresses"); - mlir::Value cast = - builder.createConvert(regionOp.getLoc(), hostType, value); + mlir::Value cast = builder.createConvert(loc, hostType, value); mapper.map(hostInput, cast); } if (mlir::Value dummyScope = hostDeclare.getDummyScope()) { @@ -4098,49 +4246,6 @@ static void genACC(Fortran::lower::AbstractConverter &converter, waitOp.setAsyncAttr(firOpBuilder.getUnitAttr()); } -template <typename GlobalOp, typename EntryOp, typename DeclareOp, - typename ExitOp> -static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder, - fir::FirOpBuilder &builder, - mlir::Location loc, fir::GlobalOp globalOp, - mlir::acc::DataClause clause, - const std::string &declareGlobalName, - bool implicit, std::stringstream &asFortran) { - GlobalOp declareGlobalOp = - GlobalOp::create(modBuilder, loc, declareGlobalName); - builder.createBlock(&declareGlobalOp.getRegion(), - declareGlobalOp.getRegion().end(), {}, {}); - builder.setInsertionPointToEnd(&declareGlobalOp.getRegion().back()); - - fir::AddrOfOp addrOp = fir::AddrOfOp::create( - builder, loc, fir::ReferenceType::get(globalOp.getType()), - globalOp.getSymbol()); - addDeclareAttr(builder, addrOp, clause); - - llvm::SmallVector<mlir::Value> bounds; - EntryOp entryOp = createDataEntryOp<EntryOp>( - builder, loc, addrOp.getResTy(), asFortran, bounds, - /*structured=*/false, implicit, clause, addrOp.getResTy().getType(), - /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); - if constexpr (std::is_same_v<DeclareOp, mlir::acc::DeclareEnterOp>) - DeclareOp::create(builder, loc, - mlir::acc::DeclareTokenType::get(entryOp.getContext()), - mlir::ValueRange(entryOp.getAccVar())); - else - DeclareOp::create(builder, loc, mlir::Value{}, - mlir::ValueRange(entryOp.getAccVar())); - if constexpr (std::is_same_v<GlobalOp, mlir::acc::GlobalDestructorOp>) { - ExitOp::create(builder, entryOp.getLoc(), entryOp.getAccVar(), - entryOp.getBounds(), entryOp.getAsyncOperands(), - entryOp.getAsyncOperandsDeviceTypeAttr(), - entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), - /*structured=*/false, /*implicit=*/false, - builder.getStringAttr(*entryOp.getName())); - } - mlir::acc::TerminatorOp::create(builder, loc); - modBuilder.setInsertionPointAfter(declareGlobalOp); -} - template <typename EntryOp> static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, @@ -4317,6 +4422,66 @@ genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter, dataClause); } +static fir::GlobalOp +lookupGlobalBySymbolOrEquivalence(Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &builder, + const Fortran::semantics::Symbol &sym) { + const Fortran::semantics::Symbol *commonBlock = + Fortran::semantics::FindCommonBlockContaining(sym); + std::string globalName = commonBlock ? converter.mangleName(*commonBlock) + : converter.mangleName(sym); + if (fir::GlobalOp g = builder.getNamedGlobal(globalName)) { + return g; + } + // Not found: if not a COMMON member, try equivalence members + if (!commonBlock) { + if (const Fortran::semantics::EquivalenceSet *eqSet = + Fortran::semantics::FindEquivalenceSet(sym)) { + for (const Fortran::semantics::EquivalenceObject &eqObj : *eqSet) { + std::string eqName = converter.mangleName(eqObj.symbol); + if (fir::GlobalOp g = builder.getNamedGlobal(eqName)) + return g; + } + } + } + return {}; +} + +template <typename EmitterFn> +static void emitCommonGlobal(Fortran::lower::AbstractConverter &converter, + fir::FirOpBuilder &builder, + const Fortran::parser::AccObject &obj, + mlir::acc::DataClause clause, + EmitterFn &&emitCtorDtor) { + Fortran::semantics::Symbol &sym = getSymbolFromAccObject(obj); + if (!(sym.detailsIf<Fortran::semantics::CommonBlockDetails>() || + Fortran::semantics::FindCommonBlockContaining(sym))) + return; + + fir::GlobalOp globalOp = + lookupGlobalBySymbolOrEquivalence(converter, builder, sym); + if (!globalOp) + llvm::report_fatal_error("could not retrieve global symbol"); + + std::stringstream ctorName; + ctorName << globalOp.getSymName().str() << "_acc_ctor"; + if (builder.getModule().lookupSymbol<mlir::acc::GlobalConstructorOp>( + ctorName.str())) + return; + + mlir::Location operandLocation = genOperandLocation(converter, obj); + addDeclareAttr(builder, globalOp.getOperation(), clause); + mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion()); + modBuilder.setInsertionPointAfter(globalOp); + std::stringstream asFortran; + asFortran << sym.name().ToString(); + + auto savedIP = builder.saveInsertionPoint(); + emitCtorDtor(modBuilder, operandLocation, globalOp, clause, asFortran, + ctorName.str()); + builder.restoreInsertionPoint(savedIP); +} + static void genDeclareInFunction(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, @@ -4342,11 +4507,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, dataClauseOperands.end()); } else if (const auto *createClause = std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - createClause->v; - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); auto crtDataStart = dataClauseOperands.size(); + const auto &accObjectList = + std::get<Fortran::parser::AccObjectList>(createClause->v.t); genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>( accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_create, @@ -4378,11 +4541,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, } else if (const auto *copyoutClause = std::get_if<Fortran::parser::AccClause::Copyout>( &clause.u)) { - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - copyoutClause->v; - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); auto crtDataStart = dataClauseOperands.size(); + const auto &accObjectList = + std::get<Fortran::parser::AccObjectList>(copyoutClause->v.t); genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>( accObjectList, converter, semanticsContext, stmtCtx, @@ -4423,6 +4584,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter, } } + // If no structured operands were generated (all objects were COMMON), + // do not create a declare region. + if (dataClauseOperands.empty()) + return; + mlir::func::FuncOp funcOp = builder.getFunction(); auto ops = funcOp.getOps<mlir::acc::DeclareEnterOp>(); mlir::Value declareToken; diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 85398be77838..1c163e6de7e5 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1080,9 +1080,8 @@ bool ClauseProcessor::processHasDeviceAddr( [&](const omp::clause::HasDeviceAddr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::implicit; omp::ObjectList baseObjects; llvm::transform(clause.v, std::back_inserter(baseObjects), [&](const omp::Object &object) { @@ -1217,8 +1216,7 @@ bool ClauseProcessor::processLink( void ClauseProcessor::processMapObjects( lower::StatementContext &stmtCtx, mlir::Location clauseLocation, - const omp::ObjectList &objects, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, + const omp::ObjectList &objects, mlir::omp::ClauseMapFlags mapTypeBits, std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapVars, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, @@ -1310,10 +1308,7 @@ void ClauseProcessor::processMapObjects( mlir::omp::MapInfoOp mapOp = utils::openmp::createMapInfoOp( firOpBuilder, location, baseOp, /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds, - /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapTypeBits), + /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, mapTypeBits, mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(), /*partialMap=*/false, mapperId); @@ -1347,8 +1342,7 @@ bool ClauseProcessor::processMap( objects] = clause.t; if (attachMod) TODO(currentLocation, "ATTACH modifier is not implemented yet"); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + mlir::omp::ClauseMapFlags mapTypeBits = mlir::omp::ClauseMapFlags::none; std::string mapperIdName = "__implicit_mapper"; // If the map type is specified, then process it else set the appropriate // default value @@ -1364,36 +1358,32 @@ bool ClauseProcessor::processMap( switch (type) { case Map::MapType::To: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapTypeBits |= mlir::omp::ClauseMapFlags::to; break; case Map::MapType::From: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= mlir::omp::ClauseMapFlags::from; break; case Map::MapType::Tofrom: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapTypeBits |= + mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::from; break; case Map::MapType::Storage: - // alloc and release is the default map_type for the Target Data - // Ops, i.e. if no bits for map_type is supplied then alloc/release - // (aka storage in 6.0+) is implicitly assumed based on the target - // directive. Default value for Target Data and Enter Data is alloc - // and for Exit Data it is release. + mapTypeBits |= mlir::omp::ClauseMapFlags::storage; break; } if (typeMods) { // TODO: Still requires "self" modifier, an OpenMP 6.0+ feature if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Always)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + mapTypeBits |= mlir::omp::ClauseMapFlags::always; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Present)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + mapTypeBits |= mlir::omp::ClauseMapFlags::present; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Close)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + mapTypeBits |= mlir::omp::ClauseMapFlags::close; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::Delete)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + mapTypeBits |= mlir::omp::ClauseMapFlags::del; if (llvm::is_contained(*typeMods, Map::MapTypeModifier::OmpxHold)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; + mapTypeBits |= mlir::omp::ClauseMapFlags::ompx_hold; } if (iterator) { @@ -1437,12 +1427,12 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx, TODO(clauseLocation, "Iterator modifier is not supported yet"); } - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + mlir::omp::ClauseMapFlags mapTypeBits = std::is_same_v<llvm::remove_cvref_t<decltype(clause)>, omp::clause::To> - ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO - : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + ? mlir::omp::ClauseMapFlags::to + : mlir::omp::ClauseMapFlags::from; if (expectation && *expectation == omp::clause::To::Expectation::Present) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + mapTypeBits |= mlir::omp::ClauseMapFlags::present; processMapObjects(stmtCtx, clauseLocation, objects, mapTypeBits, parentMemberIndices, result.mapVars, mapSymbols); }; @@ -1568,8 +1558,8 @@ bool ClauseProcessor::processUseDeviceAddr( [&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::return_param; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDeviceAddrVars, useDeviceSyms); @@ -1589,8 +1579,8 @@ bool ClauseProcessor::processUseDevicePtr( [&](const omp::clause::UseDevicePtr &clause, const parser::CharBlock &source) { mlir::Location location = converter.genLocation(source); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::return_param; processMapObjects(stmtCtx, location, clause.v, mapTypeBits, parentMemberIndices, result.useDevicePtrVars, useDeviceSyms); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 9e352fa574a9..6452e39b9755 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -194,8 +194,7 @@ private: void processMapObjects( lower::StatementContext &stmtCtx, mlir::Location clauseLocation, - const omp::ObjectList &objects, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, + const omp::ObjectList &objects, mlir::omp::ClauseMapFlags mapTypeBits, std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices, llvm::SmallVectorImpl<mlir::Value> &mapVars, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms, diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 2a4ebf10bcaf..d39f9dda92a2 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -16,8 +16,6 @@ #include "flang/Semantics/openmp-modifiers.h" #include "flang/Semantics/symbol.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" - #include <list> #include <optional> #include <tuple> diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 9495ea61058c..71067283d13f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -45,7 +45,6 @@ #include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" using namespace Fortran::lower::omp; using namespace Fortran::common::openmp; @@ -945,8 +944,7 @@ getDefaultmapIfPresent(const DefaultMapsTy &defaultMaps, mlir::Type varType) { return DefMap::ImplicitBehavior::Default; } -static std::pair<llvm::omp::OpenMPOffloadMappingFlags, - mlir::omp::VariableCaptureKind> +static std::pair<mlir::omp::ClauseMapFlags, mlir::omp::VariableCaptureKind> getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, lower::AbstractConverter &converter, const DefaultMapsTy &defaultMaps, mlir::Type varType, @@ -967,8 +965,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, return size <= ptrSize && align <= ptrAlign; }; - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; auto implicitBehaviour = getDefaultmapIfPresent(defaultMaps, varType); if (implicitBehaviour == DefMap::ImplicitBehavior::Default) { @@ -986,8 +983,8 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, mlir::omp::DeclareTargetCaptureClause::link && declareTargetOp.getDeclareTargetDeviceType() != mlir::omp::DeclareTargetDeviceType::nohost) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } } else if (fir::isa_trivial(varType) || fir::isa_char(varType)) { // Scalars behave as if they were "firstprivate". @@ -996,18 +993,18 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, if (isLiteralType(varType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapFlag |= mlir::omp::ClauseMapFlags::to; } } else if (!fir::isa_builtin_cptr_type(varType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } return std::make_pair(mapFlag, captureKind); } switch (implicitBehaviour) { case DefMap::ImplicitBehavior::Alloc: - return std::make_pair(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE, + return std::make_pair(mlir::omp::ClauseMapFlags::storage, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Firstprivate: @@ -1016,26 +1013,22 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, "behaviour"); break; case DefMap::ImplicitBehavior::From: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Present: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::present, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::To: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::to, (fir::isa_trivial(varType) || fir::isa_char(varType)) ? mlir::omp::VariableCaptureKind::ByCopy : mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Tofrom: - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from | + mlir::omp::ClauseMapFlags::to, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Default: @@ -1044,9 +1037,8 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, break; } - return std::make_pair(mapFlag |= - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO, + return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from | + mlir::omp::ClauseMapFlags::to, mlir::omp::VariableCaptureKind::ByRef); } @@ -2067,37 +2059,38 @@ static void genCanonicalLoopNest( // Start lowering mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); - mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>( - loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); + mlir::Value isDownwards = mlir::arith::CmpIOp::create( + firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); // Ensure we are counting upwards. If not, negate step and swap lb and ub. mlir::Value negStep = - firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar); - mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isDownwards, negStep, loopStepVar); - mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isDownwards, loopUBVar, loopLBVar); - mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isDownwards, loopLBVar, loopUBVar); + mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar); + mlir::Value incr = mlir::arith::SelectOp::create( + firOpBuilder, loc, isDownwards, negStep, loopStepVar); + mlir::Value lb = mlir::arith::SelectOp::create( + firOpBuilder, loc, isDownwards, loopUBVar, loopLBVar); + mlir::Value ub = mlir::arith::SelectOp::create( + firOpBuilder, loc, isDownwards, loopLBVar, loopUBVar); // Compute the trip count assuming lb <= ub. This guarantees that the result // is non-negative and we can use unsigned arithmetic. - mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>( - loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::Value span = mlir::arith::SubIOp::create( + firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); mlir::Value tcMinusOne = - firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr); - mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>( - loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr); + mlir::Value tcIfLooping = + mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one, + ::mlir::arith::IntegerOverflowFlags::nuw); // Fall back to 0 if lb > ub - mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>( - loc, mlir::arith::CmpIPredicate::slt, ub, lb); - mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isZeroTC, zero, tcIfLooping); + mlir::Value isZeroTC = mlir::arith::CmpIOp::create( + firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb); + mlir::Value tripcount = mlir::arith::SelectOp::create( + firOpBuilder, loc, isZeroTC, zero, tcIfLooping); tripcounts.push_back(tripcount); // Create the CLI handle. - auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc); mlir::Value cli = newcli.getResult(); clis.push_back(cli); @@ -2130,10 +2123,10 @@ static void genCanonicalLoopNest( "Expecting all block args to have been collected by now"); for (auto j : llvm::seq<size_t>(numLoops)) { mlir::Value natIterNum = fir::getBase(blockArgs[j]); - mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>( - loc, natIterNum, loopStepVars[j]); - mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>( - loc, loopLBVars[j], scaled); + mlir::Value scaled = mlir::arith::MulIOp::create( + firOpBuilder, loc, natIterNum, loopStepVars[j]); + mlir::Value userVal = mlir::arith::AddIOp::create( + firOpBuilder, loc, loopLBVars[j], scaled); mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); @@ -2206,9 +2199,9 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter, gridGeneratees.reserve(numLoops); intratileGeneratees.reserve(numLoops); for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) { - auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + auto gridCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc); gridGeneratees.push_back(gridCLI.getResult()); - auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + auto intratileCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc); intratileGeneratees.push_back(intratileCLI.getResult()); } @@ -2217,8 +2210,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter, generatees.append(gridGeneratees); generatees.append(intratileGeneratees); - firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees, - sizesClause.sizes); + mlir::omp::TileOp::create(firOpBuilder, loc, generatees, applyees, + sizesClause.sizes); } static void genUnrollOp(Fortran::lower::AbstractConverter &converter, @@ -2612,18 +2605,14 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType())) eleType = refType.getElementType(); - std::pair<llvm::omp::OpenMPOffloadMappingFlags, - mlir::omp::VariableCaptureKind> + std::pair<mlir::omp::ClauseMapFlags, mlir::omp::VariableCaptureKind> mapFlagAndKind = getImplicitMapTypeAndKind( firOpBuilder, converter, defaultMaps, eleType, loc, sym); mlir::Value mapOp = createMapInfoOp( firOpBuilder, converter.getCurrentLocation(), baseOp, /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - std::get<0>(mapFlagAndKind)), + /*membersIndex=*/mlir::ArrayAttr{}, std::get<0>(mapFlagAndKind), std::get<1>(mapFlagAndKind), baseOp.getType(), /*partialMap=*/false, mapperId); diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 37b926e2f23f..6487f599df72 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -273,7 +273,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, omp::ObjectList &objectList, llvm::SmallVectorImpl<int64_t> &indices, OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) { + mlir::omp::ClauseMapFlags mapTypeBits) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); /// Checks if an omp::Object is an array expression with a subscript, e.g. @@ -414,11 +414,10 @@ mlir::Value createParentSymAndGenIntermediateMaps( // be safer to just pass OMP_MAP_NONE as the map type, but we may still // need some of the other map types the mapped member utilises, so for // now it's good to keep an eye on this. - llvm::omp::OpenMPOffloadMappingFlags interimMapType = mapTypeBits; - interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - interimMapType &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - interimMapType &= - ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + mlir::omp::ClauseMapFlags interimMapType = mapTypeBits; + interimMapType &= ~mlir::omp::ClauseMapFlags::to; + interimMapType &= ~mlir::omp::ClauseMapFlags::from; + interimMapType &= ~mlir::omp::ClauseMapFlags::return_param; // Create a map for the intermediate member and insert it and it's // indices into the parentMemberIndices list to track it. @@ -427,10 +426,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( /*varPtrPtr=*/mlir::Value{}, asFortran, /*bounds=*/interimBounds, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - interimMapType), + /*membersIndex=*/mlir::ArrayAttr{}, interimMapType, mlir::omp::VariableCaptureKind::ByRef, curValue.getType()); parentMemberIndices.memberPlacementIndices.push_back(interimIndices); @@ -563,7 +559,8 @@ void insertChildMapInfoIntoParent( // it allows this to work with enter and exit without causing MLIR // verification issues. The more appropriate thing may be to take // the "main" map type clause from the directive being used. - uint64_t mapType = indices.second.memberMap[0].getMapType(); + mlir::omp::ClauseMapFlags mapType = + indices.second.memberMap[0].getMapType(); llvm::SmallVector<mlir::Value> members; members.reserve(indices.second.memberMap.size()); diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 69499f9c7b62..ef1f37ac2552 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -134,7 +134,7 @@ mlir::Value createParentSymAndGenIntermediateMaps( semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, omp::ObjectList &objectList, llvm::SmallVectorImpl<int64_t> &indices, OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits); + mlir::omp::ClauseMapFlags mapTypeBits); omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember, semantics::SemanticsContext &semaCtx); diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 70bb43a2510b..478ab151b96d 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -749,6 +749,44 @@ struct VolatileCastOpConversion } }; +/// Lower `fir.assumed_size_extent` to constant -1 of index type. +struct AssumedSizeExtentOpConversion + : public fir::FIROpConversion<fir::AssumedSizeExtentOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::AssumedSizeExtentOp op, OpAdaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Type ity = lowerTy().indexType(); + auto cst = fir::genConstantIndex(loc, ity, rewriter, -1); + rewriter.replaceOp(op, cst.getResult()); + return mlir::success(); + } +}; + +/// Lower `fir.is_assumed_size_extent` to integer equality with -1. +struct IsAssumedSizeExtentOpConversion + : public fir::FIROpConversion<fir::IsAssumedSizeExtentOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::IsAssumedSizeExtentOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Value val = adaptor.getVal(); + mlir::Type valTy = val.getType(); + // Create constant -1 of the operand type. + auto negOneAttr = rewriter.getIntegerAttr(valTy, -1); + auto negOne = + mlir::LLVM::ConstantOp::create(rewriter, loc, valTy, negOneAttr); + auto cmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, val, negOne); + rewriter.replaceOp(op, cmp.getResult()); + return mlir::success(); + } +}; + /// convert value of from-type to value of to-type struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { using FIROpConversion::FIROpConversion; @@ -1113,7 +1151,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> { mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); if (auto scaleSize = fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter)) - size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands()) size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, integerCast(loc, rewriter, ity, opnd)); @@ -4360,6 +4398,7 @@ void fir::populateFIRToLLVMConversionPatterns( AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion, + AssumedSizeExtentOpConversion, IsAssumedSizeExtentOpConversion, BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion, diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 381b2a29c517..f74d635d50a7 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -242,10 +242,11 @@ struct TargetAllocMemOpConversion loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); if (auto scaleSize = fir::genAllocationScaleSize( loc, allocmemOp.getInType(), ity, rewriter)) - size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands().drop_front()) - size = rewriter.create<mlir::LLVM::MulOp>( - loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + size = mlir::LLVM::MulOp::create( + rewriter, loc, ity, size, + integerCast(lowerTy(), loc, rewriter, ity, opnd)); auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); auto mallocTy = mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index ac285b5d403d..0776346870c7 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -872,6 +872,14 @@ public: } } + // Count the number of arguments that have to stay in place at the end of + // the argument list. + unsigned trailingArgs = 0; + if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) { + trailingArgs = + func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions(); + } + // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch<mlir::Type>(ty) @@ -981,6 +989,16 @@ public: } } + // Add the argument at the end if the number of trailing arguments is 0, + // otherwise insert the argument at the appropriate index. + auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) { + unsigned inputIndex = func.front().getArguments().size() - trailingArgs; + auto newArg = trailingArgs == 0 + ? func.front().addArgument(ty, loc) + : func.front().insertArgument(inputIndex, ty, loc); + return newArg; + }; + if (!func.empty()) { // If the function has a body, then apply the fixups to the arguments and // return ops as required. These fixups are done in place. @@ -1117,8 +1135,7 @@ public: // original arguments. (Boxchar arguments.) auto newBufArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto boxTy = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg, @@ -1133,8 +1150,7 @@ public: // appended after all the original arguments. auto newProcPointerArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto tupleType = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); fir::FirOpBuilder builder(*rewriter, getModule()); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 1712af1d1eba..d0164f32d9b6 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -5143,6 +5143,34 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// +// IsAssumedSizeExtentOp and AssumedSizeExtentOp +//===----------------------------------------------------------------------===// + +namespace { +struct FoldIsAssumedSizeExtentOnCtor + : public mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp> { + using mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp>::OpRewritePattern; + mlir::LogicalResult + matchAndRewrite(fir::IsAssumedSizeExtentOp op, + mlir::PatternRewriter &rewriter) const override { + if (llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + op.getVal().getDefiningOp())) { + mlir::Type i1 = rewriter.getI1Type(); + rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( + op, i1, rewriter.getIntegerAttr(i1, 1)); + return mlir::success(); + } + return mlir::failure(); + } +}; +} // namespace + +void fir::IsAssumedSizeExtentOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add<FoldIsAssumedSizeExtentOnCtor>(context); +} + +//===----------------------------------------------------------------------===// // LocalitySpecifierOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt index d52ab097ddbf..ed8463e9b033 100644 --- a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt +++ b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt @@ -3,18 +3,22 @@ add_flang_library(MIFDialect MIFOps.cpp DEPENDS - MLIRIR MIFOpsIncGen LINK_LIBS FIRDialect FIRDialectSupport FIRSupport - MLIRIR - MLIRTargetLLVMIRExport LINK_COMPONENTS AsmParser AsmPrinter Remarks + + MLIR_DEPS + MLIRIR + + MLIR_LIBS + MLIRIR + MLIRTargetLLVMIRExport ) diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp index 4840a999ecd2..0d135a94588e 100644 --- a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp @@ -39,13 +39,13 @@ public: static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value) { - return builder.create<fir::LoadOp>(loc, value); + return fir::LoadOp::create(builder, loc, value); } static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value) { - auto alloca = builder.create<fir::AllocaOp>(loc, value.getType()); - builder.create<fir::StoreOp>(loc, value, alloca); + auto alloca = fir::AllocaOp::create(builder, loc, value.getType()); + fir::StoreOp::create(builder, loc, value, alloca); return alloca; } }; diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp index 8b9991301aae..5793d46a192a 100644 --- a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp +++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp @@ -20,8 +20,6 @@ #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" - namespace flangomp { #define GEN_PASS_DEF_AUTOMAPTOTARGETDATAPASS #include "flang/Optimizer/OpenMP/Passes.h.inc" @@ -120,12 +118,9 @@ class AutomapToTargetDataPass builder, memOp.getLoc(), memOp.getMemref().getType(), memOp.getMemref(), TypeAttr::get(fir::unwrapRefType(memOp.getMemref().getType())), - builder.getIntegerAttr( - builder.getIntegerType(64, false), - static_cast<unsigned>( - isa<fir::StoreOp>(memOp) - ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO - : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)), + builder.getAttr<omp::ClauseMapFlagsAttr>( + isa<fir::StoreOp>(memOp) ? omp::ClauseMapFlags::to + : omp::ClauseMapFlags::del), builder.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByCopy), /*var_ptr_ptr=*/mlir::Value{}, @@ -135,8 +130,8 @@ class AutomapToTargetDataPass builder.getBoolAttr(false)); clauses.mapVars.push_back(mapInfo); isa<fir::StoreOp>(memOp) - ? builder.create<omp::TargetEnterDataOp>(memOp.getLoc(), clauses) - : builder.create<omp::TargetExitDataOp>(memOp.getLoc(), clauses); + ? omp::TargetEnterDataOp::create(builder, memOp.getLoc(), clauses) + : omp::TargetExitDataOp::create(builder, memOp.getLoc(), clauses); }; for (fir::GlobalOp globalOp : automapGlobals) { diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 03ff16366a9d..1229018bd9b3 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -22,7 +22,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" namespace flangomp { #define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS @@ -568,16 +567,15 @@ private: if (auto refType = mlir::dyn_cast<fir::ReferenceType>(liveInType)) eleType = refType.getElementType(); - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; mlir::omp::VariableCaptureKind captureKind = mlir::omp::VariableCaptureKind::ByRef; if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else if (!fir::isa_builtin_cptr_type(eleType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + mapFlag |= mlir::omp::ClauseMapFlags::to; + mapFlag |= mlir::omp::ClauseMapFlags::from; } llvm::SmallVector<mlir::Value> boundsOps; @@ -587,11 +585,8 @@ private: builder, liveIn.getLoc(), rawAddr, /*varPtrPtr=*/{}, name.str(), boundsOps, /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapFlag), - captureKind, rawAddr.getType()); + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, + rawAddr.getType()); } mlir::omp::TargetOp @@ -600,7 +595,7 @@ private: mlir::omp::TargetOperands &clauseOps, mlir::omp::LoopNestOperands &loopNestClauseOps, const LiveInShapeInfoMap &liveInShapeInfoMap) const { - auto targetOp = rewriter.create<mlir::omp::TargetOp>(loc, clauseOps); + auto targetOp = mlir::omp::TargetOp::create(rewriter, loc, clauseOps); auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); mlir::Region ®ion = targetOp.getRegion(); @@ -677,7 +672,7 @@ private: // temporary. Fortran::utils::openmp::cloneOrMapRegionOutsiders(builder, targetOp); rewriter.setInsertionPoint( - rewriter.create<mlir::omp::TerminatorOp>(targetOp.getLoc())); + mlir::omp::TerminatorOp::create(rewriter, targetOp.getLoc())); return targetOp; } @@ -720,8 +715,8 @@ private: auto shapeShiftType = fir::ShapeShiftType::get( builder.getContext(), shapeShiftOperands.size() / 2); - return builder.create<fir::ShapeShiftOp>( - liveInArg.getLoc(), shapeShiftType, shapeShiftOperands); + return fir::ShapeShiftOp::create(builder, liveInArg.getLoc(), + shapeShiftType, shapeShiftOperands); } llvm::SmallVector<mlir::Value> shapeOperands; @@ -733,11 +728,11 @@ private: ++shapeIdx; } - return builder.create<fir::ShapeOp>(liveInArg.getLoc(), shapeOperands); + return fir::ShapeOp::create(builder, liveInArg.getLoc(), shapeOperands); }(); - return builder.create<hlfir::DeclareOp>(liveInArg.getLoc(), liveInArg, - liveInName, shape); + return hlfir::DeclareOp::create(builder, liveInArg.getLoc(), liveInArg, + liveInName, shape); } mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter, @@ -747,13 +742,13 @@ private: genReductions(rewriter, mapper, loop, teamsOps); mlir::Location loc = loop.getLoc(); - auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps); + auto teamsOp = mlir::omp::TeamsOp::create(rewriter, loc, teamsOps); Fortran::common::openmp::EntryBlockArgs teamsArgs; teamsArgs.reduction.vars = teamsOps.reductionVars; Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs, teamsOp.getRegion()); - rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc)); for (auto [loopVar, teamsArg] : llvm::zip_equal( loop.getReduceVars(), teamsOp.getRegion().getArguments())) { @@ -766,8 +761,8 @@ private: mlir::omp::DistributeOp genDistributeOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter) const { - auto distOp = rewriter.create<mlir::omp::DistributeOp>( - loc, /*clauses=*/mlir::omp::DistributeOperands{}); + auto distOp = mlir::omp::DistributeOp::create( + rewriter, loc, /*clauses=*/mlir::omp::DistributeOperands{}); rewriter.createBlock(&distOp.getRegion()); return distOp; diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 9278e17e74d1..7b6153998423 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -282,14 +282,14 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); for (auto arg : teamsBlock->getArguments()) newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); - auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc); - rewriter.create<omp::TerminatorOp>(loc); + auto newWorkdistribute = omp::WorkdistributeOp::create(rewriter, loc); + omp::TerminatorOp::create(rewriter, loc); rewriter.createBlock(&newWorkdistribute.getRegion(), newWorkdistribute.getRegion().begin(), {}, {}); auto *cloned = rewriter.clone(*parallelize); parallelize->replaceAllUsesWith(cloned); parallelize->erase(); - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); changed = true; } } @@ -298,10 +298,10 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { /// Generate omp.parallel operation with an empty region. static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { - auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc); + auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc); parallelOp.setComposite(composite); rewriter.createBlock(¶llelOp.getRegion()); - rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc)); return; } @@ -309,7 +309,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { mlir::omp::DistributeOperands distributeClauseOps; auto distributeOp = - rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps); + mlir::omp::DistributeOp::create(rewriter, loc, distributeClauseOps); distributeOp.setComposite(composite); auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion()); rewriter.setInsertionPointToStart(distributeBlock); @@ -334,12 +334,12 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, const mlir::omp::LoopNestOperands &clauseOps, bool composite) { - auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc()); + auto wsloopOp = mlir::omp::WsloopOp::create(rewriter, doLoop.getLoc()); wsloopOp.setComposite(composite); rewriter.createBlock(&wsloopOp.getRegion()); auto loopNestOp = - rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps); + mlir::omp::LoopNestOp::create(rewriter, doLoop.getLoc(), clauseOps); // Clone the loop's body inside the loop nest construct using the // mapped values. @@ -351,7 +351,7 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, // Erase fir.result op of do loop and create yield op. if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) { rewriter.setInsertionPoint(terminatorOp); - rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc()); + mlir::omp::YieldOp::create(rewriter, doLoop->getLoc()); terminatorOp->erase(); } } @@ -494,15 +494,15 @@ static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder, // Convert flat index to multi-dimensional indices SmallVector<Value> indices(rank); Value temp = flatIdx; - auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + auto c1 = arith::ConstantIndexOp::create(builder, loc, 1); // Work backwards through dimensions (row-major order) for (int i = rank - 1; i >= 0; --i) { - Value zeroBasedIdx = builder.create<arith::RemSIOp>(loc, temp, extents[i]); + Value zeroBasedIdx = arith::RemSIOp::create(builder, loc, temp, extents[i]); // Convert to one-based index - indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1); + indices[i] = arith::AddIOp::create(builder, loc, zeroBasedIdx, c1); if (i > 0) { - temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]); + temp = arith::DivSIOp::create(builder, loc, temp, extents[i]); } } @@ -525,7 +525,7 @@ static Value CalculateTotalElements(OpBuilder &builder, Location loc, if (i == 0) { totalElems = extent; } else { - totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent); + totalElems = arith::MulIOp::create(builder, loc, totalElems, extent); } } return totalElems; @@ -562,14 +562,14 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, // Load destination array box (if it's a reference) Value arrayBox = destBox; if (isa<fir::ReferenceType>(destBox.getType())) - arrayBox = builder.create<fir::LoadOp>(loc, destBox); + arrayBox = fir::LoadOp::create(builder, loc, destBox); - auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox); - Value scalar = builder.create<fir::LoadOp>(loc, scalarValue); + auto scalarValue = fir::BoxAddrOp::create(builder, loc, srcBox); + Value scalar = fir::LoadOp::create(builder, loc, scalarValue); // Calculate total number of elements (flattened) - auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0); - auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + auto c0 = arith::ConstantIndexOp::create(builder, loc, 0); + auto c1 = arith::ConstantIndexOp::create(builder, loc, 1); Value totalElems = CalculateTotalElements(builder, loc, arrayBox); auto *workdistributeBlock = &workdistribute.getRegion().front(); @@ -587,7 +587,7 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, nullptr, nullptr, ValueRange{indices}, ValueRange{}); - builder.create<fir::StoreOp>(loc, scalar, elemPtr); + fir::StoreOp::create(builder, loc, scalar, elemPtr); } /// workdistributeRuntimeCallLower method finds the runtime calls @@ -719,10 +719,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, SmallVector<Value> outerMapInfos; // Create new mapinfo ops for the inner target region for (auto mapInfo : mapInfos) { - auto originalMapType = - (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType()); + mlir::omp::ClauseMapFlags originalMapType = mapInfo.getMapType(); auto originalCaptureType = mapInfo.getMapCaptureType(); - llvm::omp::OpenMPOffloadMappingFlags newMapType; + mlir::omp::ClauseMapFlags newMapType; mlir::omp::VariableCaptureKind newCaptureType; // For bycopy, we keep the same map type and capture type // For byref, we change the map type to none and keep the capture type @@ -730,7 +729,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, newMapType = originalMapType; newCaptureType = originalCaptureType; } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) { - newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + newMapType = mlir::omp::ClauseMapFlags::storage; newCaptureType = originalCaptureType; outerMapInfos.push_back(mapInfo); } else { @@ -738,11 +737,8 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, return failure(); } auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo)); - innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr( - rewriter.getIntegerType(64, false), - static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - newMapType))); + innerMapInfo.setMapTypeAttr( + rewriter.getAttr<omp::ClauseMapFlagsAttr>(newMapType)); innerMapInfo.setMapCaptureType(newCaptureType); innerMapInfos.push_back(innerMapInfo.getResult()); } @@ -753,14 +749,15 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); auto devicePtrVars = targetOp.getIsDevicePtrVars(); // Create the target data op - auto targetDataOp = rewriter.create<omp::TargetDataOp>( - loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars); + auto targetDataOp = + omp::TargetDataOp::create(rewriter, loc, device, ifExpr, outerMapInfos, + deviceAddrVars, devicePtrVars); auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); - rewriter.create<mlir::omp::TerminatorOp>(loc); + mlir::omp::TerminatorOp::create(rewriter, loc); rewriter.setInsertionPointToStart(taregtDataBlock); // Create the inner target op - auto newTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + auto newTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), @@ -825,20 +822,20 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, // Get the appropriate type for allocation if (isPtr(ty)) { Type intTy = rewriter.getI32Type(); - auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, intTy, 1); allocType = llvmPtrTy; - alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one); + alloc = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, allocType, one); allocType = intTy; } else { allocType = ty; - alloc = rewriter.create<fir::AllocaOp>(loc, allocType); + alloc = fir::AllocaOp::create(rewriter, loc, allocType); } // Lambda to create mapinfo ops - auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { - return rewriter.create<omp::MapInfoOp>( - loc, alloc.getType(), alloc, TypeAttr::get(allocType), - rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), - mappingFlags), + auto getMapInfo = [&](mlir::omp::ClauseMapFlags mappingFlags, + const char *name) { + return omp::MapInfoOp::create( + rewriter, loc, alloc.getType(), alloc, TypeAttr::get(allocType), + rewriter.getAttr<omp::ClauseMapFlagsAttr>(mappingFlags), rewriter.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByRef), /*varPtrPtr=*/Value{}, @@ -849,14 +846,10 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false)); }; // Create mapinfo ops. - uint64_t mapFrom = - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); - uint64_t mapTo = - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); - auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from"); - auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to"); + auto mapInfoFrom = getMapInfo(mlir::omp::ClauseMapFlags::from, + "__flang_workdistribute_from"); + auto mapInfoTo = + getMapInfo(mlir::omp::ClauseMapFlags::to, "__flang_workdistribute_to"); return TempOmpVar{mapInfoFrom, mapInfoTo}; } @@ -987,12 +980,12 @@ static void reloadCacheAndRecompute( // If the original value is a pointer or reference, load and convert if // necessary. if (isPtr(original.getType())) { - restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg); + restored = LLVM::LoadOp::create(rewriter, loc, llvmPtrTy, newArg); if (!isa<LLVM::LLVMPointerType>(original.getType())) restored = - rewriter.create<fir::ConvertOp>(loc, original.getType(), restored); + fir::ConvertOp::create(rewriter, loc, original.getType(), restored); } else { - restored = rewriter.create<fir::LoadOp>(loc, newArg); + restored = fir::LoadOp::create(rewriter, loc, newArg); } irMapping.map(original, restored); } @@ -1061,7 +1054,7 @@ static mlir::LLVM::ConstantOp genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { mlir::Type i32Ty = rewriter.getI32Type(); mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); - return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); + return mlir::LLVM::ConstantOp::create(rewriter, loc, i32Ty, attr); } /// Given a box descriptor, extract the base address of the data it describes. @@ -1238,8 +1231,8 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module); Value srcPtr = genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); - Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getI64IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getI64IntegerAttr(0)); // Generate the call to omp_target_memcpy to perform the data copy on the // device. @@ -1356,23 +1349,24 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, for (Operation *op : opsToReplace) { if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) { rewriter.setInsertionPoint(allocOp); - auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>( - allocOp.getLoc(), rewriter.getI64Type(), device, + auto ompAllocmemOp = omp::TargetAllocMemOp::create( + rewriter, allocOp.getLoc(), rewriter.getI64Type(), device, allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(), allocOp.getBindcNameAttr(), allocOp.getTypeparams(), allocOp.getShape()); - auto firConvertOp = rewriter.create<fir::ConvertOp>( - allocOp.getLoc(), allocOp.getResult().getType(), - ompAllocmemOp.getResult()); + auto firConvertOp = fir::ConvertOp::create(rewriter, allocOp.getLoc(), + allocOp.getResult().getType(), + ompAllocmemOp.getResult()); rewriter.replaceOp(allocOp, firConvertOp.getResult()); } // Replace fir.freemem with omp.target_freemem. else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) { rewriter.setInsertionPoint(freeOp); - auto firConvertOp = rewriter.create<fir::ConvertOp>( - freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref()); - rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device, - firConvertOp.getResult()); + auto firConvertOp = + fir::ConvertOp::create(rewriter, freeOp.getLoc(), + rewriter.getI64Type(), freeOp.getHeapref()); + omp::TargetFreeMemOp::create(rewriter, freeOp.getLoc(), device, + firConvertOp.getResult()); rewriter.eraseOp(freeOp); } // fir.declare changes its type when hoisting it out of omp.target to @@ -1384,8 +1378,9 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, dyn_cast<fir::ReferenceType>(clonedInType); Type clonedEleTy = clonedRefType.getElementType(); rewriter.setInsertionPoint(op); - Value loadedValue = rewriter.create<fir::LoadOp>( - clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref()); + Value loadedValue = + fir::LoadOp::create(rewriter, clonedDeclareOp.getLoc(), clonedEleTy, + clonedDeclareOp.getMemref()); clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue); } // Replace runtime calls with omp versions. @@ -1481,8 +1476,8 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, auto *targetBlock = &targetOp.getRegion().front(); SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()}; // update the hostEvalVars of preTargetOp - omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp preTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars, @@ -1521,13 +1516,13 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, // Create the store operation. if (isPtr(originalResult.getType())) { if (!isa<LLVM::LLVMPointerType>(toStore.getType())) - toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore); - rewriter.create<LLVM::StoreOp>(loc, toStore, newArg); + toStore = fir::ConvertOp::create(rewriter, loc, llvmPtrTy, toStore); + LLVM::StoreOp::create(rewriter, loc, toStore, newArg); } else { - rewriter.create<fir::StoreOp>(loc, toStore, newArg); + fir::StoreOp::create(rewriter, loc, toStore, newArg); } } - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); // Update hostEvalVars with the mapped values for the loop bounds if we have // a loopNestOp and we are not generating code for the target device. @@ -1571,8 +1566,8 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, hostEvalVars.steps.end()); } // Create the isolated target op - omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp isolatedTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), @@ -1598,7 +1593,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, // Clone the original operations. rewriter.clone(*splitBeforeOp, isolatedMapping); - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); // update the loop bounds in the isolatedTargetOp if we have host_eval vars // and we are not generating code for the target device. @@ -1651,8 +1646,8 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, auto *targetBlock = &targetOp.getRegion().front(); SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()}; // Create the post target op - omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp postTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars, diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 2bbd8034fa52..bd07d7fe01b8 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -43,7 +43,6 @@ #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringSet.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cstddef> @@ -350,7 +349,7 @@ class MapInfoFinalizationPass /// the descriptor map onto the base address map. mlir::omp::MapInfoOp genBaseAddrMap(mlir::Value descriptor, mlir::OperandRange bounds, - int64_t mapType, + mlir::omp::ClauseMapFlags mapType, fir::FirOpBuilder &builder) { mlir::Location loc = descriptor.getLoc(); mlir::Value baseAddrAddr = fir::BoxOffsetOp::create( @@ -368,7 +367,7 @@ class MapInfoFinalizationPass return mlir::omp::MapInfoOp::create( builder, loc, baseAddrAddr.getType(), descriptor, mlir::TypeAttr::get(underlyingVarType), - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType), builder.getAttr<mlir::omp::VariableCaptureKindAttr>( mlir::omp::VariableCaptureKind::ByRef), baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{}, @@ -428,22 +427,22 @@ class MapInfoFinalizationPass /// allowing `to` mappings, and `target update` not allowing both `to` and /// `from` simultaneously. We currently try to maintain the `implicit` flag /// where necessary, although it does not seem strictly required. - unsigned long getDescriptorMapType(unsigned long mapTypeFlag, - mlir::Operation *target) { - using mapFlags = llvm::omp::OpenMPOffloadMappingFlags; + mlir::omp::ClauseMapFlags + getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag, + mlir::Operation *target) { + using mapFlags = mlir::omp::ClauseMapFlags; if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp, mlir::omp::TargetUpdateOp>(target)) return mapTypeFlag; - mapFlags flags = mapFlags::OMP_MAP_TO | - (mapFlags(mapTypeFlag) & - (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS)); + mapFlags flags = + mapFlags::to | (mapTypeFlag & (mapFlags::implicit | mapFlags::always)); // For unified_shared_memory, we additionally add `CLOSE` on the descriptor // to ensure device-local placement where required by tests relying on USM + // close semantics. if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>())) - flags |= mapFlags::OMP_MAP_CLOSE; - return llvm::to_underlying(flags); + flags |= mapFlags::close; + return flags; } /// Check if the mapOp is present in the HasDeviceAddr clause on @@ -493,11 +492,6 @@ class MapInfoFinalizationPass mlir::Value boxAddr = fir::BoxOffsetOp::create( builder, loc, op.getVarPtr(), fir::BoxFieldAttr::base_addr); - uint64_t mapTypeToImplicit = static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); - mlir::ArrayAttr newMembersAttr; llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}}; newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); @@ -506,8 +500,9 @@ class MapInfoFinalizationPass mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create( builder, op.getLoc(), varPtr.getType(), varPtr, mlir::TypeAttr::get(boxCharType.getEleTy()), - builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), - mapTypeToImplicit), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::implicit), builder.getAttr<mlir::omp::VariableCaptureKindAttr>( mlir::omp::VariableCaptureKind::ByRef), /*varPtrPtr=*/boxAddr, @@ -568,12 +563,9 @@ class MapInfoFinalizationPass mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx); // Force CLOSE in USM paths so the pointer gets device-local placement // when required by tests relying on USM + close semantics. - uint64_t mapTypeVal = - op.getMapType() | - llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); - mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr( - builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal); + mlir::omp::ClauseMapFlagsAttr mapTypeAttr = + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + op.getMapType() | mlir::omp::ClauseMapFlags::close); mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create( builder, loc, coord.getType(), coord, @@ -683,17 +675,16 @@ class MapInfoFinalizationPass // one place in the code may differ from that address in another place. // The contents of the descriptor (the base address in particular) will // remain unchanged though. - uint64_t mapType = op.getMapType(); + mlir::omp::ClauseMapFlags mapType = op.getMapType(); if (isHasDeviceAddrFlag) { - mapType |= llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); + mapType |= mlir::omp::ClauseMapFlags::always; } mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create( builder, op->getLoc(), op.getResult().getType(), descriptor, mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())), - builder.getIntegerAttr(builder.getIntegerType(64, false), - getDescriptorMapType(mapType, target)), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + getDescriptorMapType(mapType, target)), op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers, newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{}, /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), @@ -892,20 +883,16 @@ class MapInfoFinalizationPass if (explicitMappingPresent(op, targetDataOp)) return; - mlir::omp::MapInfoOp newDescParentMapOp = - builder.create<mlir::omp::MapInfoOp>( - op->getLoc(), op.getResult().getType(), op.getVarPtr(), - op.getVarTypeAttr(), - builder.getIntegerAttr( - builder.getIntegerType(64, false), - llvm::to_underlying( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)), - op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, - mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, - /*bounds=*/mlir::SmallVector<mlir::Value>{}, - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); + mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create( + builder, op->getLoc(), op.getResult().getType(), op.getVarPtr(), + op.getVarTypeAttr(), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::always), + op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, + mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, + /*bounds=*/mlir::SmallVector<mlir::Value>{}, + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); targetDataOp.getMapVarsMutable().append({newDescParentMapOp}); } @@ -957,14 +944,13 @@ class MapInfoFinalizationPass // need to see how well this alteration works. auto loadBaseAddr = builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr()); - mlir::omp::MapInfoOp newBaseAddrMapOp = - builder.create<mlir::omp::MapInfoOp>( - op->getLoc(), loadBaseAddr.getType(), loadBaseAddr, - baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(), - baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, - membersAttr, baseAddr.getBounds(), - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); + mlir::omp::MapInfoOp newBaseAddrMapOp = mlir::omp::MapInfoOp::create( + builder, op->getLoc(), loadBaseAddr.getType(), loadBaseAddr, + baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(), + baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, membersAttr, + baseAddr.getBounds(), + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); op.replaceAllUsesWith(newBaseAddrMapOp.getResult()); op->erase(); baseAddr.erase(); @@ -1240,9 +1226,8 @@ class MapInfoFinalizationPass // we need to change this check for early return OR live with // over-mapping. bool hasImplicitMap = - (llvm::omp::OpenMPOffloadMappingFlags(op.getMapType()) & - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT) == - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + (op.getMapType() & mlir::omp::ClauseMapFlags::implicit) == + mlir::omp::ClauseMapFlags::implicit; if (hasImplicitMap) return; diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 30328573b74f..0972861b8450 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -35,7 +35,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/Debug.h" #include <type_traits> @@ -70,9 +69,6 @@ class MapsForPrivatizedSymbolsPass return size <= ptrSize && align <= ptrAlign; }; - uint64_t mapTypeTo = static_cast< - std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); Operation *definingOp = var.getDefiningOp(); Value varPtr = var; @@ -122,8 +118,7 @@ class MapsForPrivatizedSymbolsPass builder, loc, varPtr.getType(), varPtr, TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType()) .getElementType()), - builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), - mapTypeTo), + builder.getAttr<omp::ClauseMapFlagsAttr>(omp::ClauseMapFlags::to), builder.getAttr<omp::VariableCaptureKindAttr>(captureKind), /*varPtrPtr=*/Value{}, /*members=*/SmallVector<Value>{}, diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 6dae39b26976..103e736accca 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -426,6 +426,12 @@ void createMLIRToLLVMPassPipeline(mlir::PassManager &pm, // Add codegen pass pipeline. fir::createDefaultFIRCodeGenPassPipeline(pm, config, inputFilename); + + // Run a pass to prepare for translation of delayed privatization in the + // context of deferred target tasks. + addPassConditionally(pm, disableFirToLlvmIr, [&]() { + return mlir::omp::createPrepareForOMPOffloadPrivatizationPass(); + }); } } // namespace fir diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index 92390e4a3a23..2f33d8956479 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -66,7 +66,7 @@ fir::genConstantIndex(mlir::Location loc, mlir::Type ity, mlir::ConversionPatternRewriter &rewriter, std::int64_t offset) { auto cattr = rewriter.getI64IntegerAttr(offset); - return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); + return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr); } mlir::Value @@ -125,9 +125,9 @@ mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter, return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); } else { if (toSize < fromSize) - return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); + return mlir::LLVM::TruncOp::create(rewriter, loc, ty, val); if (toSize > fromSize) - return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); + return mlir::LLVM::SExtOp::create(rewriter, loc, ty, val); } return val; } diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp index ed9a2ae11f0d..5bf783db92bf 100644 --- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -832,8 +832,8 @@ static mlir::Type getEleTy(mlir::Type ty) { static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) { if (extents.empty()) return false; - auto cstLen = fir::getIntIfConstant(extents.back()); - return cstLen.has_value() && *cstLen == -1; + return llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + extents.back().getDefiningOp()); } // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 759e3a65dd24..8d00272b09f4 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -454,6 +454,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { mlir::LogicalResult matchAndRewrite(fir::DeclareOp op, mlir::PatternRewriter &rewriter) const override { + if (op.getResult().getUsers().empty()) + return success(); if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { if (auto global = symTab.lookup<fir::GlobalOp>( addrOfOp.getSymbol().getRootReference().getValue())) { @@ -963,6 +965,8 @@ public: } target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) { + if (op.getResult().getUsers().empty()) + return true; if (inDeviceContext(op)) return true; if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp index 9dfe26cbf589..fea511cc63a3 100644 --- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp +++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp @@ -99,10 +99,6 @@ void FunctionAttrPass::runOnOperation() { func->setAttr( mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName), mlir::BoolAttr::get(context, true)); - if (unsafeFPMath) - func->setAttr( - mlir::LLVM::LLVMFuncOp::getUnsafeFpMathAttrName(llvmFuncOpName), - mlir::BoolAttr::get(context, true)); if (!reciprocals.empty()) func->setAttr( mlir::LLVM::LLVMFuncOp::getReciprocalEstimatesAttrName(llvmFuncOpName), diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp index ad035e6ade32..0dec56521f75 100644 --- a/flang/lib/Parser/openacc-parsers.cpp +++ b/flang/lib/Parser/openacc-parsers.cpp @@ -75,21 +75,21 @@ TYPE_PARSER( // tile size is one of: // * (represented as an empty std::optional<ScalarIntExpr>) // constant-int-expr -TYPE_PARSER(construct<AccTileExpr>(scalarIntConstantExpr) || +TYPE_PARSER(sourced(construct<AccTileExpr>(scalarIntConstantExpr) || construct<AccTileExpr>( - "*" >> construct<std::optional<ScalarIntConstantExpr>>())) + "*" >> construct<std::optional<ScalarIntConstantExpr>>()))) TYPE_PARSER(construct<AccTileExprList>(nonemptyList(Parser<AccTileExpr>{}))) // 2.9 (1979-1982) gang-arg is one of : // [num:]int-expr // dim:int-expr // static:size-expr -TYPE_PARSER(construct<AccGangArg>(construct<AccGangArg::Static>( - "STATIC: " >> Parser<AccSizeExpr>{})) || +TYPE_PARSER(sourced(construct<AccGangArg>(construct<AccGangArg::Static>( + "STATIC: " >> Parser<AccSizeExpr>{})) || construct<AccGangArg>( construct<AccGangArg::Dim>("DIM: " >> scalarIntExpr)) || construct<AccGangArg>( - construct<AccGangArg::Num>(maybe("NUM: "_tok) >> scalarIntExpr))) + construct<AccGangArg::Num>(maybe("NUM: "_tok) >> scalarIntExpr)))) // 2.9 gang-arg-list TYPE_PARSER( @@ -101,7 +101,7 @@ TYPE_PARSER(construct<AccCollapseArg>( // 2.5.15 Reduction, F'2023 R1131, and CUF reduction-op // Operator for reduction -TYPE_PARSER(sourced(construct<ReductionOperator>( +TYPE_PARSER(construct<ReductionOperator>( first("+" >> pure(ReductionOperator::Operator::Plus), "*" >> pure(ReductionOperator::Operator::Multiply), "MAX" >> pure(ReductionOperator::Operator::Max), @@ -112,32 +112,32 @@ TYPE_PARSER(sourced(construct<ReductionOperator>( ".AND." >> pure(ReductionOperator::Operator::And), ".OR." >> pure(ReductionOperator::Operator::Or), ".EQV." >> pure(ReductionOperator::Operator::Eqv), - ".NEQV." >> pure(ReductionOperator::Operator::Neqv))))) + ".NEQV." >> pure(ReductionOperator::Operator::Neqv)))) // 2.15.1 Bind clause -TYPE_PARSER(sourced(construct<AccBindClause>(name)) || - sourced(construct<AccBindClause>(scalarDefaultCharExpr))) +TYPE_PARSER(sourced(construct<AccBindClause>(name) || + construct<AccBindClause>(scalarDefaultCharExpr))) // 2.5.16 Default clause -TYPE_PARSER(construct<AccDefaultClause>( +TYPE_PARSER(sourced(construct<AccDefaultClause>( first("NONE" >> pure(llvm::acc::DefaultValue::ACC_Default_none), - "PRESENT" >> pure(llvm::acc::DefaultValue::ACC_Default_present)))) + "PRESENT" >> pure(llvm::acc::DefaultValue::ACC_Default_present))))) // SELF clause is either a simple optional condition for compute construct // or a synonym of the HOST clause for the update directive 2.14.4 holding // an object list. -TYPE_PARSER( +TYPE_PARSER(sourced( construct<AccSelfClause>(Parser<AccObjectList>{}) / lookAhead(")"_tok) || - construct<AccSelfClause>(scalarLogicalExpr / lookAhead(")"_tok)) || + construct<AccSelfClause>(scalarLogicalExpr) / lookAhead(")"_tok) || construct<AccSelfClause>( recovery(fail<std::optional<ScalarLogicalExpr>>( "logical expression or object list expected"_err_en_US), - SkipTo<')'>{} >> pure<std::optional<ScalarLogicalExpr>>()))) + SkipTo<')'>{} >> pure<std::optional<ScalarLogicalExpr>>())))) // Modifier for copyin, copyout, cache and create -TYPE_PARSER(construct<AccDataModifier>( +TYPE_PARSER(sourced(construct<AccDataModifier>( first("ZERO:" >> pure(AccDataModifier::Modifier::Zero), - "READONLY:" >> pure(AccDataModifier::Modifier::ReadOnly)))) + "READONLY:" >> pure(AccDataModifier::Modifier::ReadOnly))))) // Combined directives TYPE_PARSER(sourced(construct<AccCombinedDirective>( @@ -166,14 +166,13 @@ TYPE_PARSER(sourced(construct<AccStandaloneDirective>( TYPE_PARSER(sourced(construct<AccLoopDirective>( first("LOOP" >> pure(llvm::acc::Directive::ACCD_loop))))) -TYPE_PARSER(construct<AccBeginLoopDirective>( - sourced(Parser<AccLoopDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<AccBeginLoopDirective>( + Parser<AccLoopDirective>{}, Parser<AccClauseList>{}))) TYPE_PARSER(construct<AccEndLoop>("END LOOP"_tok)) TYPE_PARSER(construct<OpenACCLoopConstruct>( - sourced(Parser<AccBeginLoopDirective>{} / endAccLine), - maybe(Parser<DoConstruct>{}), + Parser<AccBeginLoopDirective>{} / endAccLine, maybe(Parser<DoConstruct>{}), maybe(startAccLine >> Parser<AccEndLoop>{} / endAccLine))) // 2.15.1 Routine directive @@ -186,8 +185,8 @@ TYPE_PARSER(sourced( parenthesized(Parser<AccObjectListWithModifier>{})))) // 2.11 Combined constructs -TYPE_PARSER(construct<AccBeginCombinedDirective>( - sourced(Parser<AccCombinedDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<AccBeginCombinedDirective>( + Parser<AccCombinedDirective>{}, Parser<AccClauseList>{}))) // 2.12 Atomic constructs TYPE_PARSER(construct<AccEndAtomic>(startAccLine >> "END ATOMIC"_tok)) @@ -213,10 +212,10 @@ TYPE_PARSER("ATOMIC" >> statement(assignmentStmt), Parser<AccEndAtomic>{} / endAccLine)) TYPE_PARSER( - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicRead>{})) || - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicCapture>{})) || - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicWrite>{})) || - sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicUpdate>{}))) + sourced(construct<OpenACCAtomicConstruct>(Parser<AccAtomicRead>{}) || + construct<OpenACCAtomicConstruct>(Parser<AccAtomicCapture>{}) || + construct<OpenACCAtomicConstruct>(Parser<AccAtomicWrite>{}) || + construct<OpenACCAtomicConstruct>(Parser<AccAtomicUpdate>{}))) // 2.13 Declare constructs TYPE_PARSER(sourced(construct<AccDeclarativeDirective>( @@ -250,18 +249,18 @@ TYPE_PARSER(construct<OpenACCBlockConstruct>( pure(llvm::acc::Directive::ACCD_data)))))) // Standalone constructs -TYPE_PARSER(construct<OpenACCStandaloneConstruct>( - sourced(Parser<AccStandaloneDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<OpenACCStandaloneConstruct>( + Parser<AccStandaloneDirective>{}, Parser<AccClauseList>{}))) // Standalone declarative constructs -TYPE_PARSER(construct<OpenACCStandaloneDeclarativeConstruct>( - sourced(Parser<AccDeclarativeDirective>{}), Parser<AccClauseList>{})) +TYPE_PARSER(sourced(construct<OpenACCStandaloneDeclarativeConstruct>( + Parser<AccDeclarativeDirective>{}, Parser<AccClauseList>{}))) TYPE_PARSER(startAccLine >> withMessage("expected OpenACC directive"_err_en_US, - first(sourced(construct<OpenACCDeclarativeConstruct>( - Parser<OpenACCStandaloneDeclarativeConstruct>{})), - sourced(construct<OpenACCDeclarativeConstruct>( + sourced(first(construct<OpenACCDeclarativeConstruct>( + Parser<OpenACCStandaloneDeclarativeConstruct>{}), + construct<OpenACCDeclarativeConstruct>( Parser<OpenACCRoutineConstruct>{}))))) TYPE_PARSER(sourced(construct<OpenACCEndConstruct>( @@ -293,9 +292,9 @@ TYPE_PARSER(startAccLine >> "SERIAL"_tok >> maybe("LOOP"_tok) >> pure(llvm::acc::Directive::ACCD_serial_loop)))))) -TYPE_PARSER(construct<OpenACCCombinedConstruct>( - sourced(Parser<AccBeginCombinedDirective>{} / endAccLine), +TYPE_PARSER(sourced(construct<OpenACCCombinedConstruct>( + Parser<AccBeginCombinedDirective>{} / endAccLine, maybe(Parser<DoConstruct>{}), - maybe(Parser<AccEndCombinedDirective>{} / endAccLine))) + maybe(Parser<AccEndCombinedDirective>{} / endAccLine)))) } // namespace Fortran::parser diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 56fcac3e741a..c0472ad3c069 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1835,8 +1835,8 @@ TYPE_PARSER(sourced(construct<OpenMPDeclareMapperConstruct>( TYPE_PARSER(construct<OmpReductionCombiner>(Parser<AssignmentStmt>{}) || construct<OmpReductionCombiner>(Parser<FunctionReference>{})) -TYPE_PARSER(construct<OpenMPCriticalConstruct>( - OmpBlockConstructParser{llvm::omp::Directive::OMPD_critical})) +TYPE_PARSER(sourced(construct<OpenMPCriticalConstruct>( + OmpBlockConstructParser{llvm::omp::Directive::OMPD_critical}))) // 2.11.3 Executable Allocate directive TYPE_PARSER( @@ -1911,12 +1911,12 @@ TYPE_PARSER( Parser<OmpMetadirectiveDirective>{})) / endOmpLine)) -TYPE_PARSER(construct<OpenMPAssumeConstruct>( - sourced(OmpBlockConstructParser{llvm::omp::Directive::OMPD_assume}))) +TYPE_PARSER(sourced(construct<OpenMPAssumeConstruct>( + OmpBlockConstructParser{llvm::omp::Directive::OMPD_assume}))) // Block Construct #define MakeBlockConstruct(dir) \ - construct<OmpBlockConstruct>(OmpBlockConstructParser{dir}) + sourced(construct<OmpBlockConstruct>(OmpBlockConstructParser{dir})) TYPE_PARSER( // MakeBlockConstruct(llvm::omp::Directive::OMPD_masked) || MakeBlockConstruct(llvm::omp::Directive::OMPD_master) || diff --git a/flang/lib/Parser/prescan.cpp b/flang/lib/Parser/prescan.cpp index 66e5b2cbd5c7..df0372bbe554 100644 --- a/flang/lib/Parser/prescan.cpp +++ b/flang/lib/Parser/prescan.cpp @@ -140,17 +140,9 @@ void Prescanner::Statement() { CHECK(*at_ == '!'); } std::optional<int> condOffset; - if (InOpenMPConditionalLine()) { + if (InOpenMPConditionalLine()) { // !$ condOffset = 2; - } else if (directiveSentinel_[0] == '@' && directiveSentinel_[1] == 'c' && - directiveSentinel_[2] == 'u' && directiveSentinel_[3] == 'f' && - directiveSentinel_[4] == '\0') { - // CUDA conditional compilation line. - condOffset = 5; - } else if (directiveSentinel_[0] == '@' && directiveSentinel_[1] == 'a' && - directiveSentinel_[2] == 'c' && directiveSentinel_[3] == 'c' && - directiveSentinel_[4] == '\0') { - // OpenACC conditional compilation line. + } else if (InOpenACCOrCUDAConditionalLine()) { // !@acc or !@cuf condOffset = 5; } if (condOffset && !preprocessingOnly_) { @@ -166,7 +158,8 @@ void Prescanner::Statement() { } else { // Compiler directive. Emit normalized sentinel, squash following spaces. // Conditional compilation lines (!$) take this path in -E mode too - // so that -fopenmp only has to appear on the later compilation. + // so that -fopenmp only has to appear on the later compilation + // (ditto for !@cuf and !@acc). EmitChar(tokens, '!'); ++at_, ++column_; for (const char *sp{directiveSentinel_}; *sp != '\0'; @@ -202,7 +195,7 @@ void Prescanner::Statement() { } tokens.CloseToken(); SkipSpaces(); - if (InOpenMPConditionalLine() && inFixedForm_ && !tabInCurrentLine_ && + if (InConditionalLine() && inFixedForm_ && !tabInCurrentLine_ && column_ == 6 && *at_ != '\n') { // !$ 0 - turn '0' into a space // !$ 1 - turn '1' into '&' @@ -347,7 +340,7 @@ void Prescanner::Statement() { while (CompilerDirectiveContinuation(tokens, line.sentinel)) { newlineProvenance = GetCurrentProvenance(); } - if (preprocessingOnly_ && inFixedForm_ && InOpenMPConditionalLine() && + if (preprocessingOnly_ && inFixedForm_ && InConditionalLine() && nextLine_ < limit_) { // In -E mode, when the line after !$ conditional compilation is a // regular fixed form continuation line, append a '&' to the line. @@ -1360,11 +1353,10 @@ const char *Prescanner::FixedFormContinuationLine(bool atNewline) { features_.IsEnabled(LanguageFeature::OldDebugLines))) && nextLine_[1] == ' ' && nextLine_[2] == ' ' && nextLine_[3] == ' ' && nextLine_[4] == ' '}; - if (InCompilerDirective() && - !(InOpenMPConditionalLine() && !preprocessingOnly_)) { + if (InCompilerDirective() && !(InConditionalLine() && !preprocessingOnly_)) { // !$ under -E is not continued, but deferred to later compilation if (IsFixedFormCommentChar(col1) && - !(InOpenMPConditionalLine() && preprocessingOnly_)) { + !(InConditionalLine() && preprocessingOnly_)) { int j{1}; for (; j < 5; ++j) { char ch{directiveSentinel_[j - 1]}; @@ -1443,7 +1435,7 @@ const char *Prescanner::FreeFormContinuationLine(bool ampersand) { } p = SkipWhiteSpaceIncludingEmptyMacros(p); if (InCompilerDirective()) { - if (InOpenMPConditionalLine()) { + if (InConditionalLine()) { if (preprocessingOnly_) { // in -E mode, don't treat !$ as a continuation return nullptr; diff --git a/flang/lib/Parser/prescan.h b/flang/lib/Parser/prescan.h index fc38adb92653..5e7481781d94 100644 --- a/flang/lib/Parser/prescan.h +++ b/flang/lib/Parser/prescan.h @@ -171,7 +171,17 @@ private: bool InOpenMPConditionalLine() const { return directiveSentinel_ && directiveSentinel_[0] == '$' && !directiveSentinel_[1]; - ; + } + bool InOpenACCOrCUDAConditionalLine() const { + return directiveSentinel_ && directiveSentinel_[0] == '@' && + ((directiveSentinel_[1] == 'a' && directiveSentinel_[2] == 'c' && + directiveSentinel_[3] == 'c') || + (directiveSentinel_[1] == 'c' && directiveSentinel_[2] == 'u' && + directiveSentinel_[3] == 'f')) && + directiveSentinel_[4] == '\0'; + } + bool InConditionalLine() const { + return InOpenMPConditionalLine() || InOpenACCOrCUDAConditionalLine(); } bool InFixedFormSource() const { return inFixedForm_ && !inPreprocessorDirective_ && !InCompilerDirective(); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index 515121af04d5..2707921ca1df 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -286,7 +286,7 @@ static std::optional<AnalyzedCondStmt> AnalyzeConditionalStmt( // Extract the evaluate::Expr from ScalarLogicalExpr. auto getFromLogical{[](const parser::ScalarLogicalExpr &logical) { // ScalarLogicalExpr is Scalar<Logical<common::Indirection<Expr>>> - const parser::Expr &expr{logical.thing.thing.value()}; + auto &expr{parser::UnwrapRef<parser::Expr>(logical)}; return GetEvaluateExpr(expr); }}; diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index e2f8796aeb5e..41416304c1ea 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -61,6 +61,124 @@ namespace Fortran::semantics { using namespace Fortran::semantics::omp; using namespace Fortran::parser::omp; +OmpStructureChecker::OmpStructureChecker(SemanticsContext &context) + : DirectiveStructureChecker(context, +#define GEN_FLANG_DIRECTIVE_CLAUSE_MAP +#include "llvm/Frontend/OpenMP/OMP.inc" + ) { + scopeStack_.push_back(&context.globalScope()); +} + +bool OmpStructureChecker::Enter(const parser::MainProgram &x) { + using StatementProgramStmt = parser::Statement<parser::ProgramStmt>; + if (auto &stmt{std::get<std::optional<StatementProgramStmt>>(x.t)}) { + scopeStack_.push_back(stmt->statement.v.symbol->scope()); + } else { + for (const Scope &scope : context_.globalScope().children()) { + // There can only be one main program. + if (scope.kind() == Scope::Kind::MainProgram) { + scopeStack_.push_back(&scope); + break; + } + } + } + return true; +} + +void OmpStructureChecker::Leave(const parser::MainProgram &x) { + scopeStack_.pop_back(); +} + +bool OmpStructureChecker::Enter(const parser::BlockData &x) { + // The BLOCK DATA name is optional, so we need to look for the + // corresponding scope in the global scope. + auto &stmt{std::get<parser::Statement<parser::BlockDataStmt>>(x.t)}; + if (auto &name{stmt.statement.v}) { + scopeStack_.push_back(name->symbol->scope()); + } else { + for (const Scope &scope : context_.globalScope().children()) { + if (scope.kind() == Scope::Kind::BlockData) { + if (scope.symbol()->name().empty()) { + scopeStack_.push_back(&scope); + break; + } + } + } + } + return true; +} + +void OmpStructureChecker::Leave(const parser::BlockData &x) { + scopeStack_.pop_back(); +} + +bool OmpStructureChecker::Enter(const parser::Module &x) { + auto &stmt{std::get<parser::Statement<parser::ModuleStmt>>(x.t)}; + const Symbol *sym{stmt.statement.v.symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +void OmpStructureChecker::Leave(const parser::Module &x) { + scopeStack_.pop_back(); +} + +bool OmpStructureChecker::Enter(const parser::Submodule &x) { + auto &stmt{std::get<parser::Statement<parser::SubmoduleStmt>>(x.t)}; + const Symbol *sym{std::get<parser::Name>(stmt.statement.t).symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +void OmpStructureChecker::Leave(const parser::Submodule &x) { + scopeStack_.pop_back(); +} + +// Function/subroutine subprogram nodes don't appear in INTERFACEs, but +// the subprogram/end statements do. +bool OmpStructureChecker::Enter(const parser::SubroutineStmt &x) { + const Symbol *sym{std::get<parser::Name>(x.t).symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +bool OmpStructureChecker::Enter(const parser::EndSubroutineStmt &x) { + scopeStack_.pop_back(); + return true; +} + +bool OmpStructureChecker::Enter(const parser::FunctionStmt &x) { + const Symbol *sym{std::get<parser::Name>(x.t).symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +bool OmpStructureChecker::Enter(const parser::EndFunctionStmt &x) { + scopeStack_.pop_back(); + return true; +} + +bool OmpStructureChecker::Enter(const parser::BlockConstruct &x) { + auto &specPart{std::get<parser::BlockSpecificationPart>(x.t)}; + auto &execPart{std::get<parser::Block>(x.t)}; + if (auto &&source{parser::GetSource(specPart)}) { + scopeStack_.push_back(&context_.FindScope(*source)); + } else if (auto &&source{parser::GetSource(execPart)}) { + scopeStack_.push_back(&context_.FindScope(*source)); + } + return true; +} + +void OmpStructureChecker::Leave(const parser::BlockConstruct &x) { + auto &specPart{std::get<parser::BlockSpecificationPart>(x.t)}; + auto &execPart{std::get<parser::Block>(x.t)}; + if (auto &&source{parser::GetSource(specPart)}) { + scopeStack_.push_back(&context_.FindScope(*source)); + } else if (auto &&source{parser::GetSource(execPart)}) { + scopeStack_.push_back(&context_.FindScope(*source)); + } +} + // Use when clause falls under 'struct OmpClause' in 'parse-tree.h'. #define CHECK_SIMPLE_CLAUSE(X, Y) \ void OmpStructureChecker::Enter(const parser::OmpClause::X &) { \ @@ -362,6 +480,36 @@ bool OmpStructureChecker::IsNestedInDirective(llvm::omp::Directive directive) { return false; } +bool OmpStructureChecker::InTargetRegion() { + if (IsNestedInDirective(llvm::omp::Directive::OMPD_target)) { + // Return true even for device_type(host). + return true; + } + for (const Scope *scope : llvm::reverse(scopeStack_)) { + if (const auto *symbol{scope->symbol()}) { + if (symbol->test(Symbol::Flag::OmpDeclareTarget)) { + return true; + } + } + } + return false; +} + +bool OmpStructureChecker::HasRequires(llvm::omp::Clause req) { + const Scope &unit{GetProgramUnit(*scopeStack_.back())}; + return common::visit( + [&](const auto &details) { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + if (auto *reqs{details.ompRequires()}) { + return reqs->test(req); + } + } + return false; + }, + DEREF(unit.symbol()).details()); +} + void OmpStructureChecker::CheckVariableListItem( const SymbolSourceMap &symbols) { for (auto &[symbol, source] : symbols) { @@ -1562,52 +1710,95 @@ void OmpStructureChecker::Leave(const parser::OpenMPRequiresConstruct &) { dirContext_.pop_back(); } -void OmpStructureChecker::CheckAlignValue(const parser::OmpClause &clause) { - if (auto *align{std::get_if<parser::OmpClause::Align>(&clause.u)}) { - if (const auto &v{GetIntValue(align->v)}; v && *v <= 0) { - context_.Say(clause.source, "The alignment should be positive"_err_en_US); +void OmpStructureChecker::CheckAllocateDirective(parser::CharBlock source, + const parser::OmpObjectList &objects, + const parser::OmpClauseList &clauses) { + const Scope &thisScope{context_.FindScope(source)}; + SymbolSourceMap symbols; + GetSymbolsInObjectList(objects, symbols); + + auto maybeHasPredefinedAllocator{[&](const parser::OmpClause *calloc) { + // Return "true" if the ALLOCATOR clause was provided with an argument + // that is either a prefdefined allocator, or a run-time value. + // Otherwise return "false". + if (!calloc) { + return false; } - } -} + auto *allocator{std::get_if<parser::OmpClause::Allocator>(&calloc->u)}; + if (auto val{ToInt64(GetEvaluateExpr(DEREF(allocator).v))}) { + // Predefined allocators (defined in OpenMP 6.0 20.8.1): + // omp_null_allocator = 0, + // omp_default_mem_alloc = 1, + // omp_large_cap_mem_alloc = 2, + // omp_const_mem_alloc = 3, + // omp_high_bw_mem_alloc = 4, + // omp_low_lat_mem_alloc = 5, + // omp_cgroup_mem_alloc = 6, + // omp_pteam_mem_alloc = 7, + // omp_thread_mem_alloc = 8 + return *val >= 0 && *val <= 8; + } + return true; + }}; -void OmpStructureChecker::Enter(const parser::OpenMPDeclarativeAllocate &x) { - isPredefinedAllocator = true; - const auto &dir{std::get<parser::Verbatim>(x.t)}; - const auto &objectList{std::get<parser::OmpObjectList>(x.t)}; - PushContextAndClauseSets(dir.source, llvm::omp::Directive::OMPD_allocate); - const auto &clauseList{std::get<parser::OmpClauseList>(x.t)}; - SymbolSourceMap currSymbols; - GetSymbolsInObjectList(objectList, currSymbols); - for (auto &[symbol, source] : currSymbols) { - if (IsPointer(*symbol)) { + const auto *allocator{FindClause(llvm::omp::Clause::OMPC_allocator)}; + if (InTargetRegion()) { + bool hasDynAllocators{ + HasRequires(llvm::omp::Clause::OMPC_dynamic_allocators)}; + if (!allocator && !hasDynAllocators) { context_.Say(source, - "List item '%s' in ALLOCATE directive must not have POINTER " - "attribute"_err_en_US, - source.ToString()); + "An ALLOCATE directive in a TARGET region must specify an ALLOCATOR clause or REQUIRES(DYNAMIC_ALLOCATORS) must be specified"_err_en_US); } - if (IsDummy(*symbol)) { - context_.Say(source, - "List item '%s' in ALLOCATE directive must not be a dummy " - "argument"_err_en_US, - source.ToString()); + } + + auto maybePredefined{maybeHasPredefinedAllocator(allocator)}; + + for (auto &[symbol, source] : symbols) { + if (!inExecutableAllocate_) { + if (symbol->owner() != thisScope) { + context_.Say(source, + "A list item on a declarative ALLOCATE must be declared in the same scope in which the directive appears"_err_en_US); + } + if (IsPointer(*symbol) || IsAllocatable(*symbol)) { + context_.Say(source, + "A list item in a declarative ALLOCATE cannot have the ALLOCATABLE or POINTER attribute"_err_en_US); + } } if (symbol->GetUltimate().has<AssocEntityDetails>()) { context_.Say(source, - "List item '%s' in ALLOCATE directive must not be an associate " - "name"_err_en_US, - source.ToString()); + "A list item in a declarative ALLOCATE cannot be an associate name"_err_en_US); + } + if (symbol->attrs().test(Attr::SAVE) || IsCommonBlock(*symbol)) { + if (!allocator) { + context_.Say(source, + "If a list item is a named common block or has SAVE attribute, an ALLOCATOR clause must be present with a predefined allocator"_err_en_US); + } else if (!maybePredefined) { + context_.Say(source, + "If a list item is a named common block or has SAVE attribute, only a predefined allocator may be used on the ALLOCATOR clause"_err_en_US); + } + } + if (FindCommonBlockContaining(*symbol)) { + context_.Say(source, + "A variable that is part of a common block may not be specified as a list item in an ALLOCATE directive, except implicitly via the named common block"_err_en_US); } } - for (const auto &clause : clauseList.v) { - CheckAlignValue(clause); - } - CheckVarIsNotPartOfAnotherVar(dir.source, objectList); + CheckVarIsNotPartOfAnotherVar(source, objects); } -void OmpStructureChecker::Leave(const parser::OpenMPDeclarativeAllocate &x) { +void OmpStructureChecker::Enter(const parser::OpenMPDeclarativeAllocate &x) { const auto &dir{std::get<parser::Verbatim>(x.t)}; - const auto &objectList{std::get<parser::OmpObjectList>(x.t)}; - CheckPredefinedAllocatorRestriction(dir.source, objectList); + PushContextAndClauseSets(dir.source, llvm::omp::Directive::OMPD_allocate); +} + +void OmpStructureChecker::Leave(const parser::OpenMPDeclarativeAllocate &x) { + if (!inExecutableAllocate_) { + const auto &dir{std::get<parser::Verbatim>(x.t)}; + const auto &clauseList{std::get<parser::OmpClauseList>(x.t)}; + const auto &objectList{std::get<parser::OmpObjectList>(x.t)}; + + isPredefinedAllocator = true; + CheckAllocateDirective(dir.source, objectList, clauseList); + } dirContext_.pop_back(); } @@ -1963,6 +2154,7 @@ void OmpStructureChecker::CheckNameInAllocateStmt( } void OmpStructureChecker::Enter(const parser::OpenMPExecutableAllocate &x) { + inExecutableAllocate_ = true; const auto &dir{std::get<parser::Verbatim>(x.t)}; PushContextAndClauseSets(dir.source, llvm::omp::Directive::OMPD_allocate); @@ -1972,24 +2164,6 @@ void OmpStructureChecker::Enter(const parser::OpenMPExecutableAllocate &x) { "The executable form of the OpenMP ALLOCATE directive has been deprecated, please use ALLOCATORS instead"_warn_en_US); } - bool hasAllocator = false; - // TODO: Investigate whether searching the clause list can be done with - // parser::Unwrap instead of the following loop - const auto &clauseList{std::get<parser::OmpClauseList>(x.t)}; - for (const auto &clause : clauseList.v) { - if (std::get_if<parser::OmpClause::Allocator>(&clause.u)) { - hasAllocator = true; - } - } - - if (IsNestedInDirective(llvm::omp::Directive::OMPD_target) && !hasAllocator) { - // TODO: expand this check to exclude the case when a requires - // directive with the dynamic_allocators clause is present - // in the same compilation unit (OMP5.0 2.11.3). - context_.Say(x.source, - "ALLOCATE directives that appear in a TARGET region must specify an allocator clause"_err_en_US); - } - const auto &allocateStmt = std::get<parser::Statement<parser::AllocateStmt>>(x.t).statement; if (const auto &list{std::get<std::optional<parser::OmpObjectList>>(x.t)}) { @@ -2006,21 +2180,34 @@ void OmpStructureChecker::Enter(const parser::OpenMPExecutableAllocate &x) { } isPredefinedAllocator = true; - const auto &objectList{std::get<std::optional<parser::OmpObjectList>>(x.t)}; - for (const auto &clause : clauseList.v) { - CheckAlignValue(clause); - } - if (objectList) { - CheckVarIsNotPartOfAnotherVar(dir.source, *objectList); - } } void OmpStructureChecker::Leave(const parser::OpenMPExecutableAllocate &x) { - const auto &dir{std::get<parser::Verbatim>(x.t)}; - const auto &objectList{std::get<std::optional<parser::OmpObjectList>>(x.t)}; - if (objectList) - CheckPredefinedAllocatorRestriction(dir.source, *objectList); + parser::OmpObjectList empty{std::list<parser::OmpObject>{}}; + auto &objects{[&]() -> const parser::OmpObjectList & { + if (auto &objects{std::get<std::optional<parser::OmpObjectList>>(x.t)}) { + return *objects; + } else { + return empty; + } + }()}; + auto &clauses{std::get<parser::OmpClauseList>(x.t)}; + CheckAllocateDirective( + std::get<parser::Verbatim>(x.t).source, objects, clauses); + + if (const auto &subDirs{ + std::get<std::optional<std::list<parser::OpenMPDeclarativeAllocate>>>( + x.t)}) { + for (const auto &dalloc : *subDirs) { + const auto &dir{std::get<parser::Verbatim>(x.t)}; + const auto &clauses{std::get<parser::OmpClauseList>(dalloc.t)}; + const auto &objects{std::get<parser::OmpObjectList>(dalloc.t)}; + CheckAllocateDirective(dir.source, objects, clauses); + } + } + dirContext_.pop_back(); + inExecutableAllocate_ = false; } void OmpStructureChecker::Enter(const parser::OpenMPAllocatorsConstruct &x) { @@ -3234,7 +3421,6 @@ CHECK_SIMPLE_CLAUSE(AdjustArgs, OMPC_adjust_args) CHECK_SIMPLE_CLAUSE(AppendArgs, OMPC_append_args) CHECK_SIMPLE_CLAUSE(MemoryOrder, OMPC_memory_order) CHECK_SIMPLE_CLAUSE(Bind, OMPC_bind) -CHECK_SIMPLE_CLAUSE(Align, OMPC_align) CHECK_SIMPLE_CLAUSE(Compare, OMPC_compare) CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute) CHECK_SIMPLE_CLAUSE(Weak, OMPC_weak) @@ -3898,6 +4084,19 @@ void OmpStructureChecker::CheckIsLoopIvPartOfClause( } } +void OmpStructureChecker::Enter(const parser::OmpClause::Align &x) { + CheckAllowedClause(llvm::omp::Clause::OMPC_align); + if (const auto &v{GetIntValue(x.v.v)}) { + if (*v <= 0) { + context_.Say(GetContext().clauseSource, + "The alignment should be positive"_err_en_US); + } else if (!llvm::isPowerOf2_64(*v)) { + context_.Say(GetContext().clauseSource, + "The alignment should be a power of 2"_err_en_US); + } + } +} + // Restrictions specific to each clause are implemented apart from the // generalized restrictions. void OmpStructureChecker::Enter(const parser::OmpClause::Aligned &x) { diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index 543642ff322a..7426559e77ff 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -19,7 +19,6 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/semantics.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" using OmpClauseSet = Fortran::common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>; @@ -57,21 +56,32 @@ using SymbolSourceMap = std::multimap<const Symbol *, parser::CharBlock>; using DirectivesClauseTriple = std::multimap<llvm::omp::Directive, std::pair<llvm::omp::Directive, const OmpClauseSet>>; -class OmpStructureChecker - : public DirectiveStructureChecker<llvm::omp::Directive, llvm::omp::Clause, - parser::OmpClause, llvm::omp::Clause_enumSize> { +using OmpStructureCheckerBase = DirectiveStructureChecker<llvm::omp::Directive, + llvm::omp::Clause, parser::OmpClause, llvm::omp::Clause_enumSize>; + +class OmpStructureChecker : public OmpStructureCheckerBase { public: - using Base = DirectiveStructureChecker<llvm::omp::Directive, - llvm::omp::Clause, parser::OmpClause, llvm::omp::Clause_enumSize>; + using Base = OmpStructureCheckerBase; + + OmpStructureChecker(SemanticsContext &context); - OmpStructureChecker(SemanticsContext &context) - : DirectiveStructureChecker(context, -#define GEN_FLANG_DIRECTIVE_CLAUSE_MAP -#include "llvm/Frontend/OpenMP/OMP.inc" - ) { - } using llvmOmpClause = const llvm::omp::Clause; + bool Enter(const parser::MainProgram &); + void Leave(const parser::MainProgram &); + bool Enter(const parser::BlockData &); + void Leave(const parser::BlockData &); + bool Enter(const parser::Module &); + void Leave(const parser::Module &); + bool Enter(const parser::Submodule &); + void Leave(const parser::Submodule &); + bool Enter(const parser::SubroutineStmt &); + bool Enter(const parser::EndSubroutineStmt &); + bool Enter(const parser::FunctionStmt &); + bool Enter(const parser::EndFunctionStmt &); + bool Enter(const parser::BlockConstruct &); + void Leave(const parser::BlockConstruct &); + void Enter(const parser::OpenMPConstruct &); void Leave(const parser::OpenMPConstruct &); void Enter(const parser::OpenMPInteropConstruct &); @@ -178,10 +188,12 @@ private: const parser::CharBlock &, const OmpDirectiveSet &); bool IsCloselyNestedRegion(const OmpDirectiveSet &set); bool IsNestedInDirective(llvm::omp::Directive directive); + bool InTargetRegion(); void HasInvalidTeamsNesting( const llvm::omp::Directive &dir, const parser::CharBlock &source); void HasInvalidDistributeNesting(const parser::OpenMPLoopConstruct &x); void HasInvalidLoopBinding(const parser::OpenMPLoopConstruct &x); + bool HasRequires(llvm::omp::Clause req); // specific clause related void CheckAllowedMapTypes( parser::OmpMapType::Value, llvm::ArrayRef<parser::OmpMapType::Value>); @@ -251,6 +263,9 @@ private: bool CheckTargetBlockOnlyTeams(const parser::Block &); void CheckWorkshareBlockStmts(const parser::Block &, parser::CharBlock); void CheckWorkdistributeBlockStmts(const parser::Block &, parser::CharBlock); + void CheckAllocateDirective(parser::CharBlock source, + const parser::OmpObjectList &objects, + const parser::OmpClauseList &clauses); void CheckIteratorRange(const parser::OmpIteratorSpecifier &x); void CheckIteratorModifier(const parser::OmpIterator &x); @@ -347,8 +362,6 @@ private: void CheckAllowedRequiresClause(llvmOmpClause clause); bool deviceConstructFound_{false}; - void CheckAlignValue(const parser::OmpClause &); - void AddEndDirectiveClauses(const parser::OmpClauseList &clauses); void EnterDirectiveNest(const int index) { directiveNest_[index]++; } @@ -370,12 +383,15 @@ private: }; int directiveNest_[LastType + 1] = {0}; + bool inExecutableAllocate_{false}; parser::CharBlock visitedAtomicSource_; SymbolSourceMap deferredNonVariables_; using LoopConstruct = std::variant<const parser::DoConstruct *, const parser::OpenMPLoopConstruct *>; std::vector<LoopConstruct> loopStack_; + // Scopes for scoping units. + std::vector<const Scope *> scopeStack_; }; /// Find a duplicate entry in the range, and return an iterator to it. diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index 292e73b4899c..cc55bb4954cc 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -218,7 +218,7 @@ bool IsMapExitingType(parser::OmpMapType::Value type) { } } -std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) { +MaybeExpr GetEvaluateExpr(const parser::Expr &parserExpr) { const parser::TypedExpr &typedExpr{parserExpr.typedExpr}; // ForwardOwningPointer typedExpr // `- GenericExprWrapper ^.get() diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 3bb586c51c58..196755e2912a 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -31,15 +31,17 @@ #include <list> #include <map> +namespace Fortran::semantics { + template <typename T> -static Fortran::semantics::Scope *GetScope( - Fortran::semantics::SemanticsContext &context, const T &x) { - std::optional<Fortran::parser::CharBlock> source{GetLastSource(x)}; - return source ? &context.FindScope(*source) : nullptr; +static Scope *GetScope(SemanticsContext &context, const T &x) { + if (auto source{GetLastSource(x)}) { + return &context.FindScope(*source); + } else { + return nullptr; + } } -namespace Fortran::semantics { - template <typename T> class DirectiveAttributeVisitor { public: explicit DirectiveAttributeVisitor(SemanticsContext &context) @@ -361,7 +363,7 @@ private: void ResolveAccObject(const parser::AccObject &, Symbol::Flag); Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &); Symbol *ResolveAcc(Symbol &, Symbol::Flag, Scope &); - Symbol *ResolveName(const parser::Name &, bool parentScope = false); + Symbol *ResolveName(const parser::Name &); Symbol *ResolveFctName(const parser::Name &); Symbol *ResolveAccCommonBlockName(const parser::Name *); Symbol *DeclareOrMarkOtherAccessEntity(const parser::Name &, Symbol::Flag); @@ -560,7 +562,7 @@ public: auto getArgument{[&](auto &&maybeClause) { if (maybeClause) { // Scalar<Logical<Constant<common::Indirection<Expr>>>> - auto &parserExpr{maybeClause->v.thing.thing.thing.value()}; + auto &parserExpr{parser::UnwrapRef<parser::Expr>(*maybeClause)}; evaluate::ExpressionAnalyzer ea{context_}; if (auto &&maybeExpr{ea.Analyze(parserExpr)}) { if (auto v{omp::GetLogicalValue(*maybeExpr)}) { @@ -1257,31 +1259,22 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCStandaloneConstruct &x) { return true; } -Symbol *AccAttributeVisitor::ResolveName( - const parser::Name &name, bool parentScope) { - Symbol *prev{currScope().FindSymbol(name.source)}; - // Check in parent scope if asked for. - if (!prev && parentScope) { - prev = currScope().parent().FindSymbol(name.source); - } - if (prev != name.symbol) { - name.symbol = prev; - } - return prev; +Symbol *AccAttributeVisitor::ResolveName(const parser::Name &name) { + return name.symbol; } Symbol *AccAttributeVisitor::ResolveFctName(const parser::Name &name) { Symbol *prev{currScope().FindSymbol(name.source)}; - if (!prev || (prev && prev->IsFuncResult())) { + if (prev && prev->IsFuncResult()) { prev = currScope().parent().FindSymbol(name.source); - if (!prev) { - prev = &context_.globalScope().MakeSymbol( - name.source, Attrs{}, ProcEntityDetails{}); - } } - if (prev != name.symbol) { - name.symbol = prev; + if (!prev) { + prev = &*context_.globalScope() + .try_emplace(name.source, ProcEntityDetails{}) + .first->second; } + CHECK(!name.symbol || name.symbol == prev); + name.symbol = prev; return prev; } @@ -1388,9 +1381,8 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) { } else { PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine); } - const auto &optName{std::get<std::optional<parser::Name>>(x.t)}; - if (optName) { - if (Symbol *sym = ResolveFctName(*optName)) { + if (const auto &optName{std::get<std::optional<parser::Name>>(x.t)}) { + if (Symbol * sym{ResolveFctName(*optName)}) { Symbol &ultimate{sym->GetUltimate()}; AddRoutineInfoToSymbol(ultimate, x); } else { @@ -1425,7 +1417,7 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCCombinedConstruct &x) { case llvm::acc::Directive::ACCD_kernels_loop: case llvm::acc::Directive::ACCD_parallel_loop: case llvm::acc::Directive::ACCD_serial_loop: - PushContext(combinedDir.source, combinedDir.v); + PushContext(x.source, combinedDir.v); break; default: break; @@ -1706,41 +1698,37 @@ void AccAttributeVisitor::Post(const parser::AccDefaultClause &x) { } } -// For OpenACC constructs, check all the data-refs within the constructs -// and adjust the symbol for each Name if necessary void AccAttributeVisitor::Post(const parser::Name &name) { - auto *symbol{name.symbol}; - if (symbol && WithinConstruct()) { - symbol = &symbol->GetUltimate(); - if (!symbol->owner().IsDerivedType() && !symbol->has<ProcEntityDetails>() && - !symbol->has<SubprogramDetails>() && !IsObjectWithVisibleDSA(*symbol)) { + if (name.symbol && WithinConstruct()) { + const Symbol &symbol{name.symbol->GetUltimate()}; + if (!symbol.owner().IsDerivedType() && !symbol.has<ProcEntityDetails>() && + !symbol.has<SubprogramDetails>() && !IsObjectWithVisibleDSA(symbol)) { if (Symbol * found{currScope().FindSymbol(name.source)}) { - if (symbol != found) { - name.symbol = found; // adjust the symbol within region + if (&symbol != found) { + // adjust the symbol within the region + // TODO: why didn't name resolution set the right name originally? + name.symbol = found; } else if (GetContext().defaultDSA == Symbol::Flag::AccNone) { // 2.5.14. context_.Say(name.source, "The DEFAULT(NONE) clause requires that '%s' must be listed in a data-mapping clause"_err_en_US, - symbol->name()); + symbol.name()); } + } else { + // TODO: assertion here? or clear name.symbol? } } - } // within OpenACC construct + } } Symbol *AccAttributeVisitor::ResolveAccCommonBlockName( const parser::Name *name) { - if (auto *prev{name - ? GetContext().scope.parent().FindCommonBlock(name->source) - : nullptr}) { - name->symbol = prev; - return prev; - } - // Check if the Common Block is declared in the current scope - if (auto *commonBlockSymbol{ - name ? GetContext().scope.FindCommonBlock(name->source) : nullptr}) { - name->symbol = commonBlockSymbol; - return commonBlockSymbol; + if (name) { + if (Symbol * + cb{GetContext().scope.FindCommonBlockInVisibleScopes(name->source)}) { + name->symbol = cb; + return cb; + } } return nullptr; } @@ -1790,8 +1778,8 @@ void AccAttributeVisitor::ResolveAccObject( } } else { context_.Say(name.source, - "COMMON block must be declared in the same scoping unit " - "in which the OpenACC directive or clause appears"_err_en_US); + "Could not find COMMON block '%s' used in OpenACC directive"_err_en_US, + name.ToString()); } }, }, @@ -1810,13 +1798,11 @@ Symbol *AccAttributeVisitor::ResolveAcc( Symbol *AccAttributeVisitor::DeclareOrMarkOtherAccessEntity( const parser::Name &name, Symbol::Flag accFlag) { - Symbol *prev{currScope().FindSymbol(name.source)}; - if (!name.symbol || !prev) { + if (name.symbol) { + return DeclareOrMarkOtherAccessEntity(*name.symbol, accFlag); + } else { return nullptr; - } else if (prev != name.symbol) { - name.symbol = prev; } - return DeclareOrMarkOtherAccessEntity(*prev, accFlag); } Symbol *AccAttributeVisitor::DeclareOrMarkOtherAccessEntity( @@ -2990,6 +2976,7 @@ void OmpAttributeVisitor::Post(const parser::Name &name) { } Symbol *OmpAttributeVisitor::ResolveName(const parser::Name *name) { + // TODO: why is the symbol not properly resolved by name resolution? if (auto *resolvedSymbol{ name ? GetContext().scope.FindSymbol(name->source) : nullptr}) { name->symbol = resolvedSymbol; @@ -3107,26 +3094,6 @@ void OmpAttributeVisitor::ResolveOmpDesignator( AddAllocateName(name); } } - if (ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective && - IsAllocatable(*symbol) && - !IsNestedInDirective(llvm::omp::Directive::OMPD_allocate)) { - context_.Say(designator.source, - "List items specified in the ALLOCATE directive must not have the ALLOCATABLE attribute unless the directive is associated with an ALLOCATE statement"_err_en_US); - } - bool checkScope{ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective}; - // In 5.1 the scope check only applies to declarative allocate. - if (version == 50 && !checkScope) { - checkScope = ompFlag == Symbol::Flag::OmpExecutableAllocateDirective; - } - if (checkScope) { - if (omp::GetScopingUnit(GetContext().scope) != - omp::GetScopingUnit(symbol->GetUltimate().owner())) { - context_.Say(designator.source, // 2.15.3 - "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US, - parser::ToUpperCaseLetters( - llvm::omp::getOpenMPDirectiveName(directive, version))); - } - } if (ompFlag == Symbol::Flag::OmpReduction) { // Using variables inside of a namelist in OpenMP reductions // is allowed by the standard, but is not allowed for diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 0af1c94502bb..db75437708a6 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1441,6 +1441,30 @@ public: void Post(const parser::AccBeginLoopDirective &x) { messageHandler().set_currStmtSource(std::nullopt); } + bool Pre(const parser::OpenACCStandaloneConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCCacheConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCWaitConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCAtomicConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCEndConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } + bool Pre(const parser::OpenACCDeclarativeConstruct &x) { + currScope().AddSourceRange(x.source); + return true; + } void CopySymbolWithDevice(const parser::Name *name); @@ -1480,7 +1504,8 @@ void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { // symbols are created for the one appearing in the use_device // clause. These new symbols have the CUDA Fortran device // attribute. - if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) { + if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA) && + name->symbol) { name->symbol = currScope().CopySymbol(*name->symbol); if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) { object->set_cudaDataAttr(common::CUDADataAttr::Device); @@ -1490,15 +1515,12 @@ void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { for (const auto &accObject : x.v.v) { + Walk(accObject); common::visit( common::visitors{ [&](const parser::Designator &designator) { if (const auto *name{ parser::GetDesignatorNameIfDataRef(designator)}) { - Symbol *prev{currScope().FindSymbol(name->source)}; - if (prev != name->symbol) { - name->symbol = prev; - } CopySymbolWithDevice(name); } else { if (const auto *dataRef{ @@ -1507,13 +1529,8 @@ bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { common::Indirection<parser::ArrayElement>; if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) { const parser::ArrayElement &arrayElement{ind->value()}; - Walk(arrayElement.subscripts); const parser::DataRef &base{arrayElement.base}; if (auto *name{std::get_if<parser::Name>(&base.u)}) { - Symbol *prev{currScope().FindSymbol(name->source)}; - if (prev != name->symbol) { - name->symbol = prev; - } CopySymbolWithDevice(name); } } @@ -1537,6 +1554,7 @@ void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) { bool AccVisitor::Pre(const parser::OpenACCCombinedConstruct &x) { PushScope(Scope::Kind::OpenACCConstruct, nullptr); + currScope().AddSourceRange(x.source); return true; } @@ -3627,6 +3645,20 @@ void ModuleVisitor::Post(const parser::UseStmt &x) { } } } + // Go through the list of COMMON block symbols in the module scope and add + // their USE association to the current scope's USE-associated COMMON blocks. + for (const auto &[name, symbol] : useModuleScope_->commonBlocks()) { + if (!currScope().FindCommonBlockInVisibleScopes(name)) { + currScope().AddCommonBlockUse( + name, symbol->attrs(), symbol->GetUltimate()); + } + } + // Go through the list of USE-associated COMMON block symbols in the module + // scope and add USE associations to their ultimate symbols to the current + // scope's USE-associated COMMON blocks. + for (const auto &[name, symbol] : useModuleScope_->commonBlockUses()) { + currScope().AddCommonBlockUse(name, symbol->attrs(), symbol->GetUltimate()); + } useModuleScope_ = nullptr; } @@ -5433,7 +5465,8 @@ void SubprogramVisitor::PushBlockDataScope(const parser::Name &name) { } } -// If name is a generic, return specific subprogram with the same name. +// If name is a generic in the same scope, return its specific subprogram with +// the same name, if any. Symbol *SubprogramVisitor::GetSpecificFromGeneric(const parser::Name &name) { // Search for the name but don't resolve it if (auto *symbol{currScope().FindSymbol(name.source)}) { @@ -5443,6 +5476,9 @@ Symbol *SubprogramVisitor::GetSpecificFromGeneric(const parser::Name &name) { // symbol doesn't inherit it and ruin the ability to check it. symbol->attrs().reset(Attr::MODULE); } + } else if (&symbol->owner() != &currScope() && inInterfaceBlock() && + !isGeneric()) { + // non-generic interface shadows outer definition } else if (auto *details{symbol->detailsIf<GenericDetails>()}) { // found generic, want specific procedure auto *specific{details->specific()}; diff --git a/flang/lib/Semantics/scope.cpp b/flang/lib/Semantics/scope.cpp index 4af371f3611f..ab75d4c60838 100644 --- a/flang/lib/Semantics/scope.cpp +++ b/flang/lib/Semantics/scope.cpp @@ -144,9 +144,8 @@ void Scope::add_crayPointer(const SourceName &name, Symbol &pointer) { } Symbol &Scope::MakeCommonBlock(SourceName name, SourceName location) { - const auto it{commonBlocks_.find(name)}; - if (it != commonBlocks_.end()) { - return *it->second; + if (auto *cb{FindCommonBlock(name)}) { + return *cb; } else { Symbol &symbol{MakeSymbol( name, Attrs{}, CommonBlockDetails{name.empty() ? location : name})}; @@ -154,9 +153,25 @@ Symbol &Scope::MakeCommonBlock(SourceName name, SourceName location) { return symbol; } } -Symbol *Scope::FindCommonBlock(const SourceName &name) const { - const auto it{commonBlocks_.find(name)}; - return it != commonBlocks_.end() ? &*it->second : nullptr; + +Symbol *Scope::FindCommonBlockInVisibleScopes(const SourceName &name) const { + if (Symbol * cb{FindCommonBlock(name)}) { + return cb; + } else if (Symbol * cb{FindCommonBlockUse(name)}) { + return &cb->GetUltimate(); + } else if (IsSubmodule()) { + if (const Scope *parent{ + symbol_ ? symbol_->get<ModuleDetails>().parent() : nullptr}) { + if (auto *cb{parent->FindCommonBlockInVisibleScopes(name)}) { + return cb; + } + } + } else if (!IsTopLevel() && parent_) { + if (auto *cb{parent_->FindCommonBlockInVisibleScopes(name)}) { + return cb; + } + } + return nullptr; } Scope *Scope::FindSubmodule(const SourceName &name) const { @@ -167,6 +182,31 @@ Scope *Scope::FindSubmodule(const SourceName &name) const { return &*it->second; } } + +bool Scope::AddCommonBlockUse( + const SourceName &name, Attrs attrs, Symbol &cbUltimate) { + CHECK(cbUltimate.has<CommonBlockDetails>()); + // Make a symbol, but don't add it to the Scope, since it needs to + // be added to the USE-associated COMMON blocks + Symbol &useCB{MakeSymbol(name, attrs, UseDetails{name, cbUltimate})}; + return commonBlockUses_.emplace(name, useCB).second; +} + +Symbol *Scope::FindCommonBlock(const SourceName &name) const { + if (const auto it{commonBlocks_.find(name)}; it != commonBlocks_.end()) { + return &*it->second; + } + return nullptr; +} + +Symbol *Scope::FindCommonBlockUse(const SourceName &name) const { + if (const auto it{commonBlockUses_.find(name)}; + it != commonBlockUses_.end()) { + return &*it->second; + } + return nullptr; +} + bool Scope::AddSubmodule(const SourceName &name, Scope &submodule) { return submodules_.emplace(name, submodule).second; } diff --git a/flang/lib/Semantics/semantics.cpp b/flang/lib/Semantics/semantics.cpp index bdb5377265c1..2606d997b1cd 100644 --- a/flang/lib/Semantics/semantics.cpp +++ b/flang/lib/Semantics/semantics.cpp @@ -452,6 +452,15 @@ void SemanticsContext::UpdateScopeIndex( } } +void SemanticsContext::DumpScopeIndex(llvm::raw_ostream &out) const { + out << "scopeIndex_:\n"; + for (const auto &[source, scope] : scopeIndex_) { + out << "source '" << source.ToString() << "' -> scope " << scope + << "... whose source range is '" << scope.sourceRange().ToString() + << "'\n"; + } +} + bool SemanticsContext::IsInModuleFile(parser::CharBlock source) const { for (const Scope *scope{&FindScope(source)}; !scope->IsGlobal(); scope = &scope->parent()) { diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp index 2261912fec22..c2036c4a383f 100644 --- a/flang/lib/Utils/OpenMP.cpp +++ b/flang/lib/Utils/OpenMP.cpp @@ -22,8 +22,9 @@ mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr, mlir::Value varPtrPtr, llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex, - uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, - mlir::Type retTy, bool partialMap, mlir::FlatSymbolRefAttr mapperId) { + mlir::omp::ClauseMapFlags mapType, + mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, + bool partialMap, mlir::FlatSymbolRefAttr mapperId) { if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr); @@ -42,7 +43,7 @@ mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder, mlir::omp::MapInfoOp op = mlir::omp::MapInfoOp::create(builder, loc, retTy, baseAddr, varType, - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType), builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), varPtrPtr, members, membersIndex, bounds, mapperId, builder.getStringAttr(name), builder.getBoolAttr(partialMap)); @@ -75,8 +76,7 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, firOpBuilder.setInsertionPoint(targetOp); - llvm::omp::OpenMPOffloadMappingFlags mapFlag = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; mlir::omp::VariableCaptureKind captureKind = mlir::omp::VariableCaptureKind::ByRef; @@ -88,16 +88,14 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { captureKind = mlir::omp::VariableCaptureKind::ByCopy; } else if (!fir::isa_builtin_cptr_type(eleType)) { - mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapFlag |= mlir::omp::ClauseMapFlags::to; } mlir::Value mapOp = createMapInfoOp(firOpBuilder, copyVal.getLoc(), copyVal, /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/llvm::SmallVector<mlir::Value>{}, - /*membersIndex=*/mlir::ArrayAttr{}, - static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( - mapFlag), - captureKind, copyVal.getType()); + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, + copyVal.getType()); auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); mlir::Region ®ion = targetOp.getRegion(); @@ -114,7 +112,7 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, mlir::Block *entryBlock = ®ion.getBlocks().front(); firOpBuilder.setInsertionPointToStart(entryBlock); auto loadOp = - firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(), clonedValArg); + fir::LoadOp::create(firOpBuilder, clonedValArg.getLoc(), clonedValArg); return loadOp.getResult(); } |
