diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp')
| -rw-r--r-- | llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 167 |
1 files changed, 134 insertions, 33 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 53557049ea33..29526cf5a527 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -178,8 +178,20 @@ static unsigned getIntegerExtensionOperandEEW(unsigned Factor, return Log2EEW; } -static std::optional<unsigned> -getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { +#define VSEG_CASES(Prefix, EEW) \ + RISCV::Prefix##SEG2E##EEW##_V: \ + case RISCV::Prefix##SEG3E##EEW##_V: \ + case RISCV::Prefix##SEG4E##EEW##_V: \ + case RISCV::Prefix##SEG5E##EEW##_V: \ + case RISCV::Prefix##SEG6E##EEW##_V: \ + case RISCV::Prefix##SEG7E##EEW##_V: \ + case RISCV::Prefix##SEG8E##EEW##_V +#define VSSEG_CASES(EEW) VSEG_CASES(VS, EEW) +#define VSSSEG_CASES(EEW) VSEG_CASES(VSS, EEW) +#define VSUXSEG_CASES(EEW) VSEG_CASES(VSUX, I##EEW) +#define VSOXSEG_CASES(EEW) VSEG_CASES(VSOX, I##EEW) + +static std::optional<unsigned> getOperandLog2EEW(const MachineOperand &MO) { const MachineInstr &MI = *MO.getParent(); const MCInstrDesc &Desc = MI.getDesc(); const RISCVVPseudosTable::PseudoInfo *RVV = @@ -225,21 +237,29 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { case RISCV::VSE8_V: case RISCV::VLSE8_V: case RISCV::VSSE8_V: + case VSSEG_CASES(8): + case VSSSEG_CASES(8): return 3; case RISCV::VLE16_V: case RISCV::VSE16_V: case RISCV::VLSE16_V: case RISCV::VSSE16_V: + case VSSEG_CASES(16): + case VSSSEG_CASES(16): return 4; case RISCV::VLE32_V: case RISCV::VSE32_V: case RISCV::VLSE32_V: case RISCV::VSSE32_V: + case VSSEG_CASES(32): + case VSSSEG_CASES(32): return 5; case RISCV::VLE64_V: case RISCV::VSE64_V: case RISCV::VLSE64_V: case RISCV::VSSE64_V: + case VSSEG_CASES(64): + case VSSSEG_CASES(64): return 6; // Vector Indexed Instructions @@ -248,7 +268,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { case RISCV::VLUXEI8_V: case RISCV::VLOXEI8_V: case RISCV::VSUXEI8_V: - case RISCV::VSOXEI8_V: { + case RISCV::VSOXEI8_V: + case VSUXSEG_CASES(8): + case VSOXSEG_CASES(8): { if (MO.getOperandNo() == 0) return MILog2SEW; return 3; @@ -256,7 +278,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { case RISCV::VLUXEI16_V: case RISCV::VLOXEI16_V: case RISCV::VSUXEI16_V: - case RISCV::VSOXEI16_V: { + case RISCV::VSOXEI16_V: + case VSUXSEG_CASES(16): + case VSOXSEG_CASES(16): { if (MO.getOperandNo() == 0) return MILog2SEW; return 4; @@ -264,7 +288,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { case RISCV::VLUXEI32_V: case RISCV::VLOXEI32_V: case RISCV::VSUXEI32_V: - case RISCV::VSOXEI32_V: { + case RISCV::VSOXEI32_V: + case VSUXSEG_CASES(32): + case VSOXSEG_CASES(32): { if (MO.getOperandNo() == 0) return MILog2SEW; return 5; @@ -272,7 +298,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { case RISCV::VLUXEI64_V: case RISCV::VLOXEI64_V: case RISCV::VSUXEI64_V: - case RISCV::VSOXEI64_V: { + case RISCV::VSOXEI64_V: + case VSUXSEG_CASES(64): + case VSOXSEG_CASES(64): { if (MO.getOperandNo() == 0) return MILog2SEW; return 6; @@ -422,9 +450,6 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { case RISCV::VRGATHER_VI: case RISCV::VRGATHER_VV: case RISCV::VRGATHER_VX: - // Vector Compress Instruction - // EEW=SEW. - case RISCV::VCOMPRESS_VM: // Vector Element Index Instruction case RISCV::VID_V: // Vector Single-Width Floating-Point Add/Subtract Instructions @@ -674,6 +699,12 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { return MILog2SEW; } + // Vector Compress Instruction + // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled + // before this switch. + case RISCV::VCOMPRESS_VM: + return MO.getOperandNo() == 3 ? 0 : MILog2SEW; + // Vector Iota Instruction // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled // before this switch. @@ -778,14 +809,13 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) { } } -static std::optional<OperandInfo> -getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) { +static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) { const MachineInstr &MI = *MO.getParent(); const RISCVVPseudosTable::PseudoInfo *RVV = RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); assert(RVV && "Could not find MI in PseudoTable"); - std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI); + std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO); if (!Log2EEW) return std::nullopt; @@ -900,13 +930,6 @@ static bool isSupportedInstr(const MachineInstr &MI) { case RISCV::VSEXT_VF4: case RISCV::VZEXT_VF8: case RISCV::VSEXT_VF8: - // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions - // FIXME: Add support - case RISCV::VMADC_VV: - case RISCV::VMADC_VI: - case RISCV::VMADC_VX: - case RISCV::VMSBC_VV: - case RISCV::VMSBC_VX: // Vector Narrowing Integer Right Shift Instructions case RISCV::VNSRL_WX: case RISCV::VNSRL_WI: @@ -993,6 +1016,11 @@ static bool isSupportedInstr(const MachineInstr &MI) { case RISCV::VSBC_VXM: case RISCV::VMSBC_VVM: case RISCV::VMSBC_VXM: + case RISCV::VMADC_VV: + case RISCV::VMADC_VI: + case RISCV::VMADC_VX: + case RISCV::VMSBC_VV: + case RISCV::VMSBC_VX: // Vector Widening Integer Multiply-Add Instructions case RISCV::VWMACCU_VV: case RISCV::VWMACCU_VX: @@ -1001,10 +1029,7 @@ static bool isSupportedInstr(const MachineInstr &MI) { case RISCV::VWMACCSU_VV: case RISCV::VWMACCSU_VX: case RISCV::VWMACCUS_VX: - // Vector Integer Merge Instructions - // FIXME: Add support // Vector Integer Move Instructions - // FIXME: Add support case RISCV::VMV_V_I: case RISCV::VMV_V_X: case RISCV::VMV_V_V: @@ -1306,7 +1331,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { // TODO: Use a better approach than a white-list, such as adding // properties to instructions using something like TSFlags. if (!isSupportedInstr(MI)) { - LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n"); + LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction: " + << MI); return false; } @@ -1328,14 +1354,14 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { const MCInstrDesc &Desc = UserMI.getDesc(); if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { - LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" + LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" " use VLMAX\n"); return std::nullopt; } if (RISCVII::readsPastVL( TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags)) { - LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); + LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); return std::nullopt; } @@ -1352,7 +1378,7 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc())); auto DemandedVL = DemandedVLs.lookup(&UserMI); if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) { - LLVM_DEBUG(dbgs() << " Abort because user is passthru in " + LLVM_DEBUG(dbgs() << " Abort because user is passthru in " "instruction with demanded tail\n"); return std::nullopt; } @@ -1376,6 +1402,54 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { return VLOp; } +/// Return true if MI is an instruction used for assembling registers +/// for segmented store instructions, namely, RISCVISD::TUPLE_INSERT. +/// Currently it's lowered to INSERT_SUBREG. +static bool isTupleInsertInstr(const MachineInstr &MI) { + if (!MI.isInsertSubreg()) + return false; + + const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + const TargetRegisterClass *DstRC = MRI.getRegClass(MI.getOperand(0).getReg()); + const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); + if (!RISCVRI::isVRegClass(DstRC->TSFlags)) + return false; + unsigned NF = RISCVRI::getNF(DstRC->TSFlags); + if (NF < 2) + return false; + + // Check whether INSERT_SUBREG has the correct subreg index for tuple inserts. + auto VLMul = RISCVRI::getLMul(DstRC->TSFlags); + unsigned SubRegIdx = MI.getOperand(3).getImm(); + [[maybe_unused]] auto [LMul, IsFractional] = RISCVVType::decodeVLMUL(VLMul); + assert(!IsFractional && "unexpected LMUL for tuple register classes"); + return TRI->getSubRegIdxSize(SubRegIdx) == RISCV::RVVBitsPerBlock * LMul; +} + +static bool isSegmentedStoreInstr(const MachineInstr &MI) { + switch (RISCV::getRVVMCOpcode(MI.getOpcode())) { + case VSSEG_CASES(8): + case VSSSEG_CASES(8): + case VSUXSEG_CASES(8): + case VSOXSEG_CASES(8): + case VSSEG_CASES(16): + case VSSSEG_CASES(16): + case VSUXSEG_CASES(16): + case VSOXSEG_CASES(16): + case VSSEG_CASES(32): + case VSSSEG_CASES(32): + case VSUXSEG_CASES(32): + case VSOXSEG_CASES(32): + case VSSEG_CASES(64): + case VSSSEG_CASES(64): + case VSUXSEG_CASES(64): + case VSOXSEG_CASES(64): + return true; + default: + return false; + } +} + std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { std::optional<MachineOperand> CommonVL; @@ -1396,6 +1470,23 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { continue; } + if (isTupleInsertInstr(UserMI)) { + LLVM_DEBUG(dbgs().indent(4) << "Peeking through uses of INSERT_SUBREG\n"); + for (MachineOperand &UseOp : + MRI->use_operands(UserMI.getOperand(0).getReg())) { + const MachineInstr &CandidateMI = *UseOp.getParent(); + // We should not propagate the VL if the user is not a segmented store + // or another INSERT_SUBREG, since VL just works differently + // between segmented operations (per-field) v.s. other RVV ops (on the + // whole register group). + if (!isTupleInsertInstr(CandidateMI) && + !isSegmentedStoreInstr(CandidateMI)) + return std::nullopt; + Worklist.insert(&UseOp); + } + continue; + } + if (UserMI.isPHI()) { // Don't follow PHI cycles if (!PHISeen.insert(&UserMI).second) @@ -1425,9 +1516,8 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { return std::nullopt; } - std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI); - std::optional<OperandInfo> ProducerInfo = - getOperandInfo(MI.getOperand(0), MRI); + std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp); + std::optional<OperandInfo> ProducerInfo = getOperandInfo(MI.getOperand(0)); if (!ConsumerInfo || !ProducerInfo) { LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); @@ -1449,7 +1539,7 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { - LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); + LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI); unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); MachineOperand &VLOp = MI.getOperand(VLOpNum); @@ -1468,14 +1558,23 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && "Expected VL to be an Imm or virtual Reg"); + // If the VL is defined by a vleff that doesn't dominate MI, try using the + // vleff's AVL. It will be greater than or equal to the output VL. + if (CommonVL->isReg()) { + const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); + if (RISCVInstrInfo::isFaultOnlyFirstLoad(*VLMI) && + !MDT->dominates(VLMI, &MI)) + CommonVL = VLMI->getOperand(RISCVII::getVLOpNum(VLMI->getDesc())); + } + if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { - LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); + LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); return false; } if (CommonVL->isIdenticalTo(VLOp)) { LLVM_DEBUG( - dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n"); + dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n"); return false; } @@ -1486,8 +1585,10 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { return true; } const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); - if (!MDT->dominates(VLMI, &MI)) + if (!MDT->dominates(VLMI, &MI)) { + LLVM_DEBUG(dbgs() << " Abort due to VL not dominating.\n"); return false; + } LLVM_DEBUG( dbgs() << " Reduce VL from " << VLOp << " to " << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) |
