summaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp')
-rw-r--r--llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp83
1 files changed, 66 insertions, 17 deletions
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index aa1f1957d9cb..765a3bcbed7e 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,7 +15,6 @@
#include "llvm/ADT/bit.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
-#include "llvm/IR/Module.h"
namespace llvm {
namespace hlsl {
@@ -169,25 +168,44 @@ void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
OS << "}";
}
+namespace {
+
+// We use the OverloadBuild with std::visit to ensure the compiler catches if a
+// new RootElement variant type is added but it's metadata generation isn't
+// handled.
+template <class... Ts> struct OverloadedBuild : Ts... {
+ using Ts::operator()...;
+};
+template <class... Ts> OverloadedBuild(Ts...) -> OverloadedBuild<Ts...>;
+
+} // namespace
+
MDNode *MetadataBuilder::BuildRootSignature() {
+ const auto Visitor = OverloadedBuild{
+ [this](const RootFlags &Flags) -> MDNode * {
+ return BuildRootFlags(Flags);
+ },
+ [this](const RootConstants &Constants) -> MDNode * {
+ return BuildRootConstants(Constants);
+ },
+ [this](const RootDescriptor &Descriptor) -> MDNode * {
+ return BuildRootDescriptor(Descriptor);
+ },
+ [this](const DescriptorTableClause &Clause) -> MDNode * {
+ return BuildDescriptorTableClause(Clause);
+ },
+ [this](const DescriptorTable &Table) -> MDNode * {
+ return BuildDescriptorTable(Table);
+ },
+ [this](const StaticSampler &Sampler) -> MDNode * {
+ return BuildStaticSampler(Sampler);
+ },
+ };
+
for (const RootElement &Element : Elements) {
- MDNode *ElementMD = nullptr;
- if (const auto &Flags = std::get_if<RootFlags>(&Element))
- ElementMD = BuildRootFlags(*Flags);
- else if (const auto &Constants = std::get_if<RootConstants>(&Element))
- ElementMD = BuildRootConstants(*Constants);
- else if (const auto &Descriptor = std::get_if<RootDescriptor>(&Element))
- ElementMD = BuildRootDescriptor(*Descriptor);
- else if (const auto &Clause = std::get_if<DescriptorTableClause>(&Element))
- ElementMD = BuildDescriptorTableClause(*Clause);
- else if (const auto &Table = std::get_if<DescriptorTable>(&Element))
- ElementMD = BuildDescriptorTable(*Table);
-
- // FIXME(#126586): remove once all RootElemnt variants are handled in a
- // visit or otherwise
+ MDNode *ElementMD = std::visit(Visitor, Element);
assert(ElementMD != nullptr &&
- "Constructed an unhandled root element type.");
-
+ "Root Element must be initialized and validated");
GeneratedMetadata.push_back(ElementMD);
}
@@ -274,6 +292,37 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause(
});
}
+MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
+ IRBuilder<> Builder(Ctx);
+ Metadata *Operands[] = {
+ MDString::get(Ctx, "StaticSampler"),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.Filter))),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.AddressU))),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.AddressV))),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.AddressW))),
+ ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx),
+ Sampler.MipLODBias)),
+ ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))),
+ ConstantAsMetadata::get(
+ llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)),
+ ConstantAsMetadata::get(
+ llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)),
+ ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)),
+ ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Sampler.Visibility))),
+ };
+ return MDNode::get(Ctx, Operands);
+}
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm