diff options
| author | Mingming Liu <mingmingl@google.com> | 2025-09-10 15:25:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-10 15:25:31 -0700 |
| commit | 1417dafa1db9cb1b2b09438aa9f53ea5ab6e36e2 (patch) | |
| tree | 57f4b1f313c8cf74eed8819870f39c36ea263c68 /llvm/lib/Target/DirectX | |
| parent | 898b813bc8a6d0276bf0f4769f5f2f64b34e632d (diff) | |
| parent | b8cefcb601ddaa18482555c4ff363c01a270c2fe (diff) | |
Merge branch 'main' into users/mingmingl-llvm/samplefdo-profile-format (branch: users/mingmingl-llvm/samplefdo-profile-format)
Diffstat (limited to 'llvm/lib/Target/DirectX')
17 files changed, 440 insertions, 108 deletions
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index 8100f941c8d9..6c079517e22d 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -41,6 +41,7 @@ add_llvm_target(DirectXCodeGen LINK_COMPONENTS Analysis AsmPrinter + BinaryFormat CodeGen CodeGenTypes Core diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index a1ef2578f00a..ca81d30473c0 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -158,12 +158,15 @@ void DXContainerGlobals::addRootSignature(Module &M, if (MMI.ShaderProfile == llvm::Triple::Library) return; - assert(MMI.EntryPropertyVec.size() == 1); - auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); - const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry; - const mcdxbc::RootSignatureDesc *RS = RSA.getDescForFunction(EntryFunction); + const Function *EntryFunction = nullptr; + if (MMI.ShaderProfile != llvm::Triple::RootSignature) { + assert(MMI.EntryPropertyVec.size() == 1); + EntryFunction = MMI.EntryPropertyVec[0].Entry; + } + + const mcdxbc::RootSignatureDesc *RS = RSA.getDescForFunction(EntryFunction); if (!RS) return; @@ -258,7 +261,8 @@ void DXContainerGlobals::addPipelineStateValidationInfo( dxil::ModuleMetadataInfo &MMI = getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); assert(MMI.EntryPropertyVec.size() == 1 || - MMI.ShaderProfile == Triple::Library); + MMI.ShaderProfile == Triple::Library || + MMI.ShaderProfile == Triple::RootSignature); PSV.BaseData.ShaderStage = static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel); @@ -279,7 +283,8 @@ void DXContainerGlobals::addPipelineStateValidationInfo( break; } - if (MMI.ShaderProfile != Triple::Library) + if (MMI.ShaderProfile != Triple::Library && + MMI.ShaderProfile != Triple::RootSignature) PSV.EntryName = 
MMI.EntryPropertyVec[0].Entry->getName(); PSV.finalize(MMI.ShaderProfile); diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index feecfc0880e2..d507d71b99fc 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -343,9 +343,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { GOp->replaceAllUsesWith(NewGEP); - if (auto *CE = dyn_cast<ConstantExpr>(GOp)) - CE->destroyConstant(); - else if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp)) + if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp)) OldGEPI->eraseFromParent(); return true; diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp index 13e3408815bb..aa16e795dc76 100644 --- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp +++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp @@ -22,11 +22,13 @@ static bool finalizeLinkage(Module &M) { // Convert private globals and external globals with no usage to internal // linkage. 
- for (GlobalVariable &GV : M.globals()) + for (GlobalVariable &GV : M.globals()) { + GV.removeDeadConstantUsers(); if (GV.hasPrivateLinkage() || (GV.hasExternalLinkage() && GV.use_empty())) { GV.setLinkage(GlobalValue::InternalLinkage); MadeChange = true; } + } SmallVector<Function *> Funcs; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index ee1db54446cb..e2469d8df957 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -51,6 +51,150 @@ static bool resourceAccessNeeds64BitExpansion(Module *M, Type *OverloadTy, return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64); } +static Value *expand16BitIsInf(CallInst *Orig) { + Module *M = Orig->getModule(); + if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9)) + return nullptr; + + Value *Val = Orig->getOperand(0); + Type *ValTy = Val->getType(); + if (!ValTy->getScalarType()->isHalfTy()) + return nullptr; + + IRBuilder<> Builder(Orig); + Type *IType = Type::getInt16Ty(M->getContext()); + Constant *PosInf = + ValTy->isVectorTy() + ? ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0x7c00)) + : ConstantInt::get(IType, 0x7c00); + + Constant *NegInf = + ValTy->isVectorTy() + ? 
ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0xfc00)) + : ConstantInt::get(IType, 0xfc00); + + Value *IVal = Builder.CreateBitCast(Val, PosInf->getType()); + Value *B1 = Builder.CreateICmpEQ(IVal, PosInf); + Value *B2 = Builder.CreateICmpEQ(IVal, NegInf); + Value *B3 = Builder.CreateOr(B1, B2); + return B3; +} + +static Value *expand16BitIsNaN(CallInst *Orig) { + Module *M = Orig->getModule(); + if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9)) + return nullptr; + + Value *Val = Orig->getOperand(0); + Type *ValTy = Val->getType(); + if (!ValTy->getScalarType()->isHalfTy()) + return nullptr; + + IRBuilder<> Builder(Orig); + Type *IType = Type::getInt16Ty(M->getContext()); + + Constant *ExpBitMask = + ValTy->isVectorTy() + ? ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0x7c00)) + : ConstantInt::get(IType, 0x7c00); + Constant *SigBitMask = + ValTy->isVectorTy() + ? ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0x3ff)) + : ConstantInt::get(IType, 0x3ff); + + Constant *Zero = + ValTy->isVectorTy() + ? 
ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0)) + : ConstantInt::get(IType, 0); + + Value *IVal = Builder.CreateBitCast(Val, ExpBitMask->getType()); + Value *Exp = Builder.CreateAnd(IVal, ExpBitMask); + Value *B1 = Builder.CreateICmpEQ(Exp, ExpBitMask); + + Value *Sig = Builder.CreateAnd(IVal, SigBitMask); + Value *B2 = Builder.CreateICmpNE(Sig, Zero); + Value *B3 = Builder.CreateAnd(B1, B2); + return B3; +} + +static Value *expand16BitIsFinite(CallInst *Orig) { + Module *M = Orig->getModule(); + if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9)) + return nullptr; + + Value *Val = Orig->getOperand(0); + Type *ValTy = Val->getType(); + if (!ValTy->getScalarType()->isHalfTy()) + return nullptr; + + IRBuilder<> Builder(Orig); + Type *IType = Type::getInt16Ty(M->getContext()); + + Constant *ExpBitMask = + ValTy->isVectorTy() + ? ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0x7c00)) + : ConstantInt::get(IType, 0x7c00); + + Value *IVal = Builder.CreateBitCast(Val, ExpBitMask->getType()); + Value *Exp = Builder.CreateAnd(IVal, ExpBitMask); + Value *B1 = Builder.CreateICmpNE(Exp, ExpBitMask); + return B1; +} + +static Value *expand16BitIsNormal(CallInst *Orig) { + Module *M = Orig->getModule(); + if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9)) + return nullptr; + + Value *Val = Orig->getOperand(0); + Type *ValTy = Val->getType(); + if (!ValTy->getScalarType()->isHalfTy()) + return nullptr; + + IRBuilder<> Builder(Orig); + Type *IType = Type::getInt16Ty(M->getContext()); + + Constant *ExpBitMask = + ValTy->isVectorTy() + ? ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0x7c00)) + : ConstantInt::get(IType, 0x7c00); + Constant *Zero = + ValTy->isVectorTy() + ? 
ConstantVector::getSplat( + ElementCount::getFixed( + cast<FixedVectorType>(ValTy)->getNumElements()), + ConstantInt::get(IType, 0)) + : ConstantInt::get(IType, 0); + + Value *IVal = Builder.CreateBitCast(Val, ExpBitMask->getType()); + Value *Exp = Builder.CreateAnd(IVal, ExpBitMask); + Value *NotAllZeroes = Builder.CreateICmpNE(Exp, Zero); + Value *NotAllOnes = Builder.CreateICmpNE(Exp, ExpBitMask); + Value *B1 = Builder.CreateAnd(NotAllZeroes, NotAllOnes); + return B1; +} + static bool isIntrinsicExpansion(Function &F) { switch (F.getIntrinsicID()) { case Intrinsic::abs: @@ -68,6 +212,7 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::dx_sclamp: case Intrinsic::dx_nclamp: case Intrinsic::dx_degrees: + case Intrinsic::dx_isinf: case Intrinsic::dx_lerp: case Intrinsic::dx_normalize: case Intrinsic::dx_fdot: @@ -301,13 +446,16 @@ static Value *expandIsFPClass(CallInst *Orig) { auto *TCI = dyn_cast<ConstantInt>(T); // These FPClassTest cases have DXIL opcodes, so they will be handled in - // DXIL Op Lowering instead. + // DXIL Op Lowering instead for all non f16 cases. 
switch (TCI->getZExtValue()) { case FPClassTest::fcInf: + return expand16BitIsInf(Orig); case FPClassTest::fcNan: + return expand16BitIsNaN(Orig); case FPClassTest::fcNormal: + return expand16BitIsNormal(Orig); case FPClassTest::fcFinite: - return nullptr; + return expand16BitIsFinite(Orig); } IRBuilder<> Builder(Orig); @@ -873,6 +1021,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { case Intrinsic::dx_degrees: Result = expandDegreesIntrinsic(Orig); break; + case Intrinsic::dx_isinf: + Result = expand16BitIsInf(Orig); + break; case Intrinsic::dx_lerp: Result = expandLerpIntrinsic(Orig); break; diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index bd421771e8ed..577b4624458b 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -220,7 +220,7 @@ public: removeResourceGlobals(CI); - auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(5)); + auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(4)); CI->replaceAllUsesWith(Replacement); CI->eraseFromParent(); @@ -233,6 +233,7 @@ public: IRBuilder<> &IRB = OpBuilder.getIRB(); Type *Int8Ty = IRB.getInt8Ty(); Type *Int32Ty = IRB.getInt32Ty(); + Type *Int1Ty = IRB.getInt1Ty(); return replaceFunction(F, [&](CallInst *CI) -> Error { IRB.SetInsertPoint(CI); @@ -249,10 +250,13 @@ public: IndexOp = IRB.CreateAdd(IndexOp, ConstantInt::get(Int32Ty, Binding.LowerBound)); + // FIXME: The last argument is a NonUniform flag which needs to be set + // based on resource analysis. 
+ // https://github.com/llvm/llvm-project/issues/155701 std::array<Value *, 4> Args{ ConstantInt::get(Int8Ty, llvm::to_underlying(RC)), ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp, - CI->getArgOperand(4)}; + ConstantInt::get(Int1Ty, false)}; Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName()); if (Error E = OpCall.takeError()) @@ -267,6 +271,7 @@ public: [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) { IRBuilder<> &IRB = OpBuilder.getIRB(); Type *Int32Ty = IRB.getInt32Ty(); + Type *Int1Ty = IRB.getInt1Ty(); return replaceFunction(F, [&](CallInst *CI) -> Error { IRB.SetInsertPoint(CI); @@ -295,7 +300,11 @@ public: : Binding.LowerBound + Binding.Size - 1; Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound, Binding.Space, RC); - std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)}; + // FIXME: The last argument is a NonUniform flag which needs to be set + // based on resource analysis. 
+ // https://github.com/llvm/llvm-project/issues/155701 + Constant *NonUniform = ConstantInt::get(Int1Ty, false); + std::array<Value *, 3> BindArgs{ResBind, IndexOp, NonUniform}; Expected<CallInst *> OpBind = OpBuilder.tryCreateOp( OpCode::CreateHandleFromBinding, BindArgs, CI->getName()); if (Error E = OpBind.takeError()) diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index be2c7d1ddff3..d02f4b9f7ebc 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -25,21 +25,6 @@ using namespace llvm; using namespace llvm::dxil; -static ResourceClass toResourceClass(dxbc::DescriptorRangeType RangeType) { - using namespace dxbc; - switch (RangeType) { - case DescriptorRangeType::SRV: - return ResourceClass::SRV; - case DescriptorRangeType::UAV: - return ResourceClass::UAV; - case DescriptorRangeType::CBV: - return ResourceClass::CBuffer; - case DescriptorRangeType::Sampler: - return ResourceClass::Sampler; - } - llvm_unreachable("Unknown DescriptorRangeType"); -} - static ResourceClass toResourceClass(dxbc::RootParameterType Type) { using namespace dxbc; switch (Type) { @@ -95,7 +80,7 @@ static void reportOverlappingError(Module &M, ResourceInfo R1, } static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) { - bool ErrorFound = false; + [[maybe_unused]] bool ErrorFound = false; for (const auto &ResList : {DRM.srvs(), DRM.uavs(), DRM.cbuffers(), DRM.samplers()}) { if (ResList.empty()) @@ -118,10 +103,8 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) { "true, yet no overlapping binding was found"); } -static void -reportOverlappingRegisters(Module &M, - const llvm::hlsl::BindingInfoBuilder::Binding &R1, - const llvm::hlsl::BindingInfoBuilder::Binding &R2) { +static void reportOverlappingRegisters(Module &M, const llvm::hlsl::Binding &R1, + const llvm::hlsl::Binding &R2) 
{ SmallString<128> Message; raw_svector_ostream OS(Message); @@ -133,6 +116,17 @@ reportOverlappingRegisters(Module &M, M.getContext().diagnose(DiagnosticInfoGeneric(Message)); } +static void +reportRegNotBound(Module &M, ResourceClass Class, + const llvm::dxil::ResourceInfo::ResourceBinding &Unbound) { + SmallString<128> Message; + raw_svector_ostream OS(Message); + OS << getResourceClassName(Class) << " register " << Unbound.LowerBound + << " in space " << Unbound.Space + << " does not have a binding in the Root Signature"; + M.getContext().diagnose(DiagnosticInfoGeneric(Message)); +} + static dxbc::ShaderVisibility tripleToVisibility(llvm::Triple::EnvironmentType ET) { switch (ET) { @@ -157,22 +151,23 @@ tripleToVisibility(llvm::Triple::EnvironmentType ET) { static void validateRootSignature(Module &M, const mcdxbc::RootSignatureDesc &RSD, - dxil::ModuleMetadataInfo &MMI) { + dxil::ModuleMetadataInfo &MMI, + DXILResourceMap &DRM, + DXILResourceTypeMap &DRTM) { hlsl::BindingInfoBuilder Builder; dxbc::ShaderVisibility Visibility = tripleToVisibility(MMI.ShaderProfile); for (const mcdxbc::RootParameterInfo &ParamInfo : RSD.ParametersContainer) { dxbc::ShaderVisibility ParamVisibility = - static_cast<dxbc::ShaderVisibility>(ParamInfo.Header.ShaderVisibility); + dxbc::ShaderVisibility(ParamInfo.Visibility); if (ParamVisibility != dxbc::ShaderVisibility::All && ParamVisibility != Visibility) continue; - dxbc::RootParameterType ParamType = - static_cast<dxbc::RootParameterType>(ParamInfo.Header.ParameterType); + dxbc::RootParameterType ParamType = dxbc::RootParameterType(ParamInfo.Type); switch (ParamType) { case dxbc::RootParameterType::Constants32Bit: { - dxbc::RTS0::v1::RootConstants Const = + mcdxbc::RootConstants Const = RSD.ParametersContainer.getConstant(ParamInfo.Location); Builder.trackBinding(dxil::ResourceClass::CBuffer, Const.RegisterSpace, Const.ShaderRegister, Const.ShaderRegister, @@ -183,12 +178,11 @@ static void validateRootSignature(Module &M, case 
dxbc::RootParameterType::SRV: case dxbc::RootParameterType::UAV: case dxbc::RootParameterType::CBV: { - dxbc::RTS0::v2::RootDescriptor Desc = + mcdxbc::RootDescriptor Desc = RSD.ParametersContainer.getRootDescriptor(ParamInfo.Location); - Builder.trackBinding(toResourceClass(static_cast<dxbc::RootParameterType>( - ParamInfo.Header.ParameterType)), - Desc.RegisterSpace, Desc.ShaderRegister, - Desc.ShaderRegister, &ParamInfo); + Builder.trackBinding(toResourceClass(ParamInfo.Type), Desc.RegisterSpace, + Desc.ShaderRegister, Desc.ShaderRegister, + &ParamInfo); break; } @@ -196,16 +190,13 @@ static void validateRootSignature(Module &M, const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(ParamInfo.Location); - for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) { + for (const mcdxbc::DescriptorRange &Range : Table.Ranges) { uint32_t UpperBound = Range.NumDescriptors == ~0U ? Range.BaseShaderRegister : Range.BaseShaderRegister + Range.NumDescriptors - 1; - Builder.trackBinding( - toResourceClass( - static_cast<dxbc::DescriptorRangeType>(Range.RangeType)), - Range.RegisterSpace, Range.BaseShaderRegister, UpperBound, - &ParamInfo); + Builder.trackBinding(Range.RangeType, Range.RegisterSpace, + Range.BaseShaderRegister, UpperBound, &ParamInfo); } break; } @@ -218,11 +209,19 @@ static void validateRootSignature(Module &M, Builder.calculateBindingInfo( [&M](const llvm::hlsl::BindingInfoBuilder &Builder, - const llvm::hlsl::BindingInfoBuilder::Binding &ReportedBinding) { - const llvm::hlsl::BindingInfoBuilder::Binding &Overlaping = + const llvm::hlsl::Binding &ReportedBinding) { + const llvm::hlsl::Binding &Overlaping = Builder.findOverlapping(ReportedBinding); reportOverlappingRegisters(M, ReportedBinding, Overlaping); }); + const hlsl::BoundRegs &BoundRegs = Builder.takeBoundRegs(); + for (const ResourceInfo &RI : DRM) { + const ResourceInfo::ResourceBinding &Binding = RI.getBinding(); + ResourceClass RC = 
DRTM[RI.getHandleTy()].getResourceClass(); + if (!BoundRegs.isBound(RC, Binding.Space, Binding.LowerBound, + Binding.LowerBound + Binding.Size - 1)) + reportRegNotBound(M, RC, Binding); + } } static mcdxbc::RootSignatureDesc * @@ -236,7 +235,8 @@ getRootSignature(RootSignatureBindingInfo &RSBI, static void reportErrors(Module &M, DXILResourceMap &DRM, DXILResourceBindingInfo &DRBI, RootSignatureBindingInfo &RSBI, - dxil::ModuleMetadataInfo &MMI) { + dxil::ModuleMetadataInfo &MMI, + DXILResourceTypeMap &DRTM) { if (DRM.hasInvalidCounterDirection()) reportInvalidDirection(M, DRM); @@ -247,7 +247,7 @@ static void reportErrors(Module &M, DXILResourceMap &DRM, "DXILResourceImplicitBinding pass"); if (mcdxbc::RootSignatureDesc *RSD = getRootSignature(RSBI, MMI)) - validateRootSignature(M, *RSD, MMI); + validateRootSignature(M, *RSD, MMI, DRM, DRTM); } PreservedAnalyses @@ -256,8 +256,9 @@ DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) { DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M); RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M); ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M); + DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); - reportErrors(M, DRM, DRBI, RSBI, MMI); + reportErrors(M, DRM, DRBI, RSBI, MMI, DRTM); return PreservedAnalyses::all(); } @@ -273,8 +274,10 @@ public: getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); dxil::ModuleMetadataInfo &MMI = getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); + DXILResourceTypeMap &DRTM = + getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); - reportErrors(M, DRM, DRBI, RSBI, MMI); + reportErrors(M, DRM, DRBI, RSBI, MMI, DRTM); return false; } StringRef getPassName() const override { @@ -288,6 +291,7 @@ public: AU.addRequired<DXILResourceBindingWrapperPass>(); AU.addRequired<DXILMetadataAnalysisWrapperPass>(); AU.addRequired<RootSignatureAnalysisWrapper>(); + 
AU.addRequired<DXILResourceTypeWrapperPass>(); AU.addPreserved<DXILResourceWrapperPass>(); AU.addPreserved<DXILResourceBindingWrapperPass>(); AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); @@ -305,6 +309,7 @@ INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) INITIALIZE_PASS_DEPENDENCY(RootSignatureAnalysisWrapper) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) INITIALIZE_PASS_END(DXILPostOptimizationValidationLegacy, DEBUG_TYPE, "DXIL Post Optimization Validation", false, false) diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp index c33ec0efd73c..6579d3405cf3 100644 --- a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp +++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp @@ -8,14 +8,19 @@ #include "DXILResourceAccess.h" #include "DirectX.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Analysis/DXILResource.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsDirectX.h" +#include "llvm/IR/User.h" #include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #define DEBUG_TYPE "dxil-resource-access" @@ -198,6 +203,112 @@ static void createLoadIntrinsic(IntrinsicInst *II, LoadInst *LI, Value *Offset, llvm_unreachable("Unhandled case in switch"); } +static SmallVector<Instruction *> collectBlockUseDef(Instruction *Start) { + SmallPtrSet<Instruction *, 32> Visited; + SmallVector<Instruction *, 32> Worklist; + SmallVector<Instruction *> Out; + auto *BB = Start->getParent(); + + // Seed with direct users in this block. 
+ for (User *U : Start->users()) { + if (auto *I = dyn_cast<Instruction>(U)) { + if (I->getParent() == BB) + Worklist.push_back(I); + } + } + + // BFS over transitive users, constrained to the same block. + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!Visited.insert(I).second) + continue; + Out.push_back(I); + + for (User *U : I->users()) { + if (auto *J = dyn_cast<Instruction>(U)) { + if (J->getParent() == BB) + Worklist.push_back(J); + } + } + for (Use &V : I->operands()) { + if (auto *J = dyn_cast<Instruction>(V)) { + if (J->getParent() == BB && V != Start) + Worklist.push_back(J); + } + } + } + + // Order results in program order. + DenseMap<const Instruction *, unsigned> Ord; + unsigned Idx = 0; + for (Instruction &I : *BB) + Ord[&I] = Idx++; + + llvm::sort(Out, [&](Instruction *A, Instruction *B) { + return Ord.lookup(A) < Ord.lookup(B); + }); + + return Out; +} + +static void phiNodeRemapHelper(PHINode *Phi, BasicBlock *BB, + IRBuilder<> &Builder, + SmallVector<Instruction *> &UsesInBlock) { + + ValueToValueMapTy VMap; + Value *Val = Phi->getIncomingValueForBlock(BB); + VMap[Phi] = Val; + Builder.SetInsertPoint(&BB->back()); + for (Instruction *I : UsesInBlock) { + // don't clone over the Phi just remap them + if (auto *PhiNested = dyn_cast<PHINode>(I)) { + VMap[PhiNested] = PhiNested->getIncomingValueForBlock(BB); + continue; + } + Instruction *Clone = I->clone(); + RemapInstruction(Clone, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + Builder.Insert(Clone); + VMap[I] = Clone; + } +} + +static void phiNodeReplacement(IntrinsicInst *II, + SmallVectorImpl<Instruction *> &PrevBBDeadInsts, + SetVector<BasicBlock *> &DeadBB) { + SmallVector<Instruction *> CurrBBDeadInsts; + for (User *U : II->users()) { + auto *Phi = dyn_cast<PHINode>(U); + if (!Phi) + continue; + + IRBuilder<> Builder(Phi); + SmallVector<Instruction *> UsesInBlock = collectBlockUseDef(Phi); + bool HasReturnUse = 
isa<ReturnInst>(UsesInBlock.back()); + + for (unsigned I = 0, E = Phi->getNumIncomingValues(); I < E; I++) { + auto *CurrIncomingBB = Phi->getIncomingBlock(I); + phiNodeRemapHelper(Phi, CurrIncomingBB, Builder, UsesInBlock); + if (HasReturnUse) + PrevBBDeadInsts.push_back(&CurrIncomingBB->back()); + } + + CurrBBDeadInsts.push_back(Phi); + + for (Instruction *I : UsesInBlock) { + CurrBBDeadInsts.push_back(I); + } + if (HasReturnUse) { + BasicBlock *PhiBB = Phi->getParent(); + DeadBB.insert(PhiBB); + } + } + // Traverse the now-dead instructions in RPO and remove them. + for (Instruction *Dead : llvm::reverse(CurrBBDeadInsts)) + Dead->eraseFromParent(); + CurrBBDeadInsts.clear(); +} + static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) { // Process users keeping track of indexing accumulated from GEPs. struct AccessAndOffset { @@ -229,7 +340,6 @@ static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) { } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) { createLoadIntrinsic(II, LI, Current.Offset, RTI); DeadInsts.push_back(LI); - } else llvm_unreachable("Unhandled instruction - pointer escaped?"); } @@ -242,13 +352,27 @@ static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) { static bool transformResourcePointers(Function &F, DXILResourceTypeMap &DRTM) { SmallVector<std::pair<IntrinsicInst *, dxil::ResourceTypeInfo>> Resources; - for (BasicBlock &BB : F) + SetVector<BasicBlock *> DeadBB; + SmallVector<Instruction *> PrevBBDeadInsts; + for (BasicBlock &BB : make_early_inc_range(F)) { + for (Instruction &I : make_early_inc_range(BB)) + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer) + phiNodeReplacement(II, PrevBBDeadInsts, DeadBB); + for (Instruction &I : BB) if (auto *II = dyn_cast<IntrinsicInst>(&I)) if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer) { auto *HandleTy = cast<TargetExtType>(II->getArgOperand(0)->getType()); 
Resources.emplace_back(II, DRTM[HandleTy]); } + } + for (auto *Dead : PrevBBDeadInsts) + Dead->eraseFromParent(); + PrevBBDeadInsts.clear(); + for (auto *Dead : DeadBB) + Dead->eraseFromParent(); + DeadBB.clear(); for (auto &[II, RI] : Resources) replaceAccess(II, RI); @@ -279,7 +403,6 @@ public: bool runOnFunction(Function &F) override { DXILResourceTypeMap &DRTM = getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); - return transformResourcePointers(F, DRTM); } StringRef getPassName() const override { return "DXIL Resource Access"; } diff --git a/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp b/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp index 6e69c5ac1d63..b0d9ad8da10e 100644 --- a/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp +++ b/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp @@ -111,8 +111,7 @@ static bool assignBindings(Module &M, DXILResourceBindingInfo &DRBI, RegSlotOp, /* register slot */ IB.Call->getOperand(2), /* size */ IB.Call->getOperand(3), /* index */ - IB.Call->getOperand(4), /* non-uniform flag */ - IB.Call->getOperand(5)}); /* name */ + IB.Call->getOperand(4)}); /* name */ IB.Call->replaceAllUsesWith(NewCall); IB.Call->eraseFromParent(); Changed = true; diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index a4f5086c2f42..ac3c7dde6b89 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -24,9 +24,11 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" +#include "llvm/MC/DXContainerRootSignature.h" #include "llvm/Pass.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" #include <cstdint> @@ -70,6 +72,13 @@ analyzeModule(Module &M) { if (RootSignatureNode == nullptr) return RSDMap; + bool AllowNullFunctions = false; + if 
(M.getTargetTriple().getEnvironment() == + Triple::EnvironmentType::RootSignature) { + assert(RootSignatureNode->getNumOperands() == 1); + AllowNullFunctions = true; + } + for (const auto &RSDefNode : RootSignatureNode->operands()) { if (RSDefNode->getNumOperands() != 3) { reportError(Ctx, "Invalid Root Signature metadata - expected function, " @@ -78,24 +87,28 @@ analyzeModule(Module &M) { } // Function was pruned during compilation. - const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0); - if (FunctionPointerMdNode == nullptr) { - reportError( - Ctx, "Function associated with Root Signature definition is null."); - continue; - } + Function *F = nullptr; + + if (!AllowNullFunctions) { + const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0); + if (FunctionPointerMdNode == nullptr) { + reportError( + Ctx, "Function associated with Root Signature definition is null."); + continue; + } - ValueAsMetadata *VAM = - llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get()); - if (VAM == nullptr) { - reportError(Ctx, "First element of root signature is not a Value"); - continue; - } + ValueAsMetadata *VAM = + llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get()); + if (VAM == nullptr) { + reportError(Ctx, "First element of root signature is not a Value"); + continue; + } - Function *F = dyn_cast<Function>(VAM->getValue()); - if (F == nullptr) { - reportError(Ctx, "First element of root signature is not a Function"); - continue; + F = dyn_cast<Function>(VAM->getValue()); + if (F == nullptr) { + reportError(Ctx, "First element of root signature is not a Function"); + continue; + } } Metadata *RootElementListOperand = RSDefNode->getOperand(1).get(); @@ -171,41 +184,41 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, << "RootParametersOffset: " << RS.RootParameterOffset << "\n" << "NumParameters: " << RS.ParametersContainer.size() << "\n"; for (size_t I = 0; I < RS.ParametersContainer.size(); I++) { - const auto &[Type, Loc] 
= - RS.ParametersContainer.getTypeAndLocForParameter(I); - const dxbc::RTS0::v1::RootParameterHeader Header = - RS.ParametersContainer.getHeader(I); - - OS << "- Parameter Type: " << Type << "\n" - << " Shader Visibility: " << Header.ShaderVisibility << "\n"; - - switch (Type) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { - const dxbc::RTS0::v1::RootConstants &Constants = - RS.ParametersContainer.getConstant(Loc); + const mcdxbc::RootParameterInfo &Info = RS.ParametersContainer.getInfo(I); + + OS << "- Parameter Type: " + << enumToStringRef(Info.Type, dxbc::getRootParameterTypes()) << "\n" + << " Shader Visibility: " + << enumToStringRef(Info.Visibility, dxbc::getShaderVisibility()) + << "\n"; + switch (Info.Type) { + case dxbc::RootParameterType::Constants32Bit: { + const mcdxbc::RootConstants &Constants = + RS.ParametersContainer.getConstant(Info.Location); OS << " Register Space: " << Constants.RegisterSpace << "\n" << " Shader Register: " << Constants.ShaderRegister << "\n" << " Num 32 Bit Values: " << Constants.Num32BitValues << "\n"; break; } - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { - const dxbc::RTS0::v2::RootDescriptor &Descriptor = - RS.ParametersContainer.getRootDescriptor(Loc); + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::UAV: + case dxbc::RootParameterType::SRV: { + const mcdxbc::RootDescriptor &Descriptor = + RS.ParametersContainer.getRootDescriptor(Info.Location); OS << " Register Space: " << Descriptor.RegisterSpace << "\n" << " Shader Register: " << Descriptor.ShaderRegister << "\n"; if (RS.Version > 1) OS << " Flags: " << Descriptor.Flags << "\n"; break; } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const mcdxbc::DescriptorTable &Table = - RS.ParametersContainer.getDescriptorTable(Loc); + 
RS.ParametersContainer.getDescriptorTable(Info.Location); OS << " NumRanges: " << Table.Ranges.size() << "\n"; - for (const dxbc::RTS0::v2::DescriptorRange Range : Table) { - OS << " - Range Type: " << Range.RangeType << "\n" + for (const mcdxbc::DescriptorRange &Range : Table) { + OS << " - Range Type: " + << dxil::getResourceClassName(Range.RangeType) << "\n" << " Register Space: " << Range.RegisterSpace << "\n" << " Base Shader Register: " << Range.BaseShaderRegister << "\n" << " Num Descriptors: " << Range.NumDescriptors << "\n" diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 82bcacee7a6d..9eebcc9b1306 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -127,6 +127,8 @@ static StringRef getShortShaderStage(Triple::EnvironmentType Env) { return "ms"; case Triple::Amplification: return "as"; + case Triple::RootSignature: + return "rootsig"; default: break; } diff --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp index 1d79c3018439..bc1a3a7995bd 100644 --- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp +++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp @@ -2113,7 +2113,7 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal, } break; case Instruction::GetElementPtr: { - Code = bitc::CST_CODE_CE_GEP; + Code = bitc::CST_CODE_CE_GEP_OLD; const auto *GO = cast<GEPOperator>(C); if (GO->isInBounds()) Code = bitc::CST_CODE_CE_INBOUNDS_GEP; diff --git a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp index f99bb4f4eaee..c2e139edc6bd 100644 --- a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp +++ b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp @@ -15,25 +15,39 @@ #include "llvm/IR/GlobalVariable.h" 
#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" using namespace llvm; using namespace llvm::dxil; namespace { +Type *classifyFunctionType(const Function &F, PointerTypeMap &Map); + // Classifies the type of the value passed in by walking the value's users to // find a typed instruction to materialize a type from. Type *classifyPointerType(const Value *V, PointerTypeMap &Map) { assert(V->getType()->isPointerTy() && "classifyPointerType called with non-pointer"); + + // A CallInst will trigger this case, and we want to classify its Function + // operand as a Function rather than a generic Value. + if (const Function *F = dyn_cast<Function>(V)) + return classifyFunctionType(*F, Map); + + // There can potentially be dead constants hanging off of the globals we do + // not want to deal with. So we remove them here. + if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) + GV->removeDeadConstantUsers(); + auto It = Map.find(V); if (It != Map.end()) return It->second; Type *PointeeTy = nullptr; - if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) { - if (!Inst->getResultElementType()->isPointerTy()) - PointeeTy = Inst->getResultElementType(); + if (auto *GEP = dyn_cast<GEPOperator>(V)) { + if (!GEP->getResultElementType()->isPointerTy()) + PointeeTy = GEP->getResultElementType(); } else if (auto *Inst = dyn_cast<AllocaInst>(V)) { PointeeTy = Inst->getAllocatedType(); } else if (auto *GV = dyn_cast<GlobalVariable>(V)) { @@ -49,8 +63,8 @@ Type *classifyPointerType(const Value *V, PointerTypeMap &Map) { // When store value is ptr type, cannot get more type info. 
if (NewPointeeTy->isPointerTy()) continue; - } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) { - NewPointeeTy = Inst->getSourceElementType(); + } else if (const auto *GEP = dyn_cast<GEPOperator>(User)) { + NewPointeeTy = GEP->getSourceElementType(); } if (NewPointeeTy) { // HLSL doesn't support pointers, so it is unlikely to get more than one @@ -204,6 +218,9 @@ PointerTypeMap PointerTypeAnalysis::run(const Module &M) { for (const auto &I : B) { if (I.getType()->isPointerTy()) classifyPointerType(&I, Map); + for (const auto &O : I.operands()) + if (O.get()->getType()->isPointerTy()) + classifyPointerType(O.get(), Map); } } } diff --git a/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp b/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp index 07b68648f16c..bb2efa43d818 100644 --- a/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp @@ -11,10 +11,14 @@ //===----------------------------------------------------------------------===// #include "DirectXInstrInfo.h" +#include "DirectXSubtarget.h" #define GET_INSTRINFO_CTOR_DTOR #include "DirectXGenInstrInfo.inc" using namespace llvm; +DirectXInstrInfo::DirectXInstrInfo(const DirectXSubtarget &STI) + : DirectXGenInstrInfo(STI) {} + DirectXInstrInfo::~DirectXInstrInfo() {} diff --git a/llvm/lib/Target/DirectX/DirectXInstrInfo.h b/llvm/lib/Target/DirectX/DirectXInstrInfo.h index e2c7036fc74a..57ede28030b2 100644 --- a/llvm/lib/Target/DirectX/DirectXInstrInfo.h +++ b/llvm/lib/Target/DirectX/DirectXInstrInfo.h @@ -20,9 +20,11 @@ #include "DirectXGenInstrInfo.inc" namespace llvm { +class DirectXSubtarget; + struct DirectXInstrInfo : public DirectXGenInstrInfo { const DirectXRegisterInfo RI; - explicit DirectXInstrInfo() : DirectXGenInstrInfo() {} + explicit DirectXInstrInfo(const DirectXSubtarget &STI); const DirectXRegisterInfo &getRegisterInfo() const { return RI; } ~DirectXInstrInfo() override; }; diff --git a/llvm/lib/Target/DirectX/DirectXSubtarget.cpp 
b/llvm/lib/Target/DirectX/DirectXSubtarget.cpp index 526b7d29fb13..f8519177cc2d 100644 --- a/llvm/lib/Target/DirectX/DirectXSubtarget.cpp +++ b/llvm/lib/Target/DirectX/DirectXSubtarget.cpp @@ -24,6 +24,7 @@ using namespace llvm; DirectXSubtarget::DirectXSubtarget(const Triple &TT, StringRef CPU, StringRef FS, const DirectXTargetMachine &TM) - : DirectXGenSubtargetInfo(TT, CPU, CPU, FS), FL(*this), TL(TM, *this) {} + : DirectXGenSubtargetInfo(TT, CPU, CPU, FS), InstrInfo(*this), FL(*this), + TL(TM, *this) {} void DirectXSubtarget::anchor() {} diff --git a/llvm/lib/Target/DirectX/DirectXSubtarget.h b/llvm/lib/Target/DirectX/DirectXSubtarget.h index b2374caaf3cd..f3d71c4c4e3b 100644 --- a/llvm/lib/Target/DirectX/DirectXSubtarget.h +++ b/llvm/lib/Target/DirectX/DirectXSubtarget.h @@ -28,9 +28,9 @@ namespace llvm { class DirectXTargetMachine; class DirectXSubtarget : public DirectXGenSubtargetInfo { + DirectXInstrInfo InstrInfo; DirectXFrameLowering FL; DirectXTargetLowering TL; - DirectXInstrInfo InstrInfo; virtual void anchor(); // virtual anchor method |
