diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/Scalarizer.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/Scalarizer.cpp | 104 |
1 files changed, 103 insertions, 1 deletions
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index b1e4c7e52d99..772f4c6c35dd 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -197,6 +197,24 @@ struct VectorLayout { uint64_t SplitSize = 0; }; +static bool isStructOfMatchingFixedVectors(Type *Ty) { + if (!isa<StructType>(Ty)) + return false; + unsigned StructSize = Ty->getNumContainedTypes(); + if (StructSize < 1) + return false; + FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0)); + if (!VecTy) + return false; + unsigned VecSize = VecTy->getNumElements(); + for (unsigned I = 1; I < StructSize; I++) { + VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I)); + if (!VecTy || VecSize != VecTy->getNumElements()) + return false; + } + return true; +} + /// Concatenate the given fragments to a single vector value of the type /// described in @p VS. static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments, @@ -276,6 +294,7 @@ public: bool visitBitCastInst(BitCastInst &BCI); bool visitInsertElementInst(InsertElementInst &IEI); bool visitExtractElementInst(ExtractElementInst &EEI); + bool visitExtractValueInst(ExtractValueInst &EVI); bool visitShuffleVectorInst(ShuffleVectorInst &SVI); bool visitPHINode(PHINode &PHI); bool visitLoadInst(LoadInst &LI); @@ -667,6 +686,12 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) { if (isTriviallyVectorizable(ID)) return true; + // TODO: Move frexp to isTriviallyVectorizable. + // https://github.com/llvm/llvm-project/issues/112408 + switch (ID) { + case Intrinsic::frexp: + return true; + } return Intrinsic::isTargetIntrinsic(ID) && TTI->isTargetIntrinsicTriviallyScalarizable(ID); } @@ -674,7 +699,13 @@ bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) { /// If a call to a vector typed intrinsic function, split into a scalar call per /// element if possible for the intrinsic. bool ScalarizerVisitor::splitCall(CallInst &CI) { - std::optional<VectorSplit> VS = getVectorSplit(CI.getType()); + Type *CallType = CI.getType(); + bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(CallType); + std::optional<VectorSplit> VS; + if (AreAllVectorsOfMatchingSize) + VS = getVectorSplit(CallType->getContainedType(0)); + else + VS = getVectorSplit(CallType); if (!VS) return false; @@ -699,6 +730,23 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)) Tys.push_back(VS->SplitTy); + if (AreAllVectorsOfMatchingSize) { + for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) { + std::optional<VectorSplit> CurrVS = + getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I))); + // This case does not seem to happen, but it is possible for + // VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit + // is not returned and we will bailout of handling this call. + // The secondary bailout case is if NumPacked does not match. + // This can happen if ScalarizeMinBits is not set to the default. + // This means with certain ScalarizeMinBits intrinsics like frexp + // will only scalarize when the struct elements have the same bitness. + if (!CurrVS || CurrVS->NumPacked != VS->NumPacked) + return false; + if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I)) + Tys.push_back(CurrVS->SplitTy); + } + } // Assumes that any vector type has the same number of elements as the return // vector type, which is true for all current intrinsics. for (unsigned I = 0; I != NumArgs; ++I) { @@ -1030,6 +1078,31 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { return true; } +bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) { + Value *Op = EVI.getOperand(0); + Type *OpTy = Op->getType(); + ValueVector Res; + if (!isStructOfMatchingFixedVectors(OpTy)) + return false; + Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0)); + std::optional<VectorSplit> VS = getVectorSplit(VecType); + if (!VS) + return false; + IRBuilder<> Builder(&EVI); + Scatterer Op0 = scatter(&EVI, Op, *VS); + assert(!EVI.getIndices().empty() && "Make sure an index exists"); + // Note for our use case we only care about the top level index. + unsigned Index = EVI.getIndices()[0]; + for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) { + Value *ResElem = Builder.CreateExtractValue( + Op0[OpIdx], Index, EVI.getName() + ".elem" + Twine(Index)); + Res.push_back(ResElem); + } + + gather(&EVI, Res, *VS); + return true; +} + bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType()); if (!VS) @@ -1209,6 +1282,35 @@ bool ScalarizerVisitor::finish() { Res = concatenate(Builder, CV, VS, Op->getName()); Res->takeName(Op); + } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) { + BasicBlock *BB = Op->getParent(); + IRBuilder<> Builder(Op); + if (isa<PHINode>(Op)) + Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); + + // Iterate over each element in the struct + unsigned NumOfStructElements = Ty->getNumElements(); + SmallVector<ValueVector, 4> ElemCV(NumOfStructElements); + for (unsigned I = 0; I < NumOfStructElements; ++I) { + for (auto *CVelem : CV) { + Value *Elem = Builder.CreateExtractValue( + CVelem, I, Op->getName() + ".elem" + Twine(I)); + ElemCV[I].push_back(Elem); + } + } + Res = PoisonValue::get(Ty); + for (unsigned I = 0; I < NumOfStructElements; ++I) { + Type *ElemTy = Ty->getElementType(I); + assert(isa<FixedVectorType>(ElemTy) && + "Only Structs of all FixedVectorType supported"); + VectorSplit VS = *getVectorSplit(ElemTy); + assert(VS.NumFragments == CV.size()); + + Value *ConcatenatedVector = + concatenate(Builder, ElemCV[I], VS, Op->getName()); + Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I, + Op->getName() + ".insert"); + } } else { assert(CV.size() == 1 && Op->getType() == CV[0]->getType()); Res = CV[0]; |
