summaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp')
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp118
1 files changed, 56 insertions, 62 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index dece8f197aaf..31605e390034 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -52,6 +52,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}
+static Expected<dxbc::ShaderVisibility>
+extractShaderVisibility(MDNode *Node, unsigned int OpId) {
+ if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
+ if (!dxbc::isValidShaderVisibility(*Val))
+ return make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", *Val);
+ return dxbc::ShaderVisibility(*Val);
+ }
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+}
+
namespace {
// We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -221,17 +232,12 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
if (RootConstantNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("RootConstants Element");
- dxbc::RTS0::v1::RootParameterHeader Header;
- // The parameter offset doesn't matter here - we recalculate it during
- // serialization Header.ParameterOffset = 0;
- Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractShaderVisibility(RootConstantNode, 1);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
- dxbc::RTS0::v1::RootConstants Constants;
+ mcdxbc::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
@@ -247,7 +253,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
else
return make_error<InvalidRSMetadataValue>("Num32BitValues");
- RSD.ParametersContainer.addParameter(Header, Constants);
+ RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
+ *Visibility, Constants);
return Error::success();
}
@@ -263,28 +270,28 @@ Error MetadataParser::parseRootDescriptors(
if (RootDescriptorNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
- dxbc::RTS0::v1::RootParameterHeader Header;
+ dxbc::RootParameterType Type;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
- Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV);
+ Type = dxbc::RootParameterType::SRV;
break;
case RootSignatureElementKind::UAV:
- Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV);
+ Type = dxbc::RootParameterType::UAV;
break;
case RootSignatureElementKind::CBV:
- Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV);
+ Type = dxbc::RootParameterType::CBV;
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractShaderVisibility(RootDescriptorNode, 1);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
- dxbc::RTS0::v2::RootDescriptor Descriptor;
+ mcdxbc::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
@@ -296,7 +303,7 @@ Error MetadataParser::parseRootDescriptors(
return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (RSD.Version == 1) {
- RSD.ParametersContainer.addParameter(Header, Descriptor);
+ RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
return Error::success();
}
assert(RSD.Version > 1);
@@ -306,7 +313,7 @@ Error MetadataParser::parseRootDescriptors(
else
return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
- RSD.ParametersContainer.addParameter(Header, Descriptor);
+ RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
return Error::success();
}
@@ -315,7 +322,7 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
if (RangeDescriptorNode->getNumOperands() != 6)
return make_error<InvalidRSMetadataFormat>("Descriptor Range");
- dxbc::RTS0::v2::DescriptorRange Range;
+ mcdxbc::DescriptorRange Range;
std::optional<StringRef> ElementText =
extractMdStringValue(RangeDescriptorNode, 0);
@@ -323,15 +330,15 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
if (!ElementText.has_value())
return make_error<InvalidRSMetadataFormat>("Descriptor Range");
- Range.RangeType =
- StringSwitch<uint32_t>(*ElementText)
- .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV))
- .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV))
- .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV))
- .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler))
- .Default(~0U);
-
- if (Range.RangeType == ~0U)
+ if (*ElementText == "CBV")
+ Range.RangeType = dxil::ResourceClass::CBuffer;
+ else if (*ElementText == "SRV")
+ Range.RangeType = dxil::ResourceClass::SRV;
+ else if (*ElementText == "UAV")
+ Range.RangeType = dxil::ResourceClass::UAV;
+ else if (*ElementText == "Sampler")
+ Range.RangeType = dxil::ResourceClass::Sampler;
+ else
return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.",
RangeDescriptorNode);
@@ -372,15 +379,12 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
if (NumOperands < 2)
return make_error<InvalidRSMetadataFormat>("Descriptor Table");
- dxbc::RTS0::v1::RootParameterHeader Header;
- if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractShaderVisibility(DescriptorTableNode, 1);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
mcdxbc::DescriptorTable Table;
- Header.ParameterType =
- to_underlying(dxbc::RootParameterType::DescriptorTable);
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
@@ -392,7 +396,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
return Err;
}
- RSD.ParametersContainer.addParameter(Header, Table);
+ RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
+ *Visibility, Table);
return Error::success();
}
@@ -528,21 +533,15 @@ Error MetadataParser::validateRootSignature(
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
- if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderVisibility", Info.Header.ShaderVisibility));
-
- assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
- "Invalid value for ParameterType");
- switch (Info.Header.ParameterType) {
+ switch (Info.Type) {
+ case dxbc::RootParameterType::Constants32Bit:
+ break;
- case to_underlying(dxbc::RootParameterType::CBV):
- case to_underlying(dxbc::RootParameterType::UAV):
- case to_underlying(dxbc::RootParameterType::SRV): {
- const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::UAV:
+ case dxbc::RootParameterType::SRV: {
+ const mcdxbc::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
DeferredErrs =
@@ -566,16 +565,10 @@ Error MetadataParser::validateRootSignature(
}
break;
}
- case to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
- for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
- if (!hlsl::rootsig::verifyRangeType(Range.RangeType))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RangeType", Range.RangeType));
-
+ for (const mcdxbc::DescriptorRange &Range : Table) {
if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
@@ -589,7 +582,8 @@ Error MetadataParser::validateRootSignature(
"NumDescriptors", Range.NumDescriptors));
if (!hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType, Range.Flags))
+ RSD.Version, Range.RangeType,
+ dxbc::DescriptorRangeFlags(Range.Flags)))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(