summaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX/DXILRootSignature.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILRootSignature.cpp')
-rw-r--r--llvm/lib/Target/DirectX/DXILRootSignature.cpp132
1 files changed, 126 insertions, 6 deletions
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 40277899b5c7..3d195acb19e1 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -55,6 +55,14 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
return std::nullopt;
}
+static std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
MDNode *RootFlagNode) {
@@ -107,17 +115,79 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
return false;
}
+static bool parseRootDescriptors(LLVMContext *Ctx,
+ mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind) {
+ assert(ElementKind == RootSignatureElementKind::SRV ||
+ ElementKind == RootSignatureElementKind::UAV ||
+ ElementKind == RootSignatureElementKind::CBV &&
+ "parseRootDescriptors should only be called with RootDescriptor "
+ "element kind.");
+ if (RootDescriptorNode->getNumOperands() != 5)
+ return reportError(Ctx, "Invalid format for Root Descriptor Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ switch (ElementKind) {
+ case RootSignatureElementKind::SRV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
+ break;
+ case RootSignatureElementKind::UAV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
+ break;
+ case RootSignatureElementKind::CBV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
+ break;
+ default:
+ llvm_unreachable("invalid Root Descriptor kind");
+ break;
+ }
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderVisibility");
+
+ dxbc::RTS0::v2::RootDescriptor Descriptor;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
+ Descriptor.ShaderRegister = *Val;
+ else
+ return reportError(Ctx, "Invalid value for ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
+ Descriptor.RegisterSpace = *Val;
+ else
+ return reportError(Ctx, "Invalid value for RegisterSpace");
+
+ if (RSD.Version == 1) {
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return false;
+ }
+ assert(RSD.Version > 1);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
+ Descriptor.Flags = *Val;
+ else
+ return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return false;
+}
+
static bool parseRootSignatureElement(LLVMContext *Ctx,
mcdxbc::RootSignatureDesc &RSD,
MDNode *Element) {
- MDString *ElementText = cast<MDString>(Element->getOperand(0));
- if (ElementText == nullptr)
+ std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
+ if (!ElementText.has_value())
return reportError(Ctx, "Invalid format for Root Element");
RootSignatureElementKind ElementKind =
- StringSwitch<RootSignatureElementKind>(ElementText->getString())
+ StringSwitch<RootSignatureElementKind>(*ElementText)
.Case("RootFlags", RootSignatureElementKind::RootFlags)
.Case("RootConstants", RootSignatureElementKind::RootConstants)
+ .Case("RootCBV", RootSignatureElementKind::CBV)
+ .Case("RootSRV", RootSignatureElementKind::SRV)
+ .Case("RootUAV", RootSignatureElementKind::UAV)
.Default(RootSignatureElementKind::Error);
switch (ElementKind) {
@@ -126,10 +196,12 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
return parseRootFlags(Ctx, RSD, Element);
case RootSignatureElementKind::RootConstants:
return parseRootConstants(Ctx, RSD, Element);
- break;
+ case RootSignatureElementKind::CBV:
+ case RootSignatureElementKind::SRV:
+ case RootSignatureElementKind::UAV:
+ return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
case RootSignatureElementKind::Error:
- return reportError(Ctx, "Invalid Root Signature Element: " +
- ElementText->getString());
+ return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
@@ -157,6 +229,18 @@ static bool verifyVersion(uint32_t Version) {
return (Version == 1 || Version == 2);
}
+static bool verifyRegisterValue(uint32_t RegisterValue) {
+ return RegisterValue != ~0U;
+}
+
+// This Range is reserverved, therefore invalid, according to the spec
+// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
+static bool verifyRegisterSpace(uint32_t RegisterSpace) {
+ return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF);
+}
+
+static bool verifyDescriptorFlag(uint32_t Flags) { return (Flags & ~0xE) == 0; }
+
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
if (!verifyVersion(RSD.Version)) {
@@ -174,6 +258,28 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
+
+ switch (Info.Header.ParameterType) {
+
+ case llvm::to_underlying(dxbc::RootParameterType::CBV):
+ case llvm::to_underlying(dxbc::RootParameterType::UAV):
+ case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ RSD.ParametersContainer.getRootDescriptor(Info.Location);
+ if (!verifyRegisterValue(Descriptor.ShaderRegister))
+ return reportValueError(Ctx, "ShaderRegister",
+ Descriptor.ShaderRegister);
+
+ if (!verifyRegisterSpace(Descriptor.RegisterSpace))
+ return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+
+ if (RSD.Version > 1) {
+ if (!verifyDescriptorFlag(Descriptor.Flags))
+ return reportValueError(Ctx, "DescriptorFlag", Descriptor.Flags);
+ }
+ break;
+ }
+ }
}
return false;
@@ -313,6 +419,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
<< "Shader Register: " << Constants.ShaderRegister << "\n";
OS << indent(Space + 2)
<< "Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
+ break;
+ }
+ case llvm::to_underlying(dxbc::RootParameterType::CBV):
+ case llvm::to_underlying(dxbc::RootParameterType::UAV):
+ case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ RS.ParametersContainer.getRootDescriptor(Loc);
+ OS << indent(Space + 2)
+ << "Register Space: " << Descriptor.RegisterSpace << "\n";
+ OS << indent(Space + 2)
+ << "Shader Register: " << Descriptor.ShaderRegister << "\n";
+ if (RS.Version > 1)
+ OS << indent(Space + 2) << "Flags: " << Descriptor.Flags << "\n";
+ break;
}
}
Space--;