summaryrefslogtreecommitdiff
path: root/clang/lib/Sema/SemaHLSL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
-rw-r--r--clang/lib/Sema/SemaHLSL.cpp185
1 files changed, 165 insertions, 20 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index cc50d23e0b99..9276554bebf9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -43,7 +43,9 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/TargetParser/Triple.h"
+#include <cmath>
#include <cstddef>
#include <iterator>
#include <utility>
@@ -1064,27 +1066,119 @@ SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
void SemaHLSL::ActOnFinishRootSignatureDecl(
SourceLocation Loc, IdentifierInfo *DeclIdent,
- SmallVector<llvm::hlsl::rootsig::RootElement> &Elements) {
+ ArrayRef<hlsl::RootSignatureElement> RootElements) {
+
+ if (handleRootSignatureElements(RootElements))
+ return;
+
+ SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
+ for (auto &RootSigElement : RootElements)
+ Elements.push_back(RootSigElement.getElement());
auto *SignatureDecl = HLSLRootSignatureDecl::Create(
SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc,
DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements);
- if (handleRootSignatureDecl(SignatureDecl, Loc))
- return;
-
SignatureDecl->setImplicit();
SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
}
-bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D,
- SourceLocation Loc) {
+bool SemaHLSL::handleRootSignatureElements(
+ ArrayRef<hlsl::RootSignatureElement> Elements) {
+ // Define some common error handling functions
+ bool HadError = false;
+ auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
+ uint32_t UpperBound) {
+ HadError = true;
+ this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
+ << LowerBound << UpperBound;
+ };
+
+ auto ReportFloatError = [this, &HadError](SourceLocation Loc,
+ float LowerBound,
+ float UpperBound) {
+ HadError = true;
+ this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value)
+ << llvm::formatv("{0:f}", LowerBound).sstr<6>()
+ << llvm::formatv("{0:f}", UpperBound).sstr<6>();
+ };
+
+ auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
+ if (!llvm::hlsl::rootsig::verifyRegisterValue(Register))
+ ReportError(Loc, 0, 0xfffffffe);
+ };
+
+ auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
+ if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space))
+ ReportError(Loc, 0, 0xffffffef);
+ };
+
+ const uint32_t Version =
+ llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
+ const uint32_t VersionEnum = Version - 1;
+ auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
+ HadError = true;
+ this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
+ << /*version minor*/ VersionEnum;
+ };
+
+ // Iterate through the elements and do basic validations
+ for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
+ SourceLocation Loc = RootSigElem.getLocation();
+ const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
+ if (const auto *Descriptor =
+ std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
+ VerifyRegister(Loc, Descriptor->Reg.Number);
+ VerifySpace(Loc, Descriptor->Space);
+
+ if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
+ Version, llvm::to_underlying(Descriptor->Flags)))
+ ReportFlagError(Loc);
+ } else if (const auto *Constants =
+ std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
+ VerifyRegister(Loc, Constants->Reg.Number);
+ VerifySpace(Loc, Constants->Space);
+ } else if (const auto *Sampler =
+ std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
+ VerifyRegister(Loc, Sampler->Reg.Number);
+ VerifySpace(Loc, Sampler->Space);
+
+ assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
+ "By construction, parseFloatParam can't produce a NaN from a "
+ "float_literal token");
+
+ if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy))
+ ReportError(Loc, 0, 16);
+ if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias))
+ ReportFloatError(Loc, -16.f, 15.99f);
+ } else if (const auto *Clause =
+ std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
+ &Elem)) {
+ VerifyRegister(Loc, Clause->Reg.Number);
+ VerifySpace(Loc, Clause->Space);
+
+ if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) {
+ // NumDescriptor could techincally be ~0u but that is reserved for
+ // unbounded, so the diagnostic will not report that as a valid int
+ // value
+ ReportError(Loc, 1, 0xfffffffe);
+ }
+
+ if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+ Version, llvm::to_underlying(Clause->Type),
+ llvm::to_underlying(Clause->Flags)))
+ ReportFlagError(Loc);
+ }
+ }
+
using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
+ using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
// 1. Collect RangeInfos
- llvm::SmallVector<RangeInfo> Infos;
- for (const llvm::hlsl::rootsig::RootElement &Elem : D->getRootElements()) {
+ llvm::SmallVector<InfoPairT> InfoPairs;
+ for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
+ const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
if (const auto *Descriptor =
std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
RangeInfo Info;
@@ -1095,7 +1189,8 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D,
llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
Info.Space = Descriptor->Space;
Info.Visibility = Descriptor->Visibility;
- Infos.push_back(Info);
+
+ InfoPairs.push_back({Info, &RootSigElem});
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
RangeInfo Info;
@@ -1105,7 +1200,8 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D,
Info.Class = llvm::dxil::ResourceClass::CBuffer;
Info.Space = Constants->Space;
Info.Visibility = Constants->Visibility;
- Infos.push_back(Info);
+
+ InfoPairs.push_back({Info, &RootSigElem});
} else if (const auto *Sampler =
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
RangeInfo Info;
@@ -1115,13 +1211,17 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D,
Info.Class = llvm::dxil::ResourceClass::Sampler;
Info.Space = Sampler->Space;
Info.Visibility = Sampler->Visibility;
- Infos.push_back(Info);
+
+ InfoPairs.push_back({Info, &RootSigElem});
} else if (const auto *Clause =
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
&Elem)) {
RangeInfo Info;
Info.LowerBound = Clause->Reg.Number;
- assert(0 < Clause->NumDescriptors && "Verified as part of TODO(#129940)");
+ // Relevant error will have already been reported above and needs to be
+ // fixed before we can conduct range analysis, so shortcut error return
+ if (Clause->NumDescriptors == 0)
+ return true;
Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
? RangeInfo::Unbounded
: Info.LowerBound + Clause->NumDescriptors -
@@ -1129,38 +1229,83 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D,
Info.Class = Clause->Type;
Info.Space = Clause->Space;
+
// Note: Clause does not hold the visibility this will need to
- Infos.push_back(Info);
+ InfoPairs.push_back({Info, &RootSigElem});
} else if (const auto *Table =
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
// Table holds the Visibility of all owned Clauses in Table, so iterate
// owned Clauses and update their corresponding RangeInfo
- assert(Table->NumClauses <= Infos.size() && "RootElement");
+ assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
// The last Table->NumClauses elements of Infos are the owned Clauses
// generated RangeInfo
auto TableInfos =
- MutableArrayRef<RangeInfo>(Infos).take_back(Table->NumClauses);
- for (RangeInfo &Info : TableInfos)
- Info.Visibility = Table->Visibility;
+ MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses);
+ for (InfoPairT &Pair : TableInfos)
+ Pair.first.Visibility = Table->Visibility;
}
}
- // Helper to report diagnostics
- auto ReportOverlap = [this, Loc](OverlappingRanges Overlap) {
+ // 2. Sort with the RangeInfo <operator to prepare it for findOverlapping
+ llvm::sort(InfoPairs,
+ [](InfoPairT A, InfoPairT B) { return A.first < B.first; });
+
+ llvm::SmallVector<RangeInfo> Infos;
+ for (const InfoPairT &Pair : InfoPairs)
+ Infos.push_back(Pair.first);
+
+ // Helpers to report diagnostics
+ uint32_t DuplicateCounter = 0;
+ using ElemPair = std::pair<const hlsl::RootSignatureElement *,
+ const hlsl::RootSignatureElement *>;
+ auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
+ OverlappingRanges Overlap) -> ElemPair {
+ // Given we sorted the InfoPairs (and by implication) Infos, and,
+ // that Overlap.B is the item retrieved from the ResourceRange. Then it is
+ // guarenteed that Overlap.B <= Overlap.A.
+ //
+ // So we will find Overlap.B first and then continue to find Overlap.A
+ // after
+ auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
+ auto DistB = std::distance(Infos.begin(), InfoB);
+ auto PairB = InfoPairs.begin();
+ std::advance(PairB, DistB);
+
+ auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
+ // Similarily, from the property that we have sorted the RangeInfos,
+ // all duplicates will be processed one after the other. So
+ // DuplicateCounter can be re-used for each set of duplicates we
+ // encounter as we handle incoming errors
+ DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
+ auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter;
+ auto PairA = PairB;
+ std::advance(PairA, DistA);
+
+ return {PairA->second, PairB->second};
+ };
+
+ auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
+ auto Pair = GetElemPair(Overlap);
const RangeInfo *Info = Overlap.A;
+ const hlsl::RootSignatureElement *Elem = Pair.first;
const RangeInfo *OInfo = Overlap.B;
+
auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
? OInfo->Visibility
: Info->Visibility;
- this->Diag(Loc, diag::err_hlsl_resource_range_overlap)
+ this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
<< llvm::to_underlying(Info->Class) << Info->LowerBound
<< /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
<< Info->UpperBound << llvm::to_underlying(OInfo->Class)
<< OInfo->LowerBound
<< /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
<< OInfo->UpperBound << Info->Space << CommonVis;
+
+ const hlsl::RootSignatureElement *OElem = Pair.second;
+ this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
};
+ // 3. Invoke find overlapping ranges
llvm::SmallVector<OverlappingRanges> Overlaps =
llvm::hlsl::rootsig::findOverlappingRanges(Infos);
for (OverlappingRanges Overlap : Overlaps)