diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 93 |
1 files changed, 34 insertions, 59 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 402a012e8e55..fd6d20e146bb 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -832,7 +832,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( case ISD::SHL: { // If we are only demanding sign bits then we can use the shift source // directly. - if (std::optional<uint64_t> MaxSA = + if (std::optional<unsigned> MaxSA = DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) { SDValue Op0 = Op.getOperand(0); unsigned ShAmt = *MaxSA; @@ -847,7 +847,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( case ISD::SRL: { // If we are only demanding sign bits then we can use the shift source // directly. - if (std::optional<uint64_t> MaxSA = + if (std::optional<unsigned> MaxSA = DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) { SDValue Op0 = Op.getOperand(0); unsigned ShAmt = *MaxSA; @@ -1780,7 +1780,7 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op1 = Op.getOperand(1); EVT ShiftVT = Op1.getValueType(); - if (std::optional<uint64_t> KnownSA = + if (std::optional<unsigned> KnownSA = TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) { unsigned ShAmt = *KnownSA; if (ShAmt == 0) @@ -1792,7 +1792,7 @@ bool TargetLowering::SimplifyDemandedBits( // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::SRL) { if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) { - if (std::optional<uint64_t> InnerSA = + if (std::optional<unsigned> InnerSA = TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) { unsigned C1 = *InnerSA; unsigned Opc = ISD::SHL; @@ -1832,7 +1832,7 @@ bool TargetLowering::SimplifyDemandedBits( // TODO - support non-uniform vector amounts. if (InnerOp.getOpcode() == ISD::SRL && Op0.hasOneUse() && InnerOp.hasOneUse()) { - if (std::optional<uint64_t> SA2 = TLO.DAG.getValidShiftAmount( + if (std::optional<unsigned> SA2 = TLO.DAG.getValidShiftAmount( InnerOp, DemandedElts, Depth + 2)) { unsigned InnerShAmt = *SA2; if (InnerShAmt < ShAmt && InnerShAmt < InnerBits && @@ -1858,8 +1858,7 @@ bool TargetLowering::SimplifyDemandedBits( Op->dropFlags(SDNodeFlags::NoWrap); return true; } - Known.Zero <<= ShAmt; - Known.One <<= ShAmt; + Known <<= ShAmt; // low bits known zero. Known.Zero.setLowBits(ShAmt); @@ -1950,7 +1949,7 @@ bool TargetLowering::SimplifyDemandedBits( // If we are only demanding sign bits then we can use the shift source // directly. - if (std::optional<uint64_t> MaxSA = + if (std::optional<unsigned> MaxSA = TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) { unsigned ShAmt = *MaxSA; unsigned NumSignBits = @@ -1966,7 +1965,7 @@ bool TargetLowering::SimplifyDemandedBits( SDValue Op1 = Op.getOperand(1); EVT ShiftVT = Op1.getValueType(); - if (std::optional<uint64_t> KnownSA = + if (std::optional<unsigned> KnownSA = TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) { unsigned ShAmt = *KnownSA; if (ShAmt == 0) @@ -1978,7 +1977,7 @@ bool TargetLowering::SimplifyDemandedBits( // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::SHL) { if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) { - if (std::optional<uint64_t> InnerSA = + if (std::optional<unsigned> InnerSA = TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) { unsigned C1 = *InnerSA; unsigned Opc = ISD::SRL; @@ -1998,7 +1997,7 @@ bool TargetLowering::SimplifyDemandedBits( // single sra. We can do this if the top bits are never demanded. if (Op0.getOpcode() == ISD::SRA && Op0.hasOneUse()) { if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) { - if (std::optional<uint64_t> InnerSA = + if (std::optional<unsigned> InnerSA = TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) { unsigned C1 = *InnerSA; // Clamp the combined shift amount if it exceeds the bit width. @@ -2042,8 +2041,7 @@ bool TargetLowering::SimplifyDemandedBits( if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) return true; - Known.Zero.lshrInPlace(ShAmt); - Known.One.lshrInPlace(ShAmt); + Known >>= ShAmt; // High bits known zero. Known.Zero.setHighBits(ShAmt); @@ -2064,7 +2062,7 @@ bool TargetLowering::SimplifyDemandedBits( // If we are only demanding sign bits then we can use the shift source // directly. - if (std::optional<uint64_t> MaxSA = + if (std::optional<unsigned> MaxSA = TLO.DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) { unsigned ShAmt = *MaxSA; // Must already be signbits in DemandedBits bounds, and can't demand any @@ -2103,7 +2101,7 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isOne()) return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1)); - if (std::optional<uint64_t> KnownSA = + if (std::optional<unsigned> KnownSA = TLO.DAG.getValidShiftAmount(Op, DemandedElts, Depth + 1)) { unsigned ShAmt = *KnownSA; if (ShAmt == 0) @@ -2112,7 +2110,7 @@ bool TargetLowering::SimplifyDemandedBits( // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target // supports sext_inreg. if (Op0.getOpcode() == ISD::SHL) { - if (std::optional<uint64_t> InnerSA = + if (std::optional<unsigned> InnerSA = TLO.DAG.getValidShiftAmount(Op0, DemandedElts, Depth + 2)) { unsigned LowBits = BitWidth - ShAmt; EVT ExtVT = EVT::getIntegerVT(*TLO.DAG.getContext(), LowBits); @@ -2153,8 +2151,7 @@ bool TargetLowering::SimplifyDemandedBits( if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) return true; - Known.Zero.lshrInPlace(ShAmt); - Known.One.lshrInPlace(ShAmt); + Known >>= ShAmt; // If the input sign bit is known to be zero, or if none of the top bits // are demanded, turn this into an unsigned shift right. @@ -2225,10 +2222,8 @@ bool TargetLowering::SimplifyDemandedBits( Depth + 1)) return true; - Known2.One <<= (IsFSHL ? Amt : (BitWidth - Amt)); - Known2.Zero <<= (IsFSHL ? Amt : (BitWidth - Amt)); - Known.One.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt); - Known.Zero.lshrInPlace(IsFSHL ? (BitWidth - Amt) : Amt); + Known2 <<= (IsFSHL ? Amt : (BitWidth - Amt)); + Known >>= (IsFSHL ? (BitWidth - Amt) : Amt); Known = Known.unionWith(Known2); // Attempt to avoid multi-use ops if we don't need anything from them. @@ -2363,8 +2358,7 @@ bool TargetLowering::SimplifyDemandedBits( if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO, Depth + 1)) return true; - Known.One = Known2.One.reverseBits(); - Known.Zero = Known2.Zero.reverseBits(); + Known = Known2.reverseBits(); break; } case ISD::BSWAP: { @@ -2397,8 +2391,7 @@ bool TargetLowering::SimplifyDemandedBits( if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedElts, Known2, TLO, Depth + 1)) return true; - Known.One = Known2.One.byteSwap(); - Known.Zero = Known2.Zero.byteSwap(); + Known = Known2.byteSwap(); break; } case ISD::CTPOP: { @@ -2664,11 +2657,11 @@ bool TargetLowering::SimplifyDemandedBits( break; } - std::optional<uint64_t> ShAmtC = + std::optional<unsigned> ShAmtC = TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2); if (!ShAmtC || *ShAmtC >= BitWidth) break; - uint64_t ShVal = *ShAmtC; + unsigned ShVal = *ShAmtC; APInt HighBits = APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth); @@ -3234,27 +3227,6 @@ bool TargetLowering::SimplifyDemandedVectorElts( KnownUndef.setAllBits(); return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT)); } - SDValue ScalarSrc = Op.getOperand(0); - if (ScalarSrc.getOpcode() == ISD::EXTRACT_VECTOR_ELT) { - SDValue Src = ScalarSrc.getOperand(0); - SDValue Idx = ScalarSrc.getOperand(1); - EVT SrcVT = Src.getValueType(); - - ElementCount SrcEltCnt = SrcVT.getVectorElementCount(); - - if (SrcEltCnt.isScalable()) - return false; - - unsigned NumSrcElts = SrcEltCnt.getFixedValue(); - if (isNullConstant(Idx)) { - APInt SrcDemandedElts = APInt::getOneBitSet(NumSrcElts, 0); - APInt SrcUndef = KnownUndef.zextOrTrunc(NumSrcElts); - APInt SrcZero = KnownZero.zextOrTrunc(NumSrcElts); - if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero, - TLO, Depth + 1)) - return true; - } - } KnownUndef.setHighBits(NumElts - 1); break; } @@ -9740,8 +9712,8 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG, SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const { SDLoc dl(N); EVT VT = N->getValueType(0); - SDValue LHS = DAG.getFreeze(N->getOperand(0)); - SDValue RHS = DAG.getFreeze(N->getOperand(1)); + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); bool IsSigned = N->getOpcode() == ISD::ABDS; // abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs)) @@ -9749,34 +9721,37 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const { unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX; unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN; if (isOperationLegal(MaxOpc, VT) && isOperationLegal(MinOpc, VT)) { + LHS = DAG.getFreeze(LHS); + RHS = DAG.getFreeze(RHS); SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS); SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS); return DAG.getNode(ISD::SUB, dl, VT, Max, Min); } // abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs)) - if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT)) + if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT)) { + LHS = DAG.getFreeze(LHS); + RHS = DAG.getFreeze(RHS); return DAG.getNode(ISD::OR, dl, VT, DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS), DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS)); + } // If the subtract doesn't overflow then just use abs(sub()) - // NOTE: don't use frozen operands for value tracking. - bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) && - DAG.SignBitIsZero(N->getOperand(0)); + bool IsNonNegative = DAG.SignBitIsZero(LHS) && DAG.SignBitIsZero(RHS); - if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0), - N->getOperand(1))) + if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, LHS, RHS)) return DAG.getNode(ISD::ABS, dl, VT, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS)); - if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1), - N->getOperand(0))) + if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, RHS, LHS)) return DAG.getNode(ISD::ABS, dl, VT, DAG.getNode(ISD::SUB, dl, VT, RHS, LHS)); EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT; + LHS = DAG.getFreeze(LHS); + RHS = DAG.getFreeze(RHS); SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC); // Branchless expansion iff cmp result is allbits: |
