summary refs log tree commit diff
path: root/llvm/lib/Target/DirectX
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r-- llvm/lib/Target/DirectX/DXILDataScalarization.cpp 101
-rw-r--r-- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp 42
-rw-r--r-- llvm/lib/Target/DirectX/DXILOpLowering.cpp 8
-rw-r--r-- llvm/lib/Target/DirectX/DXILShaderFlags.cpp 54
-rw-r--r-- llvm/lib/Target/DirectX/DXILShaderFlags.h 7
-rw-r--r-- llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp 1
-rw-r--r-- llvm/lib/Target/DirectX/DirectXTargetMachine.cpp 2
7 files changed, 144 insertions, 71 deletions
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 1783e4a54631..2ab2daaff5b5 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -40,7 +40,7 @@ static bool findAndReplaceVectors(Module &M);
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
DataScalarizerVisitor() : GlobalMap() {}
- bool visit(Function &F);
+ bool visit(Instruction &I);
// InstVisitor methods. They return true if the instruction was scalarized,
// false if nothing changed.
bool visitInstruction(Instruction &I) { return false; }
@@ -65,28 +65,11 @@ public:
private:
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
- SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
- bool finish();
};
-bool DataScalarizerVisitor::visit(Function &F) {
+bool DataScalarizerVisitor::visit(Instruction &I) {
assert(!GlobalMap.empty());
- ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
- for (BasicBlock *BB : RPOT) {
- for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
- Instruction *I = &*II;
- bool Done = InstVisitor::visit(I);
- ++II;
- if (Done && I->getType()->isVoidTy())
- I->eraseFromParent();
- }
- }
- return finish();
-}
-
-bool DataScalarizerVisitor::finish() {
- RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
- return true;
+ return InstVisitor::visit(I);
}
GlobalVariable *
@@ -104,6 +87,20 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
unsigned NumOperands = LI.getNumOperands();
for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = LI.getOperand(I);
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+ if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+ GetElementPtrInst *OldGEP =
+ cast<GetElementPtrInst>(CE->getAsInstruction());
+ OldGEP->insertBefore(&LI);
+ IRBuilder<> Builder(&LI);
+ LoadInst *NewLoad =
+ Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
+ NewLoad->setAlignment(LI.getAlign());
+ LI.replaceAllUsesWith(NewLoad);
+ LI.eraseFromParent();
+ visitGetElementPtrInst(*OldGEP);
+ return true;
+ }
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
LI.setOperand(I, NewGlobal);
}
@@ -114,32 +111,48 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
unsigned NumOperands = SI.getNumOperands();
for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = SI.getOperand(I);
- if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
- SI.setOperand(I, NewGlobal);
+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
+ if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+ GetElementPtrInst *OldGEP =
+ cast<GetElementPtrInst>(CE->getAsInstruction());
+ OldGEP->insertBefore(&SI);
+ IRBuilder<> Builder(&SI);
+ StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
+ NewStore->setAlignment(SI.getAlign());
+ SI.replaceAllUsesWith(NewStore);
+ SI.eraseFromParent();
+ visitGetElementPtrInst(*OldGEP);
+ return true;
}
+ if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
+ SI.setOperand(I, NewGlobal);
}
return false;
}
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+
unsigned NumOperands = GEPI.getNumOperands();
+ GlobalVariable *NewGlobal = nullptr;
for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = GEPI.getOperand(I);
- GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
- if (!NewGlobal)
- continue;
- IRBuilder<> Builder(&GEPI);
-
- SmallVector<Value *, MaxVecSize> Indices;
- for (auto &Index : GEPI.indices())
- Indices.push_back(Index);
-
- Value *NewGEP =
- Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
-
- GEPI.replaceAllUsesWith(NewGEP);
- PotentiallyDeadInstrs.emplace_back(&GEPI);
+ NewGlobal = lookupReplacementGlobal(CurrOpperand);
+ if (NewGlobal)
+ break;
}
+ if (!NewGlobal)
+ return false;
+
+ IRBuilder<> Builder(&GEPI);
+ SmallVector<Value *, MaxVecSize> Indices;
+ for (auto &Index : GEPI.indices())
+ Indices.push_back(Index);
+
+ Value *NewGEP =
+ Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
+ GEPI.getName(), GEPI.getNoWrapFlags());
+ GEPI.replaceAllUsesWith(NewGEP);
+ GEPI.eraseFromParent();
return true;
}
@@ -245,17 +258,13 @@ static bool findAndReplaceVectors(Module &M) {
for (User *U : make_early_inc_range(G.users())) {
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
ConstantExpr *CE = cast<ConstantExpr>(U);
- convertUsersOfConstantsToInstructions(CE,
- /*RestrictToFunc=*/nullptr,
- /*RemoveDeadConstants=*/false,
- /*IncludeSelf=*/true);
- }
- if (isa<Instruction>(U)) {
- Instruction *Inst = cast<Instruction>(U);
- Function *F = Inst->getFunction();
- if (F)
- Impl.visit(*F);
+ for (User *UCE : make_early_inc_range(CE->users())) {
+ if (Instruction *Inst = dyn_cast<Instruction>(UCE))
+ Impl.visit(*Inst);
+ }
}
+ if (Instruction *Inst = dyn_cast<Instruction>(U))
+ Impl.visit(*Inst);
}
}
}
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 6077af997212..53fc1c713a8c 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -162,11 +162,18 @@ bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
Value *CurrOpperand = LI.getOperand(I);
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
- convertUsersOfConstantsToInstructions(CE,
- /*RestrictToFunc=*/nullptr,
- /*RemoveDeadConstants=*/false,
- /*IncludeSelf=*/true);
- return false;
+ GetElementPtrInst *OldGEP =
+ cast<GetElementPtrInst>(CE->getAsInstruction());
+ OldGEP->insertBefore(&LI);
+
+ IRBuilder<> Builder(&LI);
+ LoadInst *NewLoad =
+ Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
+ NewLoad->setAlignment(LI.getAlign());
+ LI.replaceAllUsesWith(NewLoad);
+ LI.eraseFromParent();
+ visitGetElementPtrInst(*OldGEP);
+ return true;
}
}
return false;
@@ -178,11 +185,17 @@ bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
Value *CurrOpperand = SI.getOperand(I);
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
- convertUsersOfConstantsToInstructions(CE,
- /*RestrictToFunc=*/nullptr,
- /*RemoveDeadConstants=*/false,
- /*IncludeSelf=*/true);
- return false;
+ GetElementPtrInst *OldGEP =
+ cast<GetElementPtrInst>(CE->getAsInstruction());
+ OldGEP->insertBefore(&SI);
+
+ IRBuilder<> Builder(&SI);
+ StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
+ NewStore->setAlignment(SI.getAlign());
+ SI.replaceAllUsesWith(NewStore);
+ SI.eraseFromParent();
+ visitGetElementPtrInst(*OldGEP);
+ return true;
}
}
return false;
@@ -315,10 +328,17 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
static void collectElements(Constant *Init,
SmallVectorImpl<Constant *> &Elements) {
// Base case: If Init is not an array, add it directly to the vector.
- if (!isa<ArrayType>(Init->getType())) {
+ auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
+ if (!ArrayTy) {
Elements.push_back(Init);
return;
}
+ unsigned ArrSize = ArrayTy->getNumElements();
+ if (isa<ConstantAggregateZero>(Init)) {
+ for (unsigned I = 0; I < ArrSize; ++I)
+ Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
+ return;
+ }
// Recursive case: Process each element in the array.
if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index a898b6a5047d..c283b9081e08 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -10,8 +10,11 @@
#include "DXILConstants.h"
#include "DXILIntrinsicExpansion.h"
#include "DXILOpBuilder.h"
+#include "DXILResourceAnalysis.h"
+#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -763,6 +766,8 @@ PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
return PreservedAnalyses::all();
PreservedAnalyses PA;
PA.preserve<DXILResourceBindingAnalysis>();
+ PA.preserve<DXILMetadataAnalysis>();
+ PA.preserve<ShaderFlagsAnalysis>();
return PA;
}
@@ -785,6 +790,9 @@ public:
AU.addRequired<DXILResourceTypeWrapperPass>();
AU.addRequired<DXILResourceBindingWrapperPass>();
AU.addPreserved<DXILResourceBindingWrapperPass>();
+ AU.addPreserved<DXILResourceMDWrapper>();
+ AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
+ AU.addPreserved<ShaderFlagsAnalysisWrapper>();
}
};
char DXILOpLoweringLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98ab..1e8896334576 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -14,16 +14,21 @@
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
using namespace llvm::dxil;
-static void updateFunctionFlags(ComputedShaderFlags &CSF,
- const Instruction &I) {
+static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
+ DXILResourceTypeMap &DRTM) {
if (!CSF.Doubles)
CSF.Doubles = I.getType()->isDoubleTy();
@@ -44,17 +49,34 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
break;
}
}
+
+ if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::dx_typedBufferLoad: {
+ dxil::ResourceTypeInfo &RTI =
+ DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
+ if (RTI.isTyped())
+ CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
+ }
+ }
+ }
}
-void ModuleShaderFlags::initialize(const Module &M) {
+void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
+
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
- if (F.isDeclaration())
+ if (F.isDeclaration()) {
+ assert(!F.getName().starts_with("dx.op.") &&
+ "DXIL Shader Flag analysis should not be run post-lowering.");
continue;
+ }
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
- updateFunctionFlags(CSF, I);
+ updateFunctionFlags(CSF, I, DRTM);
// Insert shader flag mask for function F
FunctionFlags.push_back({&F, CSF});
// Update combined shader flags mask
@@ -101,8 +123,11 @@ AnalysisKey ShaderFlagsAnalysis::Key;
ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
+ DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
+
ModuleShaderFlags MSFI;
- MSFI.initialize(M);
+ MSFI.initialize(M, DRTM);
+
return MSFI;
}
@@ -129,11 +154,22 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
- MSFI.initialize(M);
+ DXILResourceTypeMap &DRTM =
+ getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
+
+ MSFI.initialize(M, DRTM);
return false;
}
+void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
+}
+
char ShaderFlagsAnalysisWrapper::ID = 0;
-INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
- "DXIL Shader Flag Analysis", true, true)
+INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
+ "DXIL Shader Flag Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
+INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
+ "DXIL Shader Flag Analysis", true, true)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b19..67ddab39d0f3 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -26,6 +26,7 @@
namespace llvm {
class Module;
class GlobalVariable;
+class DXILResourceTypeMap;
namespace dxil {
@@ -84,7 +85,7 @@ struct ComputedShaderFlags {
};
struct ModuleShaderFlags {
- void initialize(const Module &);
+ void initialize(const Module &, DXILResourceTypeMap &DRTM);
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
@@ -135,9 +136,7 @@ public:
bool runOnModule(Module &M) override;
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.setPreservesAll();
- }
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
};
} // namespace dxil
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index f175f169f35a..5afe6b2d2883 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -393,6 +393,7 @@ public:
AU.addPreserved<DXILResourceBindingWrapperPass>();
AU.addPreserved<DXILResourceMDWrapper>();
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
+ AU.addPreserved<ShaderFlagsAnalysisWrapper>();
}
bool runOnModule(Module &M) override {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index d4e35fb75031..ecb1bf775f85 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -99,8 +99,8 @@ public:
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createScalarizerPass(DxilScalarOptions));
- addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
+ addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILPrepareModulePass());
}
};