diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VectorCombine.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 227 |
1 files changed, 142 insertions, 85 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index e608c7fb6046..7fa1b433ef11 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1669,63 +1669,109 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { return true; } -using InstLane = std::pair<Value *, int>; +using InstLane = std::pair<Use *, int>; -static InstLane lookThroughShuffles(Value *V, int Lane) { - while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) { +static InstLane lookThroughShuffles(Use *U, int Lane) { + while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) { unsigned NumElts = cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements(); int M = SV->getMaskValue(Lane); if (M < 0) return {nullptr, PoisonMaskElem}; if (static_cast<unsigned>(M) < NumElts) { - V = SV->getOperand(0); + U = &SV->getOperandUse(0); Lane = M; } else { - V = SV->getOperand(1); + U = &SV->getOperandUse(1); Lane = M - NumElts; } } - return InstLane{V, Lane}; + return InstLane{U, Lane}; } static SmallVector<InstLane> generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) { SmallVector<InstLane> NItem; for (InstLane IL : Item) { - auto [V, Lane] = IL; + auto [U, Lane] = IL; InstLane OpLane = - V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane) + U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op), + Lane) : InstLane{nullptr, PoisonMaskElem}; NItem.emplace_back(OpLane); } return NItem; } +/// Detect concat of multiple values into a vector +static bool isFreeConcat(ArrayRef<InstLane> Item, + const TargetTransformInfo &TTI) { + auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType()); + unsigned NumElts = Ty->getNumElements(); + if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0) + return false; + + // Check that the concat is free, usually meaning that the type will be split + // during legalization. + SmallVector<int, 16> ConcatMask(NumElts * 2); + std::iota(ConcatMask.begin(), ConcatMask.end(), 0); + if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, + TTI::TCK_RecipThroughput) != 0) + return false; + + unsigned NumSlices = Item.size() / NumElts; + // Currently we generate a tree of shuffles for the concats, which limits us + // to a power2. + if (!isPowerOf2_32(NumSlices)) + return false; + for (unsigned Slice = 0; Slice < NumSlices; ++Slice) { + Use *SliceV = Item[Slice * NumElts].first; + if (!SliceV || SliceV->get()->getType() != Ty) + return false; + for (unsigned Elt = 0; Elt < NumElts; ++Elt) { + auto [V, Lane] = Item[Slice * NumElts + Elt]; + if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get()) + return false; + } + } + return true; +} + static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, - const SmallPtrSet<Value *, 4> &IdentityLeafs, - const SmallPtrSet<Value *, 4> &SplatLeafs, + const SmallPtrSet<Use *, 4> &IdentityLeafs, + const SmallPtrSet<Use *, 4> &SplatLeafs, + const SmallPtrSet<Use *, 4> &ConcatLeafs, IRBuilder<> &Builder) { - auto [FrontV, FrontLane] = Item.front(); - - if (IdentityLeafs.contains(FrontV) && - all_of(drop_begin(enumerate(Item)), [Item](const auto &E) { - Value *FrontV = Item.front().first; - auto [V, Lane] = E.value(); - return !V || (V == FrontV && Lane == (int)E.index()); - })) { - return FrontV; + auto [FrontU, FrontLane] = Item.front(); + + if (IdentityLeafs.contains(FrontU)) { + return FrontU->get(); } - if (SplatLeafs.contains(FrontV)) { - if (auto *ILI = dyn_cast<Instruction>(FrontV)) - Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef()); - else if (auto *Arg = dyn_cast<Argument>(FrontV)) - Builder.SetInsertPointPastAllocas(Arg->getParent()); + if (SplatLeafs.contains(FrontU)) { SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane); - return Builder.CreateShuffleVector(FrontV, Mask); + return Builder.CreateShuffleVector(FrontU->get(), Mask); + } + if (ConcatLeafs.contains(FrontU)) { + unsigned NumElts = + cast<FixedVectorType>(FrontU->get()->getType())->getNumElements(); + SmallVector<Value *> Values(Item.size() / NumElts, nullptr); + for (unsigned S = 0; S < Values.size(); ++S) + Values[S] = Item[S * NumElts].first->get(); + + while (Values.size() > 1) { + NumElts *= 2; + SmallVector<int, 16> Mask(NumElts, 0); + std::iota(Mask.begin(), Mask.end(), 0); + SmallVector<Value *> NewValues(Values.size() / 2, nullptr); + for (unsigned S = 0; S < NewValues.size(); ++S) + NewValues[S] = + Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask); + Values = NewValues; + } + return Values[0]; } - auto *I = cast<Instruction>(FrontV); + auto *I = cast<Instruction>(FrontU->get()); auto *II = dyn_cast<IntrinsicInst>(I); unsigned NumOps = I->getNumOperands() - (II ? 1 : 0); SmallVector<Value *> Ops(NumOps); @@ -1734,16 +1780,16 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, Ops[Idx] = II->getOperand(Idx); continue; } - Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), - Ty, IdentityLeafs, SplatLeafs, Builder); + Ops[Idx] = + generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty, + IdentityLeafs, SplatLeafs, ConcatLeafs, Builder); } SmallVector<Value *, 8> ValueList; for (const auto &Lane : Item) if (Lane.first) - ValueList.push_back(Lane.first); + ValueList.push_back(Lane.first->get()); - Builder.SetInsertPoint(I); Type *DstTy = FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements()); if (auto *BI = dyn_cast<BinaryOperator>(I)) { @@ -1785,16 +1831,16 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, // do so. bool VectorCombine::foldShuffleToIdentity(Instruction &I) { auto *Ty = dyn_cast<FixedVectorType>(I.getType()); - if (!Ty) + if (!Ty || I.use_empty()) return false; SmallVector<InstLane> Start(Ty->getNumElements()); for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M) - Start[M] = lookThroughShuffles(&I, M); + Start[M] = lookThroughShuffles(&*I.use_begin(), M); SmallVector<SmallVector<InstLane>> Worklist; Worklist.push_back(Start); - SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs; + SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs; unsigned NumVisited = 0; while (!Worklist.empty()) { @@ -1802,52 +1848,52 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { return false; SmallVector<InstLane> Item = Worklist.pop_back_val(); - auto [FrontV, FrontLane] = Item.front(); + auto [FrontU, FrontLane] = Item.front(); // If we found an undef first lane then bail out to keep things simple. - if (!FrontV) + if (!FrontU) return false; // Look for an identity value. - if (!FrontLane && - cast<FixedVectorType>(FrontV->getType())->getNumElements() == + if (FrontLane == 0 && + cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() == Ty->getNumElements() && all_of(drop_begin(enumerate(Item)), [Item](const auto &E) { - Value *FrontV = Item.front().first; - return !E.value().first || (E.value().first == FrontV && + Value *FrontV = Item.front().first->get(); + return !E.value().first || (E.value().first->get() == FrontV && E.value().second == (int)E.index()); })) { - IdentityLeafs.insert(FrontV); + IdentityLeafs.insert(FrontU); continue; } // Look for constants, for the moment only supporting constant splats. - if (auto *C = dyn_cast<Constant>(FrontV); + if (auto *C = dyn_cast<Constant>(FrontU); C && C->getSplatValue() && all_of(drop_begin(Item), [Item](InstLane &IL) { - Value *FrontV = Item.front().first; - Value *V = IL.first; - return !V || V == FrontV; + Value *FrontV = Item.front().first->get(); + Use *U = IL.first; + return !U || U->get() == FrontV; })) { - SplatLeafs.insert(FrontV); + SplatLeafs.insert(FrontU); continue; } // Look for a splat value. if (all_of(drop_begin(Item), [Item](InstLane &IL) { - auto [FrontV, FrontLane] = Item.front(); - auto [V, Lane] = IL; - return !V || (V == FrontV && Lane == FrontLane); + auto [FrontU, FrontLane] = Item.front(); + auto [U, Lane] = IL; + return !U || (U->get() == FrontU->get() && Lane == FrontLane); })) { - SplatLeafs.insert(FrontV); + SplatLeafs.insert(FrontU); continue; } // We need each element to be the same type of value, and check that each // element has a single use. - if (!all_of(drop_begin(Item), [Item](InstLane IL) { - Value *FrontV = Item.front().first; - Value *V = IL.first; - if (!V) + if (all_of(drop_begin(Item), [Item](InstLane IL) { + Value *FrontV = Item.front().first->get(); + if (!IL.first) return true; + Value *V = IL.first->get(); if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse()) return false; if (V->getValueID() != FrontV->getValueID()) @@ -1864,40 +1910,49 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { return !II || (isa<IntrinsicInst>(FrontV) && II->getIntrinsicID() == cast<IntrinsicInst>(FrontV)->getIntrinsicID()); - })) - return false; - - // Check the operator is one that we support. We exclude div/rem in case - // they hit UB from poison lanes. - if ((isa<BinaryOperator>(FrontV) && - !cast<BinaryOperator>(FrontV)->isIntDivRem()) || - isa<CmpInst>(FrontV)) { - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); - } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontV)) { - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); - } else if (isa<SelectInst>(FrontV)) { - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); - Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); - } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV); - II && isTriviallyVectorizable(II->getIntrinsicID())) { - for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { - if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) { - if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { - Value *FrontV = Item.front().first; - Value *V = IL.first; - return !V || (cast<Instruction>(V)->getOperand(Op) == - cast<Instruction>(FrontV)->getOperand(Op)); - })) - return false; - continue; + })) { + // Check the operator is one that we support. + if (isa<BinaryOperator, CmpInst>(FrontU)) { + // We exclude div/rem in case they hit UB from poison lanes. + if (auto *BO = dyn_cast<BinaryOperator>(FrontU); + BO && BO->isIntDivRem()) + return false; + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); + continue; + } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) { + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); + continue; + } else if (isa<SelectInst>(FrontU)) { + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); + Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); + continue; + } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU); + II && isTriviallyVectorizable(II->getIntrinsicID())) { + for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { + if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) { + if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { + Value *FrontV = Item.front().first->get(); + Use *U = IL.first; + return !U || (cast<Instruction>(U->get())->getOperand(Op) == + cast<Instruction>(FrontV)->getOperand(Op)); + })) + return false; + continue; + } + Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); } - Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); + continue; } - } else { - return false; } + + if (isFreeConcat(Item, TTI)) { + ConcatLeafs.insert(FrontU); + continue; + } + + return false; } if (NumVisited <= 1) @@ -1905,7 +1960,9 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) { // If we got this far, we know the shuffles are superfluous and can be // removed. Scan through again and generate the new tree of instructions. - Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder); + Builder.SetInsertPoint(&I); + Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, + ConcatLeafs, Builder); replaceValue(I, *V); return true; } |
