diff options
Diffstat (limited to 'llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp')
| -rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 118 |
1 files changed, 56 insertions, 62 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index dece8f197aaf..31605e390034 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -52,6 +52,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, return NodeText->getString(); } +static Expected<dxbc::ShaderVisibility> +extractShaderVisibility(MDNode *Node, unsigned int OpId) { + if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) { + if (!dxbc::isValidShaderVisibility(*Val)) + return make_error<RootSignatureValidationError<uint32_t>>( + "ShaderVisibility", *Val); + return dxbc::ShaderVisibility(*Val); + } + return make_error<InvalidRSMetadataValue>("ShaderVisibility"); +} + namespace { // We use the OverloadVisit with std::visit to ensure the compiler catches if a @@ -221,17 +232,12 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, if (RootConstantNode->getNumOperands() != 5) return make_error<InvalidRSMetadataFormat>("RootConstants Element"); - dxbc::RTS0::v1::RootParameterHeader Header; - // The parameter offset doesn't matter here - we recalculate it during - // serialization Header.ParameterOffset = 0; - Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit); - - if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) - Header.ShaderVisibility = *Val; - else - return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractShaderVisibility(RootConstantNode, 1); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); - dxbc::RTS0::v1::RootConstants Constants; + mcdxbc::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else @@ -247,7 +253,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, else return make_error<InvalidRSMetadataValue>("Num32BitValues"); - RSD.ParametersContainer.addParameter(Header, Constants); + RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit, + *Visibility, Constants); return Error::success(); } @@ -263,28 +270,28 @@ Error MetadataParser::parseRootDescriptors( if (RootDescriptorNode->getNumOperands() != 5) return make_error<InvalidRSMetadataFormat>("Root Descriptor Element"); - dxbc::RTS0::v1::RootParameterHeader Header; + dxbc::RootParameterType Type; switch (ElementKind) { case RootSignatureElementKind::SRV: - Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV); + Type = dxbc::RootParameterType::SRV; break; case RootSignatureElementKind::UAV: - Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV); + Type = dxbc::RootParameterType::UAV; break; case RootSignatureElementKind::CBV: - Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV); + Type = dxbc::RootParameterType::CBV; break; default: llvm_unreachable("invalid Root Descriptor kind"); break; } - if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) - Header.ShaderVisibility = *Val; - else - return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractShaderVisibility(RootDescriptorNode, 1); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); - dxbc::RTS0::v2::RootDescriptor Descriptor; + mcdxbc::RootDescriptor Descriptor; if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else @@ -296,7 +303,7 @@ Error MetadataParser::parseRootDescriptors( return make_error<InvalidRSMetadataValue>("RegisterSpace"); if (RSD.Version == 1) { - RSD.ParametersContainer.addParameter(Header, Descriptor); + RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor); return Error::success(); } assert(RSD.Version > 1); @@ -306,7 +313,7 @@ Error MetadataParser::parseRootDescriptors( else return make_error<InvalidRSMetadataValue>("Root Descriptor Flags"); - RSD.ParametersContainer.addParameter(Header, Descriptor); + RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor); return Error::success(); } @@ -315,7 +322,7 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, if (RangeDescriptorNode->getNumOperands() != 6) return make_error<InvalidRSMetadataFormat>("Descriptor Range"); - dxbc::RTS0::v2::DescriptorRange Range; + mcdxbc::DescriptorRange Range; std::optional<StringRef> ElementText = extractMdStringValue(RangeDescriptorNode, 0); @@ -323,15 +330,15 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table, if (!ElementText.has_value()) return make_error<InvalidRSMetadataFormat>("Descriptor Range"); - Range.RangeType = - StringSwitch<uint32_t>(*ElementText) - .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV)) - .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV)) - .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV)) - .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler)) - .Default(~0U); - - if (Range.RangeType == ~0U) + if (*ElementText == "CBV") + Range.RangeType = dxil::ResourceClass::CBuffer; + else if (*ElementText == "SRV") + Range.RangeType = dxil::ResourceClass::SRV; + else if (*ElementText == "UAV") + Range.RangeType = dxil::ResourceClass::UAV; + else if (*ElementText == "Sampler") + Range.RangeType = dxil::ResourceClass::Sampler; + else return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.", RangeDescriptorNode); @@ -372,15 +379,12 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, if (NumOperands < 2) return make_error<InvalidRSMetadataFormat>("Descriptor Table"); - dxbc::RTS0::v1::RootParameterHeader Header; - if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1)) - Header.ShaderVisibility = *Val; - else - return make_error<InvalidRSMetadataValue>("ShaderVisibility"); + Expected<dxbc::ShaderVisibility> Visibility = + extractShaderVisibility(DescriptorTableNode, 1); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); mcdxbc::DescriptorTable Table; - Header.ParameterType = - to_underlying(dxbc::RootParameterType::DescriptorTable); for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); @@ -392,7 +396,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, return Err; } - RSD.ParametersContainer.addParameter(Header, Table); + RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable, + *Visibility, Table); return Error::success(); } @@ -528,21 +533,15 @@ Error MetadataParser::validateRootSignature( } for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { - if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "ShaderVisibility", Info.Header.ShaderVisibility)); - - assert(dxbc::isValidParameterType(Info.Header.ParameterType) && - "Invalid value for ParameterType"); - switch (Info.Header.ParameterType) { + switch (Info.Type) { + case dxbc::RootParameterType::Constants32Bit: + break; - case to_underlying(dxbc::RootParameterType::CBV): - case to_underlying(dxbc::RootParameterType::UAV): - case to_underlying(dxbc::RootParameterType::SRV): { - const dxbc::RTS0::v2::RootDescriptor &Descriptor = + case dxbc::RootParameterType::CBV: + case dxbc::RootParameterType::UAV: + case dxbc::RootParameterType::SRV: { + const mcdxbc::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) DeferredErrs = @@ -566,16 +565,10 @@ Error MetadataParser::validateRootSignature( } break; } - case to_underlying(dxbc::RootParameterType::DescriptorTable): { + case dxbc::RootParameterType::DescriptorTable: { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); - for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!hlsl::rootsig::verifyRangeType(Range.RangeType)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error<RootSignatureValidationError<uint32_t>>( - "RangeType", Range.RangeType)); - + for (const mcdxbc::DescriptorRange &Range : Table) { if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) DeferredErrs = joinErrors(std::move(DeferredErrs), @@ -589,7 +582,8 @@ Error MetadataParser::validateRootSignature( "NumDescriptors", Range.NumDescriptors)); if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, Range.Flags)) + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags))) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( |
