summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
diff options
context:
space:
mode:
authorOliver Hunt <oliver@apple.com>2025-10-20 01:38:07 -0700
committerGitHub <noreply@github.com>2025-10-20 01:38:07 -0700
commit7de01aa5d0418bd4e8db2917f831e7383c6863bb (patch)
tree1db866f57c2236573cd4b4c2d141d6d420f87a92 /llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
parent6bc540043d4c3fed8f44c8f6de86be0d1740582e (diff)
parent46a866ab7735aaa0f89fde209d516271c4825c49 (diff)
Merge branch 'main' into users/ojhunt/ptrauth-additionsusers/ojhunt/ptrauth-additions
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp64
1 files changed, 62 insertions, 2 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index aa030294ff1e..a330bb7b2fc3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -16,6 +16,7 @@
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
@@ -60,6 +61,61 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
return true;
}
+/// Let N = 2 * M.
+/// Given an N-bit integer representing a pack of two M-bit integers,
+/// we can select one of the packed integers by right-shifting by either
+/// zero or M (which is the most straightforward to check if M is a power
+/// of 2), and then isolating the lower M bits. In this case, we can
+/// represent the shift as a select on whether the shr amount is nonzero.
+static Value *simplifyShiftSelectingPackedElement(Instruction *I,
+ const APInt &DemandedMask,
+ InstCombinerImpl &IC,
+ unsigned Depth) {
+ assert(I->getOpcode() == Instruction::LShr &&
+ "Only lshr instruction supported");
+
+ uint64_t ShlAmt;
+ Value *Upper, *Lower;
+ if (!match(I->getOperand(0),
+ m_OneUse(m_c_DisjointOr(
+ m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))),
+ m_Value(Lower)))))
+ return nullptr;
+
+ if (!isPowerOf2_64(ShlAmt))
+ return nullptr;
+
+ const uint64_t DemandedBitWidth = DemandedMask.getActiveBits();
+ if (DemandedBitWidth > ShlAmt)
+ return nullptr;
+
+ // Check that upper demanded bits are not lost from lshift.
+ if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth)
+ return nullptr;
+
+ KnownBits KnownLowerBits = IC.computeKnownBits(Lower, I, Depth);
+ if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt))
+ return nullptr;
+
+ Value *ShrAmt = I->getOperand(1);
+ KnownBits KnownShrBits = IC.computeKnownBits(ShrAmt, I, Depth);
+
+ // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or
+ // zero.
+ if (~KnownShrBits.Zero != ShlAmt)
+ return nullptr;
+
+ Value *ShrAmtZ =
+ IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(ShrAmt->getType()),
+ ShrAmt->getName() + ".z");
+ // There is no existing !prof metadata we can derive the !prof metadata for
+ // this select.
+ Value *Select = IC.Builder.CreateSelectWithUnknownProfile(ShrAmtZ, Lower,
+ Upper, DEBUG_TYPE);
+ Select->takeName(I);
+ return Select;
+}
+
/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
@@ -798,9 +854,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
Known >>= ShiftAmt;
if (ShiftAmt)
Known.Zero.setHighBits(ShiftAmt); // high bits known zero.
- } else {
- llvm::computeKnownBits(I, Known, Q, Depth);
+ break;
}
+ if (Value *V =
+ simplifyShiftSelectingPackedElement(I, DemandedMask, *this, Depth))
+ return V;
+
+ llvm::computeKnownBits(I, Known, Q, Depth);
break;
}
case Instruction::AShr: {