diff options
Diffstat (limited to 'llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp')
| -rw-r--r-- | llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp | 78 |
1 files changed, 76 insertions, 2 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index f29f2c7602fc..5785505ce2b0 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -27,6 +27,11 @@ namespace rootsig { char GenericRSMetadataError::ID; char InvalidRSMetadataFormat::ID; char InvalidRSMetadataValue::ID; +char TableSamplerMixinError::ID; +char ShaderRegisterOverflowError::ID; +char OffsetOverflowError::ID; +char OffsetAppendAfterOverflow::ID; + template <typename T> char RootSignatureValidationError<T>::ID; static std::optional<uint32_t> extractMdIntValue(MDNode *Node, @@ -55,8 +60,9 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node, template <typename T, typename = std::enable_if_t< std::is_enum_v<T> && std::is_same_v<std::underlying_type_t<T>, uint32_t>>> -Expected<T> extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, - llvm::function_ref<bool(uint32_t)> VerifyFn) { +static Expected<T> +extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, + llvm::function_ref<bool(uint32_t)> VerifyFn) { if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) { if (!VerifyFn(*Val)) return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val); @@ -538,6 +544,60 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD, llvm_unreachable("Unhandled RootSignatureElementKind enum."); } +static Error +validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table, + uint32_t Location) { + dxil::ResourceClass CurrRC = dxil::ResourceClass::Sampler; + for (const mcdxbc::DescriptorRange &Range : Table.Ranges) { + if (Range.RangeType == dxil::ResourceClass::Sampler && + CurrRC != dxil::ResourceClass::Sampler) + return make_error<TableSamplerMixinError>(CurrRC, Location); + CurrRC = Range.RangeType; + } + return Error::success(); +} + +static Error +validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table, + uint32_t Location) { + uint64_t Offset = 0; + bool IsPrevUnbound = false; + for (const mcdxbc::DescriptorRange &Range : Table.Ranges) { + // Validation of NumDescriptors should have happened by this point. + if (Range.NumDescriptors == 0) + continue; + + const uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound( + Range.BaseShaderRegister, Range.NumDescriptors); + + if (!verifyNoOverflowedOffset(RangeBound)) + return make_error<ShaderRegisterOverflowError>( + Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace); + + bool IsAppending = + Range.OffsetInDescriptorsFromTableStart == DescriptorTableOffsetAppend; + if (!IsAppending) + Offset = Range.OffsetInDescriptorsFromTableStart; + + if (IsPrevUnbound && IsAppending) + return make_error<OffsetAppendAfterOverflow>( + Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace); + + const uint64_t OffsetBound = + llvm::hlsl::rootsig::computeRangeBound(Offset, Range.NumDescriptors); + + if (!verifyNoOverflowedOffset(OffsetBound)) + return make_error<OffsetOverflowError>( + Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace); + + Offset = OffsetBound + 1; + IsPrevUnbound = + Range.NumDescriptors == llvm::hlsl::rootsig::NumDescriptorsUnbounded; + } + + return Error::success(); +} + Error MetadataParser::validateRootSignature( const mcdxbc::RootSignatureDesc &RSD) { Error DeferredErrs = Error::success(); @@ -611,6 +671,14 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( "DescriptorFlag", Range.Flags)); + + if (Error Err = + validateDescriptorTableSamplerMixin(Table, Info.Location)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); + + if (Error Err = + validateDescriptorTableRegisterOverflow(Table, Info.Location)) + DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err)); } break; } @@ -651,6 +719,12 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( "RegisterSpace", Sampler.RegisterSpace)); + + if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags)) + DeferredErrs = + joinErrors(std::move(DeferredErrs), + make_error<RootSignatureValidationError<uint32_t>>( + "Static Sampler Flag", Sampler.Flags)); } return DeferredErrs; |
