summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuy David <guyda96@gmail.com>2025-10-23 10:38:34 +0300
committerGuy David <guyda96@gmail.com>2025-10-24 15:04:37 +0300
commit9e16046fae0d57381e2ea6c9f6279729dae56083 (patch)
tree2b0641a529ae54dc1630c1313cb04e48045cade3
parent89b18f0304c8a4f7e069fdba92a13d1b939a218f (diff)
-rw-r--r--llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp209
1 files changed, 209 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 019536ca91ae..0895b4c2cb48 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -83,6 +83,7 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <algorithm>
@@ -104,6 +105,8 @@ STATISTIC(
"Number of uncountable loops recognized as 'shift until bitttest' idiom");
STATISTIC(NumShiftUntilZero,
"Number of uncountable loops recognized as 'shift until zero' idiom");
+STATISTIC(NumByteLoadsWidened,
+ "Number of loops with consecutive byte loads widened");
bool DisableLIRP::All;
static cl::opt<bool, true>
@@ -249,6 +252,7 @@ private:
const SCEVAddRecExpr *StoreEv,
const SCEVAddRecExpr *LoadEv,
const SCEV *BECount);
+ bool processLoopSequentialByteLoads(BasicBlock *BB, const SCEV *BECount);
bool avoidLIRForMultiBlockLoop(bool IsMemset = false,
bool IsLoopMemset = false);
bool optimizeCRCLoop(const PolynomialInfo &Info);
@@ -625,6 +629,9 @@ bool LoopIdiomRecognize::runOnLoopBlock(
MadeChange |= processLoopMemIntrinsic<MemSetInst>(
BB, &LoopIdiomRecognize::processLoopMemSet, BECount);
+ // Check for sequential byte load patterns that can be widened
+ MadeChange |= processLoopSequentialByteLoads(BB, BECount);
+
return MadeChange;
}
@@ -1517,6 +1524,208 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
return true;
}
+/// Process sequential byte loads in a loop and widen them if profitable.
+/// Recognizes patterns like:
+/// %a = load i8, ptr %p
+/// %b = load i8, ptr %p+1
+/// %c = load i8, ptr %p+2
+/// %p_next = getelementptr inbounds i8, ptr %p, i64 3
+///
+/// And transforms to (after peeling last iteration):
+/// %wide = load i32, ptr %p
+/// %a = trunc i32 %wide to i8
+/// %b = trunc i32 (lshr %wide, 8) to i8
+/// %c = trunc i32 (lshr %wide, 16) to i8
+bool LoopIdiomRecognize::processLoopSequentialByteLoads(BasicBlock *BB,
+ const SCEV *BECount) {
+ LLVM_DEBUG(dbgs() << " processLoopSequentialByteLoads called\n");
+
+ // Quick checks
+ if (!CurLoop->isLoopSimplifyForm()) {
+ LLVM_DEBUG(dbgs() << " Loop is not in simplified form\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " Loop has " << CurLoop->getNumBlocks() << " blocks\n");
+
+ // Only process single-block loops for now
+ if (CurLoop->getNumBlocks() != 1) {
+ LLVM_DEBUG(dbgs() << " Skipping multi-block loop\n");
+ return false;
+ }
+
+ // Skip if we're compiling for code size
+ if (ApplyCodeSizeHeuristics) {
+ LLVM_DEBUG(dbgs() << " Skipping due to code size heuristics\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " Checking for sequential byte load pattern\n");
+
+ // Collect all simple i8 loads in the loop
+ SmallVector<LoadInst *, 16> ByteLoads;
+ for (Instruction &I : *BB) {
+ if (auto *LI = dyn_cast<LoadInst>(&I)) {
+ if (LI->getType()->isIntegerTy(8) && LI->isSimple())
+ ByteLoads.push_back(LI);
+ }
+ }
+
+ if (ByteLoads.size() < 3) {
+ LLVM_DEBUG(dbgs() << " Not enough byte loads (" << ByteLoads.size()
+ << ")\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " Found " << ByteLoads.size() << " byte loads\n");
+
+ // Find ALL groups of 3 consecutive loads
+ struct LoadGroup {
+ LoadInst *Load0;
+ LoadInst *Load1;
+ LoadInst *Load2;
+ };
+ SmallVector<LoadGroup, 4> LoadGroups;
+ SmallPtrSet<LoadInst *, 16> ProcessedLoads;
+
+ for (size_t i = 0; i + 2 < ByteLoads.size(); ++i) {
+ LoadInst *L0 = ByteLoads[i];
+
+ // Skip if already part of a group
+ if (ProcessedLoads.count(L0))
+ continue;
+
+ LoadInst *L1 = ByteLoads[i + 1];
+ LoadInst *L2 = ByteLoads[i + 2];
+
+ // Check if they're consecutive using isConsecutiveAccess
+ if (isConsecutiveAccess(L0, L1, *DL, *SE, false) &&
+ isConsecutiveAccess(L1, L2, *DL, *SE, false)) {
+ LoadGroups.push_back({L0, L1, L2});
+ ProcessedLoads.insert(L0);
+ ProcessedLoads.insert(L1);
+ ProcessedLoads.insert(L2);
+ LLVM_DEBUG(dbgs() << " Found consecutive loads at indices " << i << ", "
+ << (i + 1) << ", " << (i + 2) << "\n");
+ i += 2; // Skip the processed loads
+ }
+ }
+
+ if (LoadGroups.empty()) {
+ LLVM_DEBUG(dbgs() << " No consecutive load groups found\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " Found " << LoadGroups.size()
+ << " groups of consecutive loads\n");
+
+ // Use the first group for SCEV pattern checking
+ LoadInst *Load0 = LoadGroups[0].Load0;
+ LoadInst *Load1 = LoadGroups[0].Load1;
+ LoadInst *Load2 = LoadGroups[0].Load2;
+
+ // Check if the pointer has the right SCEV pattern (stride on this loop)
+ const SCEV *Ptr0SCEV = SE->getSCEV(Load0->getPointerOperand());
+ const SCEVAddRecExpr *Ptr0AR = dyn_cast<SCEVAddRecExpr>(Ptr0SCEV);
+
+ if (!Ptr0AR || Ptr0AR->getLoop() != CurLoop) {
+ LLVM_DEBUG(dbgs() << " Pointer is not an AddRec on this loop\n");
+ return false;
+ }
+
+ // Check if we can peel the last iteration
+ // Note: canPeelLastIteration already verifies the loop executes at least 2 iterations
+ if (!canPeelLastIteration(*CurLoop, *SE)) {
+ LLVM_DEBUG(dbgs() << " Cannot peel last iteration\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " Pattern detected and safety checks passed\n");
+ LLVM_DEBUG(dbgs() << " Peeling last iteration and widening loads\n");
+
+ // Peel the last iteration
+ ValueToValueMapTy VMap;
+ if (!peelLoop(CurLoop, 1, /*PeelLast=*/true, LI, SE, *DT, /*AC=*/nullptr,
+ /*PreserveLCSSA=*/true, VMap)) {
+ LLVM_DEBUG(dbgs() << " Peeling failed\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " Successfully peeled last iteration\n");
+
+ // Perform the widening transformation in the IR
+ // This is target-independent - we load 4 bytes and extract what we need
+ unsigned GroupsWidened = 0;
+ IRBuilder<> Builder(BB);
+
+ for (const auto &Group : LoadGroups) {
+ LoadInst *L0 = Group.Load0;
+ LoadInst *L1 = Group.Load1;
+ LoadInst *L2 = Group.Load2;
+
+ LLVM_DEBUG(dbgs() << " Widening load group "
+ << (GroupsWidened + 1) << "/" << LoadGroups.size() << "\n");
+
+ // Create a 32-bit load at the first load's position
+ Builder.SetInsertPoint(L0);
+ Type *I32Ty = Type::getInt32Ty(L0->getContext());
+
+ // Load 4 bytes as i32
+ LoadInst *WideLoad = Builder.CreateAlignedLoad(
+ I32Ty, L0->getPointerOperand(), L0->getAlign(), "wide.load");
+
+ // Extract the bytes we need based on endianness
+ Value *Byte0, *Byte1, *Byte2;
+ if (DL->isLittleEndian()) {
+ // Little endian: byte 0 is at the lowest bits
+ Byte0 = Builder.CreateTrunc(WideLoad, Type::getInt8Ty(L0->getContext()));
+ Value *Shift8 = Builder.CreateLShr(WideLoad, 8);
+ Byte1 = Builder.CreateTrunc(Shift8, Type::getInt8Ty(L0->getContext()));
+ Value *Shift16 = Builder.CreateLShr(WideLoad, 16);
+ Byte2 = Builder.CreateTrunc(Shift16, Type::getInt8Ty(L0->getContext()));
+ } else {
+ // Big endian: byte 0 is at the highest bits
+ Value *Shift24 = Builder.CreateLShr(WideLoad, 24);
+ Byte0 = Builder.CreateTrunc(Shift24, Type::getInt8Ty(L0->getContext()));
+ Value *Shift16 = Builder.CreateLShr(WideLoad, 16);
+ Byte1 = Builder.CreateTrunc(Shift16, Type::getInt8Ty(L0->getContext()));
+ Value *Shift8 = Builder.CreateLShr(WideLoad, 8);
+ Byte2 = Builder.CreateTrunc(Shift8, Type::getInt8Ty(L0->getContext()));
+ }
+
+ // Replace the original loads with the extracted values
+ L0->replaceAllUsesWith(Byte0);
+ L1->replaceAllUsesWith(Byte1);
+ L2->replaceAllUsesWith(Byte2);
+
+ // Remove the original loads
+ L0->eraseFromParent();
+ L1->eraseFromParent();
+ L2->eraseFromParent();
+
+ GroupsWidened++;
+ }
+
+ LLVM_DEBUG(dbgs() << " Widened " << GroupsWidened
+ << " groups of byte loads into 32-bit loads\n");
+
+ unsigned GroupsHinted = GroupsWidened;
+
+ LLVM_DEBUG(dbgs() << " Successfully widened " << GroupsWidened
+ << " groups of byte loads\n");
+
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "WidenedByteLoads",
+ BB->getTerminator()->getDebugLoc(), BB)
+ << "Widened " << ore::NV("NumGroups", GroupsWidened)
+ << " groups of 3 consecutive byte loads into 32-bit loads in "
+ << ore::NV("Function", BB->getParent());
+ });
+
+ ++NumByteLoadsWidened;
+ return true;
+}
+
// When compiling for codesize we avoid idiom recognition for a multi-block loop
// unless it is a loop_memset idiom or a memset/memcpy idiom in a nested loop.
//