summary | refs | log | tree | commit | diff
path: root/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp')
-rw-r--r-- llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp | 167
1 files changed, 134 insertions, 33 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
index 53557049ea33..29526cf5a527 100644
--- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -178,8 +178,20 @@ static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
return Log2EEW;
}
-static std::optional<unsigned>
-getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
+// Helper macros that expand to the switch labels of all segment store
+// opcodes with the given opcode prefix and EEW suffix, for NF = 2 through 8
+// (RISCV::<Prefix>SEG<NF>E<EEW>_V). The expansion deliberately omits the
+// leading `case` and the trailing `:` so that a use reads like an ordinary
+// label: `case VSSEG_CASES(8):`.
+#define VSEG_CASES(Prefix, EEW) \
+ RISCV::Prefix##SEG2E##EEW##_V: \
+ case RISCV::Prefix##SEG3E##EEW##_V: \
+ case RISCV::Prefix##SEG4E##EEW##_V: \
+ case RISCV::Prefix##SEG5E##EEW##_V: \
+ case RISCV::Prefix##SEG6E##EEW##_V: \
+ case RISCV::Prefix##SEG7E##EEW##_V: \
+ case RISCV::Prefix##SEG8E##EEW##_V
+// VSSEG<NF>E<EEW>_V and VSSSEG<NF>E<EEW>_V opcodes.
+#define VSSEG_CASES(EEW) VSEG_CASES(VS, EEW)
+#define VSSSEG_CASES(EEW) VSEG_CASES(VSS, EEW)
+// VSUXSEG<NF>EI<EEW>_V and VSOXSEG<NF>EI<EEW>_V opcodes: `I` is pasted in
+// front of EEW before it is handed to VSEG_CASES, producing the `EI<EEW>`
+// part of the opcode name.
+#define VSUXSEG_CASES(EEW) VSEG_CASES(VSUX, I##EEW)
+#define VSOXSEG_CASES(EEW) VSEG_CASES(VSOX, I##EEW)
+
+static std::optional<unsigned> getOperandLog2EEW(const MachineOperand &MO) {
const MachineInstr &MI = *MO.getParent();
const MCInstrDesc &Desc = MI.getDesc();
const RISCVVPseudosTable::PseudoInfo *RVV =
@@ -225,21 +237,29 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
case RISCV::VSE8_V:
case RISCV::VLSE8_V:
case RISCV::VSSE8_V:
+ case VSSEG_CASES(8):
+ case VSSSEG_CASES(8):
return 3;
case RISCV::VLE16_V:
case RISCV::VSE16_V:
case RISCV::VLSE16_V:
case RISCV::VSSE16_V:
+ case VSSEG_CASES(16):
+ case VSSSEG_CASES(16):
return 4;
case RISCV::VLE32_V:
case RISCV::VSE32_V:
case RISCV::VLSE32_V:
case RISCV::VSSE32_V:
+ case VSSEG_CASES(32):
+ case VSSSEG_CASES(32):
return 5;
case RISCV::VLE64_V:
case RISCV::VSE64_V:
case RISCV::VLSE64_V:
case RISCV::VSSE64_V:
+ case VSSEG_CASES(64):
+ case VSSSEG_CASES(64):
return 6;
// Vector Indexed Instructions
@@ -248,7 +268,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
case RISCV::VLUXEI8_V:
case RISCV::VLOXEI8_V:
case RISCV::VSUXEI8_V:
- case RISCV::VSOXEI8_V: {
+ case RISCV::VSOXEI8_V:
+ case VSUXSEG_CASES(8):
+ case VSOXSEG_CASES(8): {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 3;
@@ -256,7 +278,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
case RISCV::VLUXEI16_V:
case RISCV::VLOXEI16_V:
case RISCV::VSUXEI16_V:
- case RISCV::VSOXEI16_V: {
+ case RISCV::VSOXEI16_V:
+ case VSUXSEG_CASES(16):
+ case VSOXSEG_CASES(16): {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 4;
@@ -264,7 +288,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
case RISCV::VLUXEI32_V:
case RISCV::VLOXEI32_V:
case RISCV::VSUXEI32_V:
- case RISCV::VSOXEI32_V: {
+ case RISCV::VSOXEI32_V:
+ case VSUXSEG_CASES(32):
+ case VSOXSEG_CASES(32): {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 5;
@@ -272,7 +298,9 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
case RISCV::VLUXEI64_V:
case RISCV::VLOXEI64_V:
case RISCV::VSUXEI64_V:
- case RISCV::VSOXEI64_V: {
+ case RISCV::VSOXEI64_V:
+ case VSUXSEG_CASES(64):
+ case VSOXSEG_CASES(64): {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 6;
@@ -422,9 +450,6 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
case RISCV::VRGATHER_VI:
case RISCV::VRGATHER_VV:
case RISCV::VRGATHER_VX:
- // Vector Compress Instruction
- // EEW=SEW.
- case RISCV::VCOMPRESS_VM:
// Vector Element Index Instruction
case RISCV::VID_V:
// Vector Single-Width Floating-Point Add/Subtract Instructions
@@ -674,6 +699,12 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
return MILog2SEW;
}
+ // Vector Compress Instruction
+ // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
+ // before this switch.
+ case RISCV::VCOMPRESS_VM:
+ return MO.getOperandNo() == 3 ? 0 : MILog2SEW;
+
// Vector Iota Instruction
// EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
// before this switch.
@@ -778,14 +809,13 @@ getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
}
}
-static std::optional<OperandInfo>
-getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
+static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) {
const MachineInstr &MI = *MO.getParent();
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
assert(RVV && "Could not find MI in PseudoTable");
- std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
+ std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO);
if (!Log2EEW)
return std::nullopt;
@@ -900,13 +930,6 @@ static bool isSupportedInstr(const MachineInstr &MI) {
case RISCV::VSEXT_VF4:
case RISCV::VZEXT_VF8:
case RISCV::VSEXT_VF8:
- // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
- // FIXME: Add support
- case RISCV::VMADC_VV:
- case RISCV::VMADC_VI:
- case RISCV::VMADC_VX:
- case RISCV::VMSBC_VV:
- case RISCV::VMSBC_VX:
// Vector Narrowing Integer Right Shift Instructions
case RISCV::VNSRL_WX:
case RISCV::VNSRL_WI:
@@ -993,6 +1016,11 @@ static bool isSupportedInstr(const MachineInstr &MI) {
case RISCV::VSBC_VXM:
case RISCV::VMSBC_VVM:
case RISCV::VMSBC_VXM:
+ case RISCV::VMADC_VV:
+ case RISCV::VMADC_VI:
+ case RISCV::VMADC_VX:
+ case RISCV::VMSBC_VV:
+ case RISCV::VMSBC_VX:
// Vector Widening Integer Multiply-Add Instructions
case RISCV::VWMACCU_VV:
case RISCV::VWMACCU_VX:
@@ -1001,10 +1029,7 @@ static bool isSupportedInstr(const MachineInstr &MI) {
case RISCV::VWMACCSU_VV:
case RISCV::VWMACCSU_VX:
case RISCV::VWMACCUS_VX:
- // Vector Integer Merge Instructions
- // FIXME: Add support
// Vector Integer Move Instructions
- // FIXME: Add support
case RISCV::VMV_V_I:
case RISCV::VMV_V_X:
case RISCV::VMV_V_V:
@@ -1306,7 +1331,8 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
// TODO: Use a better approach than a white-list, such as adding
// properties to instructions using something like TSFlags.
if (!isSupportedInstr(MI)) {
- LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
+ LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction: "
+ << MI);
return false;
}
@@ -1328,14 +1354,14 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
const MCInstrDesc &Desc = UserMI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
- LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
+ LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
" use VLMAX\n");
return std::nullopt;
}
if (RISCVII::readsPastVL(
TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags)) {
- LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
+ LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
return std::nullopt;
}
@@ -1352,7 +1378,7 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
auto DemandedVL = DemandedVLs.lookup(&UserMI);
if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) {
- LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
+ LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
"instruction with demanded tail\n");
return std::nullopt;
}
@@ -1376,6 +1402,54 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
return VLOp;
}
+/// Detect the INSERT_SUBREG form that RISCVISD::TUPLE_INSERT is currently
+/// lowered to, i.e. an instruction used for assembling the registers of
+/// segmented store instructions.
+///
+/// \returns true only if \p MI is an INSERT_SUBREG whose destination register
+/// class is a vector register class with NF >= 2 and whose subregister index
+/// is exactly RVVBitsPerBlock * LMUL bits wide.
+static bool isTupleInsertInstr(const MachineInstr &MI) {
+  if (!MI.isInsertSubreg())
+    return false;
+
+  const MachineRegisterInfo &RegInfo = MI.getMF()->getRegInfo();
+  const TargetRegisterClass *TupleRC =
+      RegInfo.getRegClass(MI.getOperand(0).getReg());
+
+  // Reject destinations that are not a vector register class with at least
+  // two fields.
+  const bool IsTupleClass = RISCVRI::isVRegClass(TupleRC->TSFlags) &&
+                            RISCVRI::getNF(TupleRC->TSFlags) >= 2;
+  if (!IsTupleClass)
+    return false;
+
+  // Check whether INSERT_SUBREG has the correct subreg index for tuple
+  // inserts: its size must match RVVBitsPerBlock scaled by the class LMUL.
+  [[maybe_unused]] const auto [FieldLMul, Fractional] =
+      RISCVVType::decodeVLMUL(RISCVRI::getLMul(TupleRC->TSFlags));
+  assert(!Fractional && "unexpected LMUL for tuple register classes");
+  const unsigned InsertedIdx = MI.getOperand(3).getImm();
+  const TargetRegisterInfo *SubRegInfo = RegInfo.getTargetRegisterInfo();
+  return SubRegInfo->getSubRegIdxSize(InsertedIdx) ==
+         RISCV::RVVBitsPerBlock * FieldLMul;
+}
+
+/// Return true if the RVV MC opcode that \p MI maps to (as reported by
+/// RISCV::getRVVMCOpcode) is one of the VSSEG / VSSSEG / VSUXSEG / VSOXSEG
+/// opcodes for EEW 8, 16, 32 or 64.
+static bool isSegmentedStoreInstr(const MachineInstr &MI) {
+  const unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
+  switch (MCOpcode) {
+  // VSSEG<NF>E<EEW>_V
+  case VSSEG_CASES(8):
+  case VSSEG_CASES(16):
+  case VSSEG_CASES(32):
+  case VSSEG_CASES(64):
+  // VSSSEG<NF>E<EEW>_V
+  case VSSSEG_CASES(8):
+  case VSSSEG_CASES(16):
+  case VSSSEG_CASES(32):
+  case VSSSEG_CASES(64):
+  // VSUXSEG<NF>EI<EEW>_V
+  case VSUXSEG_CASES(8):
+  case VSUXSEG_CASES(16):
+  case VSUXSEG_CASES(32):
+  case VSUXSEG_CASES(64):
+  // VSOXSEG<NF>EI<EEW>_V
+  case VSOXSEG_CASES(8):
+  case VSOXSEG_CASES(16):
+  case VSOXSEG_CASES(32):
+  case VSOXSEG_CASES(64):
+    return true;
+  }
+  return false;
+}
+
std::optional<MachineOperand>
RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
std::optional<MachineOperand> CommonVL;
@@ -1396,6 +1470,23 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
continue;
}
+ if (isTupleInsertInstr(UserMI)) {
+ LLVM_DEBUG(dbgs().indent(4) << "Peeking through uses of INSERT_SUBREG\n");
+ for (MachineOperand &UseOp :
+ MRI->use_operands(UserMI.getOperand(0).getReg())) {
+ const MachineInstr &CandidateMI = *UseOp.getParent();
+ // We should not propagate the VL if the user is neither a segmented store
+ // nor another INSERT_SUBREG, since VL works differently for segmented
+ // operations (per-field) vs. other RVV ops (on the whole register
+ // group).
+ if (!isTupleInsertInstr(CandidateMI) &&
+ !isSegmentedStoreInstr(CandidateMI))
+ return std::nullopt;
+ Worklist.insert(&UseOp);
+ }
+ continue;
+ }
+
if (UserMI.isPHI()) {
// Don't follow PHI cycles
if (!PHISeen.insert(&UserMI).second)
@@ -1425,9 +1516,8 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
return std::nullopt;
}
- std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
- std::optional<OperandInfo> ProducerInfo =
- getOperandInfo(MI.getOperand(0), MRI);
+ std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp);
+ std::optional<OperandInfo> ProducerInfo = getOperandInfo(MI.getOperand(0));
if (!ConsumerInfo || !ProducerInfo) {
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
@@ -1449,7 +1539,7 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
}
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
- LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
+ LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI);
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
MachineOperand &VLOp = MI.getOperand(VLOpNum);
@@ -1468,14 +1558,23 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
"Expected VL to be an Imm or virtual Reg");
+ // If the VL is defined by a vleff that doesn't dominate MI, try using the
+ // vleff's AVL. It will be greater than or equal to the output VL.
+ if (CommonVL->isReg()) {
+ const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
+ if (RISCVInstrInfo::isFaultOnlyFirstLoad(*VLMI) &&
+ !MDT->dominates(VLMI, &MI))
+ CommonVL = VLMI->getOperand(RISCVII::getVLOpNum(VLMI->getDesc()));
+ }
+
if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
- LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
+ LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
return false;
}
if (CommonVL->isIdenticalTo(VLOp)) {
LLVM_DEBUG(
- dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
+ dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
return false;
}
@@ -1486,8 +1585,10 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
return true;
}
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
- if (!MDT->dominates(VLMI, &MI))
+ if (!MDT->dominates(VLMI, &MI)) {
+ LLVM_DEBUG(dbgs() << " Abort due to VL not dominating.\n");
return false;
+ }
LLVM_DEBUG(
dbgs() << " Reduce VL from " << VLOp << " to "
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())