summaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp93
1 files changed, 34 insertions, 59 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 402a012e8e55..fd6d20e146bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -832,7 +832,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SHL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
SDValue Op0 = Op.getOperand(0);
unsigned ShAmt = *MaxSA;
@@ -847,7 +847,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
case ISD::SRL: {
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
SDValue Op0 = Op.getOperand(0);
unsigned ShAmt = *MaxSA;
@@ -1780,7 +1780,7 @@ bool TargetLowering::SimplifyDemandedBits(
SDValue Op1 = Op.getOperand(1);
EVT ShiftVT = Op1.getValueType();
- if (std::optional<uint64_t> KnownSA =
+ if (std::optional<unsigned> KnownSA =
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *KnownSA;
if (ShAmt == 0)
@@ -1792,7 +1792,7 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (Op0.getOpcode() == ISD::SRL) {
if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned C1 = *InnerSA;
unsigned Opc = ISD::SHL;
@@ -1832,7 +1832,7 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() &&
InnerOp.hasOneUse()) {
- if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount(
+ if (std::optional<unsigned> SA2 = TLO.DAG.getValidShiftAmount(
InnerOp, DemandedElts, Depth + 2)) {
unsigned InnerShAmt = *SA2;
if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
@@ -1858,8 +1858,7 @@ bool TargetLowering::SimplifyDemandedBits(
Op->dropFlags(SDNodeFlags::NoWrap);
return true;
}
- Known.Zero <<= ShAmt;
- Known.One <<= ShAmt;
+ Known <<= ShAmt;
// low bits known zero.
Known.Zero.setLowBits(ShAmt);
@@ -1950,7 +1949,7 @@ bool TargetLowering::SimplifyDemandedBits(
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *MaxSA;
unsigned NumSignBits =
@@ -1966,7 +1965,7 @@ bool TargetLowering::SimplifyDemandedBits(
SDValue Op1 = Op.getOperand(1);
EVT ShiftVT = Op1.getValueType();
- if (std::optional<uint64_t> KnownSA =
+ if (std::optional<unsigned> KnownSA =
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *KnownSA;
if (ShAmt == 0)
@@ -1978,7 +1977,7 @@ bool TargetLowering::SimplifyDemandedBits(
// TODO - support non-uniform vector amounts.
if (Op0.getOpcode() == ISD::SHL) {
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned C1 = *InnerSA;
unsigned Opc = ISD::SRL;
@@ -1998,7 +1997,7 @@ bool TargetLowering::SimplifyDemandedBits(
// single sra. We can do this if the top bits are never demanded.
if (Op0.getOpcode() == ISD::SRA && Op0.hasOneUse()) {
if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned C1 = *InnerSA;
// Clamp the combined shift amount if it exceeds the bit width.
@@ -2042,8 +2041,7 @@ bool TargetLowering::SimplifyDemandedBits(
if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
Depth + 1))
return true;
- Known.Zero.lshrInPlace(ShAmt);
- Known.One.lshrInPlace(ShAmt);
+ Known >>= ShAmt;
// High bits known zero.
Known.Zero.setHighBits(ShAmt);
@@ -2064,7 +2062,7 @@ bool TargetLowering::SimplifyDemandedBits(
// If we are only demanding sign bits then we can use the shift source
// directly.
- if (std::optional<uint64_t> MaxSA =
+ if (std::optional<unsigned> MaxSA =
TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *MaxSA;
// Must already be signbits in DemandedBits bounds, and can't demand any
@@ -2103,7 +2101,7 @@ bool TargetLowering::SimplifyDemandedBits(
if (DemandedBits.isOne())
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
- if (std::optional<uint64_t> KnownSA =
+ if (std::optional<unsigned> KnownSA =
TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) {
unsigned ShAmt = *KnownSA;
if (ShAmt == 0)
@@ -2112,7 +2110,7 @@ bool TargetLowering::SimplifyDemandedBits(
// fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target
// supports sext_inreg.
if (Op0.getOpcode() == ISD::SHL) {
- if (std::optional<uint64_t> InnerSA =
+ if (std::optional<unsigned> InnerSA =
TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) {
unsigned LowBits = BitWidth - ShAmt;
EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits);
@@ -2153,8 +2151,7 @@ bool TargetLowering::SimplifyDemandedBits(
if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
Depth + 1))
return true;
- Known.Zero.lshrInPlace(ShAmt);
- Known.One.lshrInPlace(ShAmt);
+ Known >>= ShAmt;
// If the input sign bit is known to be zero, or if none of the top bits
// are demanded, turn this into an unsigned shift right.
@@ -2225,10 +2222,8 @@ bool TargetLowering::SimplifyDemandedBits(
Depth + 1))
return true;
- Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt));
- Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt));
- Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
- Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt);
+ Known2 <<= (IsFSHL ? Amt : (BitWidth - Amt));
+ Known >>= (IsFSHL ? (BitWidth - Amt) : Amt);
Known = Known.unionWith(Known2);
// Attempt to avoid multi-use ops if we don't need anything from them.
@@ -2363,8 +2358,7 @@ bool TargetLowering::SimplifyDemandedBits(
if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
Depth + 1))
return true;
- Known.One = Known2.One.reverseBits();
- Known.Zero = Known2.Zero.reverseBits();
+ Known = Known2.reverseBits();
break;
}
case ISD::BSWAP: {
@@ -2397,8 +2391,7 @@ bool TargetLowering::SimplifyDemandedBits(
if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO,
Depth + 1))
return true;
- Known.One = Known2.One.byteSwap();
- Known.Zero = Known2.Zero.byteSwap();
+ Known = Known2.byteSwap();
break;
}
case ISD::CTPOP: {
@@ -2664,11 +2657,11 @@ bool TargetLowering::SimplifyDemandedBits(
break;
}
- std::optional<uint64_t> ShAmtC =
+ std::optional<unsigned> ShAmtC =
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
if (!ShAmtC || *ShAmtC >= BitWidth)
break;
- uint64_t ShVal = *ShAmtC;
+ unsigned ShVal = *ShAmtC;
APInt HighBits =
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
@@ -3234,27 +3227,6 @@ bool TargetLowering::SimplifyDemandedVectorElts(
KnownUndef.setAllBits();
return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
}
- SDValue ScalarSrc = Op.getOperand(0);
- if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
- SDValue Src = ScalarSrc.getOperand(0);
- SDValue Idx = ScalarSrc.getOperand(1);
- EVT SrcVT = Src.getValueType();
-
- ElementCount SrcEltCnt = SrcVT.getVectorElementCount();
-
- if (SrcEltCnt.isScalable())
- return false;
-
- unsigned NumSrcElts = SrcEltCnt.getFixedValue();
- if (isNullConstant(Idx)) {
- APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0);
- APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts);
- APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts);
- if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
- TLO, Depth + 1))
- return true;
- }
- }
KnownUndef.setHighBits(NumElts - 1);
break;
}
@@ -9740,8 +9712,8 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
SDLoc dl(N);
EVT VT = N->getValueType(0);
- SDValue LHS = DAG.getFreeze(N->getOperand(0));
- SDValue RHS = DAG.getFreeze(N->getOperand(1));
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
bool IsSigned = N->getOpcode() == ISD::ABDS;
// abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
@@ -9749,34 +9721,37 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
if (isOperationLegal(MaxOpc, VT) && isOperationLegal(MinOpc, VT)) {
+ LHS = DAG.getFreeze(LHS);
+ RHS = DAG.getFreeze(RHS);
SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
}
// abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
- if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT))
+ if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT)) {
+ LHS = DAG.getFreeze(LHS);
+ RHS = DAG.getFreeze(RHS);
return DAG.getNode(ISD::OR, dl, VT,
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
+ }
// If the subtract doesn't overflow then just use abs(sub())
- // NOTE: don't use frozen operands for value tracking.
- bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) &&
- DAG.SignBitIsZero(N->getOperand(0));
+ bool IsNonNegative = DAG.SignBitIsZero(LHS) && DAG.SignBitIsZero(RHS);
- if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0),
- N->getOperand(1)))
+ if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, LHS, RHS))
return DAG.getNode(ISD::ABS, dl, VT,
DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));
- if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1),
- N->getOperand(0)))
+ if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, RHS, LHS))
return DAG.getNode(ISD::ABS, dl, VT,
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
+ LHS = DAG.getFreeze(LHS);
+ RHS = DAG.getFreeze(RHS);
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
// Branchless expansion iff cmp result is allbits: