summaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp')
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp173
1 files changed, 75 insertions, 98 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index a5a92cbd2d61..f29f2c7602fc 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -52,13 +52,15 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}
-static Expected<dxbc::ShaderVisibility>
-extractShaderVisibility(MDNode *Node, unsigned int OpId) {
+template <typename T, typename = std::enable_if_t<
+ std::is_enum_v<T> &&
+ std::is_same_v<std::underlying_type_t<T>, uint32_t>>>
+Expected<T> extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
+ llvm::function_ref<bool(uint32_t)> VerifyFn) {
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
- if (!dxbc::isValidShaderVisibility(*Val))
- return make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderVisibility", *Val);
- return dxbc::ShaderVisibility(*Val);
+ if (!VerifyFn(*Val))
+ return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val);
+ return static_cast<T>(*Val);
}
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
}
@@ -233,7 +235,9 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
return make_error<InvalidRSMetadataFormat>("RootConstants Element");
Expected<dxbc::ShaderVisibility> Visibility =
- extractShaderVisibility(RootConstantNode, 1);
+ extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1,
+ "ShaderVisibility",
+ dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
@@ -287,7 +291,9 @@ Error MetadataParser::parseRootDescriptors(
}
Expected<dxbc::ShaderVisibility> Visibility =
- extractShaderVisibility(RootDescriptorNode, 1);
+ extractEnumValue<dxbc::ShaderVisibility>(RootDescriptorNode, 1,
+ "ShaderVisibility",
+ dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
@@ -322,7 +328,7 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
if (RangeDescriptorNode->getNumOperands() != 6)
return make_error<InvalidRSMetadataFormat>("Descriptor Range");
- dxbc::RTS0::v2::DescriptorRange Range;
+ mcdxbc::DescriptorRange Range;
std::optional<StringRef> ElementText =
extractMdStringValue(RangeDescriptorNode, 0);
@@ -330,15 +336,15 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
if (!ElementText.has_value())
return make_error<InvalidRSMetadataFormat>("Descriptor Range");
- Range.RangeType =
- StringSwitch<uint32_t>(*ElementText)
- .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV))
- .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV))
- .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV))
- .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler))
- .Default(~0U);
-
- if (Range.RangeType == ~0U)
+ if (*ElementText == "CBV")
+ Range.RangeType = dxil::ResourceClass::CBuffer;
+ else if (*ElementText == "SRV")
+ Range.RangeType = dxil::ResourceClass::SRV;
+ else if (*ElementText == "UAV")
+ Range.RangeType = dxil::ResourceClass::UAV;
+ else if (*ElementText == "Sampler")
+ Range.RangeType = dxil::ResourceClass::Sampler;
+ else
return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.",
RangeDescriptorNode);
@@ -380,7 +386,9 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
return make_error<InvalidRSMetadataFormat>("Descriptor Table");
Expected<dxbc::ShaderVisibility> Visibility =
- extractShaderVisibility(DescriptorTableNode, 1);
+ extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1,
+ "ShaderVisibility",
+ dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
@@ -406,26 +414,34 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
if (StaticSamplerNode->getNumOperands() != 14)
return make_error<InvalidRSMetadataFormat>("Static Sampler");
- dxbc::RTS0::v1::StaticSampler Sampler;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
- Sampler.Filter = *Val;
- else
- return make_error<InvalidRSMetadataValue>("Filter");
+ mcdxbc::StaticSampler Sampler;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
- Sampler.AddressU = *Val;
- else
- return make_error<InvalidRSMetadataValue>("AddressU");
+ Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>(
+ StaticSamplerNode, 1, "Filter", dxbc::isValidSamplerFilter);
+ if (auto E = Filter.takeError())
+ return Error(std::move(E));
+ Sampler.Filter = *Filter;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
- Sampler.AddressV = *Val;
- else
- return make_error<InvalidRSMetadataValue>("AddressV");
+ Expected<dxbc::TextureAddressMode> AddressU =
+ extractEnumValue<dxbc::TextureAddressMode>(
+ StaticSamplerNode, 2, "AddressU", dxbc::isValidAddress);
+ if (auto E = AddressU.takeError())
+ return Error(std::move(E));
+ Sampler.AddressU = *AddressU;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
- Sampler.AddressW = *Val;
- else
- return make_error<InvalidRSMetadataValue>("AddressW");
+ Expected<dxbc::TextureAddressMode> AddressV =
+ extractEnumValue<dxbc::TextureAddressMode>(
+ StaticSamplerNode, 3, "AddressV", dxbc::isValidAddress);
+ if (auto E = AddressV.takeError())
+ return Error(std::move(E));
+ Sampler.AddressV = *AddressV;
+
+ Expected<dxbc::TextureAddressMode> AddressW =
+ extractEnumValue<dxbc::TextureAddressMode>(
+ StaticSamplerNode, 4, "AddressW", dxbc::isValidAddress);
+ if (auto E = AddressW.takeError())
+ return Error(std::move(E));
+ Sampler.AddressW = *AddressW;
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
@@ -437,15 +453,19 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
else
return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
- Sampler.ComparisonFunc = *Val;
- else
- return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+ Expected<dxbc::ComparisonFunc> ComparisonFunc =
+ extractEnumValue<dxbc::ComparisonFunc>(
+ StaticSamplerNode, 7, "ComparisonFunc", dxbc::isValidComparisonFunc);
+ if (auto E = ComparisonFunc.takeError())
+ return Error(std::move(E));
+ Sampler.ComparisonFunc = *ComparisonFunc;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
- Sampler.BorderColor = *Val;
- else
- return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+ Expected<dxbc::StaticBorderColor> BorderColor =
+ extractEnumValue<dxbc::StaticBorderColor>(
+ StaticSamplerNode, 8, "BorderColor", dxbc::isValidBorderColor);
+ if (auto E = BorderColor.takeError())
+ return Error(std::move(E));
+ Sampler.BorderColor = *BorderColor;
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
@@ -467,10 +487,13 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
else
return make_error<InvalidRSMetadataValue>("RegisterSpace");
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
- Sampler.ShaderVisibility = *Val;
- else
- return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13,
+ "ShaderVisibility",
+ dxbc::isValidShaderVisibility);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
+ Sampler.ShaderVisibility = *Visibility;
RSD.StaticSamplers.push_back(Sampler);
return Error::success();
@@ -568,13 +591,7 @@ Error MetadataParser::validateRootSignature(
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
- for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
- if (!hlsl::rootsig::verifyRangeType(Range.RangeType))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "RangeType", Range.RangeType));
-
+ for (const mcdxbc::DescriptorRange &Range : Table) {
if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
@@ -588,7 +605,8 @@ Error MetadataParser::validateRootSignature(
"NumDescriptors", Range.NumDescriptors));
if (!hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType, Range.Flags))
+ RSD.Version, Range.RangeType,
+ dxbc::DescriptorRangeFlags(Range.Flags)))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
@@ -599,30 +617,7 @@ Error MetadataParser::validateRootSignature(
}
}
- for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
- if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "Filter", Sampler.Filter));
-
- if (!hlsl::rootsig::verifyAddress(Sampler.AddressU))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "AddressU", Sampler.AddressU));
-
- if (!hlsl::rootsig::verifyAddress(Sampler.AddressV))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "AddressV", Sampler.AddressV));
-
- if (!hlsl::rootsig::verifyAddress(Sampler.AddressW))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "AddressW", Sampler.AddressW));
+ for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) {
if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
DeferredErrs = joinErrors(std::move(DeferredErrs),
@@ -635,18 +630,6 @@ Error MetadataParser::validateRootSignature(
make_error<RootSignatureValidationError<uint32_t>>(
"MaxAnisotropy", Sampler.MaxAnisotropy));
- if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "ComparisonFunc", Sampler.ComparisonFunc));
-
- if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "BorderColor", Sampler.BorderColor));
-
if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD))
DeferredErrs = joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<float>>(
@@ -668,12 +651,6 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"RegisterSpace", Sampler.RegisterSpace));
-
- if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
- DeferredErrs =
- joinErrors(std::move(DeferredErrs),
- make_error<RootSignatureValidationError<uint32_t>>(
- "ShaderVisibility", Sampler.ShaderVisibility));
}
return DeferredErrs;