diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 167 |
1 files changed, 109 insertions, 58 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 995ae75da1c3..1977d3372c5f 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16117,6 +16117,46 @@ static SDValue reverseZExtICmpCombine(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Res); } +// (and (i1) f, (setcc c, 0, ne)) -> (czero.nez f, c) +// (and (i1) f, (setcc c, 0, eq)) -> (czero.eqz f, c) +// (and (setcc c, 0, ne), (i1) g) -> (czero.nez g, c) +// (and (setcc c, 0, eq), (i1) g) -> (czero.eqz g, c) +static SDValue combineANDOfSETCCToCZERO(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (!Subtarget.hasCZEROLike()) + return SDValue(); + + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + auto IsEqualCompZero = [](SDValue &V) -> bool { + if (V.getOpcode() == ISD::SETCC && isNullConstant(V.getOperand(1))) { + ISD::CondCode CC = cast<CondCodeSDNode>(V.getOperand(2))->get(); + if (ISD::isIntEqualitySetCC(CC)) + return true; + } + return false; + }; + + if (!IsEqualCompZero(N0) || !N0.hasOneUse()) + std::swap(N0, N1); + if (!IsEqualCompZero(N0) || !N0.hasOneUse()) + return SDValue(); + + KnownBits Known = DAG.computeKnownBits(N1); + if (Known.getMaxValue().ugt(1)) + return SDValue(); + + unsigned CzeroOpcode = + (cast<CondCodeSDNode>(N0.getOperand(2))->get() == ISD::SETNE) + ? RISCVISD::CZERO_EQZ + : RISCVISD::CZERO_NEZ; + + EVT VT = N->getValueType(0); + SDLoc DL(N); + return DAG.getNode(CzeroOpcode, DL, VT, N1, N0.getOperand(0)); +} + static SDValue reduceANDOfAtomicLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { SelectionDAG &DAG = DCI.DAG; @@ -16180,7 +16220,9 @@ static SDValue performANDCombine(SDNode *N, if (SDValue V = reverseZExtICmpCombine(N, DAG, Subtarget)) return V; - + if (DCI.isAfterLegalizeDAG()) + if (SDValue V = combineANDOfSETCCToCZERO(N, DAG, Subtarget)) + return V; if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget)) @@ -16496,30 +16538,50 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG, } static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX, - unsigned ShY, bool AddX) { + unsigned ShY, bool AddX, unsigned Shift) { SDLoc DL(N); EVT VT = N->getValueType(0); SDValue X = N->getOperand(0); - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + // Put the shift first if we can fold a zext into the shift forming a slli.uw. + using namespace SDPatternMatch; + if (Shift != 0 && + sd_match(X, m_And(m_Value(), m_SpecificInt(UINT64_C(0xffffffff))))) { + X = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); + Shift = 0; + } + SDValue ShlAdd = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, DAG.getTargetConstant(ShY, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getTargetConstant(ShX, DL, VT), AddX ? X : Mul359); + if (ShX != 0) + ShlAdd = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, ShlAdd, + DAG.getTargetConstant(ShX, DL, VT), AddX ? X : ShlAdd); + if (Shift == 0) + return ShlAdd; + // Otherwise, put the shl last so that it can fold with following instructions + // (e.g. sext or add). + return DAG.getNode(ISD::SHL, DL, VT, ShlAdd, DAG.getConstant(Shift, DL, VT)); } static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, - uint64_t MulAmt) { - // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X)) + uint64_t MulAmt, unsigned Shift) { switch (MulAmt) { + // 3/5/9 -> (shYadd X, X) + case 3: + return getShlAddShlAdd(N, DAG, 0, 1, /*AddX=*/false, Shift); + case 5: + return getShlAddShlAdd(N, DAG, 0, 2, /*AddX=*/false, Shift); + case 9: + return getShlAddShlAdd(N, DAG, 0, 3, /*AddX=*/false, Shift); + // 3/5/9 * 3/5/9 -> (shXadd (shYadd X, X), (shYadd X, X)) case 5 * 3: - return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false); + return getShlAddShlAdd(N, DAG, 2, 1, /*AddX=*/false, Shift); case 9 * 3: - return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false); + return getShlAddShlAdd(N, DAG, 3, 1, /*AddX=*/false, Shift); case 5 * 5: - return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false); + return getShlAddShlAdd(N, DAG, 2, 2, /*AddX=*/false, Shift); case 9 * 5: - return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false); + return getShlAddShlAdd(N, DAG, 3, 2, /*AddX=*/false, Shift); case 9 * 9: - return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false); + return getShlAddShlAdd(N, DAG, 3, 3, /*AddX=*/false, Shift); default: break; } @@ -16529,7 +16591,7 @@ static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, if (int ShY = isShifted359(MulAmt - 1, ShX)) { assert(ShX != 0 && "MulAmt=4,6,10 handled before"); if (ShX <= 3) - return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true); + return getShlAddShlAdd(N, DAG, ShX, ShY, /*AddX=*/true, Shift); } return SDValue(); } @@ -16569,42 +16631,18 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // real regressions, and no other target properly freezes X in these cases // either. if (Subtarget.hasShlAdd(3)) { - SDValue X = N->getOperand(0); - int Shift; - if (int ShXAmount = isShifted359(MulAmt, Shift)) { - // 3/5/9 * 2^N -> shl (shXadd X, X), N - SDLoc DL(N); - // Put the shift first if we can fold a zext into the shift forming - // a slli.uw. - if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && - X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { - SDValue Shl = - DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, - DAG.getTargetConstant(ShXAmount, DL, VT), Shl); - } - // Otherwise, put the shl second so that it can fold with following - // instructions (e.g. sext or add). - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getTargetConstant(ShXAmount, DL, VT), X); - return DAG.getNode(ISD::SHL, DL, VT, Mul359, - DAG.getConstant(Shift, DL, VT)); - } - + // 3/5/9 * 2^N -> (shl (shXadd X, X), N) // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples // of 25 which happen to be quite common. // (2/4/8 * 3/5/9 + 1) * 2^N - Shift = llvm::countr_zero(MulAmt); - if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { - if (Shift == 0) - return V; - SDLoc DL(N); - return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); - } + unsigned Shift = llvm::countr_zero(MulAmt); + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift, Shift)) + return V; // If this is a power 2 + 2/4/8, we can use a shift followed by a single // shXadd. First check if this a sum of two power of 2s because that's // easy. Then count how many zeros are up to the first bit. + SDValue X = N->getOperand(0); if (Shift >= 1 && Shift <= 3 && isPowerOf2_64(MulAmt & (MulAmt - 1))) { unsigned ShiftAmt = llvm::countr_zero((MulAmt & (MulAmt - 1))); SDLoc DL(N); @@ -17867,6 +17905,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, SmallVector<SDNode *> Worklist; SmallPtrSet<SDNode *, 8> Inserted; + SmallPtrSet<SDNode *, 8> ExtensionsToRemove; Worklist.push_back(N); Inserted.insert(N); SmallVector<CombineResult> CombinesToApply; @@ -17876,22 +17915,25 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, NodeExtensionHelper LHS(Root, 0, DAG, Subtarget); NodeExtensionHelper RHS(Root, 1, DAG, Subtarget); - auto AppendUsersIfNeeded = [&Worklist, &Subtarget, - &Inserted](const NodeExtensionHelper &Op) { - if (Op.needToPromoteOtherUsers()) { - for (SDUse &Use : Op.OrigOperand->uses()) { - SDNode *TheUser = Use.getUser(); - if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) - return false; - // We only support the first 2 operands of FMA. - if (Use.getOperandNo() >= 2) - return false; - if (Inserted.insert(TheUser).second) - Worklist.push_back(TheUser); - } - } - return true; - }; + auto AppendUsersIfNeeded = + [&Worklist, &Subtarget, &Inserted, + &ExtensionsToRemove](const NodeExtensionHelper &Op) { + if (Op.needToPromoteOtherUsers()) { + // Remember that we're supposed to remove this extension. + ExtensionsToRemove.insert(Op.OrigOperand.getNode()); + for (SDUse &Use : Op.OrigOperand->uses()) { + SDNode *TheUser = Use.getUser(); + if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) + return false; + // We only support the first 2 operands of FMA. + if (Use.getOperandNo() >= 2) + return false; + if (Inserted.insert(TheUser).second) + Worklist.push_back(TheUser); + } + } + return true; + }; // Control the compile time by limiting the number of node we look at in // total. @@ -17912,6 +17954,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, std::optional<CombineResult> Res = FoldingStrategy(Root, LHS, RHS, DAG, Subtarget); if (Res) { + // If this strategy wouldn't remove an extension we're supposed to + // remove, reject it. + if (!Res->LHSExt.has_value() && + ExtensionsToRemove.contains(LHS.OrigOperand.getNode())) + continue; + if (!Res->RHSExt.has_value() && + ExtensionsToRemove.contains(RHS.OrigOperand.getNode())) + continue; + Matched = true; CombinesToApply.push_back(*Res); // All the inputs that are extended need to be folded, otherwise |
