summary | refs | log | tree | commit | diff
path: root/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/CodeGen/TargetRewrite.cpp')
-rw-r--r-- flang/lib/Optimizer/CodeGen/TargetRewrite.cpp 149
1 files changed, 92 insertions, 57 deletions
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 1b86d5241704..b0b9499557e2 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -62,14 +62,21 @@ struct FixupTy {
FixupTy(Codes code, std::size_t index,
std::function<void(mlir::func::FuncOp)> &&finalizer)
: code{code}, index{index}, finalizer{finalizer} {}
+ FixupTy(Codes code, std::size_t index,
+ std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+ : code{code}, index{index}, gpuFinalizer{finalizer} {}
FixupTy(Codes code, std::size_t index, std::size_t second,
std::function<void(mlir::func::FuncOp)> &&finalizer)
: code{code}, index{index}, second{second}, finalizer{finalizer} {}
+ FixupTy(Codes code, std::size_t index, std::size_t second,
+ std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+ : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
Codes code;
std::size_t index;
std::size_t second{};
std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
+ std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
}; // namespace
/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -127,10 +134,18 @@ public:
mod.walk([&](mlir::Operation *op) {
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
if (!hasPortableSignature(call.getFunctionType(), op))
- convertCallOp(call);
+ convertCallOp(call, call.getFunctionType());
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
if (!hasPortableSignature(dispatch.getFunctionType(), op))
- convertCallOp(dispatch);
+ convertCallOp(dispatch, dispatch.getFunctionType());
+ } else if (auto gpuLaunchFunc =
+ mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+ llvm::SmallVector<mlir::Type> operandsTypes;
+ for (auto arg : gpuLaunchFunc.getKernelOperands())
+ operandsTypes.push_back(arg.getType());
+ auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+ if (!hasPortableSignature(fctTy, op))
+ convertCallOp(gpuLaunchFunc, fctTy);
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
!hasPortableSignature(addr.getType(), op))
@@ -350,8 +365,7 @@ public:
// Convert fir.call and fir.dispatch Ops.
template <typename A>
- void convertCallOp(A callOp) {
- auto fnTy = callOp.getFunctionType();
+ void convertCallOp(A callOp, mlir::FunctionType fnTy) {
auto loc = callOp.getLoc();
rewriter->setInsertionPoint(callOp);
llvm::SmallVector<mlir::Type> newResTys;
@@ -369,7 +383,7 @@ public:
newOpers.push_back(callOp.getOperand(0));
dropFront = 1;
}
- } else {
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
dropFront = 1; // First operand is the polymorphic object.
}
@@ -395,10 +409,14 @@ public:
llvm::SmallVector<mlir::Type> trailingInTys;
llvm::SmallVector<mlir::Value> trailingOpers;
+ llvm::SmallVector<mlir::Value> operands;
unsigned passArgShift = 0;
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+ operands = callOp.getKernelOperands();
+ else
+ operands = callOp.getOperands().drop_front(dropFront);
for (auto e : llvm::enumerate(
- llvm::zip(fnTy.getInputs().drop_front(dropFront),
- callOp.getOperands().drop_front(dropFront)))) {
+ llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
mlir::Type ty = std::get<0>(e.value());
mlir::Value oper = std::get<1>(e.value());
unsigned index = e.index();
@@ -500,7 +518,19 @@ public:
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
llvm::SmallVector<mlir::Value, 1> newCallResults;
- if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+ auto newCall = rewriter->create<A>(
+ loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+ callOp.getBlockSizeOperandValues(),
+ callOp.getDynamicSharedMemorySize(), newOpers);
+ if (callOp.getClusterSizeX())
+ newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+ if (callOp.getClusterSizeY())
+ newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+ if (callOp.getClusterSizeZ())
+ newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+ newCallResults.append(newCall.result_begin(), newCall.result_end());
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
fir::CallOp newCall;
if (callOp.getCallee()) {
newCall =
@@ -719,12 +749,15 @@ public:
if (targetFeaturesAttr)
fn->setAttr("target_features", targetFeaturesAttr);
- convertSignature(fn);
+ convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
}
- for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
+ for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
- convertSignature(fn);
+ convertSignature<mlir::func::ReturnOp, mlir::func::FuncOp>(fn);
+ for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
+ convertSignature<mlir::gpu::ReturnOp, mlir::gpu::GPUFuncOp>(fn);
+ }
return mlir::success();
}
@@ -770,17 +803,20 @@ public:
/// Determine if the signature has host associations. The host association
/// argument may need special target specific rewriting.
- static bool hasHostAssociations(mlir::func::FuncOp func) {
+ template <typename OpTy>
+ static bool hasHostAssociations(OpTy func) {
std::size_t end = func.getFunctionType().getInputs().size();
for (std::size_t i = 0; i < end; ++i)
- if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
+ if (func.template getArgAttrOfType<mlir::UnitAttr>(
+ i, fir::getHostAssocAttrName()))
return true;
return false;
}
/// Rewrite the signatures and body of the `FuncOp`s in the module for
/// the immediately subsequent target code gen.
- void convertSignature(mlir::func::FuncOp func) {
+ template <typename ReturnOpTy, typename FuncOpTy>
+ void convertSignature(FuncOpTy func) {
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
return;
@@ -805,13 +841,13 @@ public:
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
- .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
- .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,7 +861,7 @@ public:
rewriter->getUnitAttr()));
newResTys.push_back(retTy);
})
- .Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .template Case<fir::RecordType>([&](fir::RecordType recTy) {
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
@@ -840,7 +876,7 @@ public:
auto ty = e.value();
unsigned index = e.index();
llvm::TypeSwitch<mlir::Type>(ty)
- .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+ .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
if (noCharacterConversion) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +899,10 @@ public:
}
}
})
- .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
- .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+ .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +914,7 @@ public:
fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
- .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +923,7 @@ public:
if (!extensionAttrName.empty() &&
isFuncWithCCallingConvention(func))
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
- [=](mlir::func::FuncOp func) {
+ [=](FuncOpTy func) {
func.setArgAttr(
argNo, extensionAttrName,
mlir::UnitAttr::get(func.getContext()));
@@ -903,8 +939,8 @@ public:
fir::CodeGenSpecifics::getTypeAndAttr(ty));
});
- if (func.getArgAttrOfType<mlir::UnitAttr>(index,
- fir::getHostAssocAttrName())) {
+ if (func.template getArgAttrOfType<mlir::UnitAttr>(
+ index, fir::getHostAssocAttrName())) {
extraAttrs.push_back(
{newInTyAndAttrs.size() - 1,
rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,27 +1015,27 @@ public:
auto newArg =
func.front().insertArgument(fixup.index, fixupType, loc);
offset++;
- func.walk([&](mlir::func::ReturnOp ret) {
+ func.walk([&](ReturnOpTy ret) {
rewriter->setInsertionPoint(ret);
auto oldOper = ret.getOperand(0);
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
auto cast =
rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
- rewriter->create<mlir::func::ReturnOp>(loc);
+ rewriter->create<ReturnOpTy>(loc);
ret.erase();
});
} break;
case FixupTy::Codes::ReturnType: {
// The function is still returning a value, but its type has likely
// changed to suit the target ABI convention.
- func.walk([&](mlir::func::ReturnOp ret) {
+ func.walk([&](ReturnOpTy ret) {
rewriter->setInsertionPoint(ret);
auto oldOper = ret.getOperand(0);
mlir::Value bitcast =
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
/*inputMayBeBigger=*/false);
- rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
+ rewriter->create<ReturnOpTy>(loc, bitcast);
ret.erase();
});
} break;
@@ -1101,13 +1137,18 @@ public:
}
}
- for (auto &fixup : fixups)
- if (fixup.finalizer)
- (*fixup.finalizer)(func);
+ for (auto &fixup : fixups) {
+ if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
+ if (fixup.finalizer)
+ (*fixup.finalizer)(func);
+ if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
+ if (fixup.gpuFinalizer)
+ (*fixup.gpuFinalizer)(func);
+ }
}
- template <typename Ty, typename FIXUPS>
- void doReturn(mlir::func::FuncOp func, Ty &newResTys,
+ template <typename OpTy, typename Ty, typename FIXUPS>
+ void doReturn(OpTy func, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
assert(m.size() == 1 &&
@@ -1119,7 +1160,7 @@ public:
unsigned argNo = newInTyAndAttrs.size();
if (auto align = attr.getAlignment())
fixups.emplace_back(
- FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
+ FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1171,7 @@ public:
});
else
fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
- [=](mlir::func::FuncOp func) {
+ [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1182,7 @@ public:
}
if (auto align = attr.getAlignment())
fixups.emplace_back(
- FixupTy::Codes::ReturnType, newResTys.size(),
- [=](mlir::func::FuncOp func) {
+ FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
func.setArgAttr(
newResTys.size(), "llvm.align",
rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1195,8 @@ public:
/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
- template <typename Ty, typename FIXUPS>
- void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
- Ty &newResTys,
+ template <typename OpTy, typename Ty, typename FIXUPS>
+ void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ public:
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}
- template <typename Ty, typename FIXUPS>
- void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
- Ty &newResTys,
+ template <typename OpTy, typename Ty, typename FIXUPS>
+ void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {
@@ -1182,12 +1220,10 @@ public:
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}
- template <typename FIXUPS>
- void
- createFuncOpArgFixups(mlir::func::FuncOp func,
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
- fir::CodeGenSpecifics::Marshalling &argsInTys,
- FIXUPS &fixups) {
+ template <typename OpTy, typename FIXUPS>
+ void createFuncOpArgFixups(
+ OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
: FixupTy::Codes::ArgumentType;
for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1234,7 @@ public:
if (attr.isByVal()) {
if (auto align = attr.getAlignment())
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
- [=](mlir::func::FuncOp func) {
+ [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1246,7 @@ public:
});
else
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
- newInTyAndAttrs.size(),
- [=](mlir::func::FuncOp func) {
+ newInTyAndAttrs.size(), [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1255,7 @@ public:
} else {
if (auto align = attr.getAlignment())
fixups.emplace_back(
- fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
+ fixupCode, argNo, index, [=](OpTy func) {
func.setArgAttr(argNo, "llvm.align",
rewriter->getIntegerAttr(
rewriter->getIntegerType(32), align));
@@ -1235,8 +1270,8 @@ public:
/// Convert a complex argument value. This can involve storing the value to
/// a temporary memory location or factoring the value into two distinct
/// arguments.
- template <typename FIXUPS>
- void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
+ template <typename OpTy, typename FIXUPS>
+ void doComplexArg(OpTy func, mlir::ComplexType cmplx,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ public:
createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
}
- template <typename FIXUPS>
- void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
+ template <typename OpTy, typename FIXUPS>
+ void doStructArg(OpTy func, fir::RecordType recTy,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {