diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp | 209 |
1 files changed, 209 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 019536ca91ae..0895b4c2cb48 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -83,6 +83,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopPeel.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include <algorithm>
@@ -104,6 +105,8 @@ STATISTIC(
           "Number of uncountable loops recognized as 'shift until bitttest' idiom");
 STATISTIC(NumShiftUntilZero,
           "Number of uncountable loops recognized as 'shift until zero' idiom");
+STATISTIC(NumByteLoadsWidened,
+          "Number of loops with consecutive byte loads widened");
 
 bool DisableLIRP::All;
 static cl::opt<bool, true>
@@ -249,6 +252,7 @@ private:
                                  const SCEVAddRecExpr *StoreEv,
                                  const SCEVAddRecExpr *LoadEv,
                                  const SCEV *BECount);
+  bool processLoopSequentialByteLoads(BasicBlock *BB, const SCEV *BECount);
   bool avoidLIRForMultiBlockLoop(bool IsMemset = false,
                                  bool IsLoopMemset = false);
   bool optimizeCRCLoop(const PolynomialInfo &Info);
@@ -625,6 +629,9 @@ bool LoopIdiomRecognize::runOnLoopBlock(
   MadeChange |= processLoopMemIntrinsic<MemSetInst>(
       BB, &LoopIdiomRecognize::processLoopMemSet, BECount);
 
+  // Check for sequential byte load patterns that can be widened
+  MadeChange |= processLoopSequentialByteLoads(BB, BECount);
+
   return MadeChange;
 }
 
@@ -1517,6 +1524,208 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   return true;
 }
 
