diff options
Diffstat (limited to 'llvm/lib/Target/DirectX')
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 9 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 308 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILLegalizePass.cpp | 74 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DirectXTargetMachine.cpp | 2 |
4 files changed, 230 insertions, 163 deletions
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index c97c604fdbf7..d9d9b36d0b73 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -202,7 +202,7 @@ DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec, // original vector's defining instruction if available, else immediately after // the alloca if (auto *Instr = dyn_cast<Instruction>(Vec)) - Builder.SetInsertPoint(Instr->getNextNonDebugInstruction()); + Builder.SetInsertPoint(Instr->getNextNode()); SmallVector<Value *, 4> GEPs(ArrNumElems); for (unsigned I = 0; I < ArrNumElems; ++I) { Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract"); @@ -302,7 +302,7 @@ bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { Value *PtrOperand = GEPI.getPointerOperand(); - Type *OrigGEPType = GEPI.getPointerOperandType(); + Type *OrigGEPType = GEPI.getSourceElementType(); Type *NewGEPType = OrigGEPType; bool NeedsTransform = false; @@ -319,6 +319,11 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { } } + // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will + // convert these scalar geps into flattened array geps + if (!isa<ArrayType>(OrigGEPType)) + NewGEPType = OrigGEPType; + // Note: We bail if this isn't a gep touched via alloca or global // transformations if (!NeedsTransform) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 0b7cf2f97017..f0e2e786dfaf 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/InstVisitor.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstddef> @@ -40,18 +41,19 @@ public: static char ID; // Pass identification. }; -struct GEPData { - ArrayType *ParentArrayType; - Value *ParentOperand; - SmallVector<Value *> Indices; - SmallVector<uint64_t> Dims; - bool AllIndicesAreConstInt; +struct GEPInfo { + ArrayType *RootFlattenedArrayType; + Value *RootPointerOperand; + SmallMapVector<Value *, APInt, 4> VariableOffsets; + APInt ConstantOffset; }; class DXILFlattenArraysVisitor : public InstVisitor<DXILFlattenArraysVisitor, bool> { public: - DXILFlattenArraysVisitor() {} + DXILFlattenArraysVisitor( + SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) + : GlobalMap(GlobalMap) {} bool visit(Function &F); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. @@ -78,7 +80,8 @@ public: private: SmallVector<WeakTrackingVH> PotentiallyDeadInstrs; - DenseMap<GetElementPtrInst *, GEPData> GEPChainMap; + SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap; + SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap; bool finish(); ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, @@ -86,27 +89,11 @@ private: Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder); - - // Helper function to collect indices and dimensions from a GEP instruction - void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP, - SmallVectorImpl<Value *> &Indices, - SmallVectorImpl<uint64_t> &Dims, - bool &AllIndicesAreConstInt); - - void - recursivelyCollectGEPs(GetElementPtrInst &CurrGEP, - ArrayType *FlattenedArrayType, Value *PtrOperand, - unsigned &GEPChainUseCount, - SmallVector<Value *> Indices = SmallVector<Value *>(), - SmallVector<uint64_t> Dims = SmallVector<uint64_t>(), - bool AllIndicesAreConstInt = true); - bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP); - bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo, - GetElementPtrInst &GEP); }; } // namespace bool DXILFlattenArraysVisitor::finish() { + GEPChainInfoMap.clear(); RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); return true; } @@ -225,131 +212,159 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { return true; } -void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP( - GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices, - SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) { - - Type *CurrentType = GEP.getSourceElementType(); - - // Note index 0 is the ptr index. - for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) { - Indices.push_back(Index); - AllIndicesAreConstInt &= isa<ConstantInt>(Index); +bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { + // Do not visit GEPs more than once + if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP))) + return false; - if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) { - Dims.push_back(ArrayTy->getNumElements()); - CurrentType = ArrayTy->getElementType(); - } else { - assert(false && "Expected array type in GEP chain"); - } + Value *PtrOperand = GEP.getPointerOperand(); + // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI + // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets + // pointer types, then the handling of this case can be implemented later. + assert(!isa<PHINode>(PtrOperand) && + "Pointer operand of GEP should not be a PHI Node"); + + // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that + // it can be visited + if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand); + PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEPI = + cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction()); + OldGEPI->insertBefore(GEP.getIterator()); + + IRBuilder<> Builder(&GEP); + SmallVector<Value *> Indices(GEP.indices()); + Value *NewGEP = + Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, + GEP.getName(), GEP.getNoWrapFlags()); + assert(isa<GetElementPtrInst>(NewGEP) && + "Expected newly-created GEP to be an instruction"); + GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP); + + GEP.replaceAllUsesWith(NewGEPI); + GEP.eraseFromParent(); + visitGetElementPtrInst(*OldGEPI); + visitGetElementPtrInst(*NewGEPI); + return true; } -} - -void DXILFlattenArraysVisitor::recursivelyCollectGEPs( - GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, - Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices, - SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) { - // Check if this GEP is already in the map to avoid circular references - if (GEPChainMap.count(&CurrGEP) > 0) - return; - // Collect indices and dimensions from the current GEP - collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt); - bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType()); - if (!IsMultiDimArr) { - assert(GEPChainUseCount < FlattenedArrayType->getNumElements()); - GEPChainMap.insert( - {&CurrGEP, - {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), - std::move(Dims), AllIndicesAreConstInt}}); - return; - } - bool GepUses = false; - for (auto *User : CurrGEP.users()) { - if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) { - recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand, - ++GEPChainUseCount, Indices, Dims, - AllIndicesAreConstInt); - GepUses = true; - } - } - // This case is just incase the gep chain doesn't end with a 1d array. - if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) { - GEPChainMap.insert( - {&CurrGEP, - {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), - std::move(Dims), AllIndicesAreConstInt}}); + // Construct GEPInfo for this GEP + GEPInfo Info; + + // Obtain the variable and constant byte offsets computed by this GEP + const DataLayout &DL = GEP.getDataLayout(); + unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); + Info.ConstantOffset = {BitWidth, 0}; + [[maybe_unused]] bool Success = GEP.collectOffset( + DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset); + assert(Success && "Failed to collect offsets for GEP"); + + // If there is a parent GEP, inherit the root array type and pointer, and + // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP + // chain and we need to deterine the root array type + if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) { + assert(GEPChainInfoMap.contains(PtrOpGEP) && + "Expected parent GEP to be visited before this GEP"); + GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP]; + Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType; + Info.RootPointerOperand = PGEPInfo.RootPointerOperand; + for (auto &VariableOffset : PGEPInfo.VariableOffsets) + Info.VariableOffsets.insert(VariableOffset); + Info.ConstantOffset += PGEPInfo.ConstantOffset; + } else { + Info.RootPointerOperand = PtrOperand; + + // We should try to determine the type of the root from the pointer rather + // than the GEP's source element type because this could be a scalar GEP + // into an array-typed pointer from an Alloca or Global Variable. + Type *RootTy = GEP.getSourceElementType(); + if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) { + if (GlobalMap.contains(GlobalVar)) + GlobalVar = GlobalMap[GlobalVar]; + Info.RootPointerOperand = GlobalVar; + RootTy = GlobalVar->getValueType(); + } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) + RootTy = Alloca->getAllocatedType(); + assert(!isMultiDimensionalArray(RootTy) && + "Expected root array type to be flattened"); + + // If the root type is not an array, we don't need to do any flattening + if (!isa<ArrayType>(RootTy)) + return false; + + Info.RootFlattenedArrayType = cast<ArrayType>(RootTy); } -} -bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain( - GetElementPtrInst &GEP) { - GEPData GEPInfo = GEPChainMap.at(&GEP); - return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); -} -bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase( - GEPData &GEPInfo, GetElementPtrInst &GEP) { - IRBuilder<> Builder(&GEP); - Value *FlatIndex; - if (GEPInfo.AllIndicesAreConstInt) - FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); - else - FlatIndex = - genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); - - ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType; - - // Don't append '.flat' to an empty string. If the SSA name isn't available - // it could conflict with the ParentOperand's name. - std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : ""; - - Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand, - {Builder.getInt32(0), FlatIndex}, FlatName, - GEP.getNoWrapFlags()); - - // Note: Old gep will become an invalid instruction after replaceAllUsesWith. - // Erase the old GEP in the map before to avoid invalid instructions - // and circular references. - GEPChainMap.erase(&GEP); - - GEP.replaceAllUsesWith(FlatGEP); - GEP.eraseFromParent(); - return true; -} - -bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { - auto It = GEPChainMap.find(&GEP); - if (It != GEPChainMap.end()) - return visitGetElementPtrInstInGEPChain(GEP); - if (!isMultiDimensionalArray(GEP.getSourceElementType())) - return false; - - ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType()); - IRBuilder<> Builder(&GEP); - auto [TotalElements, BaseType] = getElementCountAndType(ArrType); - ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements); - - Value *PtrOperand = GEP.getPointerOperand(); + // GEPs without users or GEPs with non-GEP users should be replaced such that + // the chain of GEPs they are a part of are collapsed to a single GEP into a + // flattened array. + bool ReplaceThisGEP = GEP.users().empty(); + for (Value *User : GEP.users()) + if (!isa<GetElementPtrInst>(User)) + ReplaceThisGEP = true; + + if (ReplaceThisGEP) { + unsigned BytesPerElem = + DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType()); + assert(isPowerOf2_32(BytesPerElem) && + "Bytes per element should be a power of 2"); + + // Compute the 32-bit index for this flattened GEP from the constant and + // variable byte offsets in the GEPInfo + IRBuilder<> Builder(&GEP); + Value *ZeroIndex = Builder.getInt32(0); + uint64_t ConstantOffset = + Info.ConstantOffset.udiv(BytesPerElem).getZExtValue(); + assert(ConstantOffset < UINT32_MAX && + "Constant byte offset for flat GEP index must fit within 32 bits"); + Value *FlattenedIndex = Builder.getInt32(ConstantOffset); + for (auto [VarIndex, Multiplier] : Info.VariableOffsets) { + assert(Multiplier.getActiveBits() <= 32 && + "The multiplier for a flat GEP index must fit within 32 bits"); + assert(VarIndex->getType()->isIntegerTy(32) && + "Expected i32-typed GEP indices"); + Value *VI; + if (Multiplier.getZExtValue() % BytesPerElem != 0) { + // This can happen, e.g., with i8 GEPs. To handle this we just divide + // by BytesPerElem using an instruction after multiplying VarIndex by + // Multiplier. + VI = Builder.CreateMul(VarIndex, + Builder.getInt32(Multiplier.getZExtValue())); + VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem))); + } else + VI = Builder.CreateMul( + VarIndex, + Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem)); + FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI); + } - unsigned GEPChainUseCount = 0; - recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount); - - // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0. - // Here recursion is used to get the length of the GEP chain. - // Handle zero uses here because there won't be an update via - // a child in the chain later. - if (GEPChainUseCount == 0) { - SmallVector<Value *> Indices; - SmallVector<uint64_t> Dims; - bool AllIndicesAreConstInt = true; - - // Collect indices and dimensions from the GEP - collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt); - GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand, - std::move(Indices), std::move(Dims), AllIndicesAreConstInt}; - return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); + // Construct a new GEP for the flattened array to replace the current GEP + Value *NewGEP = Builder.CreateGEP( + Info.RootFlattenedArrayType, Info.RootPointerOperand, + {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags()); + + // If the pointer operand is a global variable and all indices are 0, + // IRBuilder::CreateGEP will return the global variable instead of creating + // a GEP instruction or GEP ConstantExpr. In this case we have to create and + // insert our own GEP instruction. + if (!isa<GEPOperator>(NewGEP)) + NewGEP = GetElementPtrInst::Create( + Info.RootFlattenedArrayType, Info.RootPointerOperand, + {ZeroIndex, FlattenedIndex}, GEP.getNoWrapFlags(), GEP.getName(), + Builder.GetInsertPoint()); + + // Replace the current GEP with the new GEP. Store GEPInfo into the map + // for later use in case this GEP was not the end of the chain + GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)}); + GEP.replaceAllUsesWith(NewGEP); + GEP.eraseFromParent(); + return true; } + // This GEP is potentially dead at the end of the pass since it may not have + // any users anymore after GEP chains have been collapsed. We retain store + // GEPInfo for GEPs down the chain to use to compute their indices. + GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)}); PotentiallyDeadInstrs.emplace_back(&GEP); return false; } @@ -416,9 +431,8 @@ static Constant *transformInitializer(Constant *Init, Type *OrigType, return ConstantArray::get(FlattenedType, FlattenedElements); } -static void -flattenGlobalArrays(Module &M, - DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) { +static void flattenGlobalArrays( + Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) { LLVMContext &Ctx = M.getContext(); for (GlobalVariable &G : M.globals()) { Type *OrigType = G.getValueType(); @@ -456,9 +470,9 @@ flattenGlobalArrays(Module &M, static bool flattenArrays(Module &M) { bool MadeChange = false; - DXILFlattenArraysVisitor Impl; - DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; + SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; flattenGlobalArrays(M, GlobalMap); + DXILFlattenArraysVisitor Impl(GlobalMap); for (auto &F : make_early_inc_range(M.functions())) { if (F.isDeclaration()) continue; diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 76a46c7a2b76..c73648f21e8d 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -98,9 +98,9 @@ static void fixI8UseChain(Instruction &I, ElementType = AI->getAllocatedType(); if (auto *GEP = dyn_cast<GetElementPtrInst>(NewOperands[0])) { ElementType = GEP->getSourceElementType(); - if (ElementType->isArrayTy()) - ElementType = ElementType->getArrayElementType(); } + if (ElementType->isArrayTy()) + ElementType = ElementType->getArrayElementType(); LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]); ReplacedValues[Load] = NewLoad; ToRemove.push_back(Load); @@ -347,7 +347,6 @@ static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src, if (ByteLength == 0) return; - LLVMContext &Ctx = Builder.getContext(); const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout(); auto GetArrTyFromVal = [](Value *Val) -> ArrayType * { @@ -392,10 +391,11 @@ static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src, assert(ByteLength % DstElemByteSize == 0 && "memcpy length must be divisible by array element type"); for (uint64_t I = 0; I < NumElemsToCopy; ++I) { - Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I); - Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep"); + SmallVector<Value *, 2> Indices = {Builder.getInt32(0), + Builder.getInt32(I)}; + Value *SrcPtr = Builder.CreateInBoundsGEP(SrcArrTy, Src, Indices, "gep"); Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr); - Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep"); + Value *DstPtr = Builder.CreateInBoundsGEP(DstArrTy, Dst, Indices, "gep"); Builder.CreateStore(SrcVal, DstPtr); } } @@ -403,7 +403,6 @@ static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src, static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val, ConstantInt *SizeCI, DenseMap<Value *, Value *> &ReplacedValues) { - LLVMContext &Ctx = Builder.getContext(); [[maybe_unused]] const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout(); [[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue(); @@ -444,8 +443,9 @@ static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val, } for (uint64_t I = 0; I < Size; ++I) { - Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I); - Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep"); + Value *Zero = Builder.getInt32(0); + Value *Offset = Builder.getInt32(I); + Value *Ptr = Builder.CreateGEP(ArrTy, Dst, {Zero, Offset}, "gep"); Builder.CreateStore(TypedVal, Ptr); } } @@ -478,9 +478,9 @@ static void legalizeMemCpy(Instruction &I, ToRemove.push_back(CI); } -static void removeMemSet(Instruction &I, - SmallVectorImpl<Instruction *> &ToRemove, - DenseMap<Value *, Value *> &ReplacedValues) { +static void legalizeMemSet(Instruction &I, + SmallVectorImpl<Instruction *> &ToRemove, + DenseMap<Value *, Value *> &ReplacedValues) { CallInst *CI = dyn_cast<CallInst>(&I); if (!CI) @@ -562,6 +562,53 @@ legalizeGetHighLowi64Bytes(Instruction &I, } } +static void +legalizeScalarLoadStoreOnArrays(Instruction &I, + SmallVectorImpl<Instruction *> &ToRemove, + DenseMap<Value *, Value *> &) { + + Value *PtrOp; + unsigned PtrOpIndex; + [[maybe_unused]] Type *LoadStoreTy; + if (auto *LI = dyn_cast<LoadInst>(&I)) { + PtrOp = LI->getPointerOperand(); + PtrOpIndex = LI->getPointerOperandIndex(); + LoadStoreTy = LI->getType(); + } else if (auto *SI = dyn_cast<StoreInst>(&I)) { + PtrOp = SI->getPointerOperand(); + PtrOpIndex = SI->getPointerOperandIndex(); + LoadStoreTy = SI->getValueOperand()->getType(); + } else + return; + + // If the load/store is not of a single-value type (i.e., scalar or vector) + // then we do not modify it. It shouldn't be a vector either because the + // dxil-data-scalarization pass is expected to run before this, but it's not + // incorrect to apply this transformation to vector load/stores. + if (!LoadStoreTy->isSingleValueType()) + return; + + Type *ArrayTy; + if (auto *GlobalVarPtrOp = dyn_cast<GlobalVariable>(PtrOp)) + ArrayTy = GlobalVarPtrOp->getValueType(); + else if (auto *AllocaPtrOp = dyn_cast<AllocaInst>(PtrOp)) + ArrayTy = AllocaPtrOp->getAllocatedType(); + else + return; + + if (!isa<ArrayType>(ArrayTy)) + return; + + assert(ArrayTy->getArrayElementType() == LoadStoreTy && + "Expected array element type to be the same as to the scalar load or " + "store type"); + + Value *Zero = ConstantInt::get(Type::getInt32Ty(I.getContext()), 0); + Value *GEP = GetElementPtrInst::Create( + ArrayTy, PtrOp, {Zero, Zero}, GEPNoWrapFlags::all(), "", I.getIterator()); + I.setOperand(PtrOpIndex, GEP); +} + namespace { class DXILLegalizationPipeline { @@ -603,7 +650,7 @@ private: LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes); LegalizationPipeline[Stage1].push_back(legalizeFreeze); LegalizationPipeline[Stage1].push_back(legalizeMemCpy); - LegalizationPipeline[Stage1].push_back(removeMemSet); + LegalizationPipeline[Stage1].push_back(legalizeMemSet); LegalizationPipeline[Stage1].push_back(updateFnegToFsub); // Note: legalizeGetHighLowi64Bytes and // downcastI64toI32InsertExtractElements both modify extractelement, so they @@ -612,6 +659,7 @@ private: // downcastI64toI32InsertExtractElements needs to handle. LegalizationPipeline[Stage2].push_back( downcastI64toI32InsertExtractElements); + LegalizationPipeline[Stage2].push_back(legalizeScalarLoadStoreOnArrays); } }; diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index 40fe6c6e639e..84751d2db226 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -107,10 +107,10 @@ public: addPass(createDXILIntrinsicExpansionLegacyPass()); addPass(createDXILCBufferAccessLegacyPass()); addPass(createDXILDataScalarizationLegacyPass()); - addPass(createDXILFlattenArraysLegacyPass()); ScalarizerPassOptions DxilScalarOptions; DxilScalarOptions.ScalarizeLoadStore = true; addPass(createScalarizerPass(DxilScalarOptions)); + addPass(createDXILFlattenArraysLegacyPass()); addPass(createDXILForwardHandleAccessesLegacyPass()); addPass(createDXILLegalizeLegacyPass()); addPass(createDXILResourceImplicitBindingLegacyPass()); |
