summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp72
1 files changed, 58 insertions, 14 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 73ed611e8de8..3a98e257367b 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
- Instruction *Fixup = nullptr);
+ Instruction *Fixup = nullptr,
+ int64_t ScalableOffset = 0);
static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg))
@@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
- Instruction *Fixup/*= nullptr*/) {
+ Instruction *Fixup /* = nullptr */,
+ int64_t ScalableOffset) {
switch (Kind) {
case LSRUse::Address:
return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
- HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
+ HasBaseReg, Scale, AccessTy.AddrSpace,
+ Fixup, ScalableOffset);
case LSRUse::ICmpZero:
// There's not even a target hook for querying whether it would be legal to
// fold a GV into an ICmp.
- if (BaseGV)
+ if (BaseGV || ScalableOffset != 0)
return false;
// ICmp only has two operands; don't allow more than two non-trivial parts.
@@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
case LSRUse::Basic:
// Only handle single-register values.
- return !BaseGV && Scale == 0 && BaseOffset == 0;
+ return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;
case LSRUse::Special:
// Special case Basic to handle -1 scales.
- return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
+ return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
+ ScalableOffset == 0;
}
llvm_unreachable("Invalid LSRUse Kind!");
@@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
- bool HasBaseReg) {
+ bool HasBaseReg, int64_t ScalableOffset = 0) {
// Fast-path: zero is always foldable.
if (BaseOffset == 0 && !BaseGV) return true;
@@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
}
return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
- HasBaseReg, Scale);
+ HasBaseReg, Scale, nullptr, ScalableOffset);
}
static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
@@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
Value *Operand, const TargetTransformInfo &TTI) {
const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
- if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
- return false;
+ int64_t IncOffset = 0;
+ int64_t ScalableOffset = 0;
+ if (IncConst) {
+ if (IncConst && IncConst->getAPInt().getSignificantBits() > 64)
+ return false;
+ IncOffset = IncConst->getValue()->getSExtValue();
+ } else {
+ // Look for mul(vscale, constant), to detect ScalableOffset.
+ auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
+ if (!IncVScale || IncVScale->getNumOperands() != 2 ||
+ !isa<SCEVVScale>(IncVScale->getOperand(1)))
+ return false;
+ auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
+ if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+ return false;
+ ScalableOffset = Scale->getValue()->getSExtValue();
+ }
- if (IncConst->getAPInt().getSignificantBits() > 64)
+ if (!isAddressUse(TTI, UserInst, Operand))
return false;
MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
- int64_t IncOffset = IncConst->getValue()->getSExtValue();
if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
- IncOffset, /*HasBaseReg=*/false))
+ IncOffset, /*HasBaseReg=*/false, ScalableOffset))
return false;
return true;
@@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
Type *IVTy = IVSrc->getType();
Type *IntTy = SE.getEffectiveSCEVType(IVTy);
const SCEV *LeftOverExpr = nullptr;
+ const SCEV *Accum = SE.getZero(IntTy);
+ SmallVector<std::pair<const SCEV *, Value *>> Bases;
+ Bases.emplace_back(Accum, IVSrc);
+
for (const IVInc &Inc : Chain) {
Instruction *InsertPt = Inc.UserInst;
if (isa<PHINode>(InsertPt))
@@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// IncExpr was the result of subtraction of two narrow values, so must
// be signed.
const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
+ Accum = SE.getAddExpr(Accum, IncExpr);
LeftOverExpr = LeftOverExpr ?
SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
}
- if (LeftOverExpr && !LeftOverExpr->isZero()) {
+
+ // Look through each base to see if any can produce a nice addressing mode.
+ bool FoundBase = false;
+ for (auto [MapScev, MapIVOper] : reverse(Bases)) {
+ const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
+ if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
+ if (!Remainder->isZero()) {
+ Rewriter.clearPostInc();
+ Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
+ const SCEV *IVOperExpr =
+ SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
+ IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
+ } else {
+ IVOper = MapIVOper;
+ }
+
+ FoundBase = true;
+ break;
+ }
+ }
+ if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
// Expand the IV increment.
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
@@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// If an IV increment can't be folded, use it as the next IV value.
if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
+ Bases.emplace_back(Accum, IVOper);
IVSrc = IVOper;
LeftOverExpr = nullptr;
}