diff options
| author | Pedro Lobo <pedro.lobo@tecnico.ulisboa.pt> | 2025-11-22 15:44:06 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-22 15:44:06 +0000 |
| commit | e8af134bb7f891caa49178c8a04a8ca944c611df (patch) | |
| tree | e67e6dd18345210bef83888631439db6a639426c /llvm/lib/Transforms | |
| parent | cc4dd015ad4a1b33d43fbac00d62f6b309a96ff4 (diff) | |
[InstCombine] Generalize trunc-shift-icmp fold from (1 << Y) to (Pow2 << Y) (#169163)
Extends the `icmp(trunc(shl))` fold to handle any power of 2 constant as
the shift base, not just 1. This generalizes the following patterns by
adjusting the comparison offsets by `log2(Pow2)`.
```llvm
(trunc (1 << Y) to iN) == 0 --> Y u>= N
(trunc (1 << Y) to iN) != 0 --> Y u< N
(trunc (1 << Y) to iN) == 2**C --> Y == C
(trunc (1 << Y) to iN) != 2**C --> Y != C
; to
(trunc (Pow2 << Y) to iN) == 0 --> Y u>= N - log2(Pow2)
(trunc (Pow2 << Y) to iN) != 0 --> Y u< N - log2(Pow2)
(trunc (Pow2 << Y) to iN) == 2**C --> Y == C - log2(Pow2)
(trunc (Pow2 << Y) to iN) != 2**C --> Y != C - log2(Pow2)
```
Proof: https://alive2.llvm.org/ce/z/2zwTkp
Diffstat (limited to 'llvm/lib/Transforms')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index f153db177cac..cf6e7315114d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1465,20 +1465,24 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, ConstantInt::get(V->getType(), 1)); } - // TODO: Handle any shifted constant by subtracting trailing zeros. // TODO: Handle non-equality predicates. Value *Y; - if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) { - // (trunc (1 << Y) to iN) == 0 --> Y u>= N - // (trunc (1 << Y) to iN) != 0 --> Y u< N + const APInt *Pow2; + if (Cmp.isEquality() && match(X, m_Shl(m_Power2(Pow2), m_Value(Y))) && + DstBits > Pow2->logBase2()) { + // (trunc (Pow2 << Y) to iN) == 0 --> Y u>= N - log2(Pow2) + // (trunc (Pow2 << Y) to iN) != 0 --> Y u< N - log2(Pow2) + // iff N > log2(Pow2) if (C.isZero()) { auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT; - return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits)); + return new ICmpInst(NewPred, Y, + ConstantInt::get(SrcTy, DstBits - Pow2->logBase2())); } - // (trunc (1 << Y) to iN) == 2**C --> Y == C - // (trunc (1 << Y) to iN) != 2**C --> Y != C + // (trunc (Pow2 << Y) to iN) == 2**C --> Y == C - log2(Pow2) + // (trunc (Pow2 << Y) to iN) != 2**C --> Y != C - log2(Pow2) if (C.isPowerOf2()) - return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2())); + return new ICmpInst( + Pred, Y, ConstantInt::get(SrcTy, C.logBase2() - Pow2->logBase2())); } if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) { |
