summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjoaosaffran <joao.saffran@microsoft.com>2025-06-16 19:50:29 +0000
committerjoaosaffran <joao.saffran@microsoft.com>2025-06-16 19:50:29 +0000
commite62419f82edd38bb027f3451dc350ecb01b0be2c (patch)
treee975970bf1c5084f0861d8c5c8be10548d3c6f5b
parent02f1f21b8ecc608341440c573483e69c161a06d4 (diff)
-rw-r--r--llvm/lib/Target/DirectX/DXILRootSignature.cpp65
1 files changed, 38 insertions, 27 deletions
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 3a27afc6c660..57d5ee8ac467 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
-#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -31,7 +30,6 @@
#include <cmath>
#include <cstdint>
#include <optional>
-#include <string>
#include <utility>
using namespace llvm;
@@ -290,32 +288,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 1);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 2);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 3);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 3);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 4);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 4);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 5);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 5);
Table.Ranges.push_back(Range);
return false;
@@ -332,8 +330,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<MDString>(Ctx, "DescriptorTableNode",
- DescriptorTableNode, 1);
+ return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, 1);
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -362,67 +360,80 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportError(Ctx, "Invalid value for Filter");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 1);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportError(Ctx, "Invalid value for AddressU");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 2);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportError(Ctx, "Invalid value for AddressV");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 3);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportError(Ctx, "Invalid value for AddressW");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 4);
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = Val->convertToFloat();
else
- return reportError(Ctx, "Invalid value for MipLODBias");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 5);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 6);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 7);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 8);
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = Val->convertToFloat();
else
- return reportError(Ctx, "Invalid value for MinLOD");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 9);
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = Val->convertToFloat();
else
- return reportError(Ctx, "Invalid value for MaxLOD");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 10);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 11);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 12);
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 13);
RSD.StaticSamplers.push_back(Sampler);
return false;