diff options
Diffstat (limited to 'flang/lib/Optimizer/CodeGen/Target.cpp')
| -rw-r--r-- | flang/lib/Optimizer/CodeGen/Target.cpp | 146 |
1 files changed, 120 insertions, 26 deletions
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index f7bffbf53c19..c332493eb807 100644 --- a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget<TargetX86_64Win> { //===----------------------------------------------------------------------===// namespace { +// AArch64 procedure call standard: +// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing struct TargetAArch64 : public GenericTarget<TargetAArch64> { using GenericTarget::GenericTarget; @@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> { return marshal; } - // Flatten a RecordType::TypeList containing more record types or array types + // Flatten a RecordType::TypeList containing more record types or array type static std::optional<std::vector<mlir::Type>> flattenTypeList(const RecordType::TypeList &types) { std::vector<mlir::Type> flatTypes; @@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> { // Determine if the type is a Homogeneous Floating-point Aggregate (HFA). An // HFA is a record type with up to 4 floating-point members of the same type. - static bool isHFA(fir::RecordType ty) { + static std::optional<int> usedRegsForHFA(fir::RecordType ty) { RecordType::TypeList types = ty.getTypeList(); if (types.empty() || types.size() > 4) - return false; + return std::nullopt; std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types); if (!flatTypes || flatTypes->size() > 4) { - return false; + return std::nullopt; } if (!isa_real(flatTypes->front())) { - return false; + return std::nullopt; } - return llvm::all_equal(*flatTypes) ? 
std::optional<int>{flatTypes->size()} + : std::nullopt; } - // AArch64 procedure call ABI: - // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing - CodeGenSpecifics::Marshalling - structReturnType(mlir::Location loc, fir::RecordType ty) const override { - CodeGenSpecifics::Marshalling marshal; + struct NRegs { + int n{0}; + bool isSimd{false}; + }; - if (isHFA(ty)) { - // Just return the existing record type - marshal.emplace_back(ty, AT{}); - return marshal; + NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const { + if (std::optional<int> size = usedRegsForHFA(type)) + return {*size, true}; + + auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash( + loc, type, getDataLayout(), kindMap); + + if (size <= 16) + return {static_cast<int>((size + 7) / 8), false}; + + // Pass on the stack, i.e. no registers used + return {}; + } + + NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const { + return llvm::TypeSwitch<mlir::Type, NRegs>(type) + .Case<mlir::IntegerType>([&](auto intTy) { + return intTy.getWidth() == 128 ? 
NRegs{2, false} : NRegs{1, false}; + }) + .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; }) + .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; }) + .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; }) + .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; }) + .Case<fir::SequenceType>([&](auto ty) { + assert(ty.getShape().size() == 1 && + "invalid array dimensions in BIND(C)"); + NRegs nregs = usedRegsForType(loc, ty.getEleTy()); + nregs.n *= ty.getShape()[0]; + return nregs; + }) + .Case<fir::RecordType>( + [&](auto ty) { return usedRegsForRecordType(loc, ty); }) + .Case<fir::VectorType>([&](auto) { + TODO(loc, "passing vector argument to C by value is not supported"); + return NRegs{}; + }); + } + + bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type, + const Marshalling &previousArguments) const { + int availIntRegisters = 8; + int availSIMDRegisters = 8; + + // Check previous arguments to see how many registers are used already + for (auto [type, attr] : previousArguments) { + if (availIntRegisters <= 0 || availSIMDRegisters <= 0) + break; + + if (attr.isByVal()) + continue; // Previous argument passed on the stack + + NRegs nregs = usedRegsForType(loc, type); + if (nregs.isSimd) + availSIMDRegisters -= nregs.n; + else + availIntRegisters -= nregs.n; } - auto [size, align] = + NRegs nregs = usedRegsForRecordType(loc, type); + + if (nregs.isSimd) + return nregs.n <= availSIMDRegisters; + + return nregs.n <= availIntRegisters; + } + + CodeGenSpecifics::Marshalling + passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const { + CodeGenSpecifics::Marshalling marshal; + auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap); + // The stack is always 8 byte aligned + unsigned short align = + std::max(sizeAndAlign.second, static_cast<unsigned short>(8)); + marshal.emplace_back(fir::ReferenceType::get(ty), + AT{align, /*byval=*/!isResult, /*sret=*/isResult}); + return 
marshal; + } - // return in registers if size <= 16 bytes - if (size <= 16) { - std::size_t dwordSize = (size + 7) / 8; - auto newTy = fir::SequenceType::get( - dwordSize, mlir::IntegerType::get(ty.getContext(), 64)); - marshal.emplace_back(newTy, AT{}); - return marshal; + CodeGenSpecifics::Marshalling + structType(mlir::Location loc, fir::RecordType type, bool isResult) const { + NRegs nregs = usedRegsForRecordType(loc, type); + + // If the type needs no registers it must be passed on the stack + if (nregs.n == 0) + return passOnTheStack(loc, type, isResult); + + CodeGenSpecifics::Marshalling marshal; + + mlir::Type pcsType; + if (nregs.isSimd) { + pcsType = type; + } else { + pcsType = fir::SequenceType::get( + nregs.n, mlir::IntegerType::get(type.getContext(), 64)); } - unsigned short stackAlign = std::max<unsigned short>(align, 8u); - marshal.emplace_back(fir::ReferenceType::get(ty), - AT{stackAlign, false, true}); + marshal.emplace_back(pcsType, AT{}); return marshal; } + + CodeGenSpecifics::Marshalling + structArgumentType(mlir::Location loc, fir::RecordType ty, + const Marshalling &previousArguments) const override { + if (!hasEnoughRegisters(loc, ty, previousArguments)) { + return passOnTheStack(loc, ty, /*isResult=*/false); + } + + return structType(loc, ty, /*isResult=*/false); + } + + CodeGenSpecifics::Marshalling + structReturnType(mlir::Location loc, fir::RecordType ty) const override { + return structType(loc, ty, /*isResult=*/true); + } }; } // namespace |
