diff options
Diffstat (limited to 'llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp')
| -rw-r--r-- | llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp | 83 |
1 files changed, 66 insertions, 17 deletions
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index aa1f1957d9cb..765a3bcbed7e 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/bit.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/Module.h" namespace llvm { namespace hlsl { @@ -169,25 +168,44 @@ void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) { OS << "}"; } +namespace { + +// We use the OverloadBuild with std::visit to ensure the compiler catches if a +// new RootElement variant type is added but it's metadata generation isn't +// handled. +template <class... Ts> struct OverloadedBuild : Ts... { + using Ts::operator()...; +}; +template <class... Ts> OverloadedBuild(Ts...) -> OverloadedBuild<Ts...>; + +} // namespace + MDNode *MetadataBuilder::BuildRootSignature() { + const auto Visitor = OverloadedBuild{ + [this](const RootFlags &Flags) -> MDNode * { + return BuildRootFlags(Flags); + }, + [this](const RootConstants &Constants) -> MDNode * { + return BuildRootConstants(Constants); + }, + [this](const RootDescriptor &Descriptor) -> MDNode * { + return BuildRootDescriptor(Descriptor); + }, + [this](const DescriptorTableClause &Clause) -> MDNode * { + return BuildDescriptorTableClause(Clause); + }, + [this](const DescriptorTable &Table) -> MDNode * { + return BuildDescriptorTable(Table); + }, + [this](const StaticSampler &Sampler) -> MDNode * { + return BuildStaticSampler(Sampler); + }, + }; + for (const RootElement &Element : Elements) { - MDNode *ElementMD = nullptr; - if (const auto &Flags = std::get_if<RootFlags>(&Element)) - ElementMD = BuildRootFlags(*Flags); - else if (const auto &Constants = std::get_if<RootConstants>(&Element)) - ElementMD = BuildRootConstants(*Constants); - else if (const auto &Descriptor = std::get_if<RootDescriptor>(&Element)) - ElementMD = BuildRootDescriptor(*Descriptor); - else if (const auto &Clause = std::get_if<DescriptorTableClause>(&Element)) - ElementMD = BuildDescriptorTableClause(*Clause); - else if (const auto &Table = std::get_if<DescriptorTable>(&Element)) - ElementMD = BuildDescriptorTable(*Table); - - // FIXME(#126586): remove once all RootElemnt variants are handled in a - // visit or otherwise + MDNode *ElementMD = std::visit(Visitor, Element); assert(ElementMD != nullptr && - "Constructed an unhandled root element type."); - + "Root Element must be initialized and validated"); GeneratedMetadata.push_back(ElementMD); } @@ -274,6 +292,37 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( }); } +MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { + IRBuilder<> Builder(Ctx); + Metadata *Operands[] = { + MDString::get(Ctx, "StaticSampler"), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.Filter))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.AddressU))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.AddressV))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.AddressW))), + ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), + Sampler.MipLODBias)), + ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))), + ConstantAsMetadata::get( + llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)), + ConstantAsMetadata::get( + llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)), + ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Sampler.Visibility))), + }; + return MDNode::get(Ctx, Operands); +} + } // namespace rootsig } // namespace hlsl } // namespace llvm |
