diff options
| author | Jonathan Cohen <joncoh@apple.com> | 2025-03-25 15:58:20 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-03-25 15:58:20 +0200 |
| commit | 6785951410c35aa9429152d3e041b44b79db53f2 (patch) | |
| tree | 9f9f0a46766a16e2d7d31641ef03d77f61c8d6d2 /llvm/lib/CodeGen/TargetInstrInfo.cpp | |
| parent | 9768077de65e31daa619eae231f027e052d601c2 (diff) | |
[Machine-Combiner] Add a pass to reassociate chains of accumulation instructions into a tree (#132728)
This pass is designed to increase ILP by performing accumulation into
multiple registers. It currently supports only the S/UABAL accumulation
instruction, but can be extended to support additional instructions.
Reland of #126060 which was reverted due to a conflict with #131272.
Diffstat (limited to 'llvm/lib/CodeGen/TargetInstrInfo.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/TargetInstrInfo.cpp | 270 |
1 files changed, 259 insertions, 11 deletions
diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp index e517ae1a7c44..d077dc189f5a 100644 --- a/llvm/lib/CodeGen/TargetInstrInfo.cpp +++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/BinaryFormat/Dwarf.h" #include "llvm/CodeGen/MachineCombinerPattern.h" @@ -42,6 +43,19 @@ static cl::opt<bool> DisableHazardRecognizer( "disable-sched-hazard", cl::Hidden, cl::init(false), cl::desc("Disable hazard detection during preRA scheduling")); +static cl::opt<bool> EnableAccReassociation( + "acc-reassoc", cl::Hidden, cl::init(true), + cl::desc("Enable reassociation of accumulation chains")); + +static cl::opt<unsigned int> + MinAccumulatorDepth("acc-min-depth", cl::Hidden, cl::init(8), + cl::desc("Minimum length of accumulator chains " + "required for the optimization to kick in")); + +static cl::opt<unsigned int> MaxAccumulatorWidth( + "acc-max-width", cl::Hidden, cl::init(3), + cl::desc("Maximum number of branches in the accumulator tree")); + TargetInstrInfo::~TargetInstrInfo() = default; const TargetRegisterClass* @@ -899,6 +913,154 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst, hasReassociableSibling(Inst, Commuted); } +// Utility routine that checks if \param MO is defined by an +// \param CombineOpc instruction in the basic block \param MBB. +// If \param CombineOpc is not provided, the OpCode check will +// be skipped. +static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO, + unsigned CombineOpc = 0) { + MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + MachineInstr *MI = nullptr; + + if (MO.isReg() && MO.getReg().isVirtual()) + MI = MRI.getUniqueVRegDef(MO.getReg()); + // And it needs to be in the trace (otherwise, it won't have a depth). 
+ if (!MI || MI->getParent() != &MBB || + ((unsigned)MI->getOpcode() != CombineOpc && CombineOpc != 0)) + return false; + // Must only be used by the user we combine with. + if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) + return false; + + return true; +} + +// A chain of accumulation instructions will be selected IFF: +// 1. All the accumulation instructions in the chain have the same opcode, +// besides the first that has a slightly different opcode because it does +// not accumulate into a register. +// 2. All the instructions in the chain are combinable (have a single use +// which itself is part of the chain). +// 3. Meets the required minimum length. +void TargetInstrInfo::getAccumulatorChain( + MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const { + // Walk up the chain of accumulation instructions and collect them in the + // vector. + MachineBasicBlock &MBB = *CurrentInstr->getParent(); + const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + unsigned AccumulatorOpcode = CurrentInstr->getOpcode(); + std::optional<unsigned> ChainStartOpCode = + getAccumulationStartOpcode(AccumulatorOpcode); + + if (!ChainStartOpCode.has_value()) + return; + + // Push the first accumulator result to the start of the chain. + Chain.push_back(CurrentInstr->getOperand(0).getReg()); + + // Collect the accumulator input register from all instructions in the chain. + while (CurrentInstr && + canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode)) { + Chain.push_back(CurrentInstr->getOperand(1).getReg()); + CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg()); + } + + // Add the instruction at the top of the chain. + if (CurrentInstr->getOpcode() == AccumulatorOpcode && + canCombine(MBB, CurrentInstr->getOperand(1))) + Chain.push_back(CurrentInstr->getOperand(1).getReg()); +} + +/// Find chains of accumulations that can be rewritten as a tree for increased +/// ILP. 
+bool TargetInstrInfo::getAccumulatorReassociationPatterns( + MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const { + if (!EnableAccReassociation) + return false; + + unsigned Opc = Root.getOpcode(); + if (!isAccumulationOpcode(Opc)) + return false; + + // Verify that this is the end of the chain. + MachineBasicBlock &MBB = *Root.getParent(); + MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg())) + return false; + + auto User = MRI.use_instr_begin(Root.getOperand(0).getReg()); + if (User->getOpcode() == Opc) + return false; + + // Walk up the use chain and collect the reduction chain. + SmallVector<Register, 32> Chain; + getAccumulatorChain(&Root, Chain); + + // Reject chains which are too short to be worth modifying. + if (Chain.size() < MinAccumulatorDepth) + return false; + + // Check if the MBB this instruction is a part of contains any other chains. + // If so, don't apply it. + SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end()); + for (const auto &I : MBB) { + if (I.getOpcode() == Opc && + !ReductionChain.contains(I.getOperand(0).getReg())) + return false; + } + + Patterns.push_back(MachineCombinerPattern::ACC_CHAIN); + return true; +} + +// Reduce branches of the accumulator tree by adding them together. +void TargetInstrInfo::reduceAccumulatorTree( + SmallVectorImpl<Register> &RegistersToReduce, + SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF, + MachineInstr &Root, MachineRegisterInfo &MRI, + DenseMap<Register, unsigned> &InstrIdxForVirtReg, + Register ResultReg) const { + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + SmallVector<Register, 8> NewRegs; + + // Get the opcode for the reduction instruction we will need to build. + // If for some reason it is not defined, early exit and don't apply this. 
+ unsigned ReduceOpCode = getReduceOpcodeForAccumulator(Root.getOpcode()); + + for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) { + auto RHS = RegistersToReduce[i - 1]; + auto LHS = RegistersToReduce[i]; + Register Dest; + // If we are reducing 2 registers, reuse the original result register. + if (RegistersToReduce.size() == 2) + Dest = ResultReg; + // Otherwise, create a new virtual register to hold the partial sum. + else { + auto NewVR = MRI.createVirtualRegister( + MRI.getRegClass(Root.getOperand(0).getReg())); + Dest = NewVR; + NewRegs.push_back(Dest); + InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size())); + } + + // Create the new reduction instruction. + MachineInstrBuilder MIB = + BuildMI(MF, MIMetadata(Root), TII->get(ReduceOpCode), Dest) + .addReg(RHS, getKillRegState(true)) + .addReg(LHS, getKillRegState(true)); + // Copy any flags needed from the original instruction. + MIB->setFlags(Root.getFlags()); + InsInstrs.push_back(MIB); + } + + // If the number of registers to reduce is odd, add the remaining register to + // the vector of registers to reduce. 
+ if (RegistersToReduce.size() % 2 != 0) + NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]); + + RegistersToReduce = NewRegs; +} + // The concept of the reassociation pass is that these operations can benefit // from this kind of transformation: // @@ -938,6 +1100,8 @@ bool TargetInstrInfo::getMachineCombinerPatterns( } return true; } + if (getAccumulatorReassociationPatterns(Root, Patterns)) + return true; return false; } @@ -949,7 +1113,12 @@ bool TargetInstrInfo::isThroughputPattern(unsigned Pattern) const { CombinerObjective TargetInstrInfo::getCombinerObjective(unsigned Pattern) const { - return CombinerObjective::Default; + switch (Pattern) { + case MachineCombinerPattern::ACC_CHAIN: + return CombinerObjective::MustReduceDepth; + default: + return CombinerObjective::Default; + } } std::pair<unsigned, unsigned> @@ -1252,19 +1421,98 @@ void TargetInstrInfo::genAlternativeCodeSequence( SmallVectorImpl<MachineInstr *> &DelInstrs, DenseMap<Register, unsigned> &InstIdxForVirtReg) const { MachineRegisterInfo &MRI = Root.getMF()->getRegInfo(); + MachineBasicBlock &MBB = *Root.getParent(); + MachineFunction &MF = *MBB.getParent(); + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); - // Select the previous instruction in the sequence based on the input pattern. - std::array<unsigned, 5> OperandIndices; - getReassociateOperandIndices(Root, Pattern, OperandIndices); - MachineInstr *Prev = - MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg()); + switch (Pattern) { + case MachineCombinerPattern::REASSOC_AX_BY: + case MachineCombinerPattern::REASSOC_AX_YB: + case MachineCombinerPattern::REASSOC_XA_BY: + case MachineCombinerPattern::REASSOC_XA_YB: { + // Select the previous instruction in the sequence based on the input + // pattern. 
+ std::array<unsigned, 5> OperandIndices; + getReassociateOperandIndices(Root, Pattern, OperandIndices); + MachineInstr *Prev = + MRI.getUniqueVRegDef(Root.getOperand(OperandIndices[0]).getReg()); + + // Don't reassociate if Prev and Root are in different blocks. + if (Prev->getParent() != Root.getParent()) + return; - // Don't reassociate if Prev and Root are in different blocks. - if (Prev->getParent() != Root.getParent()) - return; + reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices, + InstIdxForVirtReg); + break; + } + case MachineCombinerPattern::ACC_CHAIN: { + SmallVector<Register, 32> ChainRegs; + getAccumulatorChain(&Root, ChainRegs); + unsigned int Depth = ChainRegs.size(); + assert(MaxAccumulatorWidth > 1 && + "Max accumulator width set to illegal value"); + unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth + ? Log2_32(Depth) + : MaxAccumulatorWidth; + + // Walk down the chain and rewrite it as a tree. + for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) { + // No need to rewrite the first node, it is already perfect as it is. + if (IndexedReg.index() == 0) + continue; + + MachineInstr *Instr = MRI.getUniqueVRegDef(IndexedReg.value()); + MachineInstrBuilder MIB; + Register AccReg; + if (IndexedReg.index() < MaxWidth) { + // Now we need to create new instructions for the first row. + AccReg = Instr->getOperand(0).getReg(); + unsigned OpCode = getAccumulationStartOpcode(Root.getOpcode()); + + MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(OpCode), AccReg) + .addReg(Instr->getOperand(2).getReg(), + getKillRegState(Instr->getOperand(2).isKill())) + .addReg(Instr->getOperand(3).getReg(), + getKillRegState(Instr->getOperand(3).isKill())); + } else { + // For the remaining cases, we need to use an output register of one of + // the newly inserted instructions as operand 1 + AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg() + ? 
MRI.createVirtualRegister( + MRI.getRegClass(Root.getOperand(0).getReg())) + : Instr->getOperand(0).getReg(); + assert(IndexedReg.index() >= MaxWidth); + auto AccumulatorInput = + ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1]; + MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()), + AccReg) + .addReg(AccumulatorInput, getKillRegState(true)) + .addReg(Instr->getOperand(2).getReg(), + getKillRegState(Instr->getOperand(2).isKill())) + .addReg(Instr->getOperand(3).getReg(), + getKillRegState(Instr->getOperand(3).isKill())); + } - reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, OperandIndices, - InstIdxForVirtReg); + MIB->setFlags(Instr->getFlags()); + InstIdxForVirtReg.insert(std::make_pair(AccReg, InsInstrs.size())); + InsInstrs.push_back(MIB); + DelInstrs.push_back(Instr); + } + + SmallVector<Register, 8> RegistersToReduce; + for (unsigned i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size(); + ++i) { + auto Reg = InsInstrs[i]->getOperand(0).getReg(); + RegistersToReduce.push_back(Reg); + } + + while (RegistersToReduce.size() > 1) + reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI, + InstIdxForVirtReg, Root.getOperand(0).getReg()); + + break; + } + } } MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const { |
