diff options
Diffstat (limited to 'llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp')
| -rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 173 |
1 files changed, 75 insertions, 98 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index a5a92cbd2d61..f29f2c7602fc 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -52,13 +52,15 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, return NodeText->getString(); } -static Expected<dxbc::ShaderVisibility> -extractShaderVisibility(MDNode *Node, unsigned int OpId) { +template <typename T, typename = std::enable_if_t< + std::is_enum_v<T> && + std::is_same_v<std::underlying_type_t<T>, uint32_t>>> +Expected<T> extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, + llvm::function_ref<bool(uint32_t)> VerifyFn) { if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) { - if (!dxbc::isValidShaderVisibility(*Val)) - return make_error<RootSignatureValidationError<uint32_t>>( - "ShaderVisibility", *Val); - return dxbc::ShaderVisibility(*Val); + if (!VerifyFn(*Val)) + return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val); + return static_cast<T>(*Val); } return make_error<InvalidRSMetadataValue>("ShaderVisibility"); } @@ -233,7 +235,9 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, return make_error<InvalidRSMetadataFormat>("RootConstants Element"); Expected<dxbc::ShaderVisibility> Visibility = - extractShaderVisibility(RootConstantNode, 1); + extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1, + "ShaderVisibility", + dxbc::isValidShaderVisibility); if (auto E = Visibility.takeError()) return Error(std::move(E)); @@ -287,7 +291,9 @@ Error MetadataParser::parseRootDescriptors( } Expected<dxbc::ShaderVisibility> Visibility = - extractShaderVisibility(RootDescriptorNode, 1); + extractEnumValue<dxbc::ShaderVisibility>(RootDescriptorNode, 1, + "ShaderVisibility", + dxbc::isValidShaderVisibility); if (auto E = Visibility.takeError()) return Error(std::move(E)); @@ -322,7 +328,7 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, if (RangeDescriptorNode->getNumOperands() != 6) return make_error<InvalidRSMetadataFormat>("Descriptor Range"); - dxbc::RTS0::v2::DescriptorRange Range; + mcdxbc::DescriptorRange Range; std::optional<StringRef> ElementText = extractMdStringValue(RangeDescriptorNode, 0); @@ -330,15 +336,15 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, if (!ElementText.has_value()) return make_error<InvalidRSMetadataFormat>("Descriptor Range"); - Range.RangeType = - StringSwitch<uint32_t>(*ElementText) - .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV)) - .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV)) - .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV)) - .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler)) - .Default(~0U); - - if (Range.RangeType == ~0U) + if (*ElementText == "CBV") + Range.RangeType = dxil::ResourceClass::CBuffer; + else if (*ElementText == "SRV") + Range.RangeType = dxil::ResourceClass::SRV; + else if (*ElementText == "UAV") + Range.RangeType = dxil::ResourceClass::UAV; + else if (*ElementText == "Sampler") + Range.RangeType = dxil::ResourceClass::Sampler; + else return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", RangeDescriptorNode); @@ -380,7 +386,9 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, return make_error<InvalidRSMetadataFormat>("Descriptor Table"); Expected<dxbc::ShaderVisibility> Visibility = - extractShaderVisibility(DescriptorTableNode, 1); + extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1, + "ShaderVisibility", + dxbc::isValidShaderVisibility); if (auto E = Visibility.takeError()) return Error(std::move(E)); @@ -406,26 +414,34 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, if (StaticSamplerNode->getNumOperands() != 14) return make_error<InvalidRSMetadataFormat>("Static Sampler"); - dxbc::RTS0::v1::StaticSampler Sampler; - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) - Sampler.Filter = *Val; - else - return make_error<InvalidRSMetadataValue>("Filter"); + mcdxbc::StaticSampler Sampler; - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) - Sampler.AddressU = *Val; - else - return make_error<InvalidRSMetadataValue>("AddressU"); + Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>( + StaticSamplerNode, 1, "Filter", dxbc::isValidSamplerFilter); + if (auto E = Filter.takeError()) + return Error(std::move(E)); + Sampler.Filter = *Filter; - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) - Sampler.AddressV = *Val; - else - return make_error<InvalidRSMetadataValue>("AddressV"); + Expected<dxbc::TextureAddressMode> AddressU = + extractEnumValue<dxbc::TextureAddressMode>( + StaticSamplerNode, 2, "AddressU", dxbc::isValidAddress); + if (auto E = AddressU.takeError()) + return Error(std::move(E)); + Sampler.AddressU = *AddressU; - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) - Sampler.AddressW = *Val; - else - return make_error<InvalidRSMetadataValue>("AddressW"); + Expected<dxbc::TextureAddressMode> AddressV = + extractEnumValue<dxbc::TextureAddressMode>( + StaticSamplerNode, 3, "AddressV", dxbc::isValidAddress); + if (auto E = AddressV.takeError()) + return Error(std::move(E)); + Sampler.AddressV = *AddressV; + + Expected<dxbc::TextureAddressMode> AddressW = + extractEnumValue<dxbc::TextureAddressMode>( + StaticSamplerNode, 4, "AddressW", dxbc::isValidAddress); + if (auto E = AddressW.takeError()) + return Error(std::move(E)); + Sampler.AddressW = *AddressW; if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = *Val; @@ -437,15 +453,19 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, else return make_error<InvalidRSMetadataValue>("MaxAnisotropy"); - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) - Sampler.ComparisonFunc = *Val; - else - return make_error<InvalidRSMetadataValue>("ComparisonFunc"); + Expected<dxbc::ComparisonFunc> ComparisonFunc = + extractEnumValue<dxbc::ComparisonFunc>( + StaticSamplerNode, 7, "ComparisonFunc", dxbc::isValidComparisonFunc); + if (auto E = ComparisonFunc.takeError()) + return Error(std::move(E)); + Sampler.ComparisonFunc = *ComparisonFunc; - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) - Sampler.BorderColor = *Val; - else - return make_error<InvalidRSMetadataValue>("ComparisonFunc"); + Expected<dxbc::StaticBorderColor> BorderColor = + extractEnumValue<dxbc::StaticBorderColor>( + StaticSamplerNode, 8, "BorderColor", dxbc::isValidBorderColor); + if (auto E = BorderColor.takeError()) + return Error(std::move(E)); + Sampler.BorderColor = *BorderColor; if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = *Val; @@ -467,10 +487,13 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, else return make_error<InvalidRSMetadataValue>("RegisterSpace"); - if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) - Sampler.ShaderVisibility = *Val; - else - return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13, + "ShaderVisibility", + dxbc::isValidShaderVisibility); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); + Sampler.ShaderVisibility = *Visibility; RSD.StaticSamplers.push_back(Sampler); return Error::success(); @@ -568,13 +591,7 @@ Error MetadataParser::validateRootSignature( case dxbc::RootParameterType::DescriptorTable: { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); - for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!hlsl::rootsig::verifyRangeType(Range.RangeType)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RangeType", Range.RangeType)); - + for (const mcdxbc::DescriptorRange &Range : Table) { if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) DeferredErrs = joinErrors(std::move(DeferredErrs), @@ -588,7 +605,8 @@ Error MetadataParser::validateRootSignature( "NumDescriptors", Range.NumDescriptors)); if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, Range.Flags)) + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags))) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( @@ -599,30 +617,7 @@ Error MetadataParser::validateRootSignature( } } - for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "Filter", Sampler.Filter)); - - if (!hlsl::rootsig::verifyAddress(Sampler.AddressU)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "AddressU", Sampler.AddressU)); - - if (!hlsl::rootsig::verifyAddress(Sampler.AddressV)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "AddressV", Sampler.AddressV)); - - if (!hlsl::rootsig::verifyAddress(Sampler.AddressW)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "AddressW", Sampler.AddressW)); + for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) { if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) DeferredErrs = joinErrors(std::move(DeferredErrs), @@ -635,18 +630,6 @@ Error MetadataParser::validateRootSignature( make_error<RootSignatureValidationError<uint32_t>>( "MaxAnisotropy", Sampler.MaxAnisotropy)); - if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "ComparisonFunc", Sampler.ComparisonFunc)); - - if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "BorderColor", Sampler.BorderColor)); - if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<float>>( @@ -668,12 +651,6 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( "RegisterSpace", Sampler.RegisterSpace)); - - if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "ShaderVisibility", Sampler.ShaderVisibility)); } return DeferredErrs; |
