summaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp')
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp78
1 files changed, 76 insertions, 2 deletions
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index f29f2c7602fc..5785505ce2b0 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -27,6 +27,11 @@ namespace rootsig {
char GenericRSMetadataError::ID;
char InvalidRSMetadataFormat::ID;
char InvalidRSMetadataValue::ID;
+char TableSamplerMixinError::ID;
+char ShaderRegisterOverflowError::ID;
+char OffsetOverflowError::ID;
+char OffsetAppendAfterOverflow::ID;
+
template <typename T> char RootSignatureValidationError<T>::ID;
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
@@ -55,8 +60,9 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
template <typename T, typename = std::enable_if_t<
std::is_enum_v<T> &&
std::is_same_v<std::underlying_type_t<T>, uint32_t>>>
-Expected<T> extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
- llvm::function_ref<bool(uint32_t)> VerifyFn) {
+static Expected<T>
+extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
+ llvm::function_ref<bool(uint32_t)> VerifyFn) {
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
if (!VerifyFn(*Val))
return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val);
@@ -538,6 +544,60 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
+static Error
+validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table,
+ uint32_t Location) {
+ dxil::ResourceClass CurrRC = dxil::ResourceClass::Sampler;
+ for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
+ if (Range.RangeType == dxil::ResourceClass::Sampler &&
+ CurrRC != dxil::ResourceClass::Sampler)
+ return make_error<TableSamplerMixinError>(CurrRC, Location);
+ CurrRC = Range.RangeType;
+ }
+ return Error::success();
+}
+
+static Error
+validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table,
+ uint32_t Location) {
+ uint64_t Offset = 0;
+ bool IsPrevUnbound = false;
+ for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
+ // Validation of NumDescriptors should have happened by this point.
+ if (Range.NumDescriptors == 0)
+ continue;
+
+ const uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
+ Range.BaseShaderRegister, Range.NumDescriptors);
+
+ if (!verifyNoOverflowedOffset(RangeBound))
+ return make_error<ShaderRegisterOverflowError>(
+ Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
+
+ bool IsAppending =
+ Range.OffsetInDescriptorsFromTableStart == DescriptorTableOffsetAppend;
+ if (!IsAppending)
+ Offset = Range.OffsetInDescriptorsFromTableStart;
+
+ if (IsPrevUnbound && IsAppending)
+ return make_error<OffsetAppendAfterOverflow>(
+ Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
+
+ const uint64_t OffsetBound =
+ llvm::hlsl::rootsig::computeRangeBound(Offset, Range.NumDescriptors);
+
+ if (!verifyNoOverflowedOffset(OffsetBound))
+ return make_error<OffsetOverflowError>(
+ Range.RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
+
+ Offset = OffsetBound + 1;
+ IsPrevUnbound =
+ Range.NumDescriptors == llvm::hlsl::rootsig::NumDescriptorsUnbounded;
+ }
+
+ return Error::success();
+}
+
Error MetadataParser::validateRootSignature(
const mcdxbc::RootSignatureDesc &RSD) {
Error DeferredErrs = Error::success();
@@ -611,6 +671,14 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"DescriptorFlag", Range.Flags));
+
+ if (Error Err =
+ validateDescriptorTableSamplerMixin(Table, Info.Location))
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+
+ if (Error Err =
+ validateDescriptorTableRegisterOverflow(Table, Info.Location))
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
break;
}
@@ -651,6 +719,12 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"RegisterSpace", Sampler.RegisterSpace));
+
+ if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "Static Sampler Flag", Sampler.Flags));
}
return DeferredErrs;