diff options
Diffstat (limited to 'flang/lib/Optimizer/CodeGen/TargetRewrite.cpp')
| -rw-r--r-- | flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 149 |
1 files changed, 92 insertions, 57 deletions
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index 1b86d5241704..b0b9499557e2 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -62,14 +62,21 @@ struct FixupTy { FixupTy(Codes code, std::size_t index, std::function<void(mlir::func::FuncOp)> &&finalizer) : code{code}, index{index}, finalizer{finalizer} {} + FixupTy(Codes code, std::size_t index, + std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer) + : code{code}, index{index}, gpuFinalizer{finalizer} {} FixupTy(Codes code, std::size_t index, std::size_t second, std::function<void(mlir::func::FuncOp)> &&finalizer) : code{code}, index{index}, second{second}, finalizer{finalizer} {} + FixupTy(Codes code, std::size_t index, std::size_t second, + std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer) + : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {} Codes code; std::size_t index; std::size_t second{}; std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{}; + std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{}; }; // namespace /// Target-specific rewriting of the FIR. This is a prerequisite pass to code @@ -127,10 +134,18 @@ public: mod.walk([&](mlir::Operation *op) { if (auto call = mlir::dyn_cast<fir::CallOp>(op)) { if (!hasPortableSignature(call.getFunctionType(), op)) - convertCallOp(call); + convertCallOp(call, call.getFunctionType()); } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) { if (!hasPortableSignature(dispatch.getFunctionType(), op)) - convertCallOp(dispatch); + convertCallOp(dispatch, dispatch.getFunctionType()); + } else if (auto gpuLaunchFunc = + mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) { + llvm::SmallVector<mlir::Type> operandsTypes; + for (auto arg : gpuLaunchFunc.getKernelOperands()) + operandsTypes.push_back(arg.getType()); + auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {}); + if (!hasPortableSignature(fctTy, op)) + convertCallOp(gpuLaunchFunc, fctTy); } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) { if (mlir::isa<mlir::FunctionType>(addr.getType()) && !hasPortableSignature(addr.getType(), op)) @@ -350,8 +365,7 @@ public: // Convert fir.call and fir.dispatch Ops. template <typename A> - void convertCallOp(A callOp) { - auto fnTy = callOp.getFunctionType(); + void convertCallOp(A callOp, mlir::FunctionType fnTy) { auto loc = callOp.getLoc(); rewriter->setInsertionPoint(callOp); llvm::SmallVector<mlir::Type> newResTys; @@ -369,7 +383,7 @@ public: newOpers.push_back(callOp.getOperand(0)); dropFront = 1; } - } else { + } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) { dropFront = 1; // First operand is the polymorphic object. } @@ -395,10 +409,14 @@ public: llvm::SmallVector<mlir::Type> trailingInTys; llvm::SmallVector<mlir::Value> trailingOpers; + llvm::SmallVector<mlir::Value> operands; unsigned passArgShift = 0; + if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) + operands = callOp.getKernelOperands(); + else + operands = callOp.getOperands().drop_front(dropFront); for (auto e : llvm::enumerate( - llvm::zip(fnTy.getInputs().drop_front(dropFront), - callOp.getOperands().drop_front(dropFront)))) { + llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) { mlir::Type ty = std::get<0>(e.value()); mlir::Value oper = std::get<1>(e.value()); unsigned index = e.index(); @@ -500,7 +518,19 @@ public: newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); llvm::SmallVector<mlir::Value, 1> newCallResults; - if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { + if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) { + auto newCall = rewriter->create<A>( + loc, callOp.getKernel(), callOp.getGridSizeOperandValues(), + callOp.getBlockSizeOperandValues(), + callOp.getDynamicSharedMemorySize(), newOpers); + if (callOp.getClusterSizeX()) + newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX()); + if (callOp.getClusterSizeY()) + newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY()); + if (callOp.getClusterSizeZ()) + newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ()); + newCallResults.append(newCall.result_begin(), newCall.result_end()); + } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { fir::CallOp newCall; if (callOp.getCallee()) { newCall = @@ -719,12 +749,15 @@ public: if (targetFeaturesAttr) fn->setAttr("target_features", targetFeaturesAttr); - convertSignature(fn); + convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn); } - for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) + for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) { for (auto fn : gpuMod.getOps<mlir::func::FuncOp>()) - convertSignature(fn); + convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn); + for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) + convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn); + } return mlir::success(); } @@ -770,17 +803,20 @@ public: /// Determine if the signature has host associations. The host association /// argument may need special target specific rewriting. - static bool hasHostAssociations(mlir::func::FuncOp func) { + template <typename OpTy> + static bool hasHostAssociations(OpTy func) { std::size_t end = func.getFunctionType().getInputs().size(); for (std::size_t i = 0; i < end; ++i) - if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName())) + if (func.template getArgAttrOfType<mlir::UnitAttr>( + i, fir::getHostAssocAttrName())) return true; return false; } /// Rewrite the signatures and body of the `FuncOp`s in the module for /// the immediately subsequent target code gen. - void convertSignature(mlir::func::FuncOp func) { + template <typename ReturnOpTy, typename FuncOpTy> + void convertSignature(FuncOpTy func) { auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func)) return; @@ -805,13 +841,13 @@ public: // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch<mlir::Type>(ty) - .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { if (noComplexConversion) newResTys.push_back(cmplx); else doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups); }) - .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { auto m = specifics->integerArgumentType(func.getLoc(), intTy); assert(m.size() == 1); auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); @@ -825,7 +861,7 @@ public: rewriter->getUnitAttr())); newResTys.push_back(retTy); }) - .Case<fir::RecordType>([&](fir::RecordType recTy) { + .template Case<fir::RecordType>([&](fir::RecordType recTy) { doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups); }) .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); @@ -840,7 +876,7 @@ public: auto ty = e.value(); unsigned index = e.index(); llvm::TypeSwitch<mlir::Type>(ty) - .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { + .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) { if (noCharacterConversion) { newInTyAndAttrs.push_back( fir::CodeGenSpecifics::getTypeAndAttr(boxTy)); @@ -863,10 +899,10 @@ public: } } }) - .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { + .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { doComplexArg(func, cmplx, newInTyAndAttrs, fixups); }) - .Case<mlir::TupleType>([&](mlir::TupleType tuple) { + .template Case<mlir::TupleType>([&](mlir::TupleType tuple) { if (fir::isCharacterProcedureTuple(tuple)) { fixups.emplace_back(FixupTy::Codes::TrailingCharProc, newInTyAndAttrs.size(), trailingTys.size()); @@ -878,7 +914,7 @@ public: fir::CodeGenSpecifics::getTypeAndAttr(ty)); } }) - .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { + .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) { auto m = specifics->integerArgumentType(func.getLoc(), intTy); assert(m.size() == 1); auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]); @@ -887,7 +923,7 @@ public: if (!extensionAttrName.empty() && isFuncWithCCallingConvention(func)) fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo, - [=](mlir::func::FuncOp func) { + [=](FuncOpTy func) { func.setArgAttr( argNo, extensionAttrName, mlir::UnitAttr::get(func.getContext())); @@ -903,8 +939,8 @@ public: fir::CodeGenSpecifics::getTypeAndAttr(ty)); }); - if (func.getArgAttrOfType<mlir::UnitAttr>(index, - fir::getHostAssocAttrName())) { + if (func.template getArgAttrOfType<mlir::UnitAttr>( + index, fir::getHostAssocAttrName())) { extraAttrs.push_back( {newInTyAndAttrs.size() - 1, rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())}); @@ -979,27 +1015,27 @@ public: auto newArg = func.front().insertArgument(fixup.index, fixupType, loc); offset++; - func.walk([&](mlir::func::ReturnOp ret) { + func.walk([&](ReturnOpTy ret) { rewriter->setInsertionPoint(ret); auto oldOper = ret.getOperand(0); auto oldOperTy = fir::ReferenceType::get(oldOper.getType()); auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg); rewriter->create<fir::StoreOp>(loc, oldOper, cast); - rewriter->create<mlir::func::ReturnOp>(loc); + rewriter->create<ReturnOpTy>(loc); ret.erase(); }); } break; case FixupTy::Codes::ReturnType: { // The function is still returning a value, but its type has likely // changed to suit the target ABI convention. - func.walk([&](mlir::func::ReturnOp ret) { + func.walk([&](ReturnOpTy ret) { rewriter->setInsertionPoint(ret); auto oldOper = ret.getOperand(0); mlir::Value bitcast = convertValueInMemory(loc, oldOper, newResTys[fixup.index], /*inputMayBeBigger=*/false); - rewriter->create<mlir::func::ReturnOp>(loc, bitcast); + rewriter->create<ReturnOpTy>(loc, bitcast); ret.erase(); }); } break; @@ -1101,13 +1137,18 @@ public: } } - for (auto &fixup : fixups) - if (fixup.finalizer) - (*fixup.finalizer)(func); + for (auto &fixup : fixups) { + if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>) + if (fixup.finalizer) + (*fixup.finalizer)(func); + if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) + if (fixup.gpuFinalizer) + (*fixup.gpuFinalizer)(func); + } } - template <typename Ty, typename FIXUPS> - void doReturn(mlir::func::FuncOp func, Ty &newResTys, + template <typename OpTy, typename Ty, typename FIXUPS> + void doReturn(OpTy func, Ty &newResTys, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) { assert(m.size() == 1 && @@ -1119,7 +1160,7 @@ public: unsigned argNo = newInTyAndAttrs.size(); if (auto align = attr.getAlignment()) fixups.emplace_back( - FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) { + FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) { auto elemType = fir::dyn_cast_ptrOrBoxEleTy( func.getFunctionType().getInput(argNo)); func.setArgAttr(argNo, "llvm.sret", @@ -1130,7 +1171,7 @@ public: }); else fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo, - [=](mlir::func::FuncOp func) { + [=](OpTy func) { auto elemType = fir::dyn_cast_ptrOrBoxEleTy( func.getFunctionType().getInput(argNo)); func.setArgAttr(argNo, "llvm.sret", @@ -1141,8 +1182,7 @@ public: } if (auto align = attr.getAlignment()) fixups.emplace_back( - FixupTy::Codes::ReturnType, newResTys.size(), - [=](mlir::func::FuncOp func) { + FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) { func.setArgAttr( newResTys.size(), "llvm.align", rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); @@ -1155,9 +1195,8 @@ public: /// Convert a complex return value. This can involve converting the return /// value to a "hidden" first argument or packing the complex into a wide /// GPR. - template <typename Ty, typename FIXUPS> - void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx, - Ty &newResTys, + template <typename OpTy, typename Ty, typename FIXUPS> + void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, FIXUPS &fixups) { if (noComplexConversion) { @@ -1169,9 +1208,8 @@ public: doReturn(func, newResTys, newInTyAndAttrs, fixups, m); } - template <typename Ty, typename FIXUPS> - void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy, - Ty &newResTys, + template <typename OpTy, typename Ty, typename FIXUPS> + void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, FIXUPS &fixups) { if (noStructConversion) { @@ -1182,12 +1220,10 @@ public: doReturn(func, newResTys, newInTyAndAttrs, fixups, m); } - template <typename FIXUPS> - void - createFuncOpArgFixups(mlir::func::FuncOp func, - fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, - fir::CodeGenSpecifics::Marshalling &argsInTys, - FIXUPS &fixups) { + template <typename OpTy, typename FIXUPS> + void createFuncOpArgFixups( + OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, + fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) { const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; for (auto e : llvm::enumerate(argsInTys)) { @@ -1198,7 +1234,7 @@ public: if (attr.isByVal()) { if (auto align = attr.getAlignment()) fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo, - [=](mlir::func::FuncOp func) { + [=](OpTy func) { auto elemType = fir::dyn_cast_ptrOrBoxEleTy( func.getFunctionType().getInput(argNo)); func.setArgAttr(argNo, "llvm.byval", @@ -1210,8 +1246,7 @@ public: }); else fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, - newInTyAndAttrs.size(), - [=](mlir::func::FuncOp func) { + newInTyAndAttrs.size(), [=](OpTy func) { auto elemType = fir::dyn_cast_ptrOrBoxEleTy( func.getFunctionType().getInput(argNo)); func.setArgAttr(argNo, "llvm.byval", @@ -1220,7 +1255,7 @@ public: } else { if (auto align = attr.getAlignment()) fixups.emplace_back( - fixupCode, argNo, index, [=](mlir::func::FuncOp func) { + fixupCode, argNo, index, [=](OpTy func) { func.setArgAttr(argNo, "llvm.align", rewriter->getIntegerAttr( rewriter->getIntegerType(32), align)); @@ -1235,8 +1270,8 @@ public: /// Convert a complex argument value. This can involve storing the value to /// a temporary memory location or factoring the value into two distinct /// arguments. - template <typename FIXUPS> - void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx, + template <typename OpTy, typename FIXUPS> + void doComplexArg(OpTy func, mlir::ComplexType cmplx, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, FIXUPS &fixups) { if (noComplexConversion) { @@ -1248,8 +1283,8 @@ public: createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups); } - template <typename FIXUPS> - void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy, + template <typename OpTy, typename FIXUPS> + void doStructArg(OpTy func, fir::RecordType recTy, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, FIXUPS &fixups) { if (noStructConversion) { |
