diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 167 |
1 files changed, 94 insertions, 73 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 231184587d68..fed5e7238433 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -609,6 +609,8 @@ namespace { SDValue foldABSToABD(SDNode *N, const SDLoc &DL); SDValue foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True, SDValue False, ISD::CondCode CC, const SDLoc &DL); + SDValue foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True, + SDValue False, ISD::CondCode CC, const SDLoc &DL); SDValue unfoldMaskedMerge(SDNode *N); SDValue unfoldExtremeBitClearingToShifts(SDNode *N); SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond, @@ -859,7 +861,7 @@ namespace { auto LK = TLI.getTypeConversion(*DAG.getContext(), VT); return (LK.first == TargetLoweringBase::TypeLegal || LK.first == TargetLoweringBase::TypePromoteInteger) && - TLI.isOperationLegal(ISD::UMIN, LK.second); + TLI.isOperationLegalOrCustom(ISD::UMIN, LK.second); } public: @@ -2606,9 +2608,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { return SDValue(); } - SDValue SelectOp = DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF); - SelectOp->setFlags(BO->getFlags()); - return SelectOp; + return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF, BO->getFlags()); } static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL, @@ -4095,6 +4095,26 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { return N0; } + // (sub x, ([v]select (ult x, y), 0, y)) -> (umin x, (sub x, y)) + // (sub x, ([v]select (uge x, y), y, 0)) -> (umin x, (sub x, y)) + if (N1.hasOneUse() && hasUMin(VT)) { + SDValue Y; + if (sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETULT)), + m_Zero(), m_Deferred(Y))) || + sd_match(N1, m_Select(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETUGE)), + m_Deferred(Y), m_Zero())) || + sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETULT)), + m_Zero(), m_Deferred(Y))) || + sd_match(N1, m_VSelect(m_SetCC(m_Specific(N0), m_Value(Y), + m_SpecificCondCode(ISD::SETUGE)), + m_Deferred(Y), m_Zero()))) + return DAG.getNode(ISD::UMIN, DL, VT, N0, + DAG.getNode(ISD::SUB, DL, VT, N0, Y)); + } + if (SDValue NewSel = foldBinOpIntoSelect(N)) return NewSel; @@ -4444,20 +4464,6 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { sd_match(N1, m_UMaxLike(m_Specific(A), m_Specific(B)))) return DAG.getNegative(DAG.getNode(ISD::ABDU, DL, VT, A, B), DL, VT); - // (sub x, (select (ult x, y), 0, y)) -> (umin x, (sub x, y)) - // (sub x, (select (uge x, y), y, 0)) -> (umin x, (sub x, y)) - if (hasUMin(VT)) { - SDValue Y; - if (sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETULT)), - m_Zero(), m_Deferred(Y)))) || - sd_match(N1, m_OneUse(m_Select(m_SetCC(m_Specific(N0), m_Value(Y), - m_SpecificCondCode(ISD::SETUGE)), - m_Deferred(Y), m_Zero())))) - return DAG.getNode(ISD::UMIN, DL, VT, N0, - DAG.getNode(ISD::SUB, DL, VT, N0, Y)); - } - return SDValue(); } @@ -7635,7 +7641,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (SDValue(GN0, 0).hasOneUse() && isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) && - TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) { + TLI.isVectorLoadExtDesirable(SDValue(N, 0))) { SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(), GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()}; @@ -9149,7 +9155,7 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value()) return std::nullopt; - unsigned BitWidth = Op.getValueSizeInBits(); + unsigned BitWidth = Op.getScalarValueSizeInBits(); if (BitWidth % 8 != 0) return std::nullopt; unsigned ByteWidth = BitWidth / 8; @@ -9248,7 +9254,7 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, if (!L->isSimple() || L->isIndexed()) return std::nullopt; - unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits(); + unsigned NarrowBitWidth = L->getMemoryVT().getScalarSizeInBits(); if (NarrowBitWidth % 8 != 0) return std::nullopt; uint64_t NarrowByteWidth = NarrowBitWidth / 8; @@ -12175,6 +12181,30 @@ SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True, return SDValue(); } +// ([v]select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x) +// ([v]select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C)) +SDValue DAGCombiner::foldSelectToUMin(SDValue LHS, SDValue RHS, SDValue True, + SDValue False, ISD::CondCode CC, + const SDLoc &DL) { + APInt C; + EVT VT = True.getValueType(); + if (sd_match(RHS, m_ConstInt(C)) && hasUMin(VT)) { + if (CC == ISD::SETUGT && LHS == False && + sd_match(True, m_Add(m_Specific(False), m_SpecificInt(~C)))) { + SDValue AddC = DAG.getConstant(~C, DL, VT); + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, False, AddC); + return DAG.getNode(ISD::UMIN, DL, VT, Add, False); + } + if (CC == ISD::SETULT && LHS == True && + sd_match(False, m_Add(m_Specific(True), m_SpecificInt(-C)))) { + SDValue AddC = DAG.getConstant(-C, DL, VT); + SDValue Add = DAG.getNode(ISD::ADD, DL, VT, True, AddC); + return DAG.getNode(ISD::UMIN, DL, VT, True, Add); + } + } + return SDValue(); +} + SDValue DAGCombiner::visitSELECT(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12191,11 +12221,8 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { return V; // select (not Cond), N1, N2 -> select Cond, N2, N1 - if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) { - SDValue SelectOp = DAG.getSelect(DL, VT, F, N2, N1); - SelectOp->setFlags(Flags); - return SelectOp; - } + if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) + return DAG.getSelect(DL, VT, F, N2, N1, Flags); if (SDValue V = foldSelectOfConstants(N)) return V; @@ -12363,24 +12390,8 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { // (select (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x) // (select (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C)) - APInt C; - if (sd_match(Cond1, m_ConstInt(C)) && hasUMin(VT)) { - if (CC == ISD::SETUGT && Cond0 == N2 && - sd_match(N1, m_Add(m_Specific(N2), m_SpecificInt(~C)))) { - // The resulting code relies on an unsigned wrap in ADD. - // Recreating ADD to drop possible nuw/nsw flags. - SDValue AddC = DAG.getConstant(~C, DL, VT); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N2, AddC); - return DAG.getNode(ISD::UMIN, DL, VT, Add, N2); - } - if (CC == ISD::SETULT && Cond0 == N1 && - sd_match(N2, m_Add(m_Specific(N1), m_SpecificInt(-C)))) { - // Ditto. - SDValue AddC = DAG.getConstant(-C, DL, VT); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N1, AddC); - return DAG.getNode(ISD::UMIN, DL, VT, N1, Add); - } - } + if (SDValue UMin = foldSelectToUMin(Cond0, Cond1, N1, N2, CC, DL)) + return UMin; } if (!VT.isVector()) @@ -13417,6 +13428,11 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { } } } + + // (vselect (ugt x, C), (add x, ~C), x) -> (umin (add x, ~C), x) + // (vselect (ult x, C), x, (add x, -C)) -> (umin x, (add x, -C)) + if (SDValue UMin = foldSelectToUMin(LHS, RHS, N1, N2, CC, DL)) + return UMin; } if (SimplifySelectOps(N, N1, N2)) @@ -13490,11 +13506,9 @@ SDValue DAGCombiner::visitSELECT_CC(SDNode *N) { // Fold to a simpler select_cc if (SCC.getOpcode() == ISD::SETCC) { - SDValue SelectOp = - DAG.getNode(ISD::SELECT_CC, DL, N2.getValueType(), SCC.getOperand(0), - SCC.getOperand(1), N2, N3, SCC.getOperand(2)); - SelectOp->setFlags(SCC->getFlags()); - return SelectOp; + return DAG.getNode(ISD::SELECT_CC, DL, N2.getValueType(), + SCC.getOperand(0), SCC.getOperand(1), N2, N3, + SCC.getOperand(2), SCC->getFlags()); } } @@ -15731,7 +15745,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x) if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) { if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() && - TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) { + TLI.isVectorLoadExtDesirable(SDValue(N, 0))) { SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(), GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()}; @@ -16758,12 +16772,8 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) { if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false, /*Depth*/ 1)) continue; - bool HadMaybePoisonOperands = !MaybePoisonOperands.empty(); - bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(Op).second; - if (IsNewMaybePoisonOperand) + if (MaybePoisonOperands.insert(Op).second) MaybePoisonOperandNumbers.push_back(OpNo); - if (!HadMaybePoisonOperands) - continue; } // NOTE: the whole op may be not guaranteed to not be undef or poison because // it could create undef or poison due to it's poison-generating flags. @@ -18713,6 +18723,12 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI)) return FTrunc; + // fold (sint_to_fp (trunc nsw x)) -> (sint_to_fp x) + if (N0.getOpcode() == ISD::TRUNCATE && N0->getFlags().hasNoSignedWrap() && + TLI.isTypeDesirableForOp(ISD::SINT_TO_FP, + N0.getOperand(0).getValueType())) + return DAG.getNode(ISD::SINT_TO_FP, DL, VT, N0.getOperand(0)); + return SDValue(); } @@ -18750,6 +18766,12 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { if (SDValue FTrunc = foldFPToIntToFP(N, DL, DAG, TLI)) return FTrunc; + // fold (uint_to_fp (trunc nuw x)) -> (uint_to_fp x) + if (N0.getOpcode() == ISD::TRUNCATE && N0->getFlags().hasNoUnsignedWrap() && + TLI.isTypeDesirableForOp(ISD::UINT_TO_FP, + N0.getOperand(0).getValueType())) + return DAG.getNode(ISD::UINT_TO_FP, DL, VT, N0.getOperand(0)); + return SDValue(); } @@ -28194,14 +28216,16 @@ SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) { TLI.preferScalarizeSplat(N)) { EVT SrcVT = N0.getValueType(); EVT SrcEltVT = SrcVT.getVectorElementType(); - SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL); - SDValue Elt = - DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC); - SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags()); - if (VT.isScalableVector()) - return DAG.getSplatVector(VT, DL, ScalarBO); - SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO); - return DAG.getBuildVector(VT, DL, Ops); + if (!LegalTypes || TLI.isTypeLegal(SrcEltVT)) { + SDValue IndexC = DAG.getVectorIdxConstant(Index0, DL); + SDValue Elt = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, SrcEltVT, Src0, IndexC); + SDValue ScalarBO = DAG.getNode(Opcode, DL, EltVT, Elt, N->getFlags()); + if (VT.isScalableVector()) + return DAG.getSplatVector(VT, DL, ScalarBO); + SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO); + return DAG.getBuildVector(VT, DL, Ops); + } } return SDValue(); @@ -28343,10 +28367,8 @@ SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SCC.getOperand(0), SCC.getOperand(1), SCC.getOperand(4), Flags); AddToWorklist(SETCC.getNode()); - SDValue SelectNode = DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC, - SCC.getOperand(2), SCC.getOperand(3)); - SelectNode->setFlags(Flags); - return SelectNode; + return DAG.getSelect(SDLoc(SCC), SCC.getValueType(), SETCC, + SCC.getOperand(2), SCC.getOperand(3), Flags); } return SCC; @@ -28647,9 +28669,9 @@ SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) { SDValue N10 = N1.getOperand(0); SDValue N20 = N2.getOperand(0); SDValue NewSel = DAG.getSelect(DL, N10.getValueType(), N0, N10, N20); - SDValue NewBinOp = DAG.getNode(BinOpc, DL, OpVTs, NewSel, N1.getOperand(1)); - NewBinOp->setFlags(N1->getFlags()); - NewBinOp->intersectFlagsWith(N2->getFlags()); + SDNodeFlags Flags = N1->getFlags() & N2->getFlags(); + SDValue NewBinOp = + DAG.getNode(BinOpc, DL, OpVTs, {NewSel, N1.getOperand(1)}, Flags); return SDValue(NewBinOp.getNode(), N1.getResNo()); } @@ -28661,10 +28683,9 @@ SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) { // Second op VT might be different (e.g. shift amount type) if (N11.getValueType() == N21.getValueType()) { SDValue NewSel = DAG.getSelect(DL, N11.getValueType(), N0, N11, N21); + SDNodeFlags Flags = N1->getFlags() & N2->getFlags(); SDValue NewBinOp = - DAG.getNode(BinOpc, DL, OpVTs, N1.getOperand(0), NewSel); - NewBinOp->setFlags(N1->getFlags()); - NewBinOp->intersectFlagsWith(N2->getFlags()); + DAG.getNode(BinOpc, DL, OpVTs, {N1.getOperand(0), NewSel}, Flags); return SDValue(NewBinOp.getNode(), N1.getResNo()); } } |