+/// Process sequential byte loads in a loop and widen them if profitable.
+/// Recognizes patterns like:
+///   %a = load i8, ptr %p
+///   %b = load i8, ptr %p+1
+///   %c = load i8, ptr %p+2
+///   %p_next = getelementptr inbounds i8, ptr %p, i64 3
+///
+/// And transforms to (after peeling last iteration):
+///   %wide = load i32, ptr %p
+///   %a = trunc i32 %wide to i8
+///   %b = trunc i32 (lshr %wide, 8) to i8
+///   %c = trunc i32 (lshr %wide, 16) to i8
+///
+/// The i32 load touches one byte (%p+3) beyond what the three i8 loads read.
+/// A group is therefore only widened when its address advances by exactly 3
+/// bytes per iteration, so that byte is the first byte of the next
+/// iteration's group; the last iteration has no such successor, which is why
+/// it is peeled off before the loads are widened.
+///
+/// \param BB the loop block to scan (the loop must consist of one block).
+/// \param BECount backedge-taken count of the loop; currently unused.
+/// \returns true if the IR was changed.
+bool LoopIdiomRecognize::processLoopSequentialByteLoads(BasicBlock *BB,
+                                                        const SCEV *BECount) {
+  LLVM_DEBUG(dbgs() << " processLoopSequentialByteLoads called\n");
+
+  // Quick checks
+  if (!CurLoop->isLoopSimplifyForm()) {
+    LLVM_DEBUG(dbgs() << " Loop is not in simplified form\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << " Loop has " << CurLoop->getNumBlocks() << " blocks\n");
+
+  // Only process single-block loops for now
+  if (CurLoop->getNumBlocks() != 1) {
+    LLVM_DEBUG(dbgs() << " Skipping multi-block loop\n");
+    return false;
+  }
+
+  // Skip if we're compiling for code size
+  if (ApplyCodeSizeHeuristics) {
+    LLVM_DEBUG(dbgs() << " Skipping due to code size heuristics\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << " Checking for sequential byte load pattern\n");
+
+  // Collect all simple i8 loads in the block, in program order
+  SmallVector<LoadInst *, 16> ByteLoads;
+  for (Instruction &I : *BB)
+    if (auto *Load = dyn_cast<LoadInst>(&I))
+      if (Load->getType()->isIntegerTy(8) && Load->isSimple())
+        ByteLoads.push_back(Load);
+
+  if (ByteLoads.size() < 3) {
+    LLVM_DEBUG(dbgs() << " Not enough byte loads (" << ByteLoads.size()
+                      << ")\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << " Found " << ByteLoads.size() << " byte loads\n");
+
+  // The first load's address must step through memory 3 bytes at a time on
+  // this loop; any other stride leaves the widened load's fourth byte
+  // outside the memory the original loop is known to read.
+  auto HasThreeByteStride = [&](LoadInst *First) {
+    const auto *PtrAR =
+        dyn_cast<SCEVAddRecExpr>(SE->getSCEV(First->getPointerOperand()));
+    if (!PtrAR || PtrAR->getLoop() != CurLoop) {
+      LLVM_DEBUG(dbgs() << " Pointer is not an AddRec on this loop\n");
+      return false;
+    }
+    const auto *Step = dyn_cast<SCEVConstant>(PtrAR->getStepRecurrence(*SE));
+    if (!PtrAR->isAffine() || !Step || Step->getAPInt() != 3) {
+      LLVM_DEBUG(dbgs() << " Pointer stride is not 3 bytes\n");
+      return false;
+    }
+    return true;
+  };
+
+  // The wide load is emitted at the first load, which moves the reads of the
+  // second and third byte up to that point. That is only sound if nothing in
+  // between can write memory or leave the block by unwinding.
+  auto CanHoistTo = [](LoadInst *First, LoadInst *Last) {
+    for (Instruction *I = First->getNextNode(); I != Last; I = I->getNextNode())
+      if (I->mayWriteToMemory() || I->mayThrow())
+        return false;
+    return true;
+  };
+
+  // Find ALL groups of 3 consecutive loads that are safe to widen
+  struct LoadGroup {
+    LoadInst *Load0;
+    LoadInst *Load1;
+    LoadInst *Load2;
+  };
+  SmallVector<LoadGroup, 4> LoadGroups;
+
+  for (size_t i = 0; i + 2 < ByteLoads.size(); ++i) {
+    LoadInst *L0 = ByteLoads[i];
+    LoadInst *L1 = ByteLoads[i + 1];
+    LoadInst *L2 = ByteLoads[i + 2];
+
+    // Check if they're consecutive using isConsecutiveAccess
+    if (!isConsecutiveAccess(L0, L1, *DL, *SE, false) ||
+        !isConsecutiveAccess(L1, L2, *DL, *SE, false))
+      continue;
+
+    LLVM_DEBUG(dbgs() << " Found consecutive loads at indices " << i << ", "
+                      << (i + 1) << ", " << (i + 2) << "\n");
+
+    if (!HasThreeByteStride(L0) || !CanHoistTo(L0, L2))
+      continue;
+
+    LoadGroups.push_back({L0, L1, L2});
+    i += 2; // Skip the loads just grouped so groups never overlap
+  }
+
+  if (LoadGroups.empty()) {
+    LLVM_DEBUG(dbgs() << " No consecutive load groups found\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << " Found " << LoadGroups.size()
+                    << " groups of consecutive loads\n");
+
+  // Check if we can peel the last iteration (all analysis-based bail-outs
+  // must precede peelLoop, which rewrites the IR).
+  if (!canPeelLastIteration(*CurLoop, *SE)) {
+    LLVM_DEBUG(dbgs() << " Cannot peel last iteration\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << " Pattern detected and safety checks passed\n");
+  LLVM_DEBUG(dbgs() << " Peeling last iteration and widening loads\n");
+
+  // Peel the last iteration
+  ValueToValueMapTy VMap;
+  if (!peelLoop(CurLoop, 1, /*PeelLast=*/true, LI, SE, *DT, /*AC=*/nullptr,
+                /*PreserveLCSSA=*/true, VMap)) {
+    LLVM_DEBUG(dbgs() << " Peeling failed\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << " Successfully peeled last iteration\n");
+
+  // Perform the widening transformation in the IR
+  // This is target-independent - we load 4 bytes and extract what we need
+  // NOTE(review): the narrow loads are erased without notifying any
+  // MemorySSA updater - confirm none has to be kept in sync on this path.
+  LLVMContext &Ctx = BB->getContext();
+  Type *I8Ty = Type::getInt8Ty(Ctx);
+  Type *I32Ty = Type::getInt32Ty(Ctx);
+  const bool IsLittleEndian = DL->isLittleEndian();
+  unsigned GroupsWidened = 0;
+
+  for (const LoadGroup &Group : LoadGroups) {
+    LoadInst *Narrow[3] = {Group.Load0, Group.Load1, Group.Load2};
+
+    LLVM_DEBUG(dbgs() << " Widening load group " << (GroupsWidened + 1) << "/"
+                      << LoadGroups.size() << "\n");
+
+    // Load 4 bytes as i32 at the first load's position and alignment
+    IRBuilder<> Builder(Narrow[0]);
+    LoadInst *WideLoad = Builder.CreateAlignedLoad(
+        I32Ty, Narrow[0]->getPointerOperand(), Narrow[0]->getAlign(),
+        "wide.load");
+
+    // Memory byte K sits at bit 8*K of the i32 on little-endian targets and
+    // at bit 8*(3-K) on big-endian ones; shift it down and truncate.
+    Value *Bytes[3];
+    for (unsigned K = 0; K != 3; ++K) {
+      unsigned Shift = 8 * (IsLittleEndian ? K : 3 - K);
+      Value *V = Shift ? Builder.CreateLShr(WideLoad, Shift) : WideLoad;
+      Bytes[K] = Builder.CreateTrunc(V, I8Ty);
+    }
+
+    // Replace the original loads last; the builder is anchored on Narrow[0]
+    for (unsigned K = 0; K != 3; ++K) {
+      Narrow[K]->replaceAllUsesWith(Bytes[K]);
+      Narrow[K]->eraseFromParent();
+    }
+
+    ++GroupsWidened;
+  }
+
+  LLVM_DEBUG(dbgs() << " Widened " << GroupsWidened
+                    << " groups of byte loads into 32-bit loads\n");
+
+  ORE.emit([&]() {
+    return OptimizationRemark(DEBUG_TYPE, "WidenedByteLoads",
+                              BB->getTerminator()->getDebugLoc(), BB)
+           << "Widened " << ore::NV("NumGroups", GroupsWidened)
+           << " groups of 3 consecutive byte loads into 32-bit loads in "
+           << ore::NV("Function", BB->getParent());
+  });
+
+  ++NumByteLoadsWidened;
+  return true;
+}
+
 // When compiling for codesize we avoid idiom recognition for a multi-block loop
 // unless it is a loop_memset idiom or a memset/memcpy idiom in a nested loop.
 //
