summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms
diff options
context:
space:
mode:
authorPedro Lobo <pedro.lobo@tecnico.ulisboa.pt>2025-11-22 15:44:06 +0000
committerGitHub <noreply@github.com>2025-11-22 15:44:06 +0000
commite8af134bb7f891caa49178c8a04a8ca944c611df (patch)
treee67e6dd18345210bef83888631439db6a639426c /llvm/lib/Transforms
parentcc4dd015ad4a1b33d43fbac00d62f6b309a96ff4 (diff)
[InstCombine] Generalize trunc-shift-icmp fold from (1 << Y) to (Pow2 << Y) (#169163)
Extends the `icmp(trunc(shl))` fold to handle any power of 2 constant as the shift base, not just 1. This generalizes the following patterns by adjusting the comparison offsets by `log2(Pow2)`. ```llvm (trunc (1 << Y) to iN) == 0 --> Y u>= N (trunc (1 << Y) to iN) != 0 --> Y u< N (trunc (1 << Y) to iN) == 2**C --> Y == C (trunc (1 << Y) to iN) != 2**C --> Y != C ; to (trunc (Pow2 << Y) to iN) == 0 --> Y u>= N - log2(Pow2) (trunc (Pow2 << Y) to iN) != 0 --> Y u< N - log2(Pow2) (trunc (Pow2 << Y) to iN) == 2**C --> Y == C - log2(Pow2) (trunc (Pow2 << Y) to iN) != 2**C --> Y != C - log2(Pow2) ``` Proof: https://alive2.llvm.org/ce/z/2zwTkp
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp20
1 files changed, 12 insertions, 8 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index f153db177cac..cf6e7315114d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1465,20 +1465,24 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
ConstantInt::get(V->getType(), 1));
}
- // TODO: Handle any shifted constant by subtracting trailing zeros.
// TODO: Handle non-equality predicates.
Value *Y;
- if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) {
- // (trunc (1 << Y) to iN) == 0 --> Y u>= N
- // (trunc (1 << Y) to iN) != 0 --> Y u< N
+ const APInt *Pow2;
+ if (Cmp.isEquality() && match(X, m_Shl(m_Power2(Pow2), m_Value(Y))) &&
+ DstBits > Pow2->logBase2()) {
+ // (trunc (Pow2 << Y) to iN) == 0 --> Y u>= N - log2(Pow2)
+ // (trunc (Pow2 << Y) to iN) != 0 --> Y u< N - log2(Pow2)
+ // iff N > log2(Pow2)
if (C.isZero()) {
auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT;
- return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits));
+ return new ICmpInst(NewPred, Y,
+ ConstantInt::get(SrcTy, DstBits - Pow2->logBase2()));
}
- // (trunc (1 << Y) to iN) == 2**C --> Y == C
- // (trunc (1 << Y) to iN) != 2**C --> Y != C
+ // (trunc (Pow2 << Y) to iN) == 2**C --> Y == C - log2(Pow2)
+ // (trunc (Pow2 << Y) to iN) != 2**C --> Y != C - log2(Pow2)
if (C.isPowerOf2())
- return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2()));
+ return new ICmpInst(
+ Pred, Y, ConstantInt::get(SrcTy, C.logBase2() - Pow2->logBase2()));
}
if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) {