diff options
Diffstat (limited to 'llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp | 506 |
1 files changed, 477 insertions, 29 deletions
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp index 008c18837a52..b02465d99a60 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -2916,6 +2916,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { case TargetOpcode::G_SREM: case TargetOpcode::G_SMIN: case TargetOpcode::G_SMAX: + case TargetOpcode::G_ABDS: Observer.changingInstr(MI); widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT); widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT); @@ -2953,6 +2954,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { return Legalized; case TargetOpcode::G_UDIV: case TargetOpcode::G_UREM: + case TargetOpcode::G_ABDU: Observer.changingInstr(MI); widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT); widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT); @@ -4742,6 +4744,16 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) { return lowerShlSat(MI); case G_ABS: return lowerAbsToAddXor(MI); + case G_ABDS: + case G_ABDU: { + bool IsSigned = MI.getOpcode() == G_ABDS; + LLT Ty = MRI.getType(MI.getOperand(0).getReg()); + if ((IsSigned && LI.isLegal({G_SMIN, Ty}) && LI.isLegal({G_SMAX, Ty})) || + (!IsSigned && LI.isLegal({G_UMIN, Ty}) && LI.isLegal({G_UMAX, Ty}))) { + return lowerAbsDiffToMinMax(MI); + } + return lowerAbsDiffToSelect(MI); + } case G_FABS: return lowerFAbs(MI); case G_SELECT: @@ -4773,6 +4785,16 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) { return lowerVectorReduction(MI); case G_VAARG: return lowerVAArg(MI); + case G_ATOMICRMW_SUB: { + auto [Ret, Mem, Val] = MI.getFirst3Regs(); + const LLT ValTy = MRI.getType(Val); + MachineMemOperand *MMO = *MI.memoperands_begin(); + + auto VNeg = MIRBuilder.buildNeg(ValTy, Val); + MIRBuilder.buildAtomicRMW(G_ATOMICRMW_ADD, Ret, Mem, VNeg, *MMO); + MI.eraseFromParent(); + return Legalized; + } } } @@ -5222,19 +5244,13 @@ LegalizerHelper::fewerElementsVectorExtractInsertVectorElt(MachineInstr &MI, InsertVal = MI.getOperand(2).getReg(); Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg(); - - // TODO: Handle total scalarization case. - if (!NarrowVecTy.isVector()) - return UnableToLegalize; - LLT VecTy = MRI.getType(SrcVec); // If the index is a constant, we can really break this down as you would // expect, and index into the target size pieces. - int64_t IdxVal; auto MaybeCst = getIConstantVRegValWithLookThrough(Idx, MRI); if (MaybeCst) { - IdxVal = MaybeCst->Value.getSExtValue(); + uint64_t IdxVal = MaybeCst->Value.getZExtValue(); // Avoid out of bounds indexing the pieces. if (IdxVal >= VecTy.getNumElements()) { MIRBuilder.buildUndef(DstReg); @@ -5242,33 +5258,45 @@ LegalizerHelper::fewerElementsVectorExtractInsertVectorElt(MachineInstr &MI, return Legalized; } - SmallVector<Register, 8> VecParts; - LLT GCDTy = extractGCDType(VecParts, VecTy, NarrowVecTy, SrcVec); + if (!NarrowVecTy.isVector()) { + SmallVector<Register, 8> SplitPieces; + extractParts(MI.getOperand(1).getReg(), NarrowVecTy, + VecTy.getNumElements(), SplitPieces, MIRBuilder, MRI); + if (IsInsert) { + SplitPieces[IdxVal] = InsertVal; + MIRBuilder.buildMergeLikeInstr(MI.getOperand(0).getReg(), SplitPieces); + } else { + MIRBuilder.buildCopy(MI.getOperand(0).getReg(), SplitPieces[IdxVal]); + } + } else { + SmallVector<Register, 8> VecParts; + LLT GCDTy = extractGCDType(VecParts, VecTy, NarrowVecTy, SrcVec); - // Build a sequence of NarrowTy pieces in VecParts for this operand. - LLT LCMTy = buildLCMMergePieces(VecTy, NarrowVecTy, GCDTy, VecParts, - TargetOpcode::G_ANYEXT); + // Build a sequence of NarrowTy pieces in VecParts for this operand. + LLT LCMTy = buildLCMMergePieces(VecTy, NarrowVecTy, GCDTy, VecParts, + TargetOpcode::G_ANYEXT); - unsigned NewNumElts = NarrowVecTy.getNumElements(); + unsigned NewNumElts = NarrowVecTy.getNumElements(); - LLT IdxTy = MRI.getType(Idx); - int64_t PartIdx = IdxVal / NewNumElts; - auto NewIdx = - MIRBuilder.buildConstant(IdxTy, IdxVal - NewNumElts * PartIdx); + LLT IdxTy = MRI.getType(Idx); + int64_t PartIdx = IdxVal / NewNumElts; + auto NewIdx = + MIRBuilder.buildConstant(IdxTy, IdxVal - NewNumElts * PartIdx); - if (IsInsert) { - LLT PartTy = MRI.getType(VecParts[PartIdx]); + if (IsInsert) { + LLT PartTy = MRI.getType(VecParts[PartIdx]); - // Use the adjusted index to insert into one of the subvectors. - auto InsertPart = MIRBuilder.buildInsertVectorElement( - PartTy, VecParts[PartIdx], InsertVal, NewIdx); - VecParts[PartIdx] = InsertPart.getReg(0); + // Use the adjusted index to insert into one of the subvectors. + auto InsertPart = MIRBuilder.buildInsertVectorElement( + PartTy, VecParts[PartIdx], InsertVal, NewIdx); + VecParts[PartIdx] = InsertPart.getReg(0); - // Recombine the inserted subvector with the others to reform the result - // vector. - buildWidenedRemergeToDst(DstReg, LCMTy, VecParts); - } else { - MIRBuilder.buildExtractVectorElement(DstReg, VecParts[PartIdx], NewIdx); + // Recombine the inserted subvector with the others to reform the result + // vector. + buildWidenedRemergeToDst(DstReg, LCMTy, VecParts); + } else { + MIRBuilder.buildExtractVectorElement(DstReg, VecParts[PartIdx], NewIdx); + } } MI.eraseFromParent(); @@ -5970,7 +5998,6 @@ LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt, return Legalized; } -// TODO: Optimize if constant shift amount. LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx, LLT RequestedTy) { @@ -5992,6 +6019,27 @@ LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx, if (DstEltSize % 2 != 0) return UnableToLegalize; + // Check if we should use multi-way splitting instead of recursive binary + // splitting. + // + // Multi-way splitting directly decomposes wide shifts (e.g., 128-bit -> + // 4×32-bit) in a single legalization step, avoiding the recursive overhead + // and dependency chains created by usual binary splitting approach + // (128->64->32). + // + // The >= 8 parts threshold ensures we only use this optimization when binary + // splitting would require multiple recursive passes, avoiding overhead for + // simple 2-way splits where binary approach is sufficient. + if (RequestedTy.isValid() && RequestedTy.isScalar() && + DstEltSize % RequestedTy.getSizeInBits() == 0) { + const unsigned NumParts = DstEltSize / RequestedTy.getSizeInBits(); + // Use multiway if we have 8 or more parts (i.e., would need 3+ recursive + // steps). + if (NumParts >= 8) + return narrowScalarShiftMultiway(MI, RequestedTy); + } + + // Fall back to binary splitting: // Ignore the input type. We can only go to exactly half the size of the // input. If that isn't small enough, the resulting pieces will be further // legalized. @@ -6080,6 +6128,358 @@ LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx, return Legalized; } +Register LegalizerHelper::buildConstantShiftPart(unsigned Opcode, + unsigned PartIdx, + unsigned NumParts, + ArrayRef<Register> SrcParts, + const ShiftParams &Params, + LLT TargetTy, LLT ShiftAmtTy) { + auto WordShiftConst = getIConstantVRegVal(Params.WordShift, MRI); + auto BitShiftConst = getIConstantVRegVal(Params.BitShift, MRI); + assert(WordShiftConst && BitShiftConst && "Expected constants"); + + const unsigned ShiftWords = WordShiftConst->getZExtValue(); + const unsigned ShiftBits = BitShiftConst->getZExtValue(); + const bool NeedsInterWordShift = ShiftBits != 0; + + switch (Opcode) { + case TargetOpcode::G_SHL: { + // Data moves from lower indices to higher indices + // If this part would come from a source beyond our range, it's zero + if (PartIdx < ShiftWords) + return Params.Zero; + + unsigned SrcIdx = PartIdx - ShiftWords; + if (!NeedsInterWordShift) + return SrcParts[SrcIdx]; + + // Combine shifted main part with carry from previous part + auto Hi = MIRBuilder.buildShl(TargetTy, SrcParts[SrcIdx], Params.BitShift); + if (SrcIdx > 0) { + auto Lo = MIRBuilder.buildLShr(TargetTy, SrcParts[SrcIdx - 1], + Params.InvBitShift); + return MIRBuilder.buildOr(TargetTy, Hi, Lo).getReg(0); + } + return Hi.getReg(0); + } + + case TargetOpcode::G_LSHR: { + unsigned SrcIdx = PartIdx + ShiftWords; + if (SrcIdx >= NumParts) + return Params.Zero; + if (!NeedsInterWordShift) + return SrcParts[SrcIdx]; + + // Combine shifted main part with carry from next part + auto Lo = MIRBuilder.buildLShr(TargetTy, SrcParts[SrcIdx], Params.BitShift); + if (SrcIdx + 1 < NumParts) { + auto Hi = MIRBuilder.buildShl(TargetTy, SrcParts[SrcIdx + 1], + Params.InvBitShift); + return MIRBuilder.buildOr(TargetTy, Lo, Hi).getReg(0); + } + return Lo.getReg(0); + } + + case TargetOpcode::G_ASHR: { + // Like LSHR but preserves sign bit + unsigned SrcIdx = PartIdx + ShiftWords; + if (SrcIdx >= NumParts) + return Params.SignBit; + if (!NeedsInterWordShift) + return SrcParts[SrcIdx]; + + // Only the original MSB part uses arithmetic shift to preserve sign. All + // other parts use logical shift since they're just moving data bits. + auto Lo = + (SrcIdx == NumParts - 1) + ? MIRBuilder.buildAShr(TargetTy, SrcParts[SrcIdx], Params.BitShift) + : MIRBuilder.buildLShr(TargetTy, SrcParts[SrcIdx], Params.BitShift); + Register HiSrc = + (SrcIdx + 1 < NumParts) ? SrcParts[SrcIdx + 1] : Params.SignBit; + auto Hi = MIRBuilder.buildShl(TargetTy, HiSrc, Params.InvBitShift); + return MIRBuilder.buildOr(TargetTy, Lo, Hi).getReg(0); + } + + default: + llvm_unreachable("not a shift"); + } +} + +Register LegalizerHelper::buildVariableShiftPart(unsigned Opcode, + Register MainOperand, + Register ShiftAmt, + LLT TargetTy, + Register CarryOperand) { + // This helper generates a single output part for variable shifts by combining + // the main operand (shifted by BitShift) with carry bits from an adjacent + // part. + + // For G_ASHR, individual parts don't have their own sign bit, only the + // complete value does. So we use LSHR for the main operand shift in ASHR + // context. + unsigned MainOpcode = + (Opcode == TargetOpcode::G_ASHR) ? TargetOpcode::G_LSHR : Opcode; + + // Perform the primary shift on the main operand + Register MainShifted = + MIRBuilder.buildInstr(MainOpcode, {TargetTy}, {MainOperand, ShiftAmt}) + .getReg(0); + + // No carry operand available + if (!CarryOperand.isValid()) + return MainShifted; + + // If BitShift is 0 (word-aligned shift), no inter-word bit movement occurs, + // so carry bits aren't needed. + LLT ShiftAmtTy = MRI.getType(ShiftAmt); + auto ZeroConst = MIRBuilder.buildConstant(ShiftAmtTy, 0); + LLT BoolTy = LLT::scalar(1); + auto IsZeroBitShift = + MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, BoolTy, ShiftAmt, ZeroConst); + + // Extract bits from the adjacent part that will "carry over" into this part. + // The carry direction is opposite to the main shift direction, so we can + // align the two shifted values before combining them with OR. + + // Determine the carry shift opcode (opposite direction) + unsigned CarryOpcode = (Opcode == TargetOpcode::G_SHL) ? TargetOpcode::G_LSHR + : TargetOpcode::G_SHL; + + // Calculate inverse shift amount: BitWidth - ShiftAmt + auto TargetBitsConst = + MIRBuilder.buildConstant(ShiftAmtTy, TargetTy.getScalarSizeInBits()); + auto InvShiftAmt = MIRBuilder.buildSub(ShiftAmtTy, TargetBitsConst, ShiftAmt); + + // Shift the carry operand + Register CarryBits = + MIRBuilder + .buildInstr(CarryOpcode, {TargetTy}, {CarryOperand, InvShiftAmt}) + .getReg(0); + + // If BitShift is 0, don't include carry bits (InvShiftAmt would equal + // TargetBits which would be poison for the individual carry shift operation). + auto ZeroReg = MIRBuilder.buildConstant(TargetTy, 0); + Register SafeCarryBits = + MIRBuilder.buildSelect(TargetTy, IsZeroBitShift, ZeroReg, CarryBits) + .getReg(0); + + // Combine the main shifted part with the carry bits + return MIRBuilder.buildOr(TargetTy, MainShifted, SafeCarryBits).getReg(0); +} + +LegalizerHelper::LegalizeResult +LegalizerHelper::narrowScalarShiftByConstantMultiway(MachineInstr &MI, + const APInt &Amt, + LLT TargetTy, + LLT ShiftAmtTy) { + // Any wide shift can be decomposed into WordShift + BitShift components. + // When shift amount is known constant, directly compute the decomposition + // values and generate constant registers. + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + LLT DstTy = MRI.getType(DstReg); + + const unsigned DstBits = DstTy.getScalarSizeInBits(); + const unsigned TargetBits = TargetTy.getScalarSizeInBits(); + const unsigned NumParts = DstBits / TargetBits; + + assert(DstBits % TargetBits == 0 && "Target type must evenly divide source"); + + // When the shift amount is known at compile time, we just calculate which + // source parts contribute to each output part. + + SmallVector<Register, 8> SrcParts; + extractParts(SrcReg, TargetTy, NumParts, SrcParts, MIRBuilder, MRI); + + if (Amt.isZero()) { + // No shift needed, just copy + MIRBuilder.buildMergeLikeInstr(DstReg, SrcParts); + MI.eraseFromParent(); + return Legalized; + } + + ShiftParams Params; + const unsigned ShiftWords = Amt.getZExtValue() / TargetBits; + const unsigned ShiftBits = Amt.getZExtValue() % TargetBits; + + // Generate constants and values needed by all shift types + Params.WordShift = MIRBuilder.buildConstant(ShiftAmtTy, ShiftWords).getReg(0); + Params.BitShift = MIRBuilder.buildConstant(ShiftAmtTy, ShiftBits).getReg(0); + Params.InvBitShift = + MIRBuilder.buildConstant(ShiftAmtTy, TargetBits - ShiftBits).getReg(0); + Params.Zero = MIRBuilder.buildConstant(TargetTy, 0).getReg(0); + + // For ASHR, we need the sign-extended value to fill shifted-out positions + if (MI.getOpcode() == TargetOpcode::G_ASHR) + Params.SignBit = + MIRBuilder + .buildAShr(TargetTy, SrcParts[SrcParts.size() - 1], + MIRBuilder.buildConstant(ShiftAmtTy, TargetBits - 1)) + .getReg(0); + + SmallVector<Register, 8> DstParts(NumParts); + for (unsigned I = 0; I < NumParts; ++I) + DstParts[I] = buildConstantShiftPart(MI.getOpcode(), I, NumParts, SrcParts, + Params, TargetTy, ShiftAmtTy); + + MIRBuilder.buildMergeLikeInstr(DstReg, DstParts); + MI.eraseFromParent(); + return Legalized; +} + +LegalizerHelper::LegalizeResult +LegalizerHelper::narrowScalarShiftMultiway(MachineInstr &MI, LLT TargetTy) { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + Register AmtReg = MI.getOperand(2).getReg(); + LLT DstTy = MRI.getType(DstReg); + LLT ShiftAmtTy = MRI.getType(AmtReg); + + const unsigned DstBits = DstTy.getScalarSizeInBits(); + const unsigned TargetBits = TargetTy.getScalarSizeInBits(); + const unsigned NumParts = DstBits / TargetBits; + + assert(DstBits % TargetBits == 0 && "Target type must evenly divide source"); + assert(isPowerOf2_32(TargetBits) && "Target bit width must be power of 2"); + + // If the shift amount is known at compile time, we can use direct indexing + // instead of generating select chains in the general case. + if (auto VRegAndVal = getIConstantVRegValWithLookThrough(AmtReg, MRI)) + return narrowScalarShiftByConstantMultiway(MI, VRegAndVal->Value, TargetTy, + ShiftAmtTy); + + // For runtime-variable shift amounts, we must generate a more complex + // sequence that handles all possible shift values using select chains. + + // Split the input into target-sized pieces + SmallVector<Register, 8> SrcParts; + extractParts(SrcReg, TargetTy, NumParts, SrcParts, MIRBuilder, MRI); + + // Shifting by zero should be a no-op. + auto ZeroAmtConst = MIRBuilder.buildConstant(ShiftAmtTy, 0); + LLT BoolTy = LLT::scalar(1); + auto IsZeroShift = + MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, BoolTy, AmtReg, ZeroAmtConst); + + // Any wide shift can be decomposed into two components: + // 1. WordShift: number of complete target-sized words to shift + // 2. BitShift: number of bits to shift within each word + // + // Example: 128-bit >> 50 with 32-bit target: + // WordShift = 50 / 32 = 1 (shift right by 1 complete word) + // BitShift = 50 % 32 = 18 (shift each word right by 18 bits) + unsigned TargetBitsLog2 = Log2_32(TargetBits); + auto TargetBitsLog2Const = + MIRBuilder.buildConstant(ShiftAmtTy, TargetBitsLog2); + auto TargetBitsMask = MIRBuilder.buildConstant(ShiftAmtTy, TargetBits - 1); + + Register WordShift = + MIRBuilder.buildLShr(ShiftAmtTy, AmtReg, TargetBitsLog2Const).getReg(0); + Register BitShift = + MIRBuilder.buildAnd(ShiftAmtTy, AmtReg, TargetBitsMask).getReg(0); + + // Fill values: + // - SHL/LSHR: fill with zeros + // - ASHR: fill with sign-extended MSB + Register ZeroReg = MIRBuilder.buildConstant(TargetTy, 0).getReg(0); + + Register FillValue; + if (MI.getOpcode() == TargetOpcode::G_ASHR) { + auto TargetBitsMinusOneConst = + MIRBuilder.buildConstant(ShiftAmtTy, TargetBits - 1); + FillValue = MIRBuilder + .buildAShr(TargetTy, SrcParts[NumParts - 1], + TargetBitsMinusOneConst) + .getReg(0); + } else { + FillValue = ZeroReg; + } + + SmallVector<Register, 8> DstParts(NumParts); + + // For each output part, generate a select chain that chooses the correct + // result based on the runtime WordShift value. This handles all possible + // word shift amounts by pre-calculating what each would produce. + for (unsigned I = 0; I < NumParts; ++I) { + // Initialize with appropriate default value for this shift type + Register InBoundsResult = FillValue; + + // clang-format off + // Build a branchless select chain by pre-computing results for all possible + // WordShift values (0 to NumParts-1). Each iteration nests a new select: + // + // K=0: select(WordShift==0, result0, FillValue) + // K=1: select(WordShift==1, result1, select(WordShift==0, result0, FillValue)) + // K=2: select(WordShift==2, result2, select(WordShift==1, result1, select(...))) + // clang-format on + for (unsigned K = 0; K < NumParts; ++K) { + auto WordShiftKConst = MIRBuilder.buildConstant(ShiftAmtTy, K); + auto IsWordShiftK = MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, BoolTy, + WordShift, WordShiftKConst); + + // Calculate source indices for this word shift + // + // For 4-part 128-bit value with K=1 word shift: + // SHL: [3][2][1][0] << K => [2][1][0][Z] + // -> (MainIdx = I-K, CarryIdx = I-K-1) + // LSHR: [3][2][1][0] >> K => [Z][3][2][1] + // -> (MainIdx = I+K, CarryIdx = I+K+1) + int MainSrcIdx; + int CarrySrcIdx; // Index for the word that provides the carried-in bits. + + switch (MI.getOpcode()) { + case TargetOpcode::G_SHL: + MainSrcIdx = (int)I - (int)K; + CarrySrcIdx = MainSrcIdx - 1; + break; + case TargetOpcode::G_LSHR: + case TargetOpcode::G_ASHR: + MainSrcIdx = (int)I + (int)K; + CarrySrcIdx = MainSrcIdx + 1; + break; + default: + llvm_unreachable("Not a shift"); + } + + // Check bounds and build the result for this word shift + Register ResultForK; + if (MainSrcIdx >= 0 && MainSrcIdx < (int)NumParts) { + Register MainOp = SrcParts[MainSrcIdx]; + Register CarryOp; + + // Determine carry operand with bounds checking + if (CarrySrcIdx >= 0 && CarrySrcIdx < (int)NumParts) + CarryOp = SrcParts[CarrySrcIdx]; + else if (MI.getOpcode() == TargetOpcode::G_ASHR && + CarrySrcIdx >= (int)NumParts) + CarryOp = FillValue; // Use sign extension + + ResultForK = buildVariableShiftPart(MI.getOpcode(), MainOp, BitShift, + TargetTy, CarryOp); + } else { + // Out of bounds - use fill value for this k + ResultForK = FillValue; + } + + // Select this result if WordShift equals k + InBoundsResult = + MIRBuilder + .buildSelect(TargetTy, IsWordShiftK, ResultForK, InBoundsResult) + .getReg(0); + } + + // Handle zero-shift special case: if shift is 0, use original input + DstParts[I] = + MIRBuilder + .buildSelect(TargetTy, IsZeroShift, SrcParts[I], InBoundsResult) + .getReg(0); + } + + MIRBuilder.buildMergeLikeInstr(DstReg, DstParts); + MI.eraseFromParent(); + return Legalized; +} + LegalizerHelper::LegalizeResult LegalizerHelper::moreElementsVectorPhi(MachineInstr &MI, unsigned TypeIdx, LLT MoreTy) { @@ -9537,6 +9937,54 @@ LegalizerHelper::lowerAbsToCNeg(MachineInstr &MI) { return Legalized; } +LegalizerHelper::LegalizeResult +LegalizerHelper::lowerAbsDiffToSelect(MachineInstr &MI) { + assert((MI.getOpcode() == TargetOpcode::G_ABDS || + MI.getOpcode() == TargetOpcode::G_ABDU) && + "Expected G_ABDS or G_ABDU instruction"); + + auto [DstReg, LHS, RHS] = MI.getFirst3Regs(); + LLT Ty = MRI.getType(LHS); + + // abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs)) + // abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs)) + Register LHSSub = MIRBuilder.buildSub(Ty, LHS, RHS).getReg(0); + Register RHSSub = MIRBuilder.buildSub(Ty, RHS, LHS).getReg(0); + CmpInst::Predicate Pred = (MI.getOpcode() == TargetOpcode::G_ABDS) + ? CmpInst::ICMP_SGT + : CmpInst::ICMP_UGT; + auto ICmp = MIRBuilder.buildICmp(Pred, LLT::scalar(1), LHS, RHS); + MIRBuilder.buildSelect(DstReg, ICmp, LHSSub, RHSSub); + + MI.eraseFromParent(); + return Legalized; +} + +LegalizerHelper::LegalizeResult +LegalizerHelper::lowerAbsDiffToMinMax(MachineInstr &MI) { + assert((MI.getOpcode() == TargetOpcode::G_ABDS || + MI.getOpcode() == TargetOpcode::G_ABDU) && + "Expected G_ABDS or G_ABDU instruction"); + + auto [DstReg, LHS, RHS] = MI.getFirst3Regs(); + LLT Ty = MRI.getType(LHS); + + // abds(lhs, rhs) -→ sub(smax(lhs, rhs), smin(lhs, rhs)) + // abdu(lhs, rhs) -→ sub(umax(lhs, rhs), umin(lhs, rhs)) + Register MaxReg, MinReg; + if (MI.getOpcode() == TargetOpcode::G_ABDS) { + MaxReg = MIRBuilder.buildSMax(Ty, LHS, RHS).getReg(0); + MinReg = MIRBuilder.buildSMin(Ty, LHS, RHS).getReg(0); + } else { + MaxReg = MIRBuilder.buildUMax(Ty, LHS, RHS).getReg(0); + MinReg = MIRBuilder.buildUMin(Ty, LHS, RHS).getReg(0); + } + MIRBuilder.buildSub(DstReg, MaxReg, MinReg); + + MI.eraseFromParent(); + return Legalized; +} + LegalizerHelper::LegalizeResult LegalizerHelper::lowerFAbs(MachineInstr &MI) { Register SrcReg = MI.getOperand(1).getReg(); Register DstReg = MI.getOperand(0).getReg(); |
