summaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen
diff options
context:
space:
mode:
authorHassnaa Hamdi <hassnaa.hamdi@arm.com>2025-11-18 13:15:47 +0000
committerGitHub <noreply@github.com>2025-11-18 13:15:47 +0000
commit3d5d32c6058807008e579dd5ea2faced33a7943b (patch)
tree724113316b3c3d34ca00c54f242329524548193d /llvm/lib/CodeGen
parent52f4c360e382e6926dccb315d4402af6211e25f0 (diff)
[CGP]: Optimize mul.overflow. (#148343)
- Detect cases where LHS & RHS values will not cause overflow (when the Hi halfs are zero).
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r--llvm/lib/CodeGen/CodeGenPrepare.cpp182
1 files changed, 182 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index b6dd174f9be8..587c1372b19c 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ private:
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeMulWithOverflow(Instruction *I, bool IsSigned,
+ ModifyDT &ModifiedDT);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2797,6 +2799,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
}
}
return false;
+ case Intrinsic::umul_with_overflow:
+ return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT);
+ case Intrinsic::smul_with_overflow:
+ return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT);
}
SmallVector<Value *, 2> PtrOps;
@@ -6391,6 +6397,182 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// This is a helper for CodeGenPrepare::optimizeMulWithOverflow.
+// Check the pattern we are interested in where there are maximum 2 uses
+// of the intrinsic which are the extract instructions.
+static bool matchOverflowPattern(Instruction *&I, ExtractValueInst *&MulExtract,
+ ExtractValueInst *&OverflowExtract) {
+ // Bail out if it's more than 2 users:
+ if (I->hasNUsesOrMore(3))
+ return false;
+
+ for (User *U : I->users()) {
+ auto *Extract = dyn_cast<ExtractValueInst>(U);
+ if (!Extract || Extract->getNumIndices() != 1)
+ return false;
+
+ unsigned Index = Extract->getIndices()[0];
+ if (Index == 0)
+ MulExtract = Extract;
+ else if (Index == 1)
+ OverflowExtract = Extract;
+ else
+ return false;
+ }
+ return true;
+}
+
+// Rewrite the mul_with_overflow intrinsic by checking if both of the
+// operands' value ranges are within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+// The IR after the optimization will look like:
+// entry:
+// if signed:
+// ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
+// overflow_no
+// else:
+// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
+// overflow_no:
+// overflow:
+// overflow.res:
+// \returns true if optimization was applied
+// TODO: This optimization can be further improved to optimize branching on
+// overflow where the 'overflow_no' BB can branch directly to the false
+// successor of overflow, but that would add additional complexity so we leave
+// it for future work.
+bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
+ ModifyDT &ModifiedDT) {
+ // Check if target supports this optimization.
+ if (!TLI->shouldOptimizeMulOverflowWithZeroHighBits(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))))
+ return false;
+
+ ExtractValueInst *MulExtract = nullptr, *OverflowExtract = nullptr;
+ if (!matchOverflowPattern(I, MulExtract, OverflowExtract))
+ return false;
+
+ // Keep track of the instruction to stop reoptimizing it again.
+ InsertedInsts.insert(I);
+
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ Type *Ty = LHS->getType();
+ unsigned VTHalfBitWidth = Ty->getScalarSizeInBits() / 2;
+ Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth);
+
+ // New BBs:
+ BasicBlock *OverflowEntryBB =
+ I->getParent()->splitBasicBlock(I, "", /*Before*/ true);
+ OverflowEntryBB->takeName(I->getParent());
+ // Keep the 'br' instruction that is generated as a result of the split to be
+ // erased/replaced later.
+ Instruction *OldTerminator = OverflowEntryBB->getTerminator();
+ BasicBlock *NoOverflowBB =
+ BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
+ NoOverflowBB->moveAfter(OverflowEntryBB);
+ BasicBlock *OverflowBB =
+ BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
+ OverflowBB->moveAfter(NoOverflowBB);
+
+ // BB overflow.entry:
+ IRBuilder<> Builder(OverflowEntryBB);
+ // Extract low and high halves of LHS:
+ Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ // Extract low and high halves of RHS:
+ Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ Value *IsAnyBitTrue;
+ if (IsSigned) {
+ Value *SignLoLHS =
+ Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ Value *SignLoRHS =
+ Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
+ Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
+ Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
+ IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
+ ConstantInt::getNullValue(Or->getType()));
+ } else {
+ Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
+ }
+ Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);
+
+ // BB overflow.no:
+ Builder.SetInsertPoint(NoOverflowBB);
+ Value *ExtLoLHS, *ExtLoRHS;
+ if (IsSigned) {
+ ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
+ ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
+ } else {
+ ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
+ ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
+ }
+
+ Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");
+
+ // Create the 'overflow.res' BB to merge the results of
+ // the two paths:
+ BasicBlock *OverflowResBB = I->getParent();
+ OverflowResBB->setName("overflow.res");
+
+ // BB overflow.no: jump to overflow.res BB
+ Builder.CreateBr(OverflowResBB);
+ // No we don't need the old terminator in overflow.entry BB, erase it:
+ OldTerminator->eraseFromParent();
+
+ // BB overflow.res:
+ Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
+ // Create PHI nodes to merge results from no.overflow BB and overflow BB to
+ // replace the extract instructions.
+ PHINode *OverflowResPHI = Builder.CreatePHI(Ty, 2),
+ *OverflowFlagPHI =
+ Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);
+
+ // Add the incoming values from no.overflow BB and later from overflow BB.
+ OverflowResPHI->addIncoming(Mul, NoOverflowBB);
+ OverflowFlagPHI->addIncoming(ConstantInt::getFalse(I->getContext()),
+ NoOverflowBB);
+
+ // Replace all users of MulExtract and OverflowExtract to use the PHI nodes.
+ if (MulExtract) {
+ MulExtract->replaceAllUsesWith(OverflowResPHI);
+ MulExtract->eraseFromParent();
+ }
+ if (OverflowExtract) {
+ OverflowExtract->replaceAllUsesWith(OverflowFlagPHI);
+ OverflowExtract->eraseFromParent();
+ }
+
+ // Remove the intrinsic from parent (overflow.res BB) as it will be part of
+ // overflow BB
+ I->removeFromParent();
+ // BB overflow:
+ I->insertInto(OverflowBB, OverflowBB->end());
+ Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
+ Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
+ Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
+ Builder.CreateBr(OverflowResBB);
+
+ // Add The Extracted values to the PHINodes in the overflow.res BB.
+ OverflowResPHI->addIncoming(MulOverflow, OverflowBB);
+ OverflowFlagPHI->addIncoming(OverflowFlag, OverflowBB);
+
+ ModifiedDT = ModifyDT::ModifyBBDT;
+ return true;
+}
+
/// If there are any memory operands, use OptimizeMemoryInst to sink their
/// address computing into the block when possible / profitable.
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) {