diff options
| author | Mingming Liu <mingmingl@google.com> | 2025-09-10 15:25:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-10 15:25:31 -0700 |
| commit | 1417dafa1db9cb1b2b09438aa9f53ea5ab6e36e2 (patch) | |
| tree | 57f4b1f313c8cf74eed8819870f39c36ea263c68 /flang/lib/Optimizer | |
| parent | 898b813bc8a6d0276bf0f4769f5f2f64b34e632d (diff) | |
| parent | b8cefcb601ddaa18482555c4ff363c01a270c2fe (diff) | |
Merge branch 'main' into users/mingmingl-llvm/samplefdo-profile-formatusers/mingmingl-llvm/samplefdo-profile-format
Diffstat (limited to 'flang/lib/Optimizer')
37 files changed, 1914 insertions, 287 deletions
diff --git a/flang/lib/Optimizer/Analysis/TBAAForest.cpp b/flang/lib/Optimizer/Analysis/TBAAForest.cpp index cce50e0de1bc..44a0348da3a6 100644 --- a/flang/lib/Optimizer/Analysis/TBAAForest.cpp +++ b/flang/lib/Optimizer/Analysis/TBAAForest.cpp @@ -11,12 +11,23 @@ mlir::LLVM::TBAATagAttr fir::TBAATree::SubtreeState::getTag(llvm::StringRef uniqueName) const { - std::string id = (parentId + "/" + uniqueName).str(); + std::string id = (parentId + '/' + uniqueName).str(); mlir::LLVM::TBAATypeDescriptorAttr type = mlir::LLVM::TBAATypeDescriptorAttr::get( context, id, mlir::LLVM::TBAAMemberAttr::get(parent, 0)); return mlir::LLVM::TBAATagAttr::get(type, type, 0); - // return tag; +} + +fir::TBAATree::SubtreeState & +fir::TBAATree::SubtreeState::getOrCreateNamedSubtree(mlir::StringAttr name) { + auto it = namedSubtrees.find(name); + if (it != namedSubtrees.end()) + return it->second; + + return namedSubtrees + .insert( + {name, SubtreeState(context, parentId + '/' + name.str(), parent)}) + .first->second; } mlir::LLVM::TBAATagAttr fir::TBAATree::SubtreeState::getTag() const { diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt index 8fb36a750d43..404afd185fd3 100644 --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -50,6 +50,7 @@ add_flang_library(FIRBuilder FIRDialectSupport FIRSupport FortranEvaluate + FortranSupport HLFIRDialect MLIR_DEPS diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 99533690018e..b6501fd53099 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -423,10 +423,11 @@ mlir::Value fir::FirOpBuilder::genTempDeclareOp( llvm::ArrayRef<mlir::Value> typeParams, fir::FortranVariableFlagsAttr fortranAttrs) { auto nameAttr = mlir::StringAttr::get(builder.getContext(), name); - return fir::DeclareOp::create(builder, loc, memref.getType(), memref, shape, - typeParams, - /*dummy_scope=*/nullptr, nameAttr, fortranAttrs, - cuf::DataAttributeAttr{}); + return fir::DeclareOp::create( + builder, loc, memref.getType(), memref, shape, typeParams, + /*dummy_scope=*/nullptr, + /*storage=*/nullptr, + /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{}); } mlir::Value fir::FirOpBuilder::genStackSave(mlir::Location loc) { diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 086dd6671160..f93eaf7ba90b 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -249,6 +249,7 @@ fir::FortranVariableOpInterface hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder, const fir::ExtendedValue &exv, llvm::StringRef name, fir::FortranVariableFlagsAttr flags, mlir::Value dummyScope, + mlir::Value storage, std::uint64_t storageOffset, cuf::DataAttributeAttr dataAttr) { mlir::Value base = fir::getBase(exv); @@ -278,9 +279,9 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder, box.nonDeferredLenParams().end()); }, [](const auto &) {}); - auto declareOp = - hlfir::DeclareOp::create(builder, loc, base, name, shapeOrShift, - lenParams, dummyScope, flags, dataAttr); + auto declareOp = hlfir::DeclareOp::create( + builder, loc, base, name, shapeOrShift, lenParams, dummyScope, storage, + storageOffset, flags, dataAttr); return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation()); } @@ -1372,7 +1373,8 @@ hlfir::createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder, fir::FortranVariableFlagsAttr attrs) -> mlir::Value { auto declareOp = hlfir::DeclareOp::create(builder, loc, memref, name, shape, typeParams, - /*dummy_scope=*/nullptr, attrs); + /*dummy_scope=*/nullptr, /*storage=*/nullptr, + /*storage_offset=*/0, attrs); return declareOp.getBase(); }; @@ -1409,7 +1411,8 @@ hlfir::Entity hlfir::createStackTempFromMold(mlir::Location loc, } auto declareOp = hlfir::DeclareOp::create(builder, loc, alloc, tmpName, shape, lenParams, - /*dummy_scope=*/nullptr, declAttrs); + /*dummy_scope=*/nullptr, /*storage=*/nullptr, + /*storage_offset=*/0, declAttrs); return hlfir::Entity{declareOp.getBase()}; } @@ -1426,8 +1429,7 @@ hlfir::convertCharacterKind(mlir::Location loc, fir::FirOpBuilder &builder, return hlfir::EntityWithAttributes{hlfir::DeclareOp::create( builder, loc, res.getAddr(), ".temp.kindconvert", /*shape=*/nullptr, - /*typeparams=*/mlir::ValueRange{res.getLen()}, - /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{})}; + /*typeparams=*/mlir::ValueRange{res.getLen()})}; } std::pair<hlfir::Entity, std::optional<hlfir::CleanupFunction>> @@ -1499,8 +1501,7 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder, fir::ShapeShiftOp::create(builder, loc, shapeShiftType, lbAndExtents); auto declareOp = hlfir::DeclareOp::create( builder, loc, associate.getFirBase(), *associate.getUniqName(), - shapeShift, associate.getTypeparams(), /*dummy_scope=*/nullptr, - /*flags=*/fir::FortranVariableFlagsAttr{}); + shapeShift, associate.getTypeparams()); hlfir::Entity castWithLbounds = mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation()); fir::FirOpBuilder *bldr = &builder; @@ -1538,9 +1539,8 @@ std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp( extents, typeParams); mlir::Value innerMemory = evalInMem.getMemory(); temp = builder.createConvert(loc, innerMemory.getType(), temp); - auto declareOp = hlfir::DeclareOp::create( - builder, loc, temp, tmpName, shape, typeParams, - /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); + auto declareOp = + hlfir::DeclareOp::create(builder, loc, temp, tmpName, shape, typeParams); computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase()); return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated}; } diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 7e987de94c5d..6ae48c1d5d88 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -138,7 +138,7 @@ static const char __ldlu_r8x2[] = "__ldlu_r8x2_"; /// Table that drives the fir generation depending on the intrinsic or intrinsic /// module procedure one to one mapping with Fortran arguments. If no mapping is /// defined here for a generic intrinsic, genRuntimeCall will be called -/// to look for a match in the runtime a emit a call. Note that the argument +/// to look for a match in the runtime and emit a call. Note that the argument /// lowering rules for an intrinsic need to be provided only if at least one /// argument must not be lowered by value. In which case, the lowering rules /// should be provided for all the intrinsic arguments for completeness. @@ -397,6 +397,34 @@ static constexpr IntrinsicHandler handlers[]{ {"cmplx", &I::genCmplx, {{{"x", asValue}, {"y", asValue, handleDynamicOptional}}}}, + {"co_broadcast", + &I::genCoBroadcast, + {{{"a", asBox}, + {"source_image", asAddr}, + {"stat", asAddr, handleDynamicOptional}, + {"errmsg", asBox, handleDynamicOptional}}}, + /*isElemental*/ false}, + {"co_max", + &I::genCoMax, + {{{"a", asBox}, + {"result_image", asAddr, handleDynamicOptional}, + {"stat", asAddr, handleDynamicOptional}, + {"errmsg", asBox, handleDynamicOptional}}}, + /*isElemental*/ false}, + {"co_min", + &I::genCoMin, + {{{"a", asBox}, + {"result_image", asAddr, handleDynamicOptional}, + {"stat", asAddr, handleDynamicOptional}, + {"errmsg", asBox, handleDynamicOptional}}}, + /*isElemental*/ false}, + {"co_sum", + &I::genCoSum, + {{{"a", asBox}, + {"result_image", asAddr, handleDynamicOptional}, + {"stat", asAddr, handleDynamicOptional}, + {"errmsg", asBox, handleDynamicOptional}}}, + /*isElemental*/ false}, {"command_argument_count", &I::genCommandArgumentCount}, {"conjg", &I::genConjg}, {"cosd", &I::genCosd}, @@ -869,6 +897,10 @@ static constexpr IntrinsicHandler handlers[]{ {"back", asValue, handleDynamicOptional}, {"kind", asValue}}}, /*isElemental=*/true}, + {"secnds", + &I::genSecnds, + {{{"refTime", asAddr}}}, + /*isElemental=*/false}, {"second", &I::genSecond, {{{"time", asAddr}}}, @@ -1058,7 +1090,7 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc, llvm::StringRef suffix, mlir::FunctionType funcType) { std::string output = prefix.str(); llvm::raw_string_ostream sstream(output); - if (name == "pow") { + if (name == "pow" || name == "pow-unsigned") { assert(funcType.getNumInputs() == 2 && "power operator has two arguments"); std::string displayName{" ** "}; sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc, @@ -1671,6 +1703,14 @@ static constexpr MathOperation mathOperations[] = { genComplexPow}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, genLibF128Call}, + {"pow-unsigned", RTNAME_STRING(UPow1), + genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall}, + {"pow-unsigned", RTNAME_STRING(UPow2), + genFuncType<Ty::Integer<2>, Ty::Integer<2>, Ty::Integer<2>>, genLibCall}, + {"pow-unsigned", RTNAME_STRING(UPow4), + genFuncType<Ty::Integer<4>, Ty::Integer<4>, Ty::Integer<4>>, genLibCall}, + {"pow-unsigned", RTNAME_STRING(UPow8), + genFuncType<Ty::Integer<8>, Ty::Integer<8>, Ty::Integer<8>>, genLibCall}, {"remainder", "remainderf", genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Real<4>>, genLibCall}, {"remainder", "remainder", @@ -3674,6 +3714,85 @@ mlir::Value IntrinsicLibrary::genCmplx(mlir::Type resultType, imag); } +// CO_BROADCAST +void IntrinsicLibrary::genCoBroadcast(llvm::ArrayRef<fir::ExtendedValue> args) { + checkCoarrayEnabled(); + assert(args.size() == 4); + mlir::Value sourceImage = fir::getBase(args[1]); + mlir::Value status = + isStaticallyAbsent(args[2]) + ? fir::AbsentOp::create(builder, loc, + builder.getRefType(builder.getI32Type())) + .getResult() + : fir::getBase(args[2]); + mlir::Value errmsg = + isStaticallyAbsent(args[3]) + ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() + : fir::getBase(args[3]); + fir::runtime::genCoBroadcast(builder, loc, fir::getBase(args[0]), sourceImage, + status, errmsg); +} + +// CO_MAX +void IntrinsicLibrary::genCoMax(llvm::ArrayRef<fir::ExtendedValue> args) { + checkCoarrayEnabled(); + assert(args.size() == 4); + mlir::Value refNone = + fir::AbsentOp::create(builder, loc, + builder.getRefType(builder.getI32Type())) + .getResult(); + mlir::Value resultImage = + isStaticallyAbsent(args[1]) ? refNone : fir::getBase(args[1]); + mlir::Value status = + isStaticallyAbsent(args[2]) ? refNone : fir::getBase(args[2]); + mlir::Value errmsg = + isStaticallyAbsent(args[3]) + ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() + : fir::getBase(args[3]); + fir::runtime::genCoMax(builder, loc, fir::getBase(args[0]), resultImage, + status, errmsg); +} + +// CO_MIN +void IntrinsicLibrary::genCoMin(llvm::ArrayRef<fir::ExtendedValue> args) { + checkCoarrayEnabled(); + assert(args.size() == 4); + mlir::Value refNone = + fir::AbsentOp::create(builder, loc, + builder.getRefType(builder.getI32Type())) + .getResult(); + mlir::Value resultImage = + isStaticallyAbsent(args[1]) ? refNone : fir::getBase(args[1]); + mlir::Value status = + isStaticallyAbsent(args[2]) ? refNone : fir::getBase(args[2]); + mlir::Value errmsg = + isStaticallyAbsent(args[3]) + ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() + : fir::getBase(args[3]); + fir::runtime::genCoMin(builder, loc, fir::getBase(args[0]), resultImage, + status, errmsg); +} + +// CO_SUM +void IntrinsicLibrary::genCoSum(llvm::ArrayRef<fir::ExtendedValue> args) { + checkCoarrayEnabled(); + assert(args.size() == 4); + mlir::Value absentInt = + fir::AbsentOp::create(builder, loc, + builder.getRefType(builder.getI32Type())) + .getResult(); + mlir::Value resultImage = + isStaticallyAbsent(args[1]) ? absentInt : fir::getBase(args[1]); + mlir::Value status = + isStaticallyAbsent(args[2]) ? absentInt : fir::getBase(args[2]); + mlir::Value errmsg = + isStaticallyAbsent(args[3]) + ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() + : fir::getBase(args[3]); + fir::runtime::genCoSum(builder, loc, fir::getBase(args[0]), resultImage, + status, errmsg); +} + // COMMAND_ARGUMENT_COUNT fir::ExtendedValue IntrinsicLibrary::genCommandArgumentCount( mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) { @@ -3707,10 +3826,11 @@ mlir::Value IntrinsicLibrary::genCosd(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0)); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, pi / llvm::APFloat(fltSem, "180.0")); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("cos", ftype)(builder, loc, {arg}); } @@ -7863,6 +7983,22 @@ IntrinsicLibrary::genScan(mlir::Type resultType, return readAndAddCleanUp(resultMutableBox, resultType, "SCAN"); } +// SECNDS +fir::ExtendedValue +IntrinsicLibrary::genSecnds(mlir::Type resultType, + llvm::ArrayRef<fir::ExtendedValue> args) { + assert(args.size() == 1 && "SECNDS expects one argument"); + + mlir::Value refTime = fir::getBase(args[0]); + + if (!refTime) + fir::emitFatalError(loc, "expected REFERENCE TIME parameter"); + + mlir::Value result = fir::runtime::genSecnds(builder, loc, refTime); + + return builder.createConvert(loc, resultType, result); +} + // SECOND fir::ExtendedValue IntrinsicLibrary::genSecond(std::optional<mlir::Type> resultType, @@ -8171,10 +8307,11 @@ mlir::Value IntrinsicLibrary::genSind(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0)); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, pi / llvm::APFloat(fltSem, "180.0")); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("sin", ftype)(builder, loc, {arg}); } @@ -8268,10 +8405,11 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType, mlir::MLIRContext *context = builder.getContext(); mlir::FunctionType ftype = mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); - llvm::APFloat pi = llvm::APFloat(llvm::numbers::pi); - mlir::Value dfactor = builder.createRealConstant( - loc, mlir::Float64Type::get(context), pi / llvm::APFloat(180.0)); - mlir::Value factor = builder.createConvert(loc, args[0].getType(), dfactor); + const llvm::fltSemantics &fltSem = + llvm::cast<mlir::FloatType>(resultType).getFloatSemantics(); + llvm::APFloat pi = llvm::APFloat(fltSem, llvm::numbers::pis); + mlir::Value factor = builder.createRealConstant( + loc, resultType, pi / llvm::APFloat(fltSem, "180.0")); mlir::Value arg = mlir::arith::MulFOp::create(builder, loc, args[0], factor); return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg}); } @@ -9418,6 +9556,14 @@ mlir::Value genPow(fir::FirOpBuilder &builder, mlir::Location loc, // implementation and mark it 'strictfp'. // Another option is to implement it in Fortran runtime library // (just like matmul). + if (type.isUnsignedInteger()) { + assert(x.getType().isUnsignedInteger() && y.getType().isUnsignedInteger() && + "unsigned pow requires unsigned arguments"); + return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow-unsigned", type, + {x, y}); + } + assert(!x.getType().isUnsignedInteger() && !y.getType().isUnsignedInteger() && + "non-unsigned pow requires non-unsigned arguments"); return IntrinsicLibrary{builder, loc}.genRuntimeCall("pow", type, {x, y}); } diff --git a/flang/lib/Optimizer/Builder/MutableBox.cpp b/flang/lib/Optimizer/Builder/MutableBox.cpp index 50c945df5b46..d4cdfecd0b08 100644 --- a/flang/lib/Optimizer/Builder/MutableBox.cpp +++ b/flang/lib/Optimizer/Builder/MutableBox.cpp @@ -603,21 +603,23 @@ void fir::factory::associateMutableBoxWithRemap( mlir::ValueRange lbounds, mlir::ValueRange ubounds) { // Compute new extents llvm::SmallVector<mlir::Value> extents; - auto idxTy = builder.getIndexType(); + mlir::Type idxTy = builder.getIndexType(); + mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); if (!lbounds.empty()) { auto one = builder.createIntegerConstant(loc, idxTy, 1); for (auto [lb, ub] : llvm::zip(lbounds, ubounds)) { - auto lbi = builder.createConvert(loc, idxTy, lb); - auto ubi = builder.createConvert(loc, idxTy, ub); - auto diff = mlir::arith::SubIOp::create(builder, loc, idxTy, ubi, lbi); + + mlir::Value lbi = builder.createConvert(loc, idxTy, lb); + mlir::Value ubi = builder.createConvert(loc, idxTy, ub); extents.emplace_back( - mlir::arith::AddIOp::create(builder, loc, idxTy, diff, one)); + fir::factory::computeExtent(builder, loc, lbi, ubi, zero, one)); } } else { // lbounds are default. Upper bounds and extents are the same. - for (auto ub : ubounds) { - auto cast = builder.createConvert(loc, idxTy, ub); - extents.emplace_back(cast); + for (mlir::Value ub : ubounds) { + mlir::Value cast = builder.createConvert(loc, idxTy, ub); + extents.emplace_back( + fir::factory::genMaxWithZero(builder, loc, cast, zero)); } } const auto newRank = extents.size(); diff --git a/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp b/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp index a6ee98685f3c..37e4c5a706df 100644 --- a/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp @@ -47,18 +47,3 @@ void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder, builder, loc, fTy, desc, sourceFile, sourceLine)}; fir::CallOp::create(builder, loc, func, args); } - -void fir::runtime::cuda::genSetAllocatorIndex(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value desc, - mlir::Value index) { - mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFSetAllocatorIndex)>(loc, builder); - auto fTy = func.getFunctionType(); - mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); - mlir::Value sourceLine = - fir::factory::locationToLineNo(builder, loc, fTy.getInput(3)); - llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( - builder, loc, fTy, desc, index, sourceFile, sourceLine)}; - fir::CallOp::create(builder, loc, func, args); -} diff --git a/flang/lib/Optimizer/Builder/Runtime/Character.cpp b/flang/lib/Optimizer/Builder/Runtime/Character.cpp index 57fb0cccf686..540ecba299dc 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Character.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Character.cpp @@ -119,23 +119,23 @@ fir::runtime::genCharCompare(fir::FirOpBuilder &builder, mlir::Location loc, return mlir::arith::CmpIOp::create(builder, loc, cmp, tri, zero); } +static mlir::Value allocateIfNotInMemory(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value base) { + if (fir::isa_ref_type(base.getType())) + return base; + auto mem = + fir::AllocaOp::create(builder, loc, base.getType(), /*pinned=*/false); + fir::StoreOp::create(builder, loc, base, mem); + return mem; +} + mlir::Value fir::runtime::genCharCompare(fir::FirOpBuilder &builder, mlir::Location loc, mlir::arith::CmpIPredicate cmp, const fir::ExtendedValue &lhs, const fir::ExtendedValue &rhs) { - if (lhs.getBoxOf<fir::BoxValue>() || rhs.getBoxOf<fir::BoxValue>()) - TODO(loc, "character compare from descriptors"); - auto allocateIfNotInMemory = [&](mlir::Value base) -> mlir::Value { - if (fir::isa_ref_type(base.getType())) - return base; - auto mem = - fir::AllocaOp::create(builder, loc, base.getType(), /*pinned=*/false); - fir::StoreOp::create(builder, loc, base, mem); - return mem; - }; - auto lhsBuffer = allocateIfNotInMemory(fir::getBase(lhs)); - auto rhsBuffer = allocateIfNotInMemory(fir::getBase(rhs)); + auto lhsBuffer = allocateIfNotInMemory(builder, loc, fir::getBase(lhs)); + auto rhsBuffer = allocateIfNotInMemory(builder, loc, fir::getBase(rhs)); return genCharCompare(builder, loc, cmp, lhsBuffer, fir::getLen(lhs), rhsBuffer, fir::getLen(rhs)); } @@ -168,6 +168,20 @@ mlir::Value fir::runtime::genIndex(fir::FirOpBuilder &builder, return fir::CallOp::create(builder, loc, indexFunc, args).getResult(0); } +mlir::Value fir::runtime::genIndex(fir::FirOpBuilder &builder, + mlir::Location loc, + const fir::ExtendedValue &str, + const fir::ExtendedValue &substr, + mlir::Value back) { + assert(!substr.getBoxOf<fir::BoxValue>() && !str.getBoxOf<fir::BoxValue>() && + "shall use genIndexDescriptor version"); + auto strBuffer = allocateIfNotInMemory(builder, loc, fir::getBase(str)); + auto substrBuffer = allocateIfNotInMemory(builder, loc, fir::getBase(substr)); + int kind = discoverKind(strBuffer.getType()); + return genIndex(builder, loc, kind, strBuffer, fir::getLen(str), substrBuffer, + fir::getLen(substr), back); +} + void fir::runtime::genIndexDescriptor(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value resultBox, mlir::Value stringBox, diff --git a/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp index fb72fc2089e2..9a893d61122a 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp @@ -14,6 +14,24 @@ using namespace Fortran::runtime; using namespace Fortran::semantics; +// Most PRIF functions take `errmsg` and `errmsg_alloc` as two optional +// arguments of intent (out). One is allocatable, the other is not. +// It is the responsibility of the compiler to ensure that the appropriate +// optional argument is passed, and at most one must be provided in a given +// call. +// Depending on the type of `errmsg`, this function will return the pair +// corresponding to (`errmsg`, `errmsg_alloc`). +static std::pair<mlir::Value, mlir::Value> +genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value errmsg) { + bool isAllocatableErrmsg = fir::isAllocatableType(errmsg.getType()); + + mlir::Value absent = fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE); + mlir::Value errMsg = isAllocatableErrmsg ? absent : errmsg; + mlir::Value errMsgAlloc = isAllocatableErrmsg ? errmsg : absent; + return {errMsg, errMsgAlloc}; +} + /// Generate Call to runtime prif_init mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder, mlir::Location loc) { @@ -24,8 +42,8 @@ mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder, builder.createFunction(loc, PRIFNAME_SUB("init"), ftype); llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(builder, loc, ftype, result); - builder.create<fir::CallOp>(loc, funcOp, args); - return builder.create<fir::LoadOp>(loc, result); + fir::CallOp::create(builder, loc, funcOp, args); + return fir::LoadOp::create(builder, loc, result); } /// Generate Call to runtime prif_num_images @@ -38,8 +56,8 @@ mlir::Value fir::runtime::getNumImages(fir::FirOpBuilder &builder, builder.createFunction(loc, PRIFNAME_SUB("num_images"), ftype); llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(builder, loc, ftype, result); - builder.create<fir::CallOp>(loc, funcOp, args); - return builder.create<fir::LoadOp>(loc, result); + fir::CallOp::create(builder, loc, funcOp, args); + return fir::LoadOp::create(builder, loc, result); } /// Generate Call to runtime prif_num_images_with_{team|team_number} @@ -63,8 +81,8 @@ mlir::Value fir::runtime::getNumImagesWithTeam(fir::FirOpBuilder &builder, team = builder.createBox(loc, team); llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(builder, loc, ftype, team, result); - builder.create<fir::CallOp>(loc, funcOp, args); - return builder.create<fir::LoadOp>(loc, result); + fir::CallOp::create(builder, loc, funcOp, args); + return fir::LoadOp::create(builder, loc, result); } /// Generate Call to runtime prif_this_image_no_coarray @@ -78,9 +96,72 @@ mlir::Value fir::runtime::getThisImage(fir::FirOpBuilder &builder, mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); mlir::Value teamArg = - !team ? builder.create<fir::AbsentOp>(loc, boxTy) : team; + !team ? fir::AbsentOp::create(builder, loc, boxTy) : team; llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(builder, loc, ftype, teamArg, result); - builder.create<fir::CallOp>(loc, funcOp, args); - return builder.create<fir::LoadOp>(loc, result); + fir::CallOp::create(builder, loc, funcOp, args); + return fir::LoadOp::create(builder, loc, result); +} + +/// Generate call to collective subroutines except co_reduce +/// A must be lowered as a box +void genCollectiveSubroutine(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value A, mlir::Value rootImage, + mlir::Value stat, mlir::Value errmsg, + std::string coName) { + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::FunctionType ftype = + PRIF_FUNCTYPE(boxTy, builder.getRefType(builder.getI32Type()), + PRIF_STAT_TYPE, PRIF_ERRMSG_TYPE, PRIF_ERRMSG_TYPE); + mlir::func::FuncOp funcOp = builder.createFunction(loc, coName, ftype); + + auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, A, rootImage, stat, errmsgArg, errmsgAllocArg); + fir::CallOp::create(builder, loc, funcOp, args); +} + +/// Generate call to runtime subroutine prif_co_broadcast +void fir::runtime::genCoBroadcast(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value A, + mlir::Value sourceImage, mlir::Value stat, + mlir::Value errmsg) { + genCollectiveSubroutine(builder, loc, A, sourceImage, stat, errmsg, + PRIFNAME_SUB("co_broadcast")); +} + +/// Generate call to runtime subroutine prif_co_max or prif_co_max_character +void fir::runtime::genCoMax(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value A, mlir::Value resultImage, + mlir::Value stat, mlir::Value errmsg) { + mlir::Type argTy = + fir::unwrapSequenceType(fir::unwrapPassByRefType(A.getType())); + if (mlir::isa<fir::CharacterType>(argTy)) + genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, + PRIFNAME_SUB("co_max_character")); + else + genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, + PRIFNAME_SUB("co_max")); +} + +/// Generate call to runtime subroutine prif_co_min or prif_co_min_character +void fir::runtime::genCoMin(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value A, mlir::Value resultImage, + mlir::Value stat, mlir::Value errmsg) { + mlir::Type argTy = + fir::unwrapSequenceType(fir::unwrapPassByRefType(A.getType())); + if (mlir::isa<fir::CharacterType>(argTy)) + genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, + PRIFNAME_SUB("co_min_character")); + else + genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, + PRIFNAME_SUB("co_min")); +} + +/// Generate call to runtime subroutine prif_co_sum +void fir::runtime::genCoSum(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value A, mlir::Value resultImage, + mlir::Value stat, mlir::Value errmsg) { + genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, + PRIFNAME_SUB("co_sum")); } diff --git a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp index ee151576ace9..dc61903ddd36 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp @@ -276,6 +276,23 @@ void fir::runtime::genRename(fir::FirOpBuilder &builder, mlir::Location loc, fir::CallOp::create(builder, loc, runtimeFunc, args); } +mlir::Value fir::runtime::genSecnds(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value refTime) { + auto runtimeFunc = + fir::runtime::getRuntimeFunc<mkRTKey(Secnds)>(loc, builder); + + mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType(); + + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, runtimeFuncTy.getInput(2)); + + llvm::SmallVector<mlir::Value> args = {refTime, sourceFile, sourceLine}; + args = fir::runtime::createArguments(builder, loc, runtimeFuncTy, args); + + return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0); +} + /// generate runtime call to time intrinsic mlir::Value fir::runtime::genTime(fir::FirOpBuilder &builder, mlir::Location loc) { diff --git a/flang/lib/Optimizer/Builder/TemporaryStorage.cpp b/flang/lib/Optimizer/Builder/TemporaryStorage.cpp index c0d6606b8d29..7e329e357d7b 100644 --- a/flang/lib/Optimizer/Builder/TemporaryStorage.cpp +++ b/flang/lib/Optimizer/Builder/TemporaryStorage.cpp @@ -82,8 +82,7 @@ fir::factory::HomogeneousScalarStack::HomogeneousScalarStack( mlir::Value shape = builder.genShape(loc, extents); temp = hlfir::DeclareOp::create(builder, loc, tempStorage, tempName, shape, - lengths, /*dummy_scope=*/nullptr, - fir::FortranVariableFlagsAttr{}) + lengths) .getBase(); } diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 76f3cbd421cb..0800ed4db8c3 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -329,6 +329,31 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> { } // namespace namespace { + +static bool isInGlobalOp(mlir::ConversionPatternRewriter &rewriter) { + auto *thisBlock = rewriter.getInsertionBlock(); + return thisBlock && mlir::isa<mlir::LLVM::GlobalOp>(thisBlock->getParentOp()); +} + +// Inside a fir.global, the input box was produced as an llvm.struct<> +// because objects cannot be handled in memory inside a fir.global body that +// must be constant foldable. However, the type translation are not +// contextual, so the fir.box<T> type of the operation that produced the +// fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass +// manager inserted a builtin.unrealized_conversion_cast that was inserted +// and needs to be removed here. +// This should be called by any pattern operating on operations that are +// accepting fir.box inputs and are used in fir.global. +static mlir::Value +fixBoxInputInsideGlobalOp(mlir::ConversionPatternRewriter &rewriter, + mlir::Value box) { + if (isInGlobalOp(rewriter)) + if (auto unrealizedCast = + box.getDefiningOp<mlir::UnrealizedConversionCastOp>()) + return unrealizedCast.getInputs()[0]; + return box; +} + /// Lower `fir.box_addr` to the sequence of operations to extract the first /// element of the box. struct BoxAddrOpConversion : public fir::FIROpConversion<fir::BoxAddrOp> { @@ -341,6 +366,7 @@ struct BoxAddrOpConversion : public fir::FIROpConversion<fir::BoxAddrOp> { auto loc = boxaddr.getLoc(); if (auto argty = mlir::dyn_cast<fir::BaseBoxType>(boxaddr.getVal().getType())) { + a = fixBoxInputInsideGlobalOp(rewriter, a); TypePair boxTyPair = getBoxTypePair(argty); rewriter.replaceOp(boxaddr, getBaseAddrFromBox(loc, boxTyPair, a, rewriter)); @@ -1737,12 +1763,6 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> { xbox.getSubcomponent().size()); } - static bool isInGlobalOp(mlir::ConversionPatternRewriter &rewriter) { - auto *thisBlock = rewriter.getInsertionBlock(); - return thisBlock && - mlir::isa<mlir::LLVM::GlobalOp>(thisBlock->getParentOp()); - } - /// If the embox is not in a globalOp body, allocate storage for the box; /// store the value inside and return the generated alloca. Return the input /// value otherwise. @@ -2076,21 +2096,10 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> { mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = rebox.getLoc(); mlir::Type idxTy = lowerTy().indexType(); - mlir::Value loweredBox = adaptor.getOperands()[0]; + mlir::Value loweredBox = + fixBoxInputInsideGlobalOp(rewriter, adaptor.getBox()); mlir::ValueRange operands = adaptor.getOperands(); - // Inside a fir.global, the input box was produced as an llvm.struct<> - // because objects cannot be handled in memory inside a fir.global body that - // must be constant foldable. However, the type translation are not - // contextual, so the fir.box<T> type of the operation that produced the - // fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass - // manager inserted a builtin.unrealized_conversion_cast that was inserted - // and needs to be removed here. - if (isInGlobalOp(rewriter)) - if (auto unrealizedCast = - loweredBox.getDefiningOp<mlir::UnrealizedConversionCastOp>()) - loweredBox = unrealizedCast.getInputs()[0]; - TypePair inputBoxTyPair = getBoxTypePair(rebox.getBox().getType()); // Create new descriptor and fill its non-shape related data. diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 97912bda79b0..381b2a29c517 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -60,6 +60,21 @@ struct MapInfoOpConversion : public OpenMPFIROpConversion<mlir::omp::MapInfoOp> { using OpenMPFIROpConversion::OpenMPFIROpConversion; + mlir::omp::MapBoundsOp + createBoundsForCharString(mlir::ConversionPatternRewriter &rewriter, + unsigned int len, mlir::Location loc) const { + mlir::Type i64Ty = rewriter.getIntegerType(64); + auto lBound = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 0); + auto uBoundAndExt = + mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, len - 1); + auto stride = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 1); + auto baseLb = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 1); + auto mapBoundType = rewriter.getType<mlir::omp::MapBoundsType>(); + return mlir::omp::MapBoundsOp::create(rewriter, loc, mapBoundType, lBound, + uBoundAndExt, uBoundAndExt, stride, + /*strideInBytes*/ false, baseLb); + } + llvm::LogicalResult matchAndRewrite(mlir::omp::MapInfoOp curOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { @@ -69,13 +84,79 @@ struct MapInfoOpConversion return mlir::failure(); llvm::SmallVector<mlir::NamedAttribute> newAttrs; - mlir::omp::MapInfoOp newOp; + mlir::omp::MapBoundsOp mapBoundsOp; for (mlir::NamedAttribute attr : curOp->getAttrs()) { if (auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(attr.getValue())) { mlir::Type newAttr; if (fir::isTypeWithDescriptor(typeAttr.getValue())) { newAttr = lowerTy().convertBoxTypeAsStruct( mlir::cast<fir::BaseBoxType>(typeAttr.getValue())); + } else if (fir::isa_char_string(fir::unwrapSequenceType( + fir::unwrapPassByRefType(typeAttr.getValue()))) && + !characterWithDynamicLen( + fir::unwrapPassByRefType(typeAttr.getValue()))) { + // Characters with a LEN param are represented as strings + // (array of characters), the lowering to LLVM dialect + // doesn't generate bounds for these (and this is not + // done at the initial lowering either) and there is + // minor inconsistencies in the variable types we + // create for the map without this step when converting + // to the LLVM dialect. + // + // For example, given the types: + // + // 1) CHARACTER(LEN=16), dimension(:,:), allocatable :: char_arr + // 2) CHARACTER(LEN=16), dimension(10,10) :: char_arr + // + // We get the FIR types (note for 1: we already peeled off the + // dynamic extents from the type at this stage, but the conversion + // to llvm dialect does that in any case, so the final result + // is the same): + // + // 1) !fir.char<1,16> + // 2) !fir.array<10x10x!fir.char<1,16>> + // + // Which are converted to the LLVM dialect types: + // + // 1) !llvm.array<16 x i8> + // 2) llvm.array<10 x array<10 x array<16 x i8>> + // + // And in both cases, we are missing the innermost bounds for + // the !fir.char<1,16> which is expanded into a 16 x i8 array + // in the conversion to LLVM dialect. + // + // The problem with this is that we would like to treat these + // cases identically and not have to create specialised + // lowerings for either of these in the lowering to LLVM-IR + // and treat them like any other array that passes through. + // + // To do so below, we generate an extra bound for the + // innermost array (the char type/string) using the LEN + // parameter of the character type. And we "canonicalize" + // the type, stripping it down to the base element type, + // which in this case is an i8. This effectively allows + // the lowering to treat this as a 1-D array with multiple + // bounds which it is capable of handling without any special + // casing. + // TODO: Handle dynamic LEN characters. + if (auto ct = mlir::dyn_cast_or_null<fir::CharacterType>( + fir::unwrapSequenceType(typeAttr.getValue()))) { + newAttr = converter->convertType( + fir::unwrapSequenceType(typeAttr.getValue())); + if (auto type = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(newAttr)) + newAttr = type.getElementType(); + // We do not generate MapBoundsOps for the device pass, as + // MapBoundsOps are not generated for the device pass, as + // they're unused in the device lowering. + auto offloadMod = + llvm::dyn_cast_or_null<mlir::omp::OffloadModuleInterface>( + *curOp->getParentOfType<mlir::ModuleOp>()); + if (!offloadMod.getIsTargetDevice()) + mapBoundsOp = createBoundsForCharString(rewriter, ct.getLen(), + curOp.getLoc()); + } else { + newAttr = converter->convertType(typeAttr.getValue()); + } } else { newAttr = converter->convertType(typeAttr.getValue()); } @@ -85,8 +166,13 @@ struct MapInfoOpConversion } } - rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>( + auto newOp = rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>( curOp, resTypes, adaptor.getOperands(), newAttrs); + if (mapBoundsOp) { + rewriter.startOpModification(newOp); + newOp.getBoundsMutable().append(mlir::ValueRange{mapBoundsOp}); + rewriter.finalizeOpModification(newOp); + } return mlir::success(); } diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index fa935542d40f..ac285b5d403d 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -1336,7 +1336,15 @@ public: private: // Replace `op` and remove it. void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { - op->replaceAllUsesWith(newValues); + llvm::SmallVector<mlir::Value> casts; + for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) { + if (oldValue.getType() == newValue.getType()) + casts.push_back(newValue); + else + casts.push_back(fir::ConvertOp::create(*rewriter, op->getLoc(), + oldValue.getType(), newValue)); + } + op->replaceAllUsesWith(casts); op->dropAllReferences(); op->erase(); } diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index ade80716f256..687007d95722 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -345,17 +345,6 @@ llvm::LogicalResult cuf::StreamCastOp::verify() { return checkStreamType(*this); } -//===----------------------------------------------------------------------===// -// SetAllocatorOp -//===----------------------------------------------------------------------===// - -llvm::LogicalResult cuf::SetAllocatorIndexOp::verify() { - if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType()))) - return emitOpError( - "expect box to be a reference to class or box type value"); - return mlir::success(); -} - // Tablegen operators #define GET_OP_CLASSES diff --git a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp index 034f8c74ec79..f16072a90dfa 100644 --- a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp +++ b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp @@ -68,3 +68,31 @@ fir::FortranVariableOpInterface::verifyDeclareLikeOpImpl(mlir::Value memref) { } return mlir::success(); } + +mlir::LogicalResult +fir::detail::verifyFortranVariableStorageOpInterface(mlir::Operation *op) { + auto storageIface = mlir::cast<fir::FortranVariableStorageOpInterface>(op); + mlir::Value storage = storageIface.getStorage(); + std::uint64_t storageOffset = storageIface.getStorageOffset(); + if (!storage) { + if (storageOffset != 0) + return op->emitOpError( + "storage offset specified without the storage reference"); + return mlir::success(); + } + + auto storageType = + mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(storage.getType())); + if (!storageType || storageType.getDimension() != 1) + return op->emitOpError("storage must be a vector"); + if (storageType.hasDynamicExtents()) + return op->emitOpError("storage must have known extent"); + if (storageType.getEleTy() != mlir::IntegerType::get(op->getContext(), 8)) + return op->emitOpError("storage must be an array of i8 elements"); + if (storageOffset > storageType.getConstantArraySize()) + return op->emitOpError("storage offset exceeds the storage size"); + // TODO: we should probably verify that the (offset + sizeof(var)) + // is within the storage object, but this requires mlir::DataLayout. + // Can we make it available during the verification? + return mlir::success(); +} diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 3c5095da0145..1a63b1bea317 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -265,7 +265,8 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value memref, llvm::StringRef uniq_name, mlir::Value shape, mlir::ValueRange typeparams, - mlir::Value dummy_scope, + mlir::Value dummy_scope, mlir::Value storage, + std::uint64_t storage_offset, fir::FortranVariableFlagsAttr fortran_attrs, cuf::DataAttributeAttr data_attr) { auto nameAttr = builder.getStringAttr(uniq_name); @@ -279,7 +280,8 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder, auto [hlfirVariableType, firVarType] = getDeclareOutputTypes(inputType, hasExplicitLbs); build(builder, result, {hlfirVariableType, firVarType}, memref, shape, - typeparams, dummy_scope, nameAttr, fortran_attrs, data_attr); + typeparams, dummy_scope, storage, storage_offset, nameAttr, + fortran_attrs, data_attr); } llvm::LogicalResult hlfir::DeclareOp::verify() { @@ -821,6 +823,84 @@ void hlfir::ConcatOp::getEffects( } //===----------------------------------------------------------------------===// +// CmpCharOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult hlfir::CmpCharOp::verify() { + mlir::Value lchr = getLchr(); + mlir::Value rchr = getRchr(); + + unsigned kind = getCharacterKind(lchr.getType()); + if (kind != getCharacterKind(rchr.getType())) + return emitOpError("character arguments must have the same KIND"); + + switch (getPredicate()) { + case mlir::arith::CmpIPredicate::slt: + case mlir::arith::CmpIPredicate::sle: + case mlir::arith::CmpIPredicate::eq: + case mlir::arith::CmpIPredicate::ne: + case mlir::arith::CmpIPredicate::sgt: + case mlir::arith::CmpIPredicate::sge: + break; + default: + return emitOpError("expected signed predicate"); + } + + return mlir::success(); +} + +void hlfir::CmpCharOp::getEffects( + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> + &effects) { + getIntrinsicEffects(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// +// CharTrimOp +//===----------------------------------------------------------------------===// + +void hlfir::CharTrimOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value chr) { + unsigned kind = getCharacterKind(chr.getType()); + auto resultType = hlfir::ExprType::get( + builder.getContext(), hlfir::ExprType::Shape{}, + fir::CharacterType::get(builder.getContext(), kind, + fir::CharacterType::unknownLen()), + /*polymorphic=*/false); + build(builder, result, resultType, chr); +} + +void hlfir::CharTrimOp::getEffects( + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> + &effects) { + getIntrinsicEffects(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// +// IndexOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult hlfir::IndexOp::verify() { + mlir::Value substr = getSubstr(); + mlir::Value str = getStr(); + + unsigned charKind = getCharacterKind(substr.getType()); + if (charKind != getCharacterKind(str.getType())) + return emitOpError("character arguments must have the same KIND"); + + return mlir::success(); +} + +void hlfir::IndexOp::getEffects( + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> + &effects) { + getIntrinsicEffects(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// // NumericalReductionOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 886a8a59e744..1c77636d301e 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -118,7 +118,8 @@ createArrayTemp(mlir::Location loc, fir::FirOpBuilder &builder, fir::FortranVariableFlagsAttr attrs) -> mlir::Value { auto declareOp = hlfir::DeclareOp::create(builder, loc, memref, name, shape, typeParams, - /*dummy_scope=*/nullptr, attrs); + /*dummy_scope=*/nullptr, /*storage=*/nullptr, + /*storage_offset=*/0, attrs); return declareOp.getBase(); }; @@ -298,8 +299,7 @@ struct SetLengthOpConversion auto alloca = builder.createTemporary(loc, charType, tmpName, /*shape=*/{}, lenParams); auto declareOp = hlfir::DeclareOp::create( - builder, loc, alloca, tmpName, /*shape=*/mlir::Value{}, lenParams, - /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); + builder, loc, alloca, tmpName, /*shape=*/mlir::Value{}, lenParams); hlfir::Entity temp{declareOp.getBase()}; // Assign string value to the created temp. hlfir::AssignOp::create(builder, loc, string, temp, diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index cc74273d9c5d..3775a13e31e9 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -27,6 +27,8 @@ add_flang_library(HLFIRTransforms FIRSupport FIRTransforms FlangOpenMPTransforms + FortranEvaluate + FortranSupport HLFIRDialect LINK_COMPONENTS diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp index 4e7de4732357..8104e53920c2 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -305,6 +305,8 @@ public: auto firDeclareOp = fir::DeclareOp::create( rewriter, loc, memref.getType(), memref, declareOp.getShape(), declareOp.getTypeparams(), declareOp.getDummyScope(), + /*storage=*/declareOp.getStorage(), + /*storage_offset=*/declareOp.getStorageOffset(), declareOp.getUniqName(), fortranAttrs, dataAttr); // Propagate other attributes from hlfir.declare to fir.declare. diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRCopyIn.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRCopyIn.cpp index e1df01e0e2ee..b4e89b0966e9 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRCopyIn.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRCopyIn.cpp @@ -109,7 +109,9 @@ InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn, auto declareOp = hlfir::DeclareOp::create(builder, loc, alloc, tmpName, shape, lenParams, - /*dummy_scope=*/nullptr); + /*dummy_scope=*/nullptr, + /*storage=*/nullptr, + /*storage_offset=*/0); hlfir::Entity temp{declareOp.getBase()}; hlfir::LoopNest loopNest = hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index e0167cc12b8a..4239e579ae70 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -551,6 +551,107 @@ class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> { } }; +class CmpCharOpConversion : public HlfirIntrinsicConversion<hlfir::CmpCharOp> { + using HlfirIntrinsicConversion<hlfir::CmpCharOp>::HlfirIntrinsicConversion; + + llvm::LogicalResult + matchAndRewrite(hlfir::CmpCharOp cmp, + mlir::PatternRewriter &rewriter) const override { + fir::FirOpBuilder builder{rewriter, cmp.getOperation()}; + const mlir::Location &loc = cmp->getLoc(); + hlfir::Entity lhs{cmp.getLchr()}; + hlfir::Entity rhs{cmp.getRchr()}; + + auto [lhsExv, lhsCleanUp] = + hlfir::translateToExtendedValue(loc, builder, lhs); + auto [rhsExv, rhsCleanUp] = + hlfir::translateToExtendedValue(loc, builder, rhs); + + auto resultVal = fir::runtime::genCharCompare( + builder, loc, cmp.getPredicate(), lhsExv, rhsExv); + if (lhsCleanUp || rhsCleanUp) { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(cmp); + if (lhsCleanUp) + (*lhsCleanUp)(); + if (rhsCleanUp) + (*rhsCleanUp)(); + } + auto resultEntity = hlfir::EntityWithAttributes{resultVal}; + + processReturnValue(cmp, resultEntity, /*mustBeFreed=*/false, builder, + rewriter); + return mlir::success(); + } +}; + +class CharTrimOpConversion + : public HlfirIntrinsicConversion<hlfir::CharTrimOp> { + using HlfirIntrinsicConversion<hlfir::CharTrimOp>::HlfirIntrinsicConversion; + + llvm::LogicalResult + matchAndRewrite(hlfir::CharTrimOp trim, + mlir::PatternRewriter &rewriter) const override { + fir::FirOpBuilder builder{rewriter, trim.getOperation()}; + const mlir::Location &loc = trim->getLoc(); + + llvm::SmallVector<IntrinsicArgument, 1> inArgs; + mlir::Value chr = trim.getChr(); + inArgs.push_back({chr, chr.getType()}); + + auto *argLowering = fir::getIntrinsicArgumentLowering("trim"); + llvm::SmallVector<fir::ExtendedValue, 1> args = + lowerArguments(trim, inArgs, rewriter, argLowering); + + mlir::Type resultType = hlfir::getFortranElementType(trim.getType()); + + auto [resultExv, mustBeFreed] = + fir::genIntrinsicCall(builder, loc, "trim", resultType, args); + + processReturnValue(trim, resultExv, mustBeFreed, builder, rewriter); + return mlir::success(); + } +}; + +class IndexOpConversion : public HlfirIntrinsicConversion<hlfir::IndexOp> { + using HlfirIntrinsicConversion<hlfir::IndexOp>::HlfirIntrinsicConversion; + + llvm::LogicalResult + matchAndRewrite(hlfir::IndexOp op, + mlir::PatternRewriter &rewriter) const override { + fir::FirOpBuilder builder{rewriter, op.getOperation()}; + const mlir::Location &loc = op->getLoc(); + hlfir::Entity substr{op.getSubstr()}; + hlfir::Entity str{op.getStr()}; + + auto [substrExv, substrCleanUp] = + hlfir::translateToExtendedValue(loc, builder, substr); + auto [strExv, strCleanUp] = + hlfir::translateToExtendedValue(loc, builder, str); + + mlir::Value back = op.getBack(); + if (!back) + back = builder.createBool(loc, false); + + mlir::Value result = + fir::runtime::genIndex(builder, loc, strExv, substrExv, back); + result = builder.createConvert(loc, op.getType(), result); + if (strCleanUp || substrCleanUp) { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(op); + if (strCleanUp) + (*strCleanUp)(); + if (substrCleanUp) + (*substrCleanUp)(); + } + auto resultEntity = hlfir::EntityWithAttributes{result}; + + processReturnValue(op, resultEntity, /*mustBeFreed=*/false, builder, + rewriter); + return mlir::success(); + } +}; + class LowerHLFIRIntrinsics : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> { public: @@ -564,7 +665,8 @@ public: TransposeOpConversion, CountOpConversion, DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion, MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>, - ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion>(context); + ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion, + CmpCharOpConversion, CharTrimOpConversion, IndexOpConversion>(context); // While conceptually this pass is performing dialect conversion, we use // pattern rewrites here instead of dialect conversion because this pass diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index fe12f49c655b..d8e36ea294cd 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -2078,6 +2078,212 @@ private: } }; +class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> { +public: + using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::CmpCharOp cmp, + mlir::PatternRewriter &rewriter) const override { + + fir::FirOpBuilder builder{rewriter, cmp.getOperation()}; + const mlir::Location &loc = cmp->getLoc(); + + auto toVariable = + [&builder, + &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> { + mlir::Value opnd; + hlfir::AssociateOp associate; + if (mlir::isa<hlfir::ExprType>(val.getType())) { + hlfir::Entity entity{val}; + mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder); + associate = hlfir::genAssociateExpr(loc, builder, entity, + entity.getType(), "", byRefAttr); + opnd = associate.getBase(); + } else { + opnd = val; + } + return {opnd, associate}; + }; + + auto [lhsOpnd, lhsAssociate] = toVariable(cmp.getLchr()); + auto [rhsOpnd, rhsAssociate] = toVariable(cmp.getRchr()); + + hlfir::Entity lhs{lhsOpnd}; + hlfir::Entity rhs{rhsOpnd}; + + auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType()); + unsigned kind = charTy.getFKind(); + + auto bits = builder.getKindMap().getCharacterBitsize(kind); + auto intTy = builder.getIntegerType(bits); + + auto idxTy = builder.getIndexType(); + auto charLen1Ty = + fir::CharacterType::getSingleton(builder.getContext(), kind); + mlir::Type designatorType = + fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy)); + auto idxAttr = builder.getIntegerAttr(idxTy, 0); + + auto genExtractAndConvertToInt = + [&idxAttr, &intTy, &designatorType]( + mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::Entity &charStr, mlir::Value index, mlir::Value length) { + auto singleChr = hlfir::DesignateOp::create( + builder, loc, designatorType, charStr, /*component=*/{}, + /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{}, + /*substring=*/mlir::ValueRange{index, index}, + /*complexPart=*/std::nullopt, + /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{length}, + fir::FortranVariableFlagsAttr{}); + auto chrVal = fir::LoadOp::create(builder, loc, singleChr); + mlir::Value intVal = fir::ExtractValueOp::create( + builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr)); + return intVal; + }; + + mlir::arith::CmpIPredicate predicate = cmp.getPredicate(); + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + + mlir::Value lhsLen = builder.createConvert( + loc, idxTy, hlfir::genCharLength(loc, builder, lhs)); + mlir::Value rhsLen = builder.createConvert( + loc, idxTy, hlfir::genCharLength(loc, builder, rhs)); + + enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight }; + + mlir::Value zeroInt = builder.createIntegerConstant(loc, intTy, 0); + mlir::Value oneInt = builder.createIntegerConstant(loc, intTy, 1); + mlir::Value negOneInt = builder.createIntegerConstant(loc, intTy, -1); + mlir::Value blankInt = builder.createIntegerConstant(loc, intTy, ' '); + + auto step = GenCmp::LeftToRight; + auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange index, mlir::ValueRange reductionArgs) + -> llvm::SmallVector<mlir::Value, 1> { + assert(index.size() == 1 && "expected single loop"); + assert(reductionArgs.size() == 1 && "expected single reduction value"); + mlir::Value inRes = reductionArgs[0]; + auto accEQzero = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt); + + mlir::Value res = + builder + .genIfOp(loc, {intTy}, accEQzero, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value offset = + builder.createConvert(loc, idxTy, index[0]); + mlir::Value lhsInt; + mlir::Value rhsInt; + if (step == GenCmp::LeftToRight) { + lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset, + oneIdx); + rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset, + oneIdx); + } else if (step == GenCmp::LeftToBlank) { + // lhsLen > rhsLen + offset = + mlir::arith::AddIOp::create(builder, loc, rhsLen, offset); + + lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset, + oneIdx); + rhsInt = blankInt; + } else if (step == GenCmp::BlankToRight) { + // rhsLen > lhsLen + offset = + mlir::arith::AddIOp::create(builder, loc, lhsLen, offset); + + lhsInt = blankInt; + rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset, + oneIdx); + } else { + llvm_unreachable( + "unknown compare step for CmpCharOp lowering"); + } + + mlir::Value newVal = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create(builder, loc, + mlir::arith::CmpIPredicate::ult, + lhsInt, rhsInt), + negOneInt, inRes); + newVal = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create(builder, loc, + mlir::arith::CmpIPredicate::ugt, + lhsInt, rhsInt), + oneInt, newVal); + fir::ResultOp::create(builder, loc, newVal); + }) + .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); }) + .getResults()[0]; + + return {res}; + }; + + // First generate comparison of two strings for the legth of the shorter + // one. + mlir::Value minLen = mlir::arith::SelectOp::create( + builder, loc, + mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen), + lhsLen, rhsLen); + + llvm::SmallVector<mlir::Value, 1> loopOut = + hlfir::genLoopNestWithReductions(loc, builder, {minLen}, + /*reductionInits=*/{zeroInt}, genCmp, + /*isUnordered=*/false); + mlir::Value partRes = loopOut[0]; + + auto lhsLonger = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen); + mlir::Value tempRes = + builder + .genIfOp(loc, {intTy}, lhsLonger, + /*withElseRegion=*/true) + .genThen([&]() { + // If left is the longer string generate compare left to blank. + step = GenCmp::LeftToBlank; + auto lenDiff = + mlir::arith::SubIOp::create(builder, loc, lhsLen, rhsLen); + + llvm::SmallVector<mlir::Value, 1> output = + hlfir::genLoopNestWithReductions(loc, builder, {lenDiff}, + /*reductionInits=*/{partRes}, + genCmp, + /*isUnordered=*/false); + mlir::Value res = output[0]; + fir::ResultOp::create(builder, loc, res); + }) + .genElse([&]() { + // If right is the longer string generate compare blank to + // right. + step = GenCmp::BlankToRight; + auto lenDiff = + mlir::arith::SubIOp::create(builder, loc, rhsLen, lhsLen); + llvm::SmallVector<mlir::Value, 1> output = + hlfir::genLoopNestWithReductions(loc, builder, {lenDiff}, + /*reductionInits=*/{partRes}, + genCmp, + /*isUnordered=*/false); + + mlir::Value res = output[0]; + fir::ResultOp::create(builder, loc, res); + }) + .getResults()[0]; + if (lhsAssociate) + hlfir::EndAssociateOp::create(builder, loc, lhsAssociate); + if (rhsAssociate) + hlfir::EndAssociateOp::create(builder, loc, rhsAssociate); + + auto finalCmpResult = + mlir::arith::CmpIOp::create(builder, loc, predicate, tempRes, zeroInt); + rewriter.replaceOp(cmp, finalCmpResult); + return mlir::success(); + } +}; + template <typename Op> class MatmulConversion : public mlir::OpRewritePattern<Op> { public: @@ -2748,8 +2954,8 @@ public: patterns.insert<ReductionConversion<hlfir::SumOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context); + patterns.insert<CmpCharOpConversion>(context); patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context); - patterns.insert<ReductionConversion<hlfir::CountOp>>(context); patterns.insert<ReductionConversion<hlfir::AnyOp>>(context); patterns.insert<ReductionConversion<hlfir::AllOp>>(context); diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp index 5b6d904fb0d5..f4b173575d87 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp @@ -271,8 +271,6 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var, mlir::Value extent = val; mlir::Value upperbound = mlir::arith::SubIOp::create(builder, loc, extent, one); - upperbound = mlir::arith::AddIOp::create(builder, loc, lowerbound, - upperbound); mlir::Value stride = one; if (strideIncludeLowerExtent) { stride = cummulativeExtent; @@ -552,10 +550,7 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( auto getDeclareOpForType = [&](mlir::Type ty) -> hlfir::DeclareOp { auto alloca = fir::AllocaOp::create(firBuilder, loc, ty); - return hlfir::DeclareOp::create( - firBuilder, loc, alloca, varName, /*shape=*/nullptr, - llvm::ArrayRef<mlir::Value>{}, - /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); + return hlfir::DeclareOp::create(firBuilder, loc, alloca, varName); }; if (fir::isa_trivial(unwrappedTy)) { @@ -576,10 +571,8 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit( } auto alloca = fir::AllocaOp::create( firBuilder, loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents); - auto declareOp = hlfir::DeclareOp::create( - firBuilder, loc, alloca, varName, shape, - llvm::ArrayRef<mlir::Value>{}, - /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); + auto declareOp = + hlfir::DeclareOp::create(firBuilder, loc, alloca, varName, shape); if (initVal) { mlir::Type idxTy = firBuilder.getIndexType(); diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index e0aebd0714c8..b85ee7e861a4 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -26,6 +26,7 @@ add_flang_library(FlangOpenMPTransforms FIRSupport FortranSupport HLFIRDialect + FortranUtils MLIR_DEPS ${dialect_libs} diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 2b3ac169e8b5..6c7192400084 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -6,17 +6,23 @@ // //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Builder/DirectivesCommon.h" #include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/OpenMP/Passes.h" #include "flang/Optimizer/OpenMP/Utils.h" #include "flang/Support/OpenMP-utils.h" +#include "flang/Utils/OpenMP.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/IRMapping.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Frontend/OpenMP/OMPConstants.h" namespace flangomp { #define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS @@ -107,6 +113,33 @@ private: using InductionVariableInfos = llvm::SmallVector<InductionVariableInfo>; +/// Collect the list of values used inside the loop but defined outside of it. +void collectLoopLiveIns(fir::DoConcurrentLoopOp loop, + llvm::SmallVectorImpl<mlir::Value> &liveIns) { + llvm::SmallDenseSet<mlir::Value> seenValues; + llvm::SmallPtrSet<mlir::Operation *, 8> seenOps; + + for (auto [lb, ub, st] : llvm::zip_equal( + loop.getLowerBound(), loop.getUpperBound(), loop.getStep())) { + liveIns.push_back(lb); + liveIns.push_back(ub); + liveIns.push_back(st); + } + + mlir::visitUsedValuesDefinedAbove( + loop.getRegion(), [&](mlir::OpOperand *operand) { + if (!seenValues.insert(operand->get()).second) + return; + + mlir::Operation *definingOp = operand->get().getDefiningOp(); + // We want to collect ops corresponding to live-ins only once. + if (definingOp && !seenOps.insert(definingOp).second) + return; + + liveIns.push_back(operand->get()); + }); +} + /// Collects values that are local to a loop: "loop-local values". A loop-local /// value is one that is used exclusively inside the loop but allocated outside /// of it. This usually corresponds to temporary values that are used inside the @@ -168,22 +201,66 @@ static void localizeLoopLocalValue(mlir::Value local, mlir::Region &allocRegion, class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoConcurrentOp> { +private: + struct TargetDeclareShapeCreationInfo { + // Note: We use `std::vector` (rather than `llvm::SmallVector` as usual) to + // interface more easily `ShapeShiftOp::getOrigins()` which returns + // `std::vector`. + std::vector<mlir::Value> startIndices; + std::vector<mlir::Value> extents; + + TargetDeclareShapeCreationInfo(mlir::Value liveIn) { + mlir::Value shape = nullptr; + mlir::Operation *liveInDefiningOp = liveIn.getDefiningOp(); + auto declareOp = + mlir::dyn_cast_if_present<hlfir::DeclareOp>(liveInDefiningOp); + + if (declareOp != nullptr) + shape = declareOp.getShape(); + + if (!shape) + return; + + auto shapeOp = + mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp()); + auto shapeShiftOp = + mlir::dyn_cast_if_present<fir::ShapeShiftOp>(shape.getDefiningOp()); + + if (!shapeOp && !shapeShiftOp) + TODO(liveIn.getLoc(), + "Shapes not defined by `fir.shape` or `fir.shape_shift` op's are" + "not supported yet."); + + if (shapeShiftOp != nullptr) + startIndices = shapeShiftOp.getOrigins(); + + extents = shapeOp != nullptr + ? std::vector<mlir::Value>(shapeOp.getExtents().begin(), + shapeOp.getExtents().end()) + : shapeShiftOp.getExtents(); + } + + bool isShapedValue() const { return !extents.empty(); } + bool isShapeShiftedValue() const { return !startIndices.empty(); } + }; + + using LiveInShapeInfoMap = + llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>; + public: using mlir::OpConversionPattern<fir::DoConcurrentOp>::OpConversionPattern; DoConcurrentConversion( mlir::MLIRContext *context, bool mapToDevice, - llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip) + llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip, + mlir::SymbolTable &moduleSymbolTable) : OpConversionPattern(context), mapToDevice(mapToDevice), - concurrentLoopsToSkip(concurrentLoopsToSkip) {} + concurrentLoopsToSkip(concurrentLoopsToSkip), + moduleSymbolTable(moduleSymbolTable) {} mlir::LogicalResult matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - if (mapToDevice) - return doLoop.emitError( - "not yet implemented: Mapping `do concurrent` loops to device"); - looputils::InductionVariableInfos ivInfos; auto loop = mlir::cast<fir::DoConcurrentLoopOp>( doLoop.getRegion().back().getTerminator()); @@ -194,20 +271,72 @@ public: for (mlir::Value indVar : *indVars) ivInfos.emplace_back(loop, indVar); + llvm::SmallVector<mlir::Value> loopNestLiveIns; + looputils::collectLoopLiveIns(loop, loopNestLiveIns); + assert(!loopNestLiveIns.empty()); + llvm::SetVector<mlir::Value> locals; looputils::collectLoopLocalValues(loop, locals); + // We do not want to map "loop-local" values to the device through + // `omp.map.info` ops. Therefore, we remove them from the list of live-ins. + loopNestLiveIns.erase(llvm::remove_if(loopNestLiveIns, + [&](mlir::Value liveIn) { + return locals.contains(liveIn); + }), + loopNestLiveIns.end()); + + mlir::omp::TargetOp targetOp; + mlir::omp::LoopNestOperands loopNestClauseOps; + mlir::IRMapping mapper; + + if (mapToDevice) { + mlir::ModuleOp module = doLoop->getParentOfType<mlir::ModuleOp>(); + bool isTargetDevice = + llvm::cast<mlir::omp::OffloadModuleInterface>(*module) + .getIsTargetDevice(); + + mlir::omp::TargetOperands targetClauseOps; + genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper, + loopNestClauseOps, + isTargetDevice ? nullptr : &targetClauseOps); + + LiveInShapeInfoMap liveInShapeInfoMap; + fir::FirOpBuilder builder( + rewriter, + fir::getKindMapping(doLoop->getParentOfType<mlir::ModuleOp>())); + + for (mlir::Value liveIn : loopNestLiveIns) { + targetClauseOps.mapVars.push_back( + genMapInfoOpForLiveIn(builder, liveIn)); + liveInShapeInfoMap.insert( + {liveIn, TargetDeclareShapeCreationInfo(liveIn)}); + } + + targetOp = + genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns, + targetClauseOps, loopNestClauseOps, liveInShapeInfoMap); + genTeamsOp(doLoop.getLoc(), rewriter); + } + mlir::omp::ParallelOp parallelOp = genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper); - mlir::omp::LoopNestOperands loopNestClauseOps; - genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper, - loopNestClauseOps); + + // Only set as composite when part of `distribute parallel do`. + parallelOp.setComposite(mapToDevice); + + if (!mapToDevice) + genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper, + loopNestClauseOps); for (mlir::Value local : locals) looputils::localizeLoopLocalValue(local, parallelOp.getRegion(), rewriter); + if (mapToDevice) + genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true); + mlir::omp::LoopNestOp ompLoopNest = genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps, /*isComposite=*/mapToDevice); @@ -282,11 +411,11 @@ private: return result; } - void - genLoopNestClauseOps(mlir::Location loc, - mlir::ConversionPatternRewriter &rewriter, - fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper, - mlir::omp::LoopNestOperands &loopNestClauseOps) const { + void genLoopNestClauseOps( + mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper, + mlir::omp::LoopNestOperands &loopNestClauseOps, + mlir::omp::TargetOperands *targetClauseOps = nullptr) const { assert(loopNestClauseOps.loopLowerBounds.empty() && "Loop nest bounds were already emitted!"); @@ -295,11 +424,21 @@ private: bounds.push_back(var.getDefiningOp()->getResult(0)); }; + auto hostEvalCapture = [&](mlir::Value var, + llvm::SmallVectorImpl<mlir::Value> &bounds) { + populateBounds(var, bounds); + + // Ensure that loop-nest bounds are evaluated in the host and forwarded to + // the nested omp constructs when we map to the device. + if (targetClauseOps) + targetClauseOps->hostEvalVars.push_back(var); + }; + for (auto [lb, ub, st] : llvm::zip_equal( loop.getLowerBound(), loop.getUpperBound(), loop.getStep())) { - populateBounds(lb, loopNestClauseOps.loopLowerBounds); - populateBounds(ub, loopNestClauseOps.loopUpperBounds); - populateBounds(st, loopNestClauseOps.loopSteps); + hostEvalCapture(lb, loopNestClauseOps.loopLowerBounds); + hostEvalCapture(ub, loopNestClauseOps.loopUpperBounds); + hostEvalCapture(st, loopNestClauseOps.loopSteps); } loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); @@ -332,8 +471,8 @@ private: loop.getLocalVars(), loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(), loop.getRegionLocalArgs())) { - auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom< - fir::LocalitySpecifierOp>(loop, sym); + auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>( + sym.getLeafReference()); if (localizer.getLocalitySpecifierType() == fir::LocalitySpecifierType::LocalInit) TODO(localizer.getLoc(), @@ -352,6 +491,8 @@ private: cloneFIRRegionToOMP(localizer.getDeallocRegion(), privatizer.getDeallocRegion()); + moduleSymbolTable.insert(privatizer); + wsloopClauseOps.privateVars.push_back(op); wsloopClauseOps.privateSyms.push_back( mlir::SymbolRefAttr::get(privatizer)); @@ -362,28 +503,34 @@ private: loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(), loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(), loop.getRegionReduceArgs())) { - auto firReducer = - mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>( - loop, sym); + auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>( + sym.getLeafReference()); mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(firReducer); - - auto ompReducer = mlir::omp::DeclareReductionOp::create( - rewriter, firReducer.getLoc(), - sym.getLeafReference().str() + ".omp", - firReducer.getTypeAttr().getValue()); - - cloneFIRRegionToOMP(firReducer.getAllocRegion(), - ompReducer.getAllocRegion()); - cloneFIRRegionToOMP(firReducer.getInitializerRegion(), - ompReducer.getInitializerRegion()); - cloneFIRRegionToOMP(firReducer.getReductionRegion(), - ompReducer.getReductionRegion()); - cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(), - ompReducer.getAtomicReductionRegion()); - cloneFIRRegionToOMP(firReducer.getCleanupRegion(), - ompReducer.getCleanupRegion()); + std::string ompReducerName = sym.getLeafReference().str() + ".omp"; + + auto ompReducer = + moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>( + rewriter.getStringAttr(ompReducerName)); + + if (!ompReducer) { + ompReducer = mlir::omp::DeclareReductionOp::create( + rewriter, firReducer.getLoc(), ompReducerName, + firReducer.getTypeAttr().getValue()); + + cloneFIRRegionToOMP(firReducer.getAllocRegion(), + ompReducer.getAllocRegion()); + cloneFIRRegionToOMP(firReducer.getInitializerRegion(), + ompReducer.getInitializerRegion()); + cloneFIRRegionToOMP(firReducer.getReductionRegion(), + ompReducer.getReductionRegion()); + cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(), + ompReducer.getAtomicReductionRegion()); + cloneFIRRegionToOMP(firReducer.getCleanupRegion(), + ompReducer.getCleanupRegion()); + moduleSymbolTable.insert(ompReducer); + } wsloopClauseOps.reductionVars.push_back(op); wsloopClauseOps.reductionByref.push_back(byRef); @@ -429,8 +576,262 @@ private: return loopNestOp; } + void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn, + mlir::Value rawAddr, + llvm::SmallVectorImpl<mlir::Value> &boundsOps) const { + fir::ExtendedValue extVal = + hlfir::translateToExtendedValue(rawAddr.getLoc(), builder, + hlfir::Entity{liveIn}, + /*contiguousHint=*/ + true) + .first; + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + builder, rawAddr, /*isOptional=*/false, rawAddr.getLoc()); + boundsOps = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>( + builder, info, extVal, + /*dataExvIsAssumedSize=*/false, rawAddr.getLoc()); + } + + mlir::omp::MapInfoOp genMapInfoOpForLiveIn(fir::FirOpBuilder &builder, + mlir::Value liveIn) const { + mlir::Value rawAddr = liveIn; + llvm::StringRef name; + + mlir::Operation *liveInDefiningOp = liveIn.getDefiningOp(); + auto declareOp = + mlir::dyn_cast_if_present<hlfir::DeclareOp>(liveInDefiningOp); + + if (declareOp != nullptr) { + // Use the raw address to avoid unboxing `fir.box` values whenever + // possible. Put differently, if we have access to the direct value memory + // reference/address, we use it. + rawAddr = declareOp.getOriginalBase(); + name = declareOp.getUniqName(); + } + + if (!llvm::isa<mlir::omp::PointerLikeType>(rawAddr.getType())) { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(liveInDefiningOp); + auto copyVal = builder.createTemporary(liveIn.getLoc(), liveIn.getType()); + builder.createStoreWithConvert(copyVal.getLoc(), liveIn, copyVal); + rawAddr = copyVal; + } + + mlir::Type liveInType = liveIn.getType(); + mlir::Type eleType = liveInType; + if (auto refType = mlir::dyn_cast<fir::ReferenceType>(liveInType)) + eleType = refType.getElementType(); + + llvm::omp::OpenMPOffloadMappingFlags mapFlag = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + mlir::omp::VariableCaptureKind captureKind = + mlir::omp::VariableCaptureKind::ByRef; + + if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { + captureKind = mlir::omp::VariableCaptureKind::ByCopy; + } else if (!fir::isa_builtin_cptr_type(eleType)) { + mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + } + + llvm::SmallVector<mlir::Value> boundsOps; + genBoundsOps(builder, liveIn, rawAddr, boundsOps); + + return Fortran::utils::openmp::createMapInfoOp( + builder, liveIn.getLoc(), rawAddr, + /*varPtrPtr=*/{}, name.str(), boundsOps, + /*members=*/{}, + /*membersIndex=*/mlir::ArrayAttr{}, + static_cast< + std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( + mapFlag), + captureKind, rawAddr.getType()); + } + + mlir::omp::TargetOp + genTargetOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::IRMapping &mapper, llvm::ArrayRef<mlir::Value> mappedVars, + mlir::omp::TargetOperands &clauseOps, + mlir::omp::LoopNestOperands &loopNestClauseOps, + const LiveInShapeInfoMap &liveInShapeInfoMap) const { + auto targetOp = rewriter.create<mlir::omp::TargetOp>(loc, clauseOps); + auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); + + mlir::Region ®ion = targetOp.getRegion(); + + llvm::SmallVector<mlir::Type> regionArgTypes; + llvm::SmallVector<mlir::Location> regionArgLocs; + + for (auto var : llvm::concat<const mlir::Value>(clauseOps.hostEvalVars, + clauseOps.mapVars)) { + regionArgTypes.push_back(var.getType()); + regionArgLocs.push_back(var.getLoc()); + } + + rewriter.createBlock(®ion, {}, regionArgTypes, regionArgLocs); + fir::FirOpBuilder builder( + rewriter, + fir::getKindMapping(targetOp->getParentOfType<mlir::ModuleOp>())); + + // Within the loop, it is possible that we discover other values that need + // to be mapped to the target region (the shape info values for arrays, for + // example). Therefore, the map block args might be extended and resized. + // Hence, we invoke `argIface.getMapBlockArgs()` every iteration to make + // sure we access the proper vector of data. + int idx = 0; + for (auto [mapInfoOp, mappedVar] : + llvm::zip_equal(clauseOps.mapVars, mappedVars)) { + auto miOp = mlir::cast<mlir::omp::MapInfoOp>(mapInfoOp.getDefiningOp()); + hlfir::DeclareOp liveInDeclare = + genLiveInDeclare(builder, targetOp, argIface.getMapBlockArgs()[idx], + miOp, liveInShapeInfoMap.at(mappedVar)); + ++idx; + + // If `mappedVar.getDefiningOp()` is a `fir::BoxAddrOp`, we probably + // need to "unpack" the box by getting the defining op of it's value. + // However, we did not hit this case in reality yet so leaving it as a + // todo for now. + if (mlir::isa<fir::BoxAddrOp>(mappedVar.getDefiningOp())) + TODO(mappedVar.getLoc(), + "Mapped variabled defined by `BoxAddrOp` are not supported yet"); + + auto mapHostValueToDevice = [&](mlir::Value hostValue, + mlir::Value deviceValue) { + if (!llvm::isa<mlir::omp::PointerLikeType>(hostValue.getType())) + mapper.map(hostValue, + builder.loadIfRef(hostValue.getLoc(), deviceValue)); + else + mapper.map(hostValue, deviceValue); + }; + + mapHostValueToDevice(mappedVar, liveInDeclare.getOriginalBase()); + + if (auto origDeclareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>( + mappedVar.getDefiningOp())) + mapHostValueToDevice(origDeclareOp.getBase(), liveInDeclare.getBase()); + } + + for (auto [arg, hostEval] : llvm::zip_equal(argIface.getHostEvalBlockArgs(), + clauseOps.hostEvalVars)) + mapper.map(hostEval, arg); + + for (unsigned i = 0; i < loopNestClauseOps.loopLowerBounds.size(); ++i) { + loopNestClauseOps.loopLowerBounds[i] = + mapper.lookup(loopNestClauseOps.loopLowerBounds[i]); + loopNestClauseOps.loopUpperBounds[i] = + mapper.lookup(loopNestClauseOps.loopUpperBounds[i]); + loopNestClauseOps.loopSteps[i] = + mapper.lookup(loopNestClauseOps.loopSteps[i]); + } + + // Check if cloning the bounds introduced any dependency on the outer + // region. If so, then either clone them as well if they are + // MemoryEffectFree, or else copy them to a new temporary and add them to + // the map and block_argument lists and replace their uses with the new + // temporary. + Fortran::utils::openmp::cloneOrMapRegionOutsiders(builder, targetOp); + rewriter.setInsertionPoint( + rewriter.create<mlir::omp::TerminatorOp>(targetOp.getLoc())); + + return targetOp; + } + + hlfir::DeclareOp genLiveInDeclare( + fir::FirOpBuilder &builder, mlir::omp::TargetOp targetOp, + mlir::Value liveInArg, mlir::omp::MapInfoOp liveInMapInfoOp, + const TargetDeclareShapeCreationInfo &targetShapeCreationInfo) const { + mlir::Type liveInType = liveInArg.getType(); + std::string liveInName = liveInMapInfoOp.getName().has_value() + ? liveInMapInfoOp.getName().value().str() + : std::string(""); + if (fir::isa_ref_type(liveInType)) + liveInType = fir::unwrapRefType(liveInType); + + mlir::Value shape = [&]() -> mlir::Value { + if (!targetShapeCreationInfo.isShapedValue()) + return {}; + + llvm::SmallVector<mlir::Value> extentOperands; + llvm::SmallVector<mlir::Value> startIndexOperands; + + if (targetShapeCreationInfo.isShapeShiftedValue()) { + llvm::SmallVector<mlir::Value> shapeShiftOperands; + + size_t shapeIdx = 0; + for (auto [startIndex, extent] : + llvm::zip_equal(targetShapeCreationInfo.startIndices, + targetShapeCreationInfo.extents)) { + shapeShiftOperands.push_back( + Fortran::utils::openmp::mapTemporaryValue( + builder, targetOp, startIndex, + liveInName + ".start_idx.dim" + std::to_string(shapeIdx))); + shapeShiftOperands.push_back( + Fortran::utils::openmp::mapTemporaryValue( + builder, targetOp, extent, + liveInName + ".extent.dim" + std::to_string(shapeIdx))); + ++shapeIdx; + } + + auto shapeShiftType = fir::ShapeShiftType::get( + builder.getContext(), shapeShiftOperands.size() / 2); + return builder.create<fir::ShapeShiftOp>( + liveInArg.getLoc(), shapeShiftType, shapeShiftOperands); + } + + llvm::SmallVector<mlir::Value> shapeOperands; + size_t shapeIdx = 0; + for (auto extent : targetShapeCreationInfo.extents) { + shapeOperands.push_back(Fortran::utils::openmp::mapTemporaryValue( + builder, targetOp, extent, + liveInName + ".extent.dim" + std::to_string(shapeIdx))); + ++shapeIdx; + } + + return builder.create<fir::ShapeOp>(liveInArg.getLoc(), shapeOperands); + }(); + + return builder.create<hlfir::DeclareOp>(liveInArg.getLoc(), liveInArg, + liveInName, shape); + } + + mlir::omp::TeamsOp + genTeamsOp(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter) const { + auto teamsOp = rewriter.create<mlir::omp::TeamsOp>( + loc, /*clauses=*/mlir::omp::TeamsOperands{}); + + rewriter.createBlock(&teamsOp.getRegion()); + rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + + return teamsOp; + } + + mlir::omp::DistributeOp + genDistributeOp(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter) const { + auto distOp = rewriter.create<mlir::omp::DistributeOp>( + loc, /*clauses=*/mlir::omp::DistributeOperands{}); + + rewriter.createBlock(&distOp.getRegion()); + return distOp; + } + bool mapToDevice; llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip; + mlir::SymbolTable &moduleSymbolTable; +}; + +/// A listener that forwards notifyOperationErased to the given callback. +struct CallbackListener : public mlir::RewriterBase::Listener { + CallbackListener(std::function<void(mlir::Operation *op)> onOperationErased) + : onOperationErased(onOperationErased) {} + + void notifyOperationErased(mlir::Operation *op) override { + onOperationErased(op); + } + + std::function<void(mlir::Operation *op)> onOperationErased; }; class DoConcurrentConversionPass @@ -444,12 +845,9 @@ public: : DoConcurrentConversionPassBase(options) {} void runOnOperation() override { - mlir::func::FuncOp func = getOperation(); - - if (func.isDeclaration()) - return; - + mlir::ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); + mlir::SymbolTable moduleSymbolTable(module); if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host && mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) { @@ -460,10 +858,14 @@ public: } llvm::DenseSet<fir::DoConcurrentOp> concurrentLoopsToSkip; + CallbackListener callbackListener([&](mlir::Operation *op) { + if (auto loop = mlir::dyn_cast<fir::DoConcurrentOp>(op)) + concurrentLoopsToSkip.erase(loop); + }); mlir::RewritePatternSet patterns(context); patterns.insert<DoConcurrentConversion>( context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device, - concurrentLoopsToSkip); + concurrentLoopsToSkip, moduleSymbolTable); mlir::ConversionTarget target(*context); target.addDynamicallyLegalOp<fir::DoConcurrentOp>( [&](fir::DoConcurrentOp op) { @@ -472,8 +874,11 @@ public: target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); - if (mlir::failed(mlir::applyFullConversion(getOperation(), target, - std::move(patterns)))) { + mlir::ConversionConfig config; + config.allowPatternRollback = false; + config.listener = &callbackListener; + if (mlir::failed(mlir::applyFullConversion(module, target, + std::move(patterns), config))) { signalPassFailure(); } } diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp index 66593ec8104f..0ff68eb01dab 100644 --- a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp @@ -518,8 +518,10 @@ public: loopOp)); }); + mlir::ConversionConfig config; + config.allowPatternRollback = false; if (mlir::failed(mlir::applyFullConversion(getOperation(), target, - std::move(patterns)))) { + std::move(patterns), config))) { mlir::emitError(func.getLoc(), "error in converting `omp.loop` op"); signalPassFailure(); } diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 970f7d7ab063..30328573b74f 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -53,6 +53,7 @@ class MapsForPrivatizedSymbolsPass : public flangomp::impl::MapsForPrivatizedSymbolsPassBase< MapsForPrivatizedSymbolsPass> { + // TODO Use `createMapInfoOp` from `flang/Utils/OpenMP.h`. omp::MapInfoOp createMapInfo(Location loc, Value var, fir::FirOpBuilder &builder) { // Check if a value of type `type` can be passed to the kernel by value. diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 98f947a1f635..7c2777baebef 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -14,7 +14,7 @@ /// Force setting the no-alias attribute on fuction arguments when possible. static llvm::cl::opt<bool> forceNoAlias("force-no-alias", llvm::cl::Hidden, - llvm::cl::init(false)); + llvm::cl::init(true)); namespace fir { @@ -217,9 +217,6 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, pm.addPass(fir::createSimplifyFIROperations( {/*preferInlineImplementation=*/pc.OptLevel.isOptimizingForSpeed()})); - if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags) - pm.addPass(fir::createAddAliasTags()); - addNestedPassToAllTopLevelOperations<PassConstructor>( pm, fir::createStackReclaim); // convert control flow to CFG form @@ -345,6 +342,9 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config, llvm::StringRef inputFilename) { fir::addBoxedProcedurePass(pm); + if (config.OptLevel.isOptimizingForSpeed() && config.AliasAnalysis && + !disableFirAliasTags && !useOldAliasTags) + pm.addPass(fir::createAddAliasTags()); addNestedPassToAllTopLevelOperations<PassConstructor>( pm, fir::createAbstractResultOpt); addPassToGPUModuleOperations<PassConstructor>(pm, diff --git a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp index 85403ad25765..0221c7a8184d 100644 --- a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp +++ b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp @@ -14,12 +14,17 @@ #include "flang/Optimizer/Analysis/AliasAnalysis.h" #include "flang/Optimizer/Analysis/TBAAForest.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FirAliasTagOpInterface.h" +#include "flang/Optimizer/Support/DataLayout.h" +#include "flang/Optimizer/Support/Utils.h" #include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/IR/Dominance.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/CommandLine.h" @@ -57,12 +62,132 @@ static llvm::cl::opt<unsigned> localAllocsThreshold( namespace { +// Return the size and alignment (in bytes) for the given type. +// TODO: this must be combined with DebugTypeGenerator::getFieldSizeAndAlign(). +// We'd better move fir::LLVMTypeConverter out of the FIRCodeGen component. +static std::pair<std::uint64_t, unsigned short> +getTypeSizeAndAlignment(mlir::Type type, + fir::LLVMTypeConverter &llvmTypeConverter) { + mlir::Type llvmTy; + if (auto boxTy = mlir::dyn_cast_if_present<fir::BaseBoxType>(type)) + llvmTy = llvmTypeConverter.convertBoxTypeAsStruct(boxTy, getBoxRank(boxTy)); + else + llvmTy = llvmTypeConverter.convertType(type); + + const mlir::DataLayout &dataLayout = llvmTypeConverter.getDataLayout(); + uint64_t byteSize = dataLayout.getTypeSize(llvmTy); + unsigned short byteAlign = dataLayout.getTypeABIAlignment(llvmTy); + return std::pair{byteSize, byteAlign}; +} + +// IntervalTy class describes a range of bytes addressed by a variable +// within some storage. Zero-sized intervals are not allowed. +class IntervalTy { +public: + IntervalTy() = delete; + IntervalTy(std::uint64_t start, std::size_t size) + : start(start), end(start + (size - 1)) { + assert(size != 0 && "empty intervals should not be created"); + } + constexpr bool operator<(const IntervalTy &rhs) const { + if (start < rhs.start) + return true; + if (rhs.start < start) + return false; + return end < rhs.end; + } + bool overlaps(const IntervalTy &other) const { + return end >= other.start && other.end >= start; + } + bool contains(const IntervalTy &other) const { + return start <= other.start && end >= other.end; + } + void merge(const IntervalTy &other) { + start = std::min(start, other.start); + end = std::max(end, other.end); + assert(start <= end); + } + void print(llvm::raw_ostream &os) const { + os << "[" << start << "," << end << "]"; + } + std::uint64_t getStart() const { return start; } + std::uint64_t getEnd() const { return end; } + +private: + std::uint64_t start; + std::uint64_t end; +}; + +// IntervalSetTy is an ordered set of IntervalTy entities. +class IntervalSetTy : public std::set<IntervalTy> { +public: + // Find an interval from the set that contain the given interval. + // The complexity is O(log(N)), where N is the size of the set. + std::optional<IntervalTy> getContainingInterval(const IntervalTy &interval) { + if (empty()) + return std::nullopt; + + auto it = lower_bound(interval); + // The iterator points to the first interval that is not less than + // the given interval. The given interval may belong to the one + // pointed out by the iterator or to the previous one. + // + // In the following cases there might be no interval that is not less + // than the given interval, e.g.: + // Case 1: + // interval: [5,5] + // set: {[4,6]} + // Case 2: + // interval: [5,5] + // set: {[4,5]} + // We have to look starting from the last interval in the set. + if (it == end()) + --it; + + // The loop must finish in two iterator max. + do { + if (it->contains(interval)) + return *it; + // If the current interval from the set is less than the given + // interval and there is no overlap, we should not look further. + if ((!it->overlaps(interval) && *it < interval) || it == begin()) + break; + + --it; + } while (true); + + return std::nullopt; + } +}; + +// Stream operators for IntervalTy and IntervalSetTy. +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const IntervalTy &interval) { + interval.print(os); + return os; +} + +[[maybe_unused]] inline llvm::raw_ostream & +operator<<(llvm::raw_ostream &os, const IntervalSetTy &set) { + if (set.empty()) { + os << " <empty>"; + return os; + } + for (const auto &interval : set) + os << ' ' << interval; + return os; +} + /// Shared state per-module class PassState { public: - PassState(mlir::DominanceInfo &domInfo, + PassState(mlir::ModuleOp module, const mlir::DataLayout &dl, + mlir::DominanceInfo &domInfo, std::optional<unsigned> localAllocsThreshold) - : domInfo(domInfo), localAllocsThreshold(localAllocsThreshold) {} + : domInfo(domInfo), localAllocsThreshold(localAllocsThreshold), + symTab(module.getOperation()), + llvmTypeConverter(module, /*applyTBAA=*/false, + /*forceUnifiedTBAATree=*/false, dl) {} /// memoised call to fir::AliasAnalysis::getSource inline const fir::AliasAnalysis::Source &getSource(mlir::Value value) { if (!analysisCache.contains(value)) @@ -72,13 +197,14 @@ public: } /// get the per-function TBAATree for this function - inline const fir::TBAATree &getFuncTree(mlir::func::FuncOp func) { - return forrest[func]; + inline fir::TBAATree &getMutableFuncTreeWithScope(mlir::func::FuncOp func, + fir::DummyScopeOp scope) { + auto &scopeMap = scopeNames.at(func); + return forrest.getMutableFuncTreeWithScope(func, scopeMap.lookup(scope)); } inline const fir::TBAATree &getFuncTreeWithScope(mlir::func::FuncOp func, fir::DummyScopeOp scope) { - auto &scopeMap = scopeNames.at(func); - return forrest.getFuncTreeWithScope(func, scopeMap.lookup(scope)); + return getMutableFuncTreeWithScope(func, scope); } void processFunctionScopes(mlir::func::FuncOp func); @@ -98,8 +224,82 @@ public: // attachment. bool attachLocalAllocTag(); + // Return fir.global for the given name. + fir::GlobalOp getGlobalDefiningOp(mlir::StringAttr name) const { + return symTab.lookup<fir::GlobalOp>(name); + } + + // Process fir::FortranVariableStorageOpInterface operations within + // the given op, and fill in declToStorageMap with the information + // about their physical storages and layouts. + void collectPhysicalStorageAliasSets(mlir::Operation *op); + + // Return the byte size of the given declaration. + std::size_t getDeclarationSize(fir::FortranVariableStorageOpInterface decl) { + mlir::Type memType = fir::unwrapRefType(decl.getBase().getType()); + auto [size, alignment] = + getTypeSizeAndAlignment(memType, llvmTypeConverter); + return llvm::alignTo(size, alignment); + } + + // A StorageDesc specifies an operation that defines a physical storage + // and the <offset, size> pair within that physical storage where + // a variable resides. + struct StorageDesc { + StorageDesc() = delete; + StorageDesc(mlir::Operation *storageDef, std::uint64_t start, + std::size_t size) + : storageDef(storageDef), interval(start, size) {} + + // Return a string representing the byte range of the variable within + // its storage, e.g. bytes_0_to_0 for a 1-byte variable starting + // at offset 0. + std::string getByteRangeStr() const { + return ("bytes_" + llvm::Twine(interval.getStart()) + "_to_" + + llvm::Twine(interval.getEnd())) + .str(); + } + + mlir::Operation *storageDef; + IntervalTy interval; + }; + + // Fills in declToStorageMap on the first invocation. + // Returns a storage descriptor for the given op (if registered + // in declToStorageMap). + const StorageDesc *computeStorageDesc(mlir::Operation *op) { + if (!op) + return nullptr; + + // TODO: it should be safe to run collectPhysicalStorageAliasSets() + // on the parent func.func instead of the module, since the TBAA + // tags use different roots per function. This may provide better + // results for storages that have members with descriptors + // in one function but not the others. + if (!declToStorageMapComputed) + collectPhysicalStorageAliasSets(op->getParentOfType<mlir::ModuleOp>()); + return getStorageDesc(op); + } + +private: + const StorageDesc *getStorageDesc(mlir::Operation *op) const { + auto it = declToStorageMap.find(op); + return it == declToStorageMap.end() ? nullptr : &it->second; + } + + StorageDesc &getMutableStorageDesc(mlir::Operation *op) { + auto it = declToStorageMap.find(op); + assert(it != declToStorageMap.end()); + return it->second; + } + private: mlir::DominanceInfo &domInfo; + std::optional<unsigned> localAllocsThreshold; + // Symbol table cache for the module. + mlir::SymbolTable symTab; + // Type converter to compute the size of declarations. + fir::LLVMTypeConverter llvmTypeConverter; fir::AliasAnalysis analysis; llvm::DenseMap<mlir::Value, fir::AliasAnalysis::Source> analysisCache; fir::TBAAForrest forrest; @@ -118,7 +318,12 @@ private: // member(s), to avoid the cost of isRecordWithDescriptorMember(). llvm::DenseSet<mlir::Type> typesContainingDescriptors; - std::optional<unsigned> localAllocsThreshold; + // A map between fir::FortranVariableStorageOpInterface operations + // and their storage descriptors. + llvm::DenseMap<mlir::Operation *, StorageDesc> declToStorageMap; + // declToStorageMapComputed is set to true after declToStorageMap + // is initialized by collectPhysicalStorageAliasSets(). + bool declToStorageMapComputed = false; }; // Process fir.dummy_scope operations in the given func: @@ -198,6 +403,202 @@ bool PassState::attachLocalAllocTag() { return true; } +static mlir::Value getStorageDefinition(mlir::Value storageRef) { + while (auto convert = + mlir::dyn_cast_or_null<fir::ConvertOp>(storageRef.getDefiningOp())) + storageRef = convert.getValue(); + return storageRef; +} + +void PassState::collectPhysicalStorageAliasSets(mlir::Operation *op) { + // A map between fir::FortranVariableStorageOpInterface operations + // and the intervals describing their layout within their physical + // storages. + llvm::DenseMap<mlir::Operation *, IntervalSetTy> memberIntervals; + // A map between operations defining physical storages (e.g. fir.global) + // and sets of fir::FortranVariableStorageOpInterface operations + // declaring their member variables. + llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *, 10>> + storageDecls; + + bool seenUnknownStorage = false; + bool seenDeclWithDescriptor = false; + op->walk([&](fir::FortranVariableStorageOpInterface decl) { + mlir::Value storageRef = decl.getStorage(); + if (!storageRef) + return mlir::WalkResult::advance(); + + // If we have seen a declaration of a variable containing + // a descriptor, and we have not been able to identify + // a storage of any variable, then any variable may + // potentially overlap with the variable containing + // a descriptor. In this case, it is hard to make any + // assumptions about any variable with physical + // storage. Exit early. + if (seenUnknownStorage && seenDeclWithDescriptor) + return mlir::WalkResult::interrupt(); + + if (typeReferencesDescriptor(decl.getBase().getType())) + seenDeclWithDescriptor = true; + + mlir::Operation *storageDef = + getStorageDefinition(storageRef).getDefiningOp(); + // All physical storages that are defined by non-global + // objects (e.g. via fir.alloca) indicate an EQUIVALENCE. + // Inside an EQUIVALENCE each variable overlaps + // with at least one another variable. So all EQUIVALENCE + // variables belong to the same alias set, and there is + // no reason to investigate them further. + // Note that, in general, the storage may be defined by a block + // argument. + auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(storageDef); + if (!storageDef || + (!addrOfOp && !mlir::dyn_cast<fir::AllocaOp>(storageDef))) { + seenUnknownStorage = true; + return mlir::WalkResult::advance(); + } + if (!addrOfOp) + return mlir::WalkResult::advance(); + fir::GlobalOp globalDef = + getGlobalDefiningOp(addrOfOp.getSymbol().getRootReference()); + std::uint64_t storageOffset = decl.getStorageOffset(); + std::size_t declSize = getDeclarationSize(decl); + LLVM_DEBUG(llvm::dbgs() + << "Found variable with storage:\n" + << "Declaration: " << decl << "\n" + << "Storage: " << (globalDef ? globalDef : nullptr) << "\n" + << "Offset: " << storageOffset << "\n" + << "Size: " << declSize << "\n"); + if (!globalDef) { + seenUnknownStorage = true; + return mlir::WalkResult::advance(); + } + // Zero-sized variables do not need any TBAA tags, because + // they cannot be accessed. + if (declSize == 0) + return mlir::WalkResult::advance(); + + declToStorageMap.try_emplace(decl.getOperation(), globalDef.getOperation(), + storageOffset, declSize); + storageDecls.try_emplace(globalDef.getOperation()) + .first->second.push_back(decl.getOperation()); + + auto &set = + memberIntervals.try_emplace(globalDef.getOperation()).first->second; + set.insert(IntervalTy(storageOffset, declSize)); + return mlir::WalkResult::advance(); + }); + + // Mark the map as computed before any early exits below. + declToStorageMapComputed = true; + + if (seenUnknownStorage && seenDeclWithDescriptor) { + declToStorageMap.clear(); + return; + } + + // Process each physical storage. + for (auto &map : memberIntervals) { + mlir::Operation *storageDef = map.first; + const IntervalSetTy &originalSet = map.second; + LLVM_DEBUG( + llvm::dbgs() << "Merging " << originalSet.size() + << " member intervals for: "; + storageDef->print(llvm::dbgs(), mlir::OpPrintingFlags{}.skipRegions()); + llvm::dbgs() << "\nIntervals: " << originalSet << "\n"); + // Ordered set of merged overlapping intervals. + // Since the intervals in originalSet are sorted, the merged + // intervals are always added at the end of the mergedIntervals set. + IntervalSetTy mergedIntervals; + if (originalSet.size() > 1) { + auto intervalIt = originalSet.begin(); + IntervalTy mergedInterval = *intervalIt; + while (++intervalIt != originalSet.end()) { + if (mergedInterval.overlaps(*intervalIt)) { + mergedInterval.merge(*intervalIt); + } else { + mergedIntervals.insert(mergedIntervals.end(), mergedInterval); + mergedInterval = *intervalIt; + } + } + mergedIntervals.insert(mergedIntervals.end(), mergedInterval); + } else { + // 0 or 1 total interval requires no merging. + mergedIntervals = originalSet; + } + LLVM_DEBUG(llvm::dbgs() << "Merged intervals:" << mergedIntervals << "\n"); + + bool wasMerged = originalSet.size() != mergedIntervals.size(); + + // Go through all the declarations within the storage, and assign + // them to their final intervals (if some merging happened), + // and collect information about "poisoned" intervals (see below). + // invalidIntervals set will contain the "poisoned" intervals. + IntervalSetTy invalidIntervals; + for (auto *decl : storageDecls.at(storageDef)) { + StorageDesc &declStorageDesc = getMutableStorageDesc(decl); + + if (wasMerged) { + // Some intervals were merged, so we have to modify the intervals + // for some declarations. + + auto containingInterval = + mergedIntervals.getContainingInterval(declStorageDesc.interval); + assert(containingInterval && "did not find the containing interval"); + LLVM_DEBUG(llvm::dbgs() << "Placing: " << *decl << " into interval " + << *containingInterval); + declStorageDesc.interval = *containingInterval; + } + if (typeReferencesDescriptor( + mlir::cast<fir::FortranVariableStorageOpInterface>(decl) + .getBase() + .getType())) { + // If a variable contains a descriptor within it. + // We cannot attach any data tag to it, because it will + // conflict with the late TBBA tags attachment for + // the descriptor data. This also applies to all + // variables overlapping with this one, thus we should + // remove any storage descriptors for their declarations. + LLVM_DEBUG(llvm::dbgs() << " (poisoned)"); + invalidIntervals.insert(declStorageDesc.interval); + } + LLVM_DEBUG(llvm::dbgs() << "\n"); + } + + if (invalidIntervals.empty()) + continue; + + // Now that all the declarations are assigned to their intervals, + // go through the "poisoned" intervals and remove all declarations + // belonging to them from declToStorageMap, so that they do not + // have any tags attached. + LLVM_DEBUG(llvm::dbgs() + << "Invalid intervals:" << invalidIntervals << "\n"); + if (invalidIntervals.size() == mergedIntervals.size()) { + // All variables are "poisoned". Save the O(log(N)) lookups + // in invalidIntervals set, and poison them all. + for (auto *decl : storageDecls.at(storageDef)) { + LLVM_DEBUG(llvm::dbgs() + << "Removing storage descriptor for: " << *decl << "\n"); + declToStorageMap.erase(decl); + } + continue; + } + + // Some variables are "poisoned". + for (auto *decl : storageDecls.at(storageDef)) { + const StorageDesc *declStorageDesc = getStorageDesc(decl); + assert(declStorageDesc && "declaration must have a storage descriptor"); + if (auto containingInterval = invalidIntervals.getContainingInterval( + declStorageDesc->interval)) { + LLVM_DEBUG(llvm::dbgs() + << "Removing storage descriptor for: " << *decl << "\n"); + declToStorageMap.erase(decl); + } + } + } +} + class AddAliasTagsPass : public fir::impl::AddAliasTagsBase<AddAliasTagsPass> { public: void runOnOperation() override; @@ -310,14 +711,62 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, source.kind == fir::AliasAnalysis::SourceKind::Global && !source.isBoxData()) { mlir::SymbolRefAttr glbl = llvm::cast<mlir::SymbolRefAttr>(source.origin.u); - const char *name = glbl.getRootReference().data(); - LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to global " << name - << " at " << *op << "\n"); - if (source.isPointer()) + mlir::StringAttr globalName = glbl.getRootReference(); + LLVM_DEBUG(llvm::dbgs().indent(2) + << "Found reference to global " << globalName.str() << " at " + << *op << "\n"); + if (source.isPointer()) { tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag(); - else - tag = - state.getFuncTreeWithScope(func, scopeOp).globalDataTree.getTag(name); + } else { + // In general, place the tags under the "global data" root. + fir::TBAATree::SubtreeState *subTree = + &state.getMutableFuncTreeWithScope(func, scopeOp).globalDataTree; + + mlir::Operation *instantiationPoint = source.origin.instantiationPoint; + auto storageIface = + mlir::dyn_cast_or_null<fir::FortranVariableStorageOpInterface>( + instantiationPoint); + const PassState::StorageDesc *storageDesc = + state.computeStorageDesc(instantiationPoint); + + if (storageDesc) { + // This is a variable that is part of a known physical storage + // that may contain multiple and maybe overlapping variables. + // We have may assign it with a tag that relates + // to the byte range within the physical storage. + assert(instantiationPoint && "cannot be null"); + assert(storageDesc->storageDef && "cannot be null"); + assert(storageDesc->storageDef == + state.getGlobalDefiningOp(globalName) && + "alias analysis reached a different storage"); + std::string aliasSetName = storageDesc->getByteRangeStr(); + subTree = &subTree->getOrCreateNamedSubtree(globalName); + tag = subTree->getTag(aliasSetName); + LLVM_DEBUG(llvm::dbgs() + << "Variable instantiated by " << *instantiationPoint + << " tagged with '" << aliasSetName << "' under '" + << globalName << "' root\n"); + } else if (storageIface && storageIface.getStorage()) { + // This is a variable that is: + // * aliasing a descriptor, or + // * part of an unknown physical storage, or + // * zero-sized. + // If it aliases a descriptor or the storage is unknown + // (i.e. it *may* alias a descriptor), then we cannot assign any tag to + // it, because we cannot use any tag from the "any data accesses" tree. + // If it is a zero-sized variable, we do not care about + // attaching a tag, because the access is invalid. + LLVM_DEBUG(llvm::dbgs() << "WARNING: poisoned or unknown storage or " + "zero-sized variable access\n"); + } else { + // This is a variable defined by the global symbol, + // and it is the only variable that belong to that global storage. + // Tag it using the global's name. + tag = subTree->getTag(globalName); + LLVM_DEBUG(llvm::dbgs() + << "Tagged under '" << globalName << "' root\n"); + } + } // TBAA for global variables with descriptors } else if (enableDirect && @@ -401,12 +850,15 @@ void AddAliasTagsPass::runOnOperation() { // thinks the pass operates on), then the real work of the pass is done in // runOnAliasInterface auto &domInfo = getAnalysis<mlir::DominanceInfo>(); - PassState state(domInfo, localAllocsThreshold.getPosition() - ? std::optional<unsigned>(localAllocsThreshold) - : std::nullopt); - - mlir::ModuleOp mod = getOperation(); - mod.walk( + mlir::ModuleOp module = getOperation(); + mlir::DataLayout dl = *fir::support::getOrSetMLIRDataLayout( + module, /*allowDefaultLayout=*/false); + PassState state(module, dl, domInfo, + localAllocsThreshold.getPosition() + ? std::optional<unsigned>(localAllocsThreshold) + : std::nullopt); + + module.walk( [&](fir::FirAliasTagOpInterface op) { runOnAliasInterface(op, state); }); LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp index b032767eef6f..061a7d201edd 100644 --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -25,7 +25,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Visitors.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Debug.h" #include <optional> @@ -451,10 +451,10 @@ static void rewriteStore(fir::StoreOp storeOp, } static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { - for (auto &bodyOp : block->getOperations()) { + for (auto &bodyOp : llvm::make_early_inc_range(block->getOperations())) { if (isa<fir::LoadOp>(bodyOp)) rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); - if (isa<fir::StoreOp>(bodyOp)) + else if (isa<fir::StoreOp>(bodyOp)) rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); } } @@ -476,6 +476,8 @@ public: loop.dump();); LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = functionAnalysis.getChildLoopAnalysis(loop); + if (!loopAnalysis.canPromoteToAffine()) + return rewriter.notifyMatchFailure(loop, "cannot promote to affine"); auto &loopOps = loop.getBody()->getOperations(); auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator()); auto results = resultOp.getOperands(); @@ -576,12 +578,14 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { public: using OpRewritePattern::OpRewritePattern; AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) - : OpRewritePattern(context) {} + : OpRewritePattern(context), functionAnalysis(afa) {} llvm::LogicalResult matchAndRewrite(fir::IfOp op, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; op.dump();); + if (!functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()) + return rewriter.notifyMatchFailure(op, "cannot promote to affine"); auto &ifOps = op.getThenRegion().front().getOperations(); auto affineCondition = AffineIfCondition(op.getCondition()); if (!affineCondition.hasIntegerSet()) { @@ -611,6 +615,8 @@ public: rewriter.replaceOp(op, affineIf.getOperation()->getResults()); return success(); } + + AffineFunctionAnalysis &functionAnalysis; }; /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases @@ -627,28 +633,11 @@ public: mlir::RewritePatternSet patterns(context); patterns.insert<AffineIfConversion>(context, functionAnalysis); patterns.insert<AffineLoopConversion>(context, functionAnalysis); - mlir::ConversionTarget target = *context; - target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect, - mlir::scf::SCFDialect, mlir::arith::ArithDialect, - mlir::func::FuncDialect>(); - target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { - return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); - }); - target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( - fir::DoLoopOp op) { - return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); - }); - LLVM_DEBUG(llvm::dbgs() << "AffineDialectPromotion: running promotion on: \n"; function.print(llvm::dbgs());); // apply the patterns - if (mlir::failed(mlir::applyPartialConversion(function, target, - std::move(patterns)))) { - mlir::emitError(mlir::UnknownLoc::get(context), - "error in converting to affine dialect\n"); - signalPassFailure(); - } + walkAndApplyPatterns(function, std::move(patterns)); } }; } // namespace diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 9834b0499b93..609a1fc9fb02 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -557,8 +557,8 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter, mlir::Value src = op.getSrc(); if (srcTy.isInteger(1)) { // i1 is not a supported type in the descriptor and it is actually coming - // from a LOGICAL constant. Store it as a fir.logical. - srcTy = fir::LogicalType::get(rewriter.getContext(), 4); + // from a LOGICAL constant. Use the destination type to avoid mismatch. + srcTy = dstEleTy; src = createConvertOp(rewriter, loc, srcTy, src); addr = builder.createTemporary(loc, srcTy); fir::StoreOp::create(builder, loc, src, addr); @@ -650,7 +650,7 @@ struct CUFDataTransferOpConversion if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) { // Initialization of an array from a scalar value should be implemented - // via a kernel launch. Use the flan runtime via the Assign function + // via a kernel launch. Use the flang runtime via the Assign function // until we have more infrastructure. mlir::Value src = emboxSrc(rewriter, op, symtab); mlir::Value dst = emboxDst(rewriter, op, symtab); @@ -928,34 +928,6 @@ struct CUFSyncDescriptorOpConversion } }; -struct CUFSetAllocatorIndexOpConversion - : public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(cuf::SetAllocatorIndexOp op, - mlir::PatternRewriter &rewriter) const override { - auto mod = op->getParentOfType<mlir::ModuleOp>(); - fir::FirOpBuilder builder(rewriter, mod); - mlir::Location loc = op.getLoc(); - int idx = kDefaultAllocator; - if (op.getDataAttr() == cuf::DataAttribute::Device) { - idx = kDeviceAllocatorPos; - } else if (op.getDataAttr() == cuf::DataAttribute::Managed) { - idx = kManagedAllocatorPos; - } else if (op.getDataAttr() == cuf::DataAttribute::Unified) { - idx = kUnifiedAllocatorPos; - } else if (op.getDataAttr() == cuf::DataAttribute::Pinned) { - idx = kPinnedAllocatorPos; - } - mlir::Value index = - builder.createIntegerConstant(loc, builder.getI32Type(), idx); - fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index); - op.erase(); - return mlir::success(); - } -}; - class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> { public: void runOnOperation() override { @@ -1017,8 +989,8 @@ void cuf::populateCUFToFIRConversionPatterns( const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter); patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion, - CUFFreeOpConversion, CUFSyncDescriptorOpConversion, - CUFSetAllocatorIndexOpConversion>(patterns.getContext()); + CUFFreeOpConversion, CUFSyncDescriptorOpConversion>( + patterns.getContext()); patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab, &dl, &converter); patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>( diff --git a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp index 5dcb54eaf9b9..d038c467b166 100644 --- a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp +++ b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp @@ -178,8 +178,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertBoxedSequenceType( context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr, /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, - elements, dataLocation, rank, /*allocated=*/nullptr, - /*associated=*/nullptr); + dataLocation, rank, /*allocated=*/nullptr, + /*associated=*/nullptr, elements); } addOp(llvm::dwarf::DW_OP_push_object_address, {}); @@ -255,8 +255,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertBoxedSequenceType( return mlir::LLVM::DICompositeTypeAttr::get( context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr, /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, - mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, elements, - dataLocation, /*rank=*/nullptr, allocated, associated); + mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, + dataLocation, /*rank=*/nullptr, allocated, associated, elements); } std::pair<std::uint64_t, unsigned short> @@ -389,8 +389,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( context, recId, /*isRecSelf=*/true, llvm::dwarf::DW_TAG_structure_type, mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, - /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elements); DerivedTypeCache::ActiveLevels nestedRecursions = derivedTypeCache.startTranslating(Ty, placeHolder); @@ -429,8 +429,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, convertType(seqTy.getEleTy(), fileAttr, scope, declOp), mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, - arrayElements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, arrayElements); } else elemTy = convertType(fieldTy, fileAttr, scope, /*declOp=*/nullptr); offset = llvm::alignTo(offset, byteAlign); @@ -448,8 +448,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( context, recId, /*isRecSelf=*/false, llvm::dwarf::DW_TAG_structure_type, mlir::StringAttr::get(context, sourceName.name), fileAttr, line, scope, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8, - /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elements); derivedTypeCache.finalize(Ty, finalAttr, std::move(nestedRecursions)); @@ -490,8 +490,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType( context, llvm::dwarf::DW_TAG_structure_type, mlir::StringAttr::get(context, ""), fileAttr, /*line=*/0, scope, /*baseType=*/nullptr, mlir::LLVM::DIFlags::Zero, offset * 8, - /*alignInBits=*/0, elements, /*dataLocation=*/nullptr, /*rank=*/nullptr, - /*allocated=*/nullptr, /*associated=*/nullptr); + /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, + /*allocated=*/nullptr, /*associated=*/nullptr, elements); derivedTypeCache.finalize(Ty, typeAttr, std::move(nestedRecursions)); return typeAttr; } @@ -554,9 +554,9 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertSequenceType( return mlir::LLVM::DICompositeTypeAttr::get( context, llvm::dwarf::DW_TAG_array_type, /*name=*/nullptr, /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, - mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, elements, + mlir::LLVM::DIFlags::Zero, /*sizeInBits=*/0, /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, /*allocated=*/nullptr, - /*associated=*/nullptr); + /*associated=*/nullptr, elements); } mlir::LLVM::DITypeAttr DebugTypeGenerator::convertVectorType( @@ -587,9 +587,9 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertVectorType( context, llvm::dwarf::DW_TAG_array_type, mlir::StringAttr::get(context, name), /*file=*/nullptr, /*line=*/0, /*scope=*/nullptr, elemTy, - mlir::LLVM::DIFlags::Vector, sizeInBits, /*alignInBits=*/0, elements, + mlir::LLVM::DIFlags::Vector, sizeInBits, /*alignInBits=*/0, /*dataLocation=*/nullptr, /*rank=*/nullptr, /*allocated=*/nullptr, - /*associated=*/nullptr); + /*associated=*/nullptr, elements); } mlir::LLVM::DITypeAttr DebugTypeGenerator::convertCharacterType( diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp index 2fcff87fdc39..031a5aeb28d7 100644 --- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp +++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp @@ -76,12 +76,49 @@ void ExternalNameConversionPass::runOnOperation() { auto *context = &getContext(); llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings; + mlir::SymbolTable symbolTable(op); auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) { auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>( mlir::SymbolTable::getSymbolAttrName()); auto deconstructedName = fir::NameUniquer::deconstruct(symName); if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { + // Check if this is a private function that would conflict with a common + // block and get its mangled name. + if (auto funcOp = llvm::dyn_cast<mlir::func::FuncOp>(funcOrGlobal)) { + if (funcOp.isPrivate()) { + std::string mangledName = + mangleExternalName(deconstructedName, appendUnderscoreOpt); + auto mod = funcOp->getParentOfType<mlir::ModuleOp>(); + bool hasConflictingCommonBlock = false; + + // Check if any existing global has the same mangled name. + if (symbolTable.lookup<fir::GlobalOp>(mangledName)) + hasConflictingCommonBlock = true; + + // Skip externalization if the function has a conflicting common block + // and is not directly called (i.e. procedure pointers or type + // specifications) + if (hasConflictingCommonBlock) { + bool isDirectlyCalled = false; + std::optional<SymbolTable::UseRange> uses = + funcOp.getSymbolUses(mod); + if (uses.has_value()) { + for (auto use : *uses) { + mlir::Operation *user = use.getUser(); + if (mlir::isa<fir::CallOp>(user) || + mlir::isa<mlir::func::CallOp>(user)) { + isDirectlyCalled = true; + break; + } + } + } + if (!isDirectlyCalled) + return; + } + } + } + auto newName = mangleExternalName(deconstructedName, appendUnderscoreOpt); auto newAttr = mlir::StringAttr::get(context, newName); mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr); diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp index 5ac4ed8a93b6..9dfe26cbf589 100644 --- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp +++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp @@ -95,10 +95,6 @@ void FunctionAttrPass::runOnOperation() { func->setAttr( mlir::LLVM::LLVMFuncOp::getNoNansFpMathAttrName(llvmFuncOpName), mlir::BoolAttr::get(context, true)); - if (approxFuncFPMath) - func->setAttr( - mlir::LLVM::LLVMFuncOp::getApproxFuncFpMathAttrName(llvmFuncOpName), - mlir::BoolAttr::get(context, true)); if (noSignedZerosFPMath) func->setAttr( mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName), diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp index c6aec96ceb5a..03f97ebdc635 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp @@ -210,19 +210,33 @@ public: mapper.map(region.getArguments(), regionArgs); for (mlir::Operation &op : region.front().without_terminator()) (void)rewriter.clone(op, mapper); + + auto yield = mlir::cast<fir::YieldOp>(region.front().getTerminator()); + assert(yield.getResults().size() < 2); + + return yield.getResults().empty() + ? mlir::Value{} + : mapper.lookup(yield.getResults()[0]); }; - if (!localizer.getInitRegion().empty()) - cloneLocalizerRegion(localizer.getInitRegion(), {localVar, localArg}, - rewriter.getInsertionPoint()); + if (!localizer.getInitRegion().empty()) { + // Prefer the value yielded from the init region to the allocated + // private variable in case the region is operating on arguments + // by-value (e.g. Fortran character boxes). + localAlloc = cloneLocalizerRegion(localizer.getInitRegion(), + {localVar, localAlloc}, + rewriter.getInsertionPoint()); + assert(localAlloc); + } if (localizer.getLocalitySpecifierType() == fir::LocalitySpecifierType::LocalInit) - cloneLocalizerRegion(localizer.getCopyRegion(), {localVar, localArg}, + cloneLocalizerRegion(localizer.getCopyRegion(), + {localVar, localAlloc}, rewriter.getInsertionPoint()); if (!localizer.getDeallocRegion().empty()) - cloneLocalizerRegion(localizer.getDeallocRegion(), {localArg}, + cloneLocalizerRegion(localizer.getDeallocRegion(), {localAlloc}, rewriter.getInsertionBlock()->end()); rewriter.replaceAllUsesWith(localArg, localAlloc); diff --git a/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp b/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp index 7d1f86f8cee9..0cd2858ab5e7 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyRegionLite.cpp @@ -26,22 +26,16 @@ class SimplifyRegionLitePass public: void runOnOperation() override; }; - -class DummyRewriter : public mlir::PatternRewriter { -public: - DummyRewriter(mlir::MLIRContext *ctx) : mlir::PatternRewriter(ctx) {} -}; - } // namespace void SimplifyRegionLitePass::runOnOperation() { auto op = getOperation(); auto regions = op->getRegions(); mlir::RewritePatternSet patterns(op.getContext()); - DummyRewriter rewriter(op.getContext()); if (regions.empty()) return; + mlir::PatternRewriter rewriter(op.getContext()); (void)mlir::eraseUnreachableBlocks(rewriter, regions); (void)mlir::runRegionDCE(rewriter, regions); } |
