diff options
Diffstat (limited to 'llvm/lib/Target/DirectX')
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILDataScalarization.cpp | 101 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 42 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILOpLowering.cpp | 8 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 54 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILShaderFlags.h | 7 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/DirectX/DirectXTargetMachine.cpp | 2 |
7 files changed, 144 insertions, 71 deletions
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 1783e4a54631..2ab2daaff5b5 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -40,7 +40,7 @@ static bool findAndReplaceVectors(Module &M); class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> { public: DataScalarizerVisitor() : GlobalMap() {} - bool visit(Function &F); + bool visit(Instruction &I); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. bool visitInstruction(Instruction &I) { return false; } @@ -65,28 +65,11 @@ public: private: GlobalVariable *lookupReplacementGlobal(Value *CurrOperand); DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; - SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; - bool finish(); }; -bool DataScalarizerVisitor::visit(Function &F) { +bool DataScalarizerVisitor::visit(Instruction &I) { assert(!GlobalMap.empty()); - ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock()); - for (BasicBlock *BB : RPOT) { - for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { - Instruction *I = &*II; - bool Done = InstVisitor::visit(I); - ++II; - if (Done && I->getType()->isVoidTy()) - I->eraseFromParent(); - } - } - return finish(); -} - -bool DataScalarizerVisitor::finish() { - RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); - return true; + return InstVisitor::visit(I); } GlobalVariable * @@ -104,6 +87,20 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { unsigned NumOperands = LI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = LI.getOperand(I); + ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); + if (CE && CE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEP = + cast<GetElementPtrInst>(CE->getAsInstruction()); + OldGEP->insertBefore(&LI); + IRBuilder<> Builder(&LI); + LoadInst *NewLoad = + Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); + NewLoad->setAlignment(LI.getAlign()); + LI.replaceAllUsesWith(NewLoad); + LI.eraseFromParent(); + visitGetElementPtrInst(*OldGEP); + return true; + } if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) LI.setOperand(I, NewGlobal); } @@ -114,32 +111,48 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) { unsigned NumOperands = SI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = SI.getOperand(I); - if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) { - SI.setOperand(I, NewGlobal); + ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); + if (CE && CE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEP = + cast<GetElementPtrInst>(CE->getAsInstruction()); + OldGEP->insertBefore(&SI); + IRBuilder<> Builder(&SI); + StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); + NewStore->setAlignment(SI.getAlign()); + SI.replaceAllUsesWith(NewStore); + SI.eraseFromParent(); + visitGetElementPtrInst(*OldGEP); + return true; } + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) + SI.setOperand(I, NewGlobal); } return false; } bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { + unsigned NumOperands = GEPI.getNumOperands(); + GlobalVariable *NewGlobal = nullptr; for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = GEPI.getOperand(I); - GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand); - if (!NewGlobal) - continue; - IRBuilder<> Builder(&GEPI); - - SmallVector<Value *, MaxVecSize> Indices; - for (auto &Index : GEPI.indices()) - Indices.push_back(Index); - - Value *NewGEP = - Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices); - - GEPI.replaceAllUsesWith(NewGEP); - PotentiallyDeadInstrs.emplace_back(&GEPI); + NewGlobal = lookupReplacementGlobal(CurrOpperand); + if (NewGlobal) + break; } + if (!NewGlobal) + return false; + + IRBuilder<> Builder(&GEPI); + SmallVector<Value *, MaxVecSize> Indices; + for (auto &Index : GEPI.indices()) + Indices.push_back(Index); + + Value *NewGEP = + Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices, + GEPI.getName(), GEPI.getNoWrapFlags()); + GEPI.replaceAllUsesWith(NewGEP); + GEPI.eraseFromParent(); return true; } @@ -245,17 +258,13 @@ static bool findAndReplaceVectors(Module &M) { for (User *U : make_early_inc_range(G.users())) { if (isa<ConstantExpr>(U) && isa<Operator>(U)) { ConstantExpr *CE = cast<ConstantExpr>(U); - convertUsersOfConstantsToInstructions(CE, - /*RestrictToFunc=*/nullptr, - /*RemoveDeadConstants=*/false, - /*IncludeSelf=*/true); - } - if (isa<Instruction>(U)) { - Instruction *Inst = cast<Instruction>(U); - Function *F = Inst->getFunction(); - if (F) - Impl.visit(*F); + for (User *UCE : make_early_inc_range(CE->users())) { + if (Instruction *Inst = dyn_cast<Instruction>(UCE)) + Impl.visit(*Inst); + } } + if (Instruction *Inst = dyn_cast<Instruction>(U)) + Impl.visit(*Inst); } } } diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 6077af997212..53fc1c713a8c 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -162,11 +162,18 @@ bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) { Value *CurrOpperand = LI.getOperand(I); ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { - convertUsersOfConstantsToInstructions(CE, - /*RestrictToFunc=*/nullptr, - /*RemoveDeadConstants=*/false, - /*IncludeSelf=*/true); - return false; + GetElementPtrInst *OldGEP = + cast<GetElementPtrInst>(CE->getAsInstruction()); + OldGEP->insertBefore(&LI); + + IRBuilder<> Builder(&LI); + LoadInst *NewLoad = + Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); + NewLoad->setAlignment(LI.getAlign()); + LI.replaceAllUsesWith(NewLoad); + LI.eraseFromParent(); + visitGetElementPtrInst(*OldGEP); + return true; } } return false; @@ -178,11 +185,17 @@ bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) { Value *CurrOpperand = SI.getOperand(I); ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { - convertUsersOfConstantsToInstructions(CE, - /*RestrictToFunc=*/nullptr, - /*RemoveDeadConstants=*/false, - /*IncludeSelf=*/true); - return false; + GetElementPtrInst *OldGEP = + cast<GetElementPtrInst>(CE->getAsInstruction()); + OldGEP->insertBefore(&SI); + + IRBuilder<> Builder(&SI); + StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); + NewStore->setAlignment(SI.getAlign()); + SI.replaceAllUsesWith(NewStore); + SI.eraseFromParent(); + visitGetElementPtrInst(*OldGEP); + return true; } } return false; @@ -315,10 +328,17 @@ bool DXILFlattenArraysVisitor::visit(Function &F) { static void collectElements(Constant *Init, SmallVectorImpl<Constant *> &Elements) { // Base case: If Init is not an array, add it directly to the vector. - if (!isa<ArrayType>(Init->getType())) { + auto *ArrayTy = dyn_cast<ArrayType>(Init->getType()); + if (!ArrayTy) { Elements.push_back(Init); return; } + unsigned ArrSize = ArrayTy->getNumElements(); + if (isa<ConstantAggregateZero>(Init)) { + for (unsigned I = 0; I < ArrSize; ++I) + Elements.push_back(Constant::getNullValue(ArrayTy->getElementType())); + return; + } // Recursive case: Process each element in the array. if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) { diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index a898b6a5047d..c283b9081e08 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -10,8 +10,11 @@ #include "DXILConstants.h" #include "DXILIntrinsicExpansion.h" #include "DXILOpBuilder.h" +#include "DXILResourceAnalysis.h" +#include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/Analysis/DXILResource.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -763,6 +766,8 @@ PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) { return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<DXILResourceBindingAnalysis>(); + PA.preserve<DXILMetadataAnalysis>(); + PA.preserve<ShaderFlagsAnalysis>(); return PA; } @@ -785,6 +790,9 @@ public: AU.addRequired<DXILResourceTypeWrapperPass>(); AU.addRequired<DXILResourceBindingWrapperPass>(); AU.addPreserved<DXILResourceBindingWrapperPass>(); + AU.addPreserved<DXILResourceMDWrapper>(); + AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); + AU.addPreserved<ShaderFlagsAnalysisWrapper>(); } }; char DXILOpLoweringLegacy::ID = 0; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index d6917dce98ab..1e8896334576 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -14,16 +14,21 @@ #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace llvm::dxil; -static void updateFunctionFlags(ComputedShaderFlags &CSF, - const Instruction &I) { +static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, + DXILResourceTypeMap &DRTM) { if (!CSF.Doubles) CSF.Doubles = I.getType()->isDoubleTy(); @@ -44,17 +49,34 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, break; } } + + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::dx_typedBufferLoad: { + dxil::ResourceTypeInfo &RTI = + DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())]; + if (RTI.isTyped()) + CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1; + } + } + } } -void ModuleShaderFlags::initialize(const Module &M) { +void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) { + // Collect shader flags for each of the functions for (const auto &F : M.getFunctionList()) { - if (F.isDeclaration()) + if (F.isDeclaration()) { + assert(!F.getName().starts_with("dx.op.") && + "DXIL Shader Flag analysis should not be run post-lowering."); continue; + } ComputedShaderFlags CSF; for (const auto &BB : F) for (const auto &I : BB) - updateFunctionFlags(CSF, I); + updateFunctionFlags(CSF, I, DRTM); // Insert shader flag mask for function F FunctionFlags.push_back({&F, CSF}); // Update combined shader flags mask @@ -101,8 +123,11 @@ AnalysisKey ShaderFlagsAnalysis::Key; ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M); + ModuleShaderFlags MSFI; - MSFI.initialize(M); + MSFI.initialize(M, DRTM); + return MSFI; } @@ -129,11 +154,22 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { - MSFI.initialize(M); + DXILResourceTypeMap &DRTM = + getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); + + MSFI.initialize(M, DRTM); return false; } +void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DXILResourceTypeWrapperPass>(); +} + char ShaderFlagsAnalysisWrapper::ID = 0; -INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", - "DXIL Shader Flag Analysis", true, true) +INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", + "DXIL Shader Flag Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) +INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", + "DXIL Shader Flag Analysis", true, true) diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h index 2d60137f8b19..67ddab39d0f3 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.h +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h @@ -26,6 +26,7 @@ namespace llvm { class Module; class GlobalVariable; +class DXILResourceTypeMap; namespace dxil { @@ -84,7 +85,7 @@ struct ComputedShaderFlags { }; struct ModuleShaderFlags { - void initialize(const Module &); + void initialize(const Module &, DXILResourceTypeMap &DRTM); const ComputedShaderFlags &getFunctionFlags(const Function *) const; const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; } @@ -135,9 +136,7 @@ public: bool runOnModule(Module &M) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } + void getAnalysisUsage(AnalysisUsage &AU) const override; }; } // namespace dxil diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index f175f169f35a..5afe6b2d2883 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -393,6 +393,7 @@ public: AU.addPreserved<DXILResourceBindingWrapperPass>(); AU.addPreserved<DXILResourceMDWrapper>(); AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); + AU.addPreserved<ShaderFlagsAnalysisWrapper>(); } bool runOnModule(Module &M) override { diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index d4e35fb75031..ecb1bf775f85 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -99,8 +99,8 @@ public: ScalarizerPassOptions DxilScalarOptions; DxilScalarOptions.ScalarizeLoadStore = true; addPass(createScalarizerPass(DxilScalarOptions)); - addPass(createDXILOpLoweringLegacyPass()); addPass(createDXILTranslateMetadataLegacyPass()); + addPass(createDXILOpLoweringLegacyPass()); addPass(createDXILPrepareModulePass()); } }; |
