summaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/ExpandFp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/ExpandFp.cpp')
-rw-r--r--llvm/lib/CodeGen/ExpandFp.cpp496
1 files changed, 474 insertions, 22 deletions
diff --git a/llvm/lib/CodeGen/ExpandFp.cpp b/llvm/lib/CodeGen/ExpandFp.cpp
index 1c1047c1ce18..9cc6c6a706c5 100644
--- a/llvm/lib/CodeGen/ExpandFp.cpp
+++ b/llvm/lib/CodeGen/ExpandFp.cpp
@@ -16,18 +16,29 @@
#include "llvm/CodeGen/ExpandFp.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/SimplifyQuery.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/RuntimeLibcalls.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <optional>
+
+#define DEBUG_TYPE "expand-fp"
using namespace llvm;
@@ -37,6 +48,359 @@ static cl::opt<unsigned>
cl::desc("fp convert instructions on integers with "
"more than <N> bits are expanded."));
+namespace {
+/// This class implements a precise expansion of the frem instruction.
+/// The generated code is based on the fmod implementation in the AMD device
+/// libs.
+class FRemExpander {
+ /// The IRBuilder to use for the expansion.
+ IRBuilder<> &B;
+
+ /// Floating point type of the return value and the arguments of the FRem
+ /// instructions that should be expanded.
+ Type *FremTy;
+
+ /// Floating point type to use for the computation. This may be
+ /// wider than the \p FremTy.
+ Type *ComputeFpTy;
+
+ /// Integer type used to hold the exponents returned by frexp.
+ Type *ExTy;
+
+ /// How many bits of the quotient to compute per iteration of the
+ /// algorithm, stored as a value of type \p ExTy.
+ Value *Bits;
+
+ /// Constant 1 of type \p ExTy.
+ Value *One;
+
+public:
+ static bool canExpandType(Type *Ty) {
+ // TODO The expansion should work for other floating point types
+ // as well, but this would require additional testing.
+ return Ty->isIEEELikeFPTy() && !Ty->isBFloatTy() && !Ty->isFP128Ty();
+ }
+
+ static FRemExpander create(IRBuilder<> &B, Type *Ty) {
+ assert(canExpandType(Ty));
+
+ // The type to use for the computation of the remainder. This may be
+ // wider than the input/result type which affects the ...
+ Type *ComputeTy = Ty;
+ // ... maximum number of iterations of the remainder computation loop
+ // to use. This value is for the case in which the computation
+ // uses the same input/result type.
+ unsigned MaxIter = 2;
+
+ if (Ty->isHalfTy()) {
+ // Use the wider type and less iterations.
+ ComputeTy = B.getFloatTy();
+ MaxIter = 1;
+ }
+
+ unsigned Precision =
+ llvm::APFloat::semanticsPrecision(Ty->getFltSemantics());
+ return FRemExpander{B, Ty, Precision / MaxIter, ComputeTy};
+ }
+
+ /// Build the FRem expansion for the numerator \p X and the
+ /// denumerator \p Y. The type of X and Y must match \p FremTy. The
+ /// code will be generated at the insertion point of \p B and the
+ /// insertion point will be reset at exit.
+ Value *buildFRem(Value *X, Value *Y, std::optional<SimplifyQuery> &SQ) const;
+
+ /// Build an approximate FRem expansion for the numerator \p X and
+ /// the denumerator \p Y at the insertion point of builder \p B.
+ /// The type of X and Y must match \p FremTy.
+ Value *buildApproxFRem(Value *X, Value *Y) const;
+
+private:
+ FRemExpander(IRBuilder<> &B, Type *FremTy, unsigned Bits, Type *ComputeFpTy)
+ : B(B), FremTy(FremTy), ComputeFpTy(ComputeFpTy), ExTy(B.getInt32Ty()),
+ Bits(ConstantInt::get(ExTy, Bits)), One(ConstantInt::get(ExTy, 1)) {};
+
+ Value *createRcp(Value *V, const Twine &Name) const {
+ // Leave it to later optimizations to turn this into an rcp
+ // instruction if available.
+ return B.CreateFDiv(ConstantFP::get(ComputeFpTy, 1.0), V, Name);
+ }
+
+ // Helper function to build the UPDATE_AX code which is common to the
+ // loop body and the "final iteration".
+ Value *buildUpdateAx(Value *Ax, Value *Ay, Value *Ayinv) const {
+ // Build:
+ // float q = rint(ax * ayinv);
+ // ax = fma(-q, ay, ax);
+ // int clt = ax < 0.0f;
+ // float axp = ax + ay;
+ // ax = clt ? axp : ax;
+ Value *Q = B.CreateUnaryIntrinsic(Intrinsic::rint, B.CreateFMul(Ax, Ayinv),
+ {}, "q");
+ Value *AxUpdate = B.CreateFMA(B.CreateFNeg(Q), Ay, Ax, {}, "ax");
+ Value *Clt = B.CreateFCmp(CmpInst::FCMP_OLT, AxUpdate,
+ ConstantFP::getZero(ComputeFpTy), "clt");
+ Value *Axp = B.CreateFAdd(AxUpdate, Ay, "axp");
+ return B.CreateSelect(Clt, Axp, AxUpdate, "ax");
+ }
+
+ /// Build code to extract the exponent and mantissa of \p Src.
+ /// Return the exponent minus one for use as a loop bound and
+ /// the mantissa taken to the given \p NewExp power.
+ std::pair<Value *, Value *> buildExpAndPower(Value *Src, Value *NewExp,
+ const Twine &ExName,
+ const Twine &PowName) const {
+ // Build:
+ // ExName = frexp_exp(Src) - 1;
+ // PowName = fldexp(frexp_mant(ExName), NewExp);
+ Type *Ty = Src->getType();
+ Type *ExTy = B.getInt32Ty();
+ Value *Frexp = B.CreateIntrinsic(Intrinsic::frexp, {Ty, ExTy}, Src);
+ Value *Mant = B.CreateExtractValue(Frexp, {0});
+ Value *Exp = B.CreateExtractValue(Frexp, {1});
+
+ Exp = B.CreateSub(Exp, One, ExName);
+ Value *Pow = B.CreateLdexp(Mant, NewExp, {}, PowName);
+
+ return {Pow, Exp};
+ }
+
+ /// Build the main computation of the remainder for the case in which
+ /// Ax > Ay, where Ax = |X|, Ay = |Y|, and X is the numerator and Y the
+ /// denumerator. Add the incoming edge from the computation result
+ /// to \p RetPhi.
+ void buildRemainderComputation(Value *AxInitial, Value *AyInitial, Value *X,
+ PHINode *RetPhi, FastMathFlags FMF) const {
+ IRBuilder<>::FastMathFlagGuard Guard(B);
+ B.setFastMathFlags(FMF);
+
+ // Build:
+ // ex = frexp_exp(ax) - 1;
+ // ax = fldexp(frexp_mant(ax), bits);
+ // ey = frexp_exp(ay) - 1;
+ // ay = fledxp(frexp_mant(ay), 1);
+ auto [Ax, Ex] = buildExpAndPower(AxInitial, Bits, "ex", "ax");
+ auto [Ay, Ey] = buildExpAndPower(AyInitial, One, "ey", "ay");
+
+ // Build:
+ // int nb = ex - ey;
+ // float ayinv = 1.0/ay;
+ Value *Nb = B.CreateSub(Ex, Ey, "nb");
+ Value *Ayinv = createRcp(Ay, "ayinv");
+
+ // Build: while (nb > bits)
+ BasicBlock *PreheaderBB = B.GetInsertBlock();
+ Function *Fun = PreheaderBB->getParent();
+ auto *LoopBB = BasicBlock::Create(B.getContext(), "frem.loop_body", Fun);
+ auto *ExitBB = BasicBlock::Create(B.getContext(), "frem.loop_exit", Fun);
+
+ B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, Nb, Bits), LoopBB, ExitBB);
+
+ // Build loop body:
+ // UPDATE_AX
+ // ax = fldexp(ax, bits);
+ // nb -= bits;
+ // One iteration of the loop is factored out. The code shared by
+ // the loop and this "iteration" is denoted by UPDATE_AX.
+ B.SetInsertPoint(LoopBB);
+ PHINode *NbIv = B.CreatePHI(Nb->getType(), 2, "nb_iv");
+ NbIv->addIncoming(Nb, PreheaderBB);
+
+ auto *AxPhi = B.CreatePHI(ComputeFpTy, 2, "ax_loop_phi");
+ AxPhi->addIncoming(Ax, PreheaderBB);
+
+ Value *AxPhiUpdate = buildUpdateAx(AxPhi, Ay, Ayinv);
+ AxPhiUpdate = B.CreateLdexp(AxPhiUpdate, Bits, {}, "ax_update");
+ AxPhi->addIncoming(AxPhiUpdate, LoopBB);
+ NbIv->addIncoming(B.CreateSub(NbIv, Bits, "nb_update"), LoopBB);
+
+ B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, NbIv, Bits), LoopBB, ExitBB);
+
+ // Build final iteration
+ // ax = fldexp(ax, nb - bits + 1);
+ // UPDATE_AX
+ B.SetInsertPoint(ExitBB);
+
+ auto *AxPhiExit = B.CreatePHI(ComputeFpTy, 2, "ax_exit_phi");
+ AxPhiExit->addIncoming(Ax, PreheaderBB);
+ AxPhiExit->addIncoming(AxPhi, LoopBB);
+ auto *NbExitPhi = B.CreatePHI(Nb->getType(), 2, "nb_exit_phi");
+ NbExitPhi->addIncoming(NbIv, LoopBB);
+ NbExitPhi->addIncoming(Nb, PreheaderBB);
+
+ Value *AxFinal = B.CreateLdexp(
+ AxPhiExit, B.CreateAdd(B.CreateSub(NbExitPhi, Bits), One), {}, "ax");
+ AxFinal = buildUpdateAx(AxFinal, Ay, Ayinv);
+
+ // Build:
+ // ax = fldexp(ax, ey);
+ // ret = copysign(ax,x);
+ AxFinal = B.CreateLdexp(AxFinal, Ey, {}, "ax");
+ if (ComputeFpTy != FremTy)
+ AxFinal = B.CreateFPTrunc(AxFinal, FremTy);
+ Value *Ret = B.CreateCopySign(AxFinal, X);
+
+ RetPhi->addIncoming(Ret, ExitBB);
+ }
+
+ /// Build the else-branch of the conditional in the FRem
+ /// expansion, i.e. the case in wich Ax <= Ay, where Ax = |X|, Ay
+ /// = |Y|, and X is the numerator and Y the denumerator. Add the
+ /// incoming edge from the result to \p RetPhi.
+ void buildElseBranch(Value *Ax, Value *Ay, Value *X, PHINode *RetPhi) const {
+ // Build:
+ // ret = ax == ay ? copysign(0.0f, x) : x;
+ Value *ZeroWithXSign = B.CreateCopySign(ConstantFP::getZero(FremTy), X);
+ Value *Ret = B.CreateSelect(B.CreateFCmpOEQ(Ax, Ay), ZeroWithXSign, X);
+
+ RetPhi->addIncoming(Ret, B.GetInsertBlock());
+ }
+
+ /// Return a value that is NaN if one of the corner cases concerning
+ /// the inputs \p X and \p Y is detected, and \p Ret otherwise.
+ Value *handleInputCornerCases(Value *Ret, Value *X, Value *Y,
+ std::optional<SimplifyQuery> &SQ,
+ bool NoInfs) const {
+ // Build:
+ // ret = (y == 0.0f || isnan(y)) ? QNAN : ret;
+ // ret = isfinite(x) ? ret : QNAN;
+ Value *Nan = ConstantFP::getQNaN(FremTy);
+ Ret = B.CreateSelect(B.CreateFCmpUEQ(Y, ConstantFP::getZero(FremTy)), Nan,
+ Ret);
+ Value *XFinite =
+ NoInfs || (SQ && isKnownNeverInfinity(X, *SQ))
+ ? B.getTrue()
+ : B.CreateFCmpULT(B.CreateUnaryIntrinsic(Intrinsic::fabs, X),
+ ConstantFP::getInfinity(FremTy));
+ Ret = B.CreateSelect(XFinite, Ret, Nan);
+
+ return Ret;
+ }
+};
+
+Value *FRemExpander::buildApproxFRem(Value *X, Value *Y) const {
+ IRBuilder<>::FastMathFlagGuard Guard(B);
+ // Propagating the approximate functions flag to the
+ // division leads to an unacceptable drop in precision
+ // on AMDGPU.
+ // TODO Find out if any flags might be worth propagating.
+ B.clearFastMathFlags();
+
+ Value *Quot = B.CreateFDiv(X, Y);
+ Value *Trunc = B.CreateUnaryIntrinsic(Intrinsic::trunc, Quot, {});
+ Value *Neg = B.CreateFNeg(Trunc);
+
+ return B.CreateFMA(Neg, Y, X);
+}
+
+Value *FRemExpander::buildFRem(Value *X, Value *Y,
+ std::optional<SimplifyQuery> &SQ) const {
+ assert(X->getType() == FremTy && Y->getType() == FremTy);
+
+ FastMathFlags FMF = B.getFastMathFlags();
+
+ // This function generates the following code structure:
+ // if (abs(x) > abs(y))
+ // { ret = compute remainder }
+ // else
+ // { ret = x or 0 with sign of x }
+ // Adjust ret to NaN/inf in input
+ // return ret
+ Value *Ax = B.CreateUnaryIntrinsic(Intrinsic::fabs, X, {}, "ax");
+ Value *Ay = B.CreateUnaryIntrinsic(Intrinsic::fabs, Y, {}, "ay");
+ if (ComputeFpTy != X->getType()) {
+ Ax = B.CreateFPExt(Ax, ComputeFpTy, "ax");
+ Ay = B.CreateFPExt(Ay, ComputeFpTy, "ay");
+ }
+ Value *AxAyCmp = B.CreateFCmpOGT(Ax, Ay);
+
+ PHINode *RetPhi = B.CreatePHI(FremTy, 2, "ret");
+ Value *Ret = RetPhi;
+
+ // We would return NaN in all corner cases handled here.
+ // Hence, if NaNs are excluded, keep the result as it is.
+ if (!FMF.noNaNs())
+ Ret = handleInputCornerCases(Ret, X, Y, SQ, FMF.noInfs());
+
+ Function *Fun = B.GetInsertBlock()->getParent();
+ auto *ThenBB = BasicBlock::Create(B.getContext(), "frem.compute", Fun);
+ auto *ElseBB = BasicBlock::Create(B.getContext(), "frem.else", Fun);
+ SplitBlockAndInsertIfThenElse(AxAyCmp, RetPhi, &ThenBB, &ElseBB);
+
+ auto SavedInsertPt = B.GetInsertPoint();
+
+ // Build remainder computation for "then" branch
+ //
+ // The ordered comparison ensures that ax and ay are not NaNs
+ // in the then-branch. Furthermore, y cannot be an infinity and the
+ // check at the end of the function ensures that the result will not
+ // be used if x is an infinity.
+ FastMathFlags ComputeFMF = FMF;
+ ComputeFMF.setNoInfs();
+ ComputeFMF.setNoNaNs();
+
+ B.SetInsertPoint(ThenBB);
+ buildRemainderComputation(Ax, Ay, X, RetPhi, FMF);
+ B.CreateBr(RetPhi->getParent());
+
+ // Build "else"-branch
+ B.SetInsertPoint(ElseBB);
+ buildElseBranch(Ax, Ay, X, RetPhi);
+ B.CreateBr(RetPhi->getParent());
+
+ B.SetInsertPoint(SavedInsertPt);
+
+ return Ret;
+}
+} // namespace
+
+static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
+ LLVM_DEBUG(dbgs() << "Expanding instruction: " << I << '\n');
+
+ Type *ReturnTy = I.getType();
+ assert(FRemExpander::canExpandType(ReturnTy->getScalarType()));
+
+ FastMathFlags FMF = I.getFastMathFlags();
+ // TODO Make use of those flags for optimization?
+ FMF.setAllowReciprocal(false);
+ FMF.setAllowContract(false);
+
+ IRBuilder<> B(&I);
+ B.setFastMathFlags(FMF);
+ B.SetCurrentDebugLocation(I.getDebugLoc());
+
+ Type *ElemTy = ReturnTy->getScalarType();
+ const FRemExpander Expander = FRemExpander::create(B, ElemTy);
+
+ Value *Ret;
+ if (ReturnTy->isFloatingPointTy())
+ Ret = FMF.approxFunc()
+ ? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
+ : Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
+ else {
+ auto *VecTy = cast<FixedVectorType>(ReturnTy);
+
+ // This could use SplitBlockAndInsertForEachLane but the interface
+ // is a bit awkward for a constant number of elements and it will
+ // boil down to the same code.
+ // TODO Expand the FRem instruction only once and reuse the code.
+ Value *Nums = I.getOperand(0);
+ Value *Denums = I.getOperand(1);
+ Ret = PoisonValue::get(I.getType());
+ for (int I = 0, E = VecTy->getNumElements(); I != E; ++I) {
+ Value *Num = B.CreateExtractElement(Nums, I);
+ Value *Denum = B.CreateExtractElement(Denums, I);
+ Value *Rem = FMF.approxFunc() ? Expander.buildApproxFRem(Num, Denum)
+ : Expander.buildFRem(Num, Denum, SQ);
+ Ret = B.CreateInsertElement(Ret, Rem, I);
+ }
+ }
+
+ I.replaceAllUsesWith(Ret);
+ Ret->takeName(&I);
+ I.eraseFromParent();
+
+ return true;
+}
// clang-format off: preserve formatting of the following example
/// Generate code to convert a fp number to integer, replacing FPToS(U)I with
@@ -64,8 +428,8 @@ static cl::opt<unsigned>
/// br i1 %cmp6.not, label %if.end12, label %if.then8
///
/// if.then8: ; preds = %if.end
-/// %cond11 = select i1 %tobool.not, i64 9223372036854775807, i64 -9223372036854775808
-/// br label %cleanup
+/// %cond11 = select i1 %tobool.not, i64 9223372036854775807, i64
+/// -9223372036854775808 br label %cleanup
///
/// if.end12: ; preds = %if.end
/// %cmp13 = icmp ult i64 %shr, 150
@@ -83,9 +447,10 @@ static cl::opt<unsigned>
/// %mul19 = mul nsw i64 %shl, %conv
/// br label %cleanup
///
-/// cleanup: ; preds = %entry, %if.else, %if.then15, %if.then8
-/// %retval.0 = phi i64 [ %cond11, %if.then8 ], [ %mul, %if.then15 ], [ %mul19, %if.else ], [ 0, %entry ]
-/// ret i64 %retval.0
+/// cleanup: ; preds = %entry,
+/// %if.else, %if.then15, %if.then8
+/// %retval.0 = phi i64 [ %cond11, %if.then8 ], [ %mul, %if.then15 ], [
+/// %mul19, %if.else ], [ 0, %entry ] ret i64 %retval.0
/// }
///
/// Replace fp to integer with generated code.
@@ -272,13 +637,11 @@ static void expandFPToI(Instruction *FPToI) {
/// %or = or i64 %shr6, %conv11
/// br label %sw.epilog
///
-/// sw.epilog: ; preds = %sw.default, %if.then4, %sw.bb
-/// %a.addr.0 = phi i64 [ %or, %sw.default ], [ %sub, %if.then4 ], [ %shl, %sw.bb ]
-/// %1 = lshr i64 %a.addr.0, 2
-/// %2 = and i64 %1, 1
-/// %or16 = or i64 %2, %a.addr.0
-/// %inc = add nsw i64 %or16, 1
-/// %3 = and i64 %inc, 67108864
+/// sw.epilog: ; preds = %sw.default,
+/// %if.then4, %sw.bb
+/// %a.addr.0 = phi i64 [ %or, %sw.default ], [ %sub, %if.then4 ], [ %shl,
+/// %sw.bb ] %1 = lshr i64 %a.addr.0, 2 %2 = and i64 %1, 1 %or16 = or i64 %2,
+/// %a.addr.0 %inc = add nsw i64 %or16, 1 %3 = and i64 %inc, 67108864
/// %tobool.not = icmp eq i64 %3, 0
/// %spec.select.v = select i1 %tobool.not, i64 2, i64 3
/// %spec.select = ashr i64 %inc, %spec.select.v
@@ -291,7 +654,8 @@ static void expandFPToI(Instruction *FPToI) {
/// %shl25 = shl i64 %sub, %sh_prom24
/// br label %if.end26
///
-/// if.end26: ; preds = %sw.epilog, %if.else
+/// if.end26: ; preds = %sw.epilog,
+/// %if.else
/// %a.addr.1 = phi i64 [ %shl25, %if.else ], [ %spec.select, %sw.epilog ]
/// %e.0 = phi i32 [ %sub2, %if.else ], [ %spec.select56, %sw.epilog ]
/// %conv27 = trunc i64 %shr to i32
@@ -305,7 +669,8 @@ static void expandFPToI(Instruction *FPToI) {
/// %4 = bitcast i32 %or33 to float
/// br label %return
///
-/// return: ; preds = %entry, %if.end26
+/// return: ; preds = %entry,
+/// %if.end26
/// %retval.0 = phi float [ %4, %if.end26 ], [ 0.000000e+00, %entry ]
/// ret float %retval.0
/// }
@@ -594,7 +959,38 @@ static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
I->eraseFromParent();
}
-static bool runImpl(Function &F, const TargetLowering &TLI) {
+// This covers all floating point types; more than we need here.
+// TODO Move somewhere else for general use?
+/// Return the Libcall for a frem instruction of
+/// type \p Ty.
+static RTLIB::Libcall fremToLibcall(Type *Ty) {
+ assert(Ty->isFloatingPointTy());
+ if (Ty->isFloatTy() || Ty->is16bitFPTy())
+ return RTLIB::REM_F32;
+ if (Ty->isDoubleTy())
+ return RTLIB::REM_F64;
+ if (Ty->isFP128Ty())
+ return RTLIB::REM_F128;
+ if (Ty->isX86_FP80Ty())
+ return RTLIB::REM_F80;
+ if (Ty->isPPC_FP128Ty())
+ return RTLIB::REM_PPCF128;
+
+ llvm_unreachable("Unknown floating point type");
+}
+
+/* Return true if, according to \p LibInfo, the target either directly
+ supports the frem instruction for the \p Ty, has a custom lowering,
+ or uses a libcall. */
+static bool targetSupportsFrem(const TargetLowering &TLI, Type *Ty) {
+ if (!TLI.isOperationExpand(ISD::FREM, EVT::getEVT(Ty)))
+ return true;
+
+ return TLI.getLibcallName(fremToLibcall(Ty->getScalarType()));
+}
+
+static bool runImpl(Function &F, const TargetLowering &TLI,
+ AssumptionCache *AC) {
SmallVector<Instruction *, 4> Replace;
SmallVector<Instruction *, 4> ReplaceVector;
bool Modified = false;
@@ -609,6 +1005,21 @@ static bool runImpl(Function &F, const TargetLowering &TLI) {
for (auto &I : instructions(F)) {
switch (I.getOpcode()) {
+ case Instruction::FRem: {
+ Type *Ty = I.getType();
+ // TODO: This pass doesn't handle scalable vectors.
+ if (Ty->isScalableTy())
+ continue;
+
+ if (targetSupportsFrem(TLI, Ty) ||
+ !FRemExpander::canExpandType(Ty->getScalarType()))
+ continue;
+
+ Replace.push_back(&I);
+ Modified = true;
+
+ break;
+ }
case Instruction::FPToUI:
case Instruction::FPToSI: {
// TODO: This pass doesn't handle scalable vectors.
@@ -659,8 +1070,20 @@ static bool runImpl(Function &F, const TargetLowering &TLI) {
while (!Replace.empty()) {
Instruction *I = Replace.pop_back_val();
- if (I->getOpcode() == Instruction::FPToUI ||
- I->getOpcode() == Instruction::FPToSI) {
+ if (I->getOpcode() == Instruction::FRem) {
+ auto SQ = [&]() -> std::optional<SimplifyQuery> {
+ if (AC) {
+ auto Res = std::make_optional<SimplifyQuery>(
+ I->getModule()->getDataLayout(), I);
+ Res->AC = AC;
+ return Res;
+ }
+ return {};
+ }();
+
+ expandFRem(cast<BinaryOperator>(*I), SQ);
+ } else if (I->getOpcode() == Instruction::FPToUI ||
+ I->getOpcode() == Instruction::FPToSI) {
expandFPToI(I);
} else {
expandIToFP(I);
@@ -672,31 +1095,58 @@ static bool runImpl(Function &F, const TargetLowering &TLI) {
namespace {
class ExpandFpLegacyPass : public FunctionPass {
+ CodeGenOptLevel OptLevel;
+
public:
static char ID;
- ExpandFpLegacyPass() : FunctionPass(ID) {
+ ExpandFpLegacyPass(CodeGenOptLevel OptLevel)
+ : FunctionPass(ID), OptLevel(OptLevel) {
initializeExpandFpLegacyPassPass(*PassRegistry::getPassRegistry());
}
+ ExpandFpLegacyPass() : ExpandFpLegacyPass(CodeGenOptLevel::None) {};
+
bool runOnFunction(Function &F) override {
auto *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
auto *TLI = TM->getSubtargetImpl(F)->getTargetLowering();
- return runImpl(F, *TLI);
+ AssumptionCache *AC = nullptr;
+
+ if (OptLevel != CodeGenOptLevel::None && !F.hasOptNone())
+ AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ return runImpl(F, *TLI, AC);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetPassConfig>();
+ if (OptLevel != CodeGenOptLevel::None)
+ AU.addRequired<AssumptionCacheTracker>();
AU.addPreserved<AAResultsWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
}
};
} // namespace
+ExpandFpPass::ExpandFpPass(const TargetMachine *TM, CodeGenOptLevel OptLevel)
+ : TM(TM), OptLevel(OptLevel) {}
+
+void ExpandFpPass::printPipeline(
+ raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
+ static_cast<PassInfoMixin<ExpandFpPass> *>(this)->printPipeline(
+ OS, MapClassName2PassName);
+ OS << '<';
+ OS << "O" << (int)OptLevel;
+ OS << '>';
+}
+
PreservedAnalyses ExpandFpPass::run(Function &F, FunctionAnalysisManager &FAM) {
const TargetSubtargetInfo *STI = TM->getSubtargetImpl(F);
- return runImpl(F, *STI->getTargetLowering()) ? PreservedAnalyses::none()
- : PreservedAnalyses::all();
+ auto &TLI = *STI->getTargetLowering();
+ AssumptionCache *AC = nullptr;
+ if (OptLevel != CodeGenOptLevel::None)
+ AC = &FAM.getResult<AssumptionAnalysis>(F);
+ return runImpl(F, TLI, AC) ? PreservedAnalyses::none()
+ : PreservedAnalyses::all();
}
char ExpandFpLegacyPass::ID = 0;
@@ -704,4 +1154,6 @@ INITIALIZE_PASS_BEGIN(ExpandFpLegacyPass, "expand-fp",
"Expand certain fp instructions", false, false)
INITIALIZE_PASS_END(ExpandFpLegacyPass, "expand-fp", "Expand fp", false, false)
-FunctionPass *llvm::createExpandFpPass() { return new ExpandFpLegacyPass(); }
+FunctionPass *llvm::createExpandFpPass(CodeGenOptLevel OptLevel) {
+ return new ExpandFpLegacyPass(OptLevel);
+}