summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r--llvm/lib/Target/DirectX/DXILDataScalarization.cpp9
-rw-r--r--llvm/lib/Target/DirectX/DXILFlattenArrays.cpp308
-rw-r--r--llvm/lib/Target/DirectX/DXILLegalizePass.cpp74
-rw-r--r--llvm/lib/Target/DirectX/DirectXTargetMachine.cpp2
4 files changed, 230 insertions, 163 deletions
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index c97c604fdbf7..d9d9b36d0b73 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -202,7 +202,7 @@ DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
// original vector's defining instruction if available, else immediately after
// the alloca
if (auto *Instr = dyn_cast<Instruction>(Vec))
- Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());
+ Builder.SetInsertPoint(Instr->getNextNode());
SmallVector<Value *, 4> GEPs(ArrNumElems);
for (unsigned I = 0; I < ArrNumElems; ++I) {
Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
@@ -302,7 +302,7 @@ bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
Value *PtrOperand = GEPI.getPointerOperand();
- Type *OrigGEPType = GEPI.getPointerOperandType();
+ Type *OrigGEPType = GEPI.getSourceElementType();
Type *NewGEPType = OrigGEPType;
bool NeedsTransform = false;
@@ -319,6 +319,11 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
}
}
+ // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will
+ // convert these scalar geps into flattened array geps
+ if (!isa<ArrayType>(OrigGEPType))
+ NewGEPType = OrigGEPType;
+
// Note: We bail if this isn't a gep touched via alloca or global
// transformations
if (!NeedsTransform)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 0b7cf2f97017..f0e2e786dfaf 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -20,6 +20,7 @@
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstddef>
@@ -40,18 +41,19 @@ public:
static char ID; // Pass identification.
};
-struct GEPData {
- ArrayType *ParentArrayType;
- Value *ParentOperand;
- SmallVector<Value *> Indices;
- SmallVector<uint64_t> Dims;
- bool AllIndicesAreConstInt;
+struct GEPInfo {
+ ArrayType *RootFlattenedArrayType;
+ Value *RootPointerOperand;
+ SmallMapVector<Value *, APInt, 4> VariableOffsets;
+ APInt ConstantOffset;
};
class DXILFlattenArraysVisitor
: public InstVisitor<DXILFlattenArraysVisitor, bool> {
public:
- DXILFlattenArraysVisitor() {}
+ DXILFlattenArraysVisitor(
+ SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
+ : GlobalMap(GlobalMap) {}
bool visit(Function &F);
// InstVisitor methods. They return true if the instruction was scalarized,
// false if nothing changed.
@@ -78,7 +80,8 @@ public:
private:
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
- DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+ SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
+ SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
bool finish();
ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
ArrayRef<uint64_t> Dims,
@@ -86,27 +89,11 @@ private:
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
ArrayRef<uint64_t> Dims,
IRBuilder<> &Builder);
-
- // Helper function to collect indices and dimensions from a GEP instruction
- void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP,
- SmallVectorImpl<Value *> &Indices,
- SmallVectorImpl<uint64_t> &Dims,
- bool &AllIndicesAreConstInt);
-
- void
- recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
- ArrayType *FlattenedArrayType, Value *PtrOperand,
- unsigned &GEPChainUseCount,
- SmallVector<Value *> Indices = SmallVector<Value *>(),
- SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
- bool AllIndicesAreConstInt = true);
- bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
- bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
- GetElementPtrInst &GEP);
};
} // namespace
bool DXILFlattenArraysVisitor::finish() {
+ GEPChainInfoMap.clear();
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
return true;
}
@@ -225,131 +212,159 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
return true;
}
-void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
- GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
- SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
-
- Type *CurrentType = GEP.getSourceElementType();
-
- // Note index 0 is the ptr index.
- for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
- Indices.push_back(Index);
- AllIndicesAreConstInt &= isa<ConstantInt>(Index);
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+ // Do not visit GEPs more than once
+ if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
+ return false;
- if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
- Dims.push_back(ArrayTy->getNumElements());
- CurrentType = ArrayTy->getElementType();
- } else {
- assert(false && "Expected array type in GEP chain");
- }
+ Value *PtrOperand = GEP.getPointerOperand();
+ // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
+ // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
+ // pointer types, then the handling of this case can be implemented later.
+ assert(!isa<PHINode>(PtrOperand) &&
+ "Pointer operand of GEP should not be a PHI Node");
+
+ // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
+ // it can be visited
+ if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
+ PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
+ GetElementPtrInst *OldGEPI =
+ cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
+ OldGEPI->insertBefore(GEP.getIterator());
+
+ IRBuilder<> Builder(&GEP);
+ SmallVector<Value *> Indices(GEP.indices());
+ Value *NewGEP =
+ Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
+ GEP.getName(), GEP.getNoWrapFlags());
+ assert(isa<GetElementPtrInst>(NewGEP) &&
+ "Expected newly-created GEP to be an instruction");
+ GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
+
+ GEP.replaceAllUsesWith(NewGEPI);
+ GEP.eraseFromParent();
+ visitGetElementPtrInst(*OldGEPI);
+ visitGetElementPtrInst(*NewGEPI);
+ return true;
}
-}
-
-void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
- GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
- Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
- SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
- // Check if this GEP is already in the map to avoid circular references
- if (GEPChainMap.count(&CurrGEP) > 0)
- return;
- // Collect indices and dimensions from the current GEP
- collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
- bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
- if (!IsMultiDimArr) {
- assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
- GEPChainMap.insert(
- {&CurrGEP,
- {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
- std::move(Dims), AllIndicesAreConstInt}});
- return;
- }
- bool GepUses = false;
- for (auto *User : CurrGEP.users()) {
- if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
- recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
- ++GEPChainUseCount, Indices, Dims,
- AllIndicesAreConstInt);
- GepUses = true;
- }
- }
- // This case is just incase the gep chain doesn't end with a 1d array.
- if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
- GEPChainMap.insert(
- {&CurrGEP,
- {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
- std::move(Dims), AllIndicesAreConstInt}});
+ // Construct GEPInfo for this GEP
+ GEPInfo Info;
+
+ // Obtain the variable and constant byte offsets computed by this GEP
+ const DataLayout &DL = GEP.getDataLayout();
+ unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
+ Info.ConstantOffset = {BitWidth, 0};
+ [[maybe_unused]] bool Success = GEP.collectOffset(
+ DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);
+ assert(Success && "Failed to collect offsets for GEP");
+
+ // If there is a parent GEP, inherit the root array type and pointer, and
+ // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
+ // chain and we need to deterine the root array type
+ if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
+ assert(GEPChainInfoMap.contains(PtrOpGEP) &&
+ "Expected parent GEP to be visited before this GEP");
+ GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
+ Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
+ Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
+ for (auto &VariableOffset : PGEPInfo.VariableOffsets)
+ Info.VariableOffsets.insert(VariableOffset);
+ Info.ConstantOffset += PGEPInfo.ConstantOffset;
+ } else {
+ Info.RootPointerOperand = PtrOperand;
+
+ // We should try to determine the type of the root from the pointer rather
+ // than the GEP's source element type because this could be a scalar GEP
+ // into an array-typed pointer from an Alloca or Global Variable.
+ Type *RootTy = GEP.getSourceElementType();
+ if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
+ if (GlobalMap.contains(GlobalVar))
+ GlobalVar = GlobalMap[GlobalVar];
+ Info.RootPointerOperand = GlobalVar;
+ RootTy = GlobalVar->getValueType();
+ } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
+ RootTy = Alloca->getAllocatedType();
+ assert(!isMultiDimensionalArray(RootTy) &&
+ "Expected root array type to be flattened");
+
+ // If the root type is not an array, we don't need to do any flattening
+ if (!isa<ArrayType>(RootTy))
+ return false;
+
+ Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
}
-}
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
- GetElementPtrInst &GEP) {
- GEPData GEPInfo = GEPChainMap.at(&GEP);
- return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
-}
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
- GEPData &GEPInfo, GetElementPtrInst &GEP) {
- IRBuilder<> Builder(&GEP);
- Value *FlatIndex;
- if (GEPInfo.AllIndicesAreConstInt)
- FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
- else
- FlatIndex =
- genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-
- ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
-
- // Don't append '.flat' to an empty string. If the SSA name isn't available
- // it could conflict with the ParentOperand's name.
- std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : "";
-
- Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand,
- {Builder.getInt32(0), FlatIndex}, FlatName,
- GEP.getNoWrapFlags());
-
- // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
- // Erase the old GEP in the map before to avoid invalid instructions
- // and circular references.
- GEPChainMap.erase(&GEP);
-
- GEP.replaceAllUsesWith(FlatGEP);
- GEP.eraseFromParent();
- return true;
-}
-
-bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
- auto It = GEPChainMap.find(&GEP);
- if (It != GEPChainMap.end())
- return visitGetElementPtrInstInGEPChain(GEP);
- if (!isMultiDimensionalArray(GEP.getSourceElementType()))
- return false;
-
- ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
- IRBuilder<> Builder(&GEP);
- auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
- ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
-
- Value *PtrOperand = GEP.getPointerOperand();
+ // GEPs without users or GEPs with non-GEP users should be replaced such that
+ // the chain of GEPs they are a part of are collapsed to a single GEP into a
+ // flattened array.
+ bool ReplaceThisGEP = GEP.users().empty();
+ for (Value *User : GEP.users())
+ if (!isa<GetElementPtrInst>(User))
+ ReplaceThisGEP = true;
+
+ if (ReplaceThisGEP) {
+ unsigned BytesPerElem =
+ DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
+ assert(isPowerOf2_32(BytesPerElem) &&
+ "Bytes per element should be a power of 2");
+
+ // Compute the 32-bit index for this flattened GEP from the constant and
+ // variable byte offsets in the GEPInfo
+ IRBuilder<> Builder(&GEP);
+ Value *ZeroIndex = Builder.getInt32(0);
+ uint64_t ConstantOffset =
+ Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
+ assert(ConstantOffset < UINT32_MAX &&
+ "Constant byte offset for flat GEP index must fit within 32 bits");
+ Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
+ for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
+ assert(Multiplier.getActiveBits() <= 32 &&
+ "The multiplier for a flat GEP index must fit within 32 bits");
+ assert(VarIndex->getType()->isIntegerTy(32) &&
+ "Expected i32-typed GEP indices");
+ Value *VI;
+ if (Multiplier.getZExtValue() % BytesPerElem != 0) {
+ // This can happen, e.g., with i8 GEPs. To handle this we just divide
+ // by BytesPerElem using an instruction after multiplying VarIndex by
+ // Multiplier.
+ VI = Builder.CreateMul(VarIndex,
+ Builder.getInt32(Multiplier.getZExtValue()));
+ VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
+ } else
+ VI = Builder.CreateMul(
+ VarIndex,
+ Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
+ FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
+ }
- unsigned GEPChainUseCount = 0;
- recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
-
- // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
- // Here recursion is used to get the length of the GEP chain.
- // Handle zero uses here because there won't be an update via
- // a child in the chain later.
- if (GEPChainUseCount == 0) {
- SmallVector<Value *> Indices;
- SmallVector<uint64_t> Dims;
- bool AllIndicesAreConstInt = true;
-
- // Collect indices and dimensions from the GEP
- collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
- GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
- std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
- return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+ // Construct a new GEP for the flattened array to replace the current GEP
+ Value *NewGEP = Builder.CreateGEP(
+ Info.RootFlattenedArrayType, Info.RootPointerOperand,
+ {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
+
+ // If the pointer operand is a global variable and all indices are 0,
+ // IRBuilder::CreateGEP will return the global variable instead of creating
+ // a GEP instruction or GEP ConstantExpr. In this case we have to create and
+ // insert our own GEP instruction.
+ if (!isa<GEPOperator>(NewGEP))
+ NewGEP = GetElementPtrInst::Create(
+ Info.RootFlattenedArrayType, Info.RootPointerOperand,
+ {ZeroIndex, FlattenedIndex}, GEP.getNoWrapFlags(), GEP.getName(),
+ Builder.GetInsertPoint());
+
+ // Replace the current GEP with the new GEP. Store GEPInfo into the map
+ // for later use in case this GEP was not the end of the chain
+ GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
+ GEP.replaceAllUsesWith(NewGEP);
+ GEP.eraseFromParent();
+ return true;
}
+ // This GEP is potentially dead at the end of the pass since it may not have
+ // any users anymore after GEP chains have been collapsed. We retain store
+ // GEPInfo for GEPs down the chain to use to compute their indices.
+ GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
PotentiallyDeadInstrs.emplace_back(&GEP);
return false;
}
@@ -416,9 +431,8 @@ static Constant *transformInitializer(Constant *Init, Type *OrigType,
return ConstantArray::get(FlattenedType, FlattenedElements);
}
-static void
-flattenGlobalArrays(Module &M,
- DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
+static void flattenGlobalArrays(
+ Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
LLVMContext &Ctx = M.getContext();
for (GlobalVariable &G : M.globals()) {
Type *OrigType = G.getValueType();
@@ -456,9 +470,9 @@ flattenGlobalArrays(Module &M,
static bool flattenArrays(Module &M) {
bool MadeChange = false;
- DXILFlattenArraysVisitor Impl;
- DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
+ SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
flattenGlobalArrays(M, GlobalMap);
+ DXILFlattenArraysVisitor Impl(GlobalMap);
for (auto &F : make_early_inc_range(M.functions())) {
if (F.isDeclaration())
continue;
diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 76a46c7a2b76..c73648f21e8d 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -98,9 +98,9 @@ static void fixI8UseChain(Instruction &I,
ElementType = AI->getAllocatedType();
if (auto *GEP = dyn_cast<GetElementPtrInst>(NewOperands[0])) {
ElementType = GEP->getSourceElementType();
- if (ElementType->isArrayTy())
- ElementType = ElementType->getArrayElementType();
}
+ if (ElementType->isArrayTy())
+ ElementType = ElementType->getArrayElementType();
LoadInst *NewLoad = Builder.CreateLoad(ElementType, NewOperands[0]);
ReplacedValues[Load] = NewLoad;
ToRemove.push_back(Load);
@@ -347,7 +347,6 @@ static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
if (ByteLength == 0)
return;
- LLVMContext &Ctx = Builder.getContext();
const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
auto GetArrTyFromVal = [](Value *Val) -> ArrayType * {
@@ -392,10 +391,11 @@ static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
assert(ByteLength % DstElemByteSize == 0 &&
"memcpy length must be divisible by array element type");
for (uint64_t I = 0; I < NumElemsToCopy; ++I) {
- Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
- Value *SrcPtr = Builder.CreateInBoundsGEP(SrcElemTy, Src, Offset, "gep");
+ SmallVector<Value *, 2> Indices = {Builder.getInt32(0),
+ Builder.getInt32(I)};
+ Value *SrcPtr = Builder.CreateInBoundsGEP(SrcArrTy, Src, Indices, "gep");
Value *SrcVal = Builder.CreateLoad(SrcElemTy, SrcPtr);
- Value *DstPtr = Builder.CreateInBoundsGEP(DstElemTy, Dst, Offset, "gep");
+ Value *DstPtr = Builder.CreateInBoundsGEP(DstArrTy, Dst, Indices, "gep");
Builder.CreateStore(SrcVal, DstPtr);
}
}
@@ -403,7 +403,6 @@ static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
ConstantInt *SizeCI,
DenseMap<Value *, Value *> &ReplacedValues) {
- LLVMContext &Ctx = Builder.getContext();
[[maybe_unused]] const DataLayout &DL =
Builder.GetInsertBlock()->getModule()->getDataLayout();
[[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
@@ -444,8 +443,9 @@ static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
}
for (uint64_t I = 0; I < Size; ++I) {
- Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
- Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
+ Value *Zero = Builder.getInt32(0);
+ Value *Offset = Builder.getInt32(I);
+ Value *Ptr = Builder.CreateGEP(ArrTy, Dst, {Zero, Offset}, "gep");
Builder.CreateStore(TypedVal, Ptr);
}
}
@@ -478,9 +478,9 @@ static void legalizeMemCpy(Instruction &I,
ToRemove.push_back(CI);
}
-static void removeMemSet(Instruction &I,
- SmallVectorImpl<Instruction *> &ToRemove,
- DenseMap<Value *, Value *> &ReplacedValues) {
+static void legalizeMemSet(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &ReplacedValues) {
CallInst *CI = dyn_cast<CallInst>(&I);
if (!CI)
@@ -562,6 +562,53 @@ legalizeGetHighLowi64Bytes(Instruction &I,
}
}
+static void
+legalizeScalarLoadStoreOnArrays(Instruction &I,
+ SmallVectorImpl<Instruction *> &ToRemove,
+ DenseMap<Value *, Value *> &) {
+
+ Value *PtrOp;
+ unsigned PtrOpIndex;
+ [[maybe_unused]] Type *LoadStoreTy;
+ if (auto *LI = dyn_cast<LoadInst>(&I)) {
+ PtrOp = LI->getPointerOperand();
+ PtrOpIndex = LI->getPointerOperandIndex();
+ LoadStoreTy = LI->getType();
+ } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
+ PtrOp = SI->getPointerOperand();
+ PtrOpIndex = SI->getPointerOperandIndex();
+ LoadStoreTy = SI->getValueOperand()->getType();
+ } else
+ return;
+
+ // If the load/store is not of a single-value type (i.e., scalar or vector)
+ // then we do not modify it. It shouldn't be a vector either because the
+ // dxil-data-scalarization pass is expected to run before this, but it's not
+ // incorrect to apply this transformation to vector load/stores.
+ if (!LoadStoreTy->isSingleValueType())
+ return;
+
+ Type *ArrayTy;
+ if (auto *GlobalVarPtrOp = dyn_cast<GlobalVariable>(PtrOp))
+ ArrayTy = GlobalVarPtrOp->getValueType();
+ else if (auto *AllocaPtrOp = dyn_cast<AllocaInst>(PtrOp))
+ ArrayTy = AllocaPtrOp->getAllocatedType();
+ else
+ return;
+
+ if (!isa<ArrayType>(ArrayTy))
+ return;
+
+ assert(ArrayTy->getArrayElementType() == LoadStoreTy &&
+ "Expected array element type to be the same as to the scalar load or "
+ "store type");
+
+ Value *Zero = ConstantInt::get(Type::getInt32Ty(I.getContext()), 0);
+ Value *GEP = GetElementPtrInst::Create(
+ ArrayTy, PtrOp, {Zero, Zero}, GEPNoWrapFlags::all(), "", I.getIterator());
+ I.setOperand(PtrOpIndex, GEP);
+}
+
namespace {
class DXILLegalizationPipeline {
@@ -603,7 +650,7 @@ private:
LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
LegalizationPipeline[Stage1].push_back(legalizeFreeze);
LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
- LegalizationPipeline[Stage1].push_back(removeMemSet);
+ LegalizationPipeline[Stage1].push_back(legalizeMemSet);
LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
// Note: legalizeGetHighLowi64Bytes and
// downcastI64toI32InsertExtractElements both modify extractelement, so they
@@ -612,6 +659,7 @@ private:
// downcastI64toI32InsertExtractElements needs to handle.
LegalizationPipeline[Stage2].push_back(
downcastI64toI32InsertExtractElements);
+ LegalizationPipeline[Stage2].push_back(legalizeScalarLoadStoreOnArrays);
}
};
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 40fe6c6e639e..84751d2db226 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -107,10 +107,10 @@ public:
addPass(createDXILIntrinsicExpansionLegacyPass());
addPass(createDXILCBufferAccessLegacyPass());
addPass(createDXILDataScalarizationLegacyPass());
- addPass(createDXILFlattenArraysLegacyPass());
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createScalarizerPass(DxilScalarOptions));
+ addPass(createDXILFlattenArraysLegacyPass());
addPass(createDXILForwardHandleAccessesLegacyPass());
addPass(createDXILLegalizeLegacyPass());
addPass(createDXILResourceImplicitBindingLegacyPass());