summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX
diff options
context:
space:
mode:
authorMingming Liu <mingmingl@google.com>2025-09-10 15:25:31 -0700
committerGitHub <noreply@github.com>2025-09-10 15:25:31 -0700
commit1417dafa1db9cb1b2b09438aa9f53ea5ab6e36e2 (patch)
tree57f4b1f313c8cf74eed8819870f39c36ea263c68 /llvm/lib/Target/DirectX
parent898b813bc8a6d0276bf0f4769f5f2f64b34e632d (diff)
parentb8cefcb601ddaa18482555c4ff363c01a270c2fe (diff)
Merge branch 'main' into users/mingmingl-llvm/samplefdo-profile-formatusers/mingmingl-llvm/samplefdo-profile-format
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r--llvm/lib/Target/DirectX/CMakeLists.txt1
-rw-r--r--llvm/lib/Target/DirectX/DXContainerGlobals.cpp17
-rw-r--r--llvm/lib/Target/DirectX/DXILDataScalarization.cpp4
-rw-r--r--llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp4
-rw-r--r--llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp155
-rw-r--r--llvm/lib/Target/DirectX/DXILOpLowering.cpp15
-rw-r--r--llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp89
-rw-r--r--llvm/lib/Target/DirectX/DXILResourceAccess.cpp129
-rw-r--r--llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp3
-rw-r--r--llvm/lib/Target/DirectX/DXILRootSignature.cpp87
-rw-r--r--llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp2
-rw-r--r--llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp2
-rw-r--r--llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp27
-rw-r--r--llvm/lib/Target/DirectX/DirectXInstrInfo.cpp4
-rw-r--r--llvm/lib/Target/DirectX/DirectXInstrInfo.h4
-rw-r--r--llvm/lib/Target/DirectX/DirectXSubtarget.cpp3
-rw-r--r--llvm/lib/Target/DirectX/DirectXSubtarget.h2
17 files changed, 440 insertions, 108 deletions
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 8100f941c8d9..6c079517e22d 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -41,6 +41,7 @@ add_llvm_target(DirectXCodeGen
LINK_COMPONENTS
Analysis
AsmPrinter
+ BinaryFormat
CodeGen
CodeGenTypes
Core
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index a1ef2578f00a..ca81d30473c0 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -158,12 +158,15 @@ void DXContainerGlobals::addRootSignature(Module &M,
if (MMI.ShaderProfile == llvm::Triple::Library)
return;
- assert(MMI.EntryPropertyVec.size() == 1);
-
auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
- const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
- const mcdxbc::RootSignatureDesc *RS = RSA.getDescForFunction(EntryFunction);
+ const Function *EntryFunction = nullptr;
+ if (MMI.ShaderProfile != llvm::Triple::RootSignature) {
+ assert(MMI.EntryPropertyVec.size() == 1);
+ EntryFunction = MMI.EntryPropertyVec[0].Entry;
+ }
+
+ const mcdxbc::RootSignatureDesc *RS = RSA.getDescForFunction(EntryFunction);
if (!RS)
return;
@@ -258,7 +261,8 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
dxil::ModuleMetadataInfo &MMI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
assert(MMI.EntryPropertyVec.size() == 1 ||
- MMI.ShaderProfile == Triple::Library);
+ MMI.ShaderProfile == Triple::Library ||
+ MMI.ShaderProfile == Triple::RootSignature);
PSV.BaseData.ShaderStage =
static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);
@@ -279,7 +283,8 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
break;
}
- if (MMI.ShaderProfile != Triple::Library)
+ if (MMI.ShaderProfile != Triple::Library &&
+ MMI.ShaderProfile != Triple::RootSignature)
PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
PSV.finalize(MMI.ShaderProfile);
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index feecfc0880e2..d507d71b99fc 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -343,9 +343,7 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
GOp->replaceAllUsesWith(NewGEP);
- if (auto *CE = dyn_cast<ConstantExpr>(GOp))
- CE->destroyConstant();
- else if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
+ if (auto *OldGEPI = dyn_cast<GetElementPtrInst>(GOp))
OldGEPI->eraseFromParent();
return true;
diff --git a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
index 13e3408815bb..aa16e795dc76 100644
--- a/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
+++ b/llvm/lib/Target/DirectX/DXILFinalizeLinkage.cpp
@@ -22,11 +22,13 @@ static bool finalizeLinkage(Module &M) {
// Convert private globals and external globals with no usage to internal
// linkage.
- for (GlobalVariable &GV : M.globals())
+ for (GlobalVariable &GV : M.globals()) {
+ GV.removeDeadConstantUsers();
if (GV.hasPrivateLinkage() || (GV.hasExternalLinkage() && GV.use_empty())) {
GV.setLinkage(GlobalValue::InternalLinkage);
MadeChange = true;
}
+ }
SmallVector<Function *> Funcs;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ee1db54446cb..e2469d8df957 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -51,6 +51,150 @@ static bool resourceAccessNeeds64BitExpansion(Module *M, Type *OverloadTy,
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
}
+static Value *expand16BitIsInf(CallInst *Orig) {
+ Module *M = Orig->getModule();
+ if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9))
+ return nullptr;
+
+ Value *Val = Orig->getOperand(0);
+ Type *ValTy = Val->getType();
+ if (!ValTy->getScalarType()->isHalfTy())
+ return nullptr;
+
+ IRBuilder<> Builder(Orig);
+ Type *IType = Type::getInt16Ty(M->getContext());
+ Constant *PosInf =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0x7c00))
+ : ConstantInt::get(IType, 0x7c00);
+
+ Constant *NegInf =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0xfc00))
+ : ConstantInt::get(IType, 0xfc00);
+
+ Value *IVal = Builder.CreateBitCast(Val, PosInf->getType());
+ Value *B1 = Builder.CreateICmpEQ(IVal, PosInf);
+ Value *B2 = Builder.CreateICmpEQ(IVal, NegInf);
+ Value *B3 = Builder.CreateOr(B1, B2);
+ return B3;
+}
+
+static Value *expand16BitIsNaN(CallInst *Orig) {
+ Module *M = Orig->getModule();
+ if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9))
+ return nullptr;
+
+ Value *Val = Orig->getOperand(0);
+ Type *ValTy = Val->getType();
+ if (!ValTy->getScalarType()->isHalfTy())
+ return nullptr;
+
+ IRBuilder<> Builder(Orig);
+ Type *IType = Type::getInt16Ty(M->getContext());
+
+ Constant *ExpBitMask =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0x7c00))
+ : ConstantInt::get(IType, 0x7c00);
+ Constant *SigBitMask =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0x3ff))
+ : ConstantInt::get(IType, 0x3ff);
+
+ Constant *Zero =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0))
+ : ConstantInt::get(IType, 0);
+
+ Value *IVal = Builder.CreateBitCast(Val, ExpBitMask->getType());
+ Value *Exp = Builder.CreateAnd(IVal, ExpBitMask);
+ Value *B1 = Builder.CreateICmpEQ(Exp, ExpBitMask);
+
+ Value *Sig = Builder.CreateAnd(IVal, SigBitMask);
+ Value *B2 = Builder.CreateICmpNE(Sig, Zero);
+ Value *B3 = Builder.CreateAnd(B1, B2);
+ return B3;
+}
+
+static Value *expand16BitIsFinite(CallInst *Orig) {
+ Module *M = Orig->getModule();
+ if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9))
+ return nullptr;
+
+ Value *Val = Orig->getOperand(0);
+ Type *ValTy = Val->getType();
+ if (!ValTy->getScalarType()->isHalfTy())
+ return nullptr;
+
+ IRBuilder<> Builder(Orig);
+ Type *IType = Type::getInt16Ty(M->getContext());
+
+ Constant *ExpBitMask =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0x7c00))
+ : ConstantInt::get(IType, 0x7c00);
+
+ Value *IVal = Builder.CreateBitCast(Val, ExpBitMask->getType());
+ Value *Exp = Builder.CreateAnd(IVal, ExpBitMask);
+ Value *B1 = Builder.CreateICmpNE(Exp, ExpBitMask);
+ return B1;
+}
+
+static Value *expand16BitIsNormal(CallInst *Orig) {
+ Module *M = Orig->getModule();
+ if (M->getTargetTriple().getDXILVersion() >= VersionTuple(1, 9))
+ return nullptr;
+
+ Value *Val = Orig->getOperand(0);
+ Type *ValTy = Val->getType();
+ if (!ValTy->getScalarType()->isHalfTy())
+ return nullptr;
+
+ IRBuilder<> Builder(Orig);
+ Type *IType = Type::getInt16Ty(M->getContext());
+
+ Constant *ExpBitMask =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0x7c00))
+ : ConstantInt::get(IType, 0x7c00);
+ Constant *Zero =
+ ValTy->isVectorTy()
+ ? ConstantVector::getSplat(
+ ElementCount::getFixed(
+ cast<FixedVectorType>(ValTy)->getNumElements()),
+ ConstantInt::get(IType, 0))
+ : ConstantInt::get(IType, 0);
+
+ Value *IVal = Builder.CreateBitCast(Val, ExpBitMask->getType());
+ Value *Exp = Builder.CreateAnd(IVal, ExpBitMask);
+ Value *NotAllZeroes = Builder.CreateICmpNE(Exp, Zero);
+ Value *NotAllOnes = Builder.CreateICmpNE(Exp, ExpBitMask);
+ Value *B1 = Builder.CreateAnd(NotAllZeroes, NotAllOnes);
+ return B1;
+}
+
static bool isIntrinsicExpansion(Function &F) {
switch (F.getIntrinsicID()) {
case Intrinsic::abs:
@@ -68,6 +212,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_sclamp:
case Intrinsic::dx_nclamp:
case Intrinsic::dx_degrees:
+ case Intrinsic::dx_isinf:
case Intrinsic::dx_lerp:
case Intrinsic::dx_normalize:
case Intrinsic::dx_fdot:
@@ -301,13 +446,16 @@ static Value *expandIsFPClass(CallInst *Orig) {
auto *TCI = dyn_cast<ConstantInt>(T);
// These FPClassTest cases have DXIL opcodes, so they will be handled in
- // DXIL Op Lowering instead.
+ // DXIL Op Lowering instead for all non f16 cases.
switch (TCI->getZExtValue()) {
case FPClassTest::fcInf:
+ return expand16BitIsInf(Orig);
case FPClassTest::fcNan:
+ return expand16BitIsNaN(Orig);
case FPClassTest::fcNormal:
+ return expand16BitIsNormal(Orig);
case FPClassTest::fcFinite:
- return nullptr;
+ return expand16BitIsFinite(Orig);
}
IRBuilder<> Builder(Orig);
@@ -873,6 +1021,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::dx_degrees:
Result = expandDegreesIntrinsic(Orig);
break;
+ case Intrinsic::dx_isinf:
+ Result = expand16BitIsInf(Orig);
+ break;
case Intrinsic::dx_lerp:
Result = expandLerpIntrinsic(Orig);
break;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index bd421771e8ed..577b4624458b 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -220,7 +220,7 @@ public:
removeResourceGlobals(CI);
- auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(5));
+ auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(4));
CI->replaceAllUsesWith(Replacement);
CI->eraseFromParent();
@@ -233,6 +233,7 @@ public:
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int8Ty = IRB.getInt8Ty();
Type *Int32Ty = IRB.getInt32Ty();
+ Type *Int1Ty = IRB.getInt1Ty();
return replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);
@@ -249,10 +250,13 @@ public:
IndexOp = IRB.CreateAdd(IndexOp,
ConstantInt::get(Int32Ty, Binding.LowerBound));
+ // FIXME: The last argument is a NonUniform flag which needs to be set
+ // based on resource analysis.
+ // https://github.com/llvm/llvm-project/issues/155701
std::array<Value *, 4> Args{
ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
- CI->getArgOperand(4)};
+ ConstantInt::get(Int1Ty, false)};
Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
if (Error E = OpCall.takeError())
@@ -267,6 +271,7 @@ public:
[[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();
+ Type *Int1Ty = IRB.getInt1Ty();
return replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);
@@ -295,7 +300,11 @@ public:
: Binding.LowerBound + Binding.Size - 1;
Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
Binding.Space, RC);
- std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
+ // FIXME: The last argument is a NonUniform flag which needs to be set
+ // based on resource analysis.
+ // https://github.com/llvm/llvm-project/issues/155701
+ Constant *NonUniform = ConstantInt::get(Int1Ty, false);
+ std::array<Value *, 3> BindArgs{ResBind, IndexOp, NonUniform};
Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
if (Error E = OpBind.takeError())
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index be2c7d1ddff3..d02f4b9f7ebc 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -25,21 +25,6 @@
using namespace llvm;
using namespace llvm::dxil;
-static ResourceClass toResourceClass(dxbc::DescriptorRangeType RangeType) {
- using namespace dxbc;
- switch (RangeType) {
- case DescriptorRangeType::SRV:
- return ResourceClass::SRV;
- case DescriptorRangeType::UAV:
- return ResourceClass::UAV;
- case DescriptorRangeType::CBV:
- return ResourceClass::CBuffer;
- case DescriptorRangeType::Sampler:
- return ResourceClass::Sampler;
- }
- llvm_unreachable("Unknown DescriptorRangeType");
-}
-
static ResourceClass toResourceClass(dxbc::RootParameterType Type) {
using namespace dxbc;
switch (Type) {
@@ -95,7 +80,7 @@ static void reportOverlappingError(Module &M, ResourceInfo R1,
}
static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
- bool ErrorFound = false;
+ [[maybe_unused]] bool ErrorFound = false;
for (const auto &ResList :
{DRM.srvs(), DRM.uavs(), DRM.cbuffers(), DRM.samplers()}) {
if (ResList.empty())
@@ -118,10 +103,8 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
"true, yet no overlapping binding was found");
}
-static void
-reportOverlappingRegisters(Module &M,
- const llvm::hlsl::BindingInfoBuilder::Binding &R1,
- const llvm::hlsl::BindingInfoBuilder::Binding &R2) {
+static void reportOverlappingRegisters(Module &M, const llvm::hlsl::Binding &R1,
+ const llvm::hlsl::Binding &R2) {
SmallString<128> Message;
raw_svector_ostream OS(Message);
@@ -133,6 +116,17 @@ reportOverlappingRegisters(Module &M,
M.getContext().diagnose(DiagnosticInfoGeneric(Message));
}
+static void
+reportRegNotBound(Module &M, ResourceClass Class,
+ const llvm::dxil::ResourceInfo::ResourceBinding &Unbound) {
+ SmallString<128> Message;
+ raw_svector_ostream OS(Message);
+ OS << getResourceClassName(Class) << " register " << Unbound.LowerBound
+ << " in space " << Unbound.Space
+ << " does not have a binding in the Root Signature";
+ M.getContext().diagnose(DiagnosticInfoGeneric(Message));
+}
+
static dxbc::ShaderVisibility
tripleToVisibility(llvm::Triple::EnvironmentType ET) {
switch (ET) {
@@ -157,22 +151,23 @@ tripleToVisibility(llvm::Triple::EnvironmentType ET) {
static void validateRootSignature(Module &M,
const mcdxbc::RootSignatureDesc &RSD,
- dxil::ModuleMetadataInfo &MMI) {
+ dxil::ModuleMetadataInfo &MMI,
+ DXILResourceMap &DRM,
+ DXILResourceTypeMap &DRTM) {
hlsl::BindingInfoBuilder Builder;
dxbc::ShaderVisibility Visibility = tripleToVisibility(MMI.ShaderProfile);
for (const mcdxbc::RootParameterInfo &ParamInfo : RSD.ParametersContainer) {
dxbc::ShaderVisibility ParamVisibility =
- static_cast<dxbc::ShaderVisibility>(ParamInfo.Header.ShaderVisibility);
+ dxbc::ShaderVisibility(ParamInfo.Visibility);
if (ParamVisibility != dxbc::ShaderVisibility::All &&
ParamVisibility != Visibility)
continue;
- dxbc::RootParameterType ParamType =
- static_cast<dxbc::RootParameterType>(ParamInfo.Header.ParameterType);
+ dxbc::RootParameterType ParamType = dxbc::RootParameterType(ParamInfo.Type);
switch (ParamType) {
case dxbc::RootParameterType::Constants32Bit: {
- dxbc::RTS0::v1::RootConstants Const =
+ mcdxbc::RootConstants Const =
RSD.ParametersContainer.getConstant(ParamInfo.Location);
Builder.trackBinding(dxil::ResourceClass::CBuffer, Const.RegisterSpace,
Const.ShaderRegister, Const.ShaderRegister,
@@ -183,12 +178,11 @@ static void validateRootSignature(Module &M,
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::CBV: {
- dxbc::RTS0::v2::RootDescriptor Desc =
+ mcdxbc::RootDescriptor Desc =
RSD.ParametersContainer.getRootDescriptor(ParamInfo.Location);
- Builder.trackBinding(toResourceClass(static_cast<dxbc::RootParameterType>(
- ParamInfo.Header.ParameterType)),
- Desc.RegisterSpace, Desc.ShaderRegister,
- Desc.ShaderRegister, &ParamInfo);
+ Builder.trackBinding(toResourceClass(ParamInfo.Type), Desc.RegisterSpace,
+ Desc.ShaderRegister, Desc.ShaderRegister,
+ &ParamInfo);
break;
}
@@ -196,16 +190,13 @@ static void validateRootSignature(Module &M,
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(ParamInfo.Location);
- for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
+ for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
uint32_t UpperBound =
Range.NumDescriptors == ~0U
? Range.BaseShaderRegister
: Range.BaseShaderRegister + Range.NumDescriptors - 1;
- Builder.trackBinding(
- toResourceClass(
- static_cast<dxbc::DescriptorRangeType>(Range.RangeType)),
- Range.RegisterSpace, Range.BaseShaderRegister, UpperBound,
- &ParamInfo);
+ Builder.trackBinding(Range.RangeType, Range.RegisterSpace,
+ Range.BaseShaderRegister, UpperBound, &ParamInfo);
}
break;
}
@@ -218,11 +209,19 @@ static void validateRootSignature(Module &M,
Builder.calculateBindingInfo(
[&M](const llvm::hlsl::BindingInfoBuilder &Builder,
- const llvm::hlsl::BindingInfoBuilder::Binding &ReportedBinding) {
- const llvm::hlsl::BindingInfoBuilder::Binding &Overlaping =
+ const llvm::hlsl::Binding &ReportedBinding) {
+ const llvm::hlsl::Binding &Overlaping =
Builder.findOverlapping(ReportedBinding);
reportOverlappingRegisters(M, ReportedBinding, Overlaping);
});
+ const hlsl::BoundRegs &BoundRegs = Builder.takeBoundRegs();
+ for (const ResourceInfo &RI : DRM) {
+ const ResourceInfo::ResourceBinding &Binding = RI.getBinding();
+ ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
+ if (!BoundRegs.isBound(RC, Binding.Space, Binding.LowerBound,
+ Binding.LowerBound + Binding.Size - 1))
+ reportRegNotBound(M, RC, Binding);
+ }
}
static mcdxbc::RootSignatureDesc *
@@ -236,7 +235,8 @@ getRootSignature(RootSignatureBindingInfo &RSBI,
static void reportErrors(Module &M, DXILResourceMap &DRM,
DXILResourceBindingInfo &DRBI,
RootSignatureBindingInfo &RSBI,
- dxil::ModuleMetadataInfo &MMI) {
+ dxil::ModuleMetadataInfo &MMI,
+ DXILResourceTypeMap &DRTM) {
if (DRM.hasInvalidCounterDirection())
reportInvalidDirection(M, DRM);
@@ -247,7 +247,7 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
"DXILResourceImplicitBinding pass");
if (mcdxbc::RootSignatureDesc *RSD = getRootSignature(RSBI, MMI))
- validateRootSignature(M, *RSD, MMI);
+ validateRootSignature(M, *RSD, MMI, DRM, DRTM);
}
PreservedAnalyses
@@ -256,8 +256,9 @@ DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
+ DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
- reportErrors(M, DRM, DRBI, RSBI, MMI);
+ reportErrors(M, DRM, DRBI, RSBI, MMI, DRTM);
return PreservedAnalyses::all();
}
@@ -273,8 +274,10 @@ public:
getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
dxil::ModuleMetadataInfo &MMI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+ DXILResourceTypeMap &DRTM =
+ getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
- reportErrors(M, DRM, DRBI, RSBI, MMI);
+ reportErrors(M, DRM, DRBI, RSBI, MMI, DRTM);
return false;
}
StringRef getPassName() const override {
@@ -288,6 +291,7 @@ public:
AU.addRequired<DXILResourceBindingWrapperPass>();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
AU.addRequired<RootSignatureAnalysisWrapper>();
+ AU.addRequired<DXILResourceTypeWrapperPass>();
AU.addPreserved<DXILResourceWrapperPass>();
AU.addPreserved<DXILResourceBindingWrapperPass>();
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
@@ -305,6 +309,7 @@ INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
INITIALIZE_PASS_DEPENDENCY(RootSignatureAnalysisWrapper)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_END(DXILPostOptimizationValidationLegacy, DEBUG_TYPE,
"DXIL Post Optimization Validation", false, false)
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
index c33ec0efd73c..6579d3405cf3 100644
--- a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
@@ -8,14 +8,19 @@
#include "DXILResourceAccess.h"
#include "DirectX.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/User.h"
#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
#define DEBUG_TYPE "dxil-resource-access"
@@ -198,6 +203,112 @@ static void createLoadIntrinsic(IntrinsicInst *II, LoadInst *LI, Value *Offset,
llvm_unreachable("Unhandled case in switch");
}
+static SmallVector<Instruction *> collectBlockUseDef(Instruction *Start) {
+ SmallPtrSet<Instruction *, 32> Visited;
+ SmallVector<Instruction *, 32> Worklist;
+ SmallVector<Instruction *> Out;
+ auto *BB = Start->getParent();
+
+ // Seed with direct users in this block.
+ for (User *U : Start->users()) {
+ if (auto *I = dyn_cast<Instruction>(U)) {
+ if (I->getParent() == BB)
+ Worklist.push_back(I);
+ }
+ }
+
+ // BFS over transitive users, constrained to the same block.
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.pop_back_val();
+ if (!Visited.insert(I).second)
+ continue;
+ Out.push_back(I);
+
+ for (User *U : I->users()) {
+ if (auto *J = dyn_cast<Instruction>(U)) {
+ if (J->getParent() == BB)
+ Worklist.push_back(J);
+ }
+ }
+ for (Use &V : I->operands()) {
+ if (auto *J = dyn_cast<Instruction>(V)) {
+ if (J->getParent() == BB && V != Start)
+ Worklist.push_back(J);
+ }
+ }
+ }
+
+ // Order results in program order.
+ DenseMap<const Instruction *, unsigned> Ord;
+ unsigned Idx = 0;
+ for (Instruction &I : *BB)
+ Ord[&I] = Idx++;
+
+ llvm::sort(Out, [&](Instruction *A, Instruction *B) {
+ return Ord.lookup(A) < Ord.lookup(B);
+ });
+
+ return Out;
+}
+
+static void phiNodeRemapHelper(PHINode *Phi, BasicBlock *BB,
+ IRBuilder<> &Builder,
+ SmallVector<Instruction *> &UsesInBlock) {
+
+ ValueToValueMapTy VMap;
+ Value *Val = Phi->getIncomingValueForBlock(BB);
+ VMap[Phi] = Val;
+ Builder.SetInsertPoint(&BB->back());
+ for (Instruction *I : UsesInBlock) {
+ // don't clone over the Phi just remap them
+ if (auto *PhiNested = dyn_cast<PHINode>(I)) {
+ VMap[PhiNested] = PhiNested->getIncomingValueForBlock(BB);
+ continue;
+ }
+ Instruction *Clone = I->clone();
+ RemapInstruction(Clone, VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ Builder.Insert(Clone);
+ VMap[I] = Clone;
+ }
+}
+
+static void phiNodeReplacement(IntrinsicInst *II,
+ SmallVectorImpl<Instruction *> &PrevBBDeadInsts,
+ SetVector<BasicBlock *> &DeadBB) {
+ SmallVector<Instruction *> CurrBBDeadInsts;
+ for (User *U : II->users()) {
+ auto *Phi = dyn_cast<PHINode>(U);
+ if (!Phi)
+ continue;
+
+ IRBuilder<> Builder(Phi);
+ SmallVector<Instruction *> UsesInBlock = collectBlockUseDef(Phi);
+ bool HasReturnUse = isa<ReturnInst>(UsesInBlock.back());
+
+ for (unsigned I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+ auto *CurrIncomingBB = Phi->getIncomingBlock(I);
+ phiNodeRemapHelper(Phi, CurrIncomingBB, Builder, UsesInBlock);
+ if (HasReturnUse)
+ PrevBBDeadInsts.push_back(&CurrIncomingBB->back());
+ }
+
+ CurrBBDeadInsts.push_back(Phi);
+
+ for (Instruction *I : UsesInBlock) {
+ CurrBBDeadInsts.push_back(I);
+ }
+ if (HasReturnUse) {
+ BasicBlock *PhiBB = Phi->getParent();
+ DeadBB.insert(PhiBB);
+ }
+ }
+ // Traverse the now-dead instructions in RPO and remove them.
+ for (Instruction *Dead : llvm::reverse(CurrBBDeadInsts))
+ Dead->eraseFromParent();
+ CurrBBDeadInsts.clear();
+}
+
static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) {
// Process users keeping track of indexing accumulated from GEPs.
struct AccessAndOffset {
@@ -229,7 +340,6 @@ static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) {
} else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
createLoadIntrinsic(II, LI, Current.Offset, RTI);
DeadInsts.push_back(LI);
-
} else
llvm_unreachable("Unhandled instruction - pointer escaped?");
}
@@ -242,13 +352,27 @@ static void replaceAccess(IntrinsicInst *II, dxil::ResourceTypeInfo &RTI) {
static bool transformResourcePointers(Function &F, DXILResourceTypeMap &DRTM) {
SmallVector<std::pair<IntrinsicInst *, dxil::ResourceTypeInfo>> Resources;
- for (BasicBlock &BB : F)
+ SetVector<BasicBlock *> DeadBB;
+ SmallVector<Instruction *> PrevBBDeadInsts;
+ for (BasicBlock &BB : make_early_inc_range(F)) {
+ for (Instruction &I : make_early_inc_range(BB))
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
+ if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer)
+ phiNodeReplacement(II, PrevBBDeadInsts, DeadBB);
+
for (Instruction &I : BB)
if (auto *II = dyn_cast<IntrinsicInst>(&I))
if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer) {
auto *HandleTy = cast<TargetExtType>(II->getArgOperand(0)->getType());
Resources.emplace_back(II, DRTM[HandleTy]);
}
+ }
+ for (auto *Dead : PrevBBDeadInsts)
+ Dead->eraseFromParent();
+ PrevBBDeadInsts.clear();
+ for (auto *Dead : DeadBB)
+ Dead->eraseFromParent();
+ DeadBB.clear();
for (auto &[II, RI] : Resources)
replaceAccess(II, RI);
@@ -279,7 +403,6 @@ public:
bool runOnFunction(Function &F) override {
DXILResourceTypeMap &DRTM =
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
-
return transformResourcePointers(F, DRTM);
}
StringRef getPassName() const override { return "DXIL Resource Access"; }
diff --git a/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp b/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp
index 6e69c5ac1d63..b0d9ad8da10e 100644
--- a/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp
+++ b/llvm/lib/Target/DirectX/DXILResourceImplicitBinding.cpp
@@ -111,8 +111,7 @@ static bool assignBindings(Module &M, DXILResourceBindingInfo &DRBI,
RegSlotOp, /* register slot */
IB.Call->getOperand(2), /* size */
IB.Call->getOperand(3), /* index */
- IB.Call->getOperand(4), /* non-uniform flag */
- IB.Call->getOperand(5)}); /* name */
+ IB.Call->getOperand(4)}); /* name */
IB.Call->replaceAllUsesWith(NewCall);
IB.Call->eraseFromParent();
Changed = true;
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index a4f5086c2f42..ac3c7dde6b89 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -24,9 +24,11 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
+#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/Pass.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -70,6 +72,13 @@ analyzeModule(Module &M) {
if (RootSignatureNode == nullptr)
return RSDMap;
+ bool AllowNullFunctions = false;
+ if (M.getTargetTriple().getEnvironment() ==
+ Triple::EnvironmentType::RootSignature) {
+ assert(RootSignatureNode->getNumOperands() == 1);
+ AllowNullFunctions = true;
+ }
+
for (const auto &RSDefNode : RootSignatureNode->operands()) {
if (RSDefNode->getNumOperands() != 3) {
reportError(Ctx, "Invalid Root Signature metadata - expected function, "
@@ -78,24 +87,28 @@ analyzeModule(Module &M) {
}
// Function was pruned during compilation.
- const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
- if (FunctionPointerMdNode == nullptr) {
- reportError(
- Ctx, "Function associated with Root Signature definition is null.");
- continue;
- }
+ Function *F = nullptr;
+
+ if (!AllowNullFunctions) {
+ const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
+ if (FunctionPointerMdNode == nullptr) {
+ reportError(
+ Ctx, "Function associated with Root Signature definition is null.");
+ continue;
+ }
- ValueAsMetadata *VAM =
- llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
- if (VAM == nullptr) {
- reportError(Ctx, "First element of root signature is not a Value");
- continue;
- }
+ ValueAsMetadata *VAM =
+ llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
+ if (VAM == nullptr) {
+ reportError(Ctx, "First element of root signature is not a Value");
+ continue;
+ }
- Function *F = dyn_cast<Function>(VAM->getValue());
- if (F == nullptr) {
- reportError(Ctx, "First element of root signature is not a Function");
- continue;
+ F = dyn_cast<Function>(VAM->getValue());
+ if (F == nullptr) {
+ reportError(Ctx, "First element of root signature is not a Function");
+ continue;
+ }
}
Metadata *RootElementListOperand = RSDefNode->getOperand(1).get();
@@ -171,41 +184,41 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
<< "RootParametersOffset: " << RS.RootParameterOffset << "\n"
<< "NumParameters: " << RS.ParametersContainer.size() << "\n";
for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
- const auto &[Type, Loc] =
- RS.ParametersContainer.getTypeAndLocForParameter(I);
- const dxbc::RTS0::v1::RootParameterHeader Header =
- RS.ParametersContainer.getHeader(I);
-
- OS << "- Parameter Type: " << Type << "\n"
- << " Shader Visibility: " << Header.ShaderVisibility << "\n";
-
- switch (Type) {
- case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
- const dxbc::RTS0::v1::RootConstants &Constants =
- RS.ParametersContainer.getConstant(Loc);
+ const mcdxbc::RootParameterInfo &Info = RS.ParametersContainer.getInfo(I);
+
+ OS << "- Parameter Type: "
+ << enumToStringRef(Info.Type, dxbc::getRootParameterTypes()) << "\n"
+ << " Shader Visibility: "
+ << enumToStringRef(Info.Visibility, dxbc::getShaderVisibility())
+ << "\n";
+ switch (Info.Type) {
+ case dxbc::RootParameterType::Constants32Bit: {
+ const mcdxbc::RootConstants &Constants =
+ RS.ParametersContainer.getConstant(Info.Location);
OS << " Register Space: " << Constants.RegisterSpace << "\n"
<< " Shader Register: " << Constants.ShaderRegister << "\n"
<< " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV): {
- const dxbc::RTS0::v2::RootDescriptor &Descriptor =
- RS.ParametersContainer.getRootDescriptor(Loc);
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::UAV:
+ case dxbc::RootParameterType::SRV: {
+ const mcdxbc::RootDescriptor &Descriptor =
+ RS.ParametersContainer.getRootDescriptor(Info.Location);
OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
<< " Shader Register: " << Descriptor.ShaderRegister << "\n";
if (RS.Version > 1)
OS << " Flags: " << Descriptor.Flags << "\n";
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
- RS.ParametersContainer.getDescriptorTable(Loc);
+ RS.ParametersContainer.getDescriptorTable(Info.Location);
OS << " NumRanges: " << Table.Ranges.size() << "\n";
- for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
- OS << " - Range Type: " << Range.RangeType << "\n"
+ for (const mcdxbc::DescriptorRange &Range : Table) {
+ OS << " - Range Type: "
+ << dxil::getResourceClassName(Range.RangeType) << "\n"
<< " Register Space: " << Range.RegisterSpace << "\n"
<< " Base Shader Register: " << Range.BaseShaderRegister << "\n"
<< " Num Descriptors: " << Range.NumDescriptors << "\n"
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 82bcacee7a6d..9eebcc9b1306 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -127,6 +127,8 @@ static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
return "ms";
case Triple::Amplification:
return "as";
+ case Triple::RootSignature:
+ return "rootsig";
default:
break;
}
diff --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
index 1d79c3018439..bc1a3a7995bd 100644
--- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
+++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.cpp
@@ -2113,7 +2113,7 @@ void DXILBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
}
break;
case Instruction::GetElementPtr: {
- Code = bitc::CST_CODE_CE_GEP;
+ Code = bitc::CST_CODE_CE_GEP_OLD;
const auto *GO = cast<GEPOperator>(C);
if (GO->isInBounds())
Code = bitc::CST_CODE_CE_INBOUNDS_GEP;
diff --git a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
index f99bb4f4eaee..c2e139edc6bd 100644
--- a/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
+++ b/llvm/lib/Target/DirectX/DirectXIRPasses/PointerTypeAnalysis.cpp
@@ -15,25 +15,39 @@
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
using namespace llvm;
using namespace llvm::dxil;
namespace {
+Type *classifyFunctionType(const Function &F, PointerTypeMap &Map);
+
// Classifies the type of the value passed in by walking the value's users to
// find a typed instruction to materialize a type from.
Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
assert(V->getType()->isPointerTy() &&
"classifyPointerType called with non-pointer");
+
+ // A CallInst will trigger this case, and we want to classify its Function
+ // operand as a Function rather than a generic Value.
+ if (const Function *F = dyn_cast<Function>(V))
+ return classifyFunctionType(*F, Map);
+
+ // There can potentially be dead constants hanging off of the globals we do
+ // not want to deal with. So we remove them here.
+ if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+ GV->removeDeadConstantUsers();
+
auto It = Map.find(V);
if (It != Map.end())
return It->second;
Type *PointeeTy = nullptr;
- if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
- if (!Inst->getResultElementType()->isPointerTy())
- PointeeTy = Inst->getResultElementType();
+ if (auto *GEP = dyn_cast<GEPOperator>(V)) {
+ if (!GEP->getResultElementType()->isPointerTy())
+ PointeeTy = GEP->getResultElementType();
} else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
PointeeTy = Inst->getAllocatedType();
} else if (auto *GV = dyn_cast<GlobalVariable>(V)) {
@@ -49,8 +63,8 @@ Type *classifyPointerType(const Value *V, PointerTypeMap &Map) {
// When store value is ptr type, cannot get more type info.
if (NewPointeeTy->isPointerTy())
continue;
- } else if (const auto *Inst = dyn_cast<GetElementPtrInst>(User)) {
- NewPointeeTy = Inst->getSourceElementType();
+ } else if (const auto *GEP = dyn_cast<GEPOperator>(User)) {
+ NewPointeeTy = GEP->getSourceElementType();
}
if (NewPointeeTy) {
// HLSL doesn't support pointers, so it is unlikely to get more than one
@@ -204,6 +218,9 @@ PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
for (const auto &I : B) {
if (I.getType()->isPointerTy())
classifyPointerType(&I, Map);
+ for (const auto &O : I.operands())
+ if (O.get()->getType()->isPointerTy())
+ classifyPointerType(O.get(), Map);
}
}
}
diff --git a/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp b/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp
index 07b68648f16c..bb2efa43d818 100644
--- a/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXInstrInfo.cpp
@@ -11,10 +11,14 @@
//===----------------------------------------------------------------------===//
#include "DirectXInstrInfo.h"
+#include "DirectXSubtarget.h"
#define GET_INSTRINFO_CTOR_DTOR
#include "DirectXGenInstrInfo.inc"
using namespace llvm;
+DirectXInstrInfo::DirectXInstrInfo(const DirectXSubtarget &STI)
+ : DirectXGenInstrInfo(STI) {}
+
DirectXInstrInfo::~DirectXInstrInfo() {}
diff --git a/llvm/lib/Target/DirectX/DirectXInstrInfo.h b/llvm/lib/Target/DirectX/DirectXInstrInfo.h
index e2c7036fc74a..57ede28030b2 100644
--- a/llvm/lib/Target/DirectX/DirectXInstrInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXInstrInfo.h
@@ -20,9 +20,11 @@
#include "DirectXGenInstrInfo.inc"
namespace llvm {
+class DirectXSubtarget;
+
struct DirectXInstrInfo : public DirectXGenInstrInfo {
const DirectXRegisterInfo RI;
- explicit DirectXInstrInfo() : DirectXGenInstrInfo() {}
+ explicit DirectXInstrInfo(const DirectXSubtarget &STI);
const DirectXRegisterInfo &getRegisterInfo() const { return RI; }
~DirectXInstrInfo() override;
};
diff --git a/llvm/lib/Target/DirectX/DirectXSubtarget.cpp b/llvm/lib/Target/DirectX/DirectXSubtarget.cpp
index 526b7d29fb13..f8519177cc2d 100644
--- a/llvm/lib/Target/DirectX/DirectXSubtarget.cpp
+++ b/llvm/lib/Target/DirectX/DirectXSubtarget.cpp
@@ -24,6 +24,7 @@ using namespace llvm;
DirectXSubtarget::DirectXSubtarget(const Triple &TT, StringRef CPU,
StringRef FS, const DirectXTargetMachine &TM)
- : DirectXGenSubtargetInfo(TT, CPU, CPU, FS), FL(*this), TL(TM, *this) {}
+ : DirectXGenSubtargetInfo(TT, CPU, CPU, FS), InstrInfo(*this), FL(*this),
+ TL(TM, *this) {}
void DirectXSubtarget::anchor() {}
diff --git a/llvm/lib/Target/DirectX/DirectXSubtarget.h b/llvm/lib/Target/DirectX/DirectXSubtarget.h
index b2374caaf3cd..f3d71c4c4e3b 100644
--- a/llvm/lib/Target/DirectX/DirectXSubtarget.h
+++ b/llvm/lib/Target/DirectX/DirectXSubtarget.h
@@ -28,9 +28,9 @@ namespace llvm {
class DirectXTargetMachine;
class DirectXSubtarget : public DirectXGenSubtargetInfo {
+ DirectXInstrInfo InstrInfo;
DirectXFrameLowering FL;
DirectXTargetLowering TL;
- DirectXInstrInfo InstrInfo;
virtual void anchor(); // virtual anchor method