diff options
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
| -rw-r--r-- | clang/lib/Sema/SemaHLSL.cpp | 185 |
1 files changed, 165 insertions, 20 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index cc50d23e0b99..9276554bebf9 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -43,7 +43,9 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/TargetParser/Triple.h" +#include <cmath> #include <cstddef> #include <iterator> #include <utility> @@ -1064,27 +1066,119 @@ SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) { void SemaHLSL::ActOnFinishRootSignatureDecl( SourceLocation Loc, IdentifierInfo *DeclIdent, - SmallVector<llvm::hlsl::rootsig::RootElement> &Elements) { + ArrayRef<hlsl::RootSignatureElement> RootElements) { + + if (handleRootSignatureElements(RootElements)) + return; + + SmallVector<llvm::hlsl::rootsig::RootElement> Elements; + for (auto &RootSigElement : RootElements) + Elements.push_back(RootSigElement.getElement()); auto *SignatureDecl = HLSLRootSignatureDecl::Create( SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc, DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements); - if (handleRootSignatureDecl(SignatureDecl, Loc)) - return; - SignatureDecl->setImplicit(); SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope()); } -bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, - SourceLocation Loc) { +bool SemaHLSL::handleRootSignatureElements( + ArrayRef<hlsl::RootSignatureElement> Elements) { + // Define some common error handling functions + bool HadError = false; + auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound, + uint32_t UpperBound) { + HadError = true; + this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value) + << LowerBound << UpperBound; + }; + + auto ReportFloatError = [this, &HadError](SourceLocation Loc, + float LowerBound, + float UpperBound) { + HadError = true; + this->Diag(Loc, diag::err_hlsl_invalid_rootsig_value) + << llvm::formatv("{0:f}", LowerBound).sstr<6>() + << llvm::formatv("{0:f}", UpperBound).sstr<6>(); + }; + + auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) { + if (!llvm::hlsl::rootsig::verifyRegisterValue(Register)) + ReportError(Loc, 0, 0xfffffffe); + }; + + auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) { + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Space)) + ReportError(Loc, 0, 0xffffffef); + }; + + const uint32_t Version = + llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer); + const uint32_t VersionEnum = Version - 1; + auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) { + HadError = true; + this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag) + << /*version minor*/ VersionEnum; + }; + + // Iterate through the elements and do basic validations + for (const hlsl::RootSignatureElement &RootSigElem : Elements) { + SourceLocation Loc = RootSigElem.getLocation(); + const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement(); + if (const auto *Descriptor = + std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) { + VerifyRegister(Loc, Descriptor->Reg.Number); + VerifySpace(Loc, Descriptor->Space); + + if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag( + Version, llvm::to_underlying(Descriptor->Flags))) + ReportFlagError(Loc); + } else if (const auto *Constants = + std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) { + VerifyRegister(Loc, Constants->Reg.Number); + VerifySpace(Loc, Constants->Space); + } else if (const auto *Sampler = + std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) { + VerifyRegister(Loc, Sampler->Reg.Number); + VerifySpace(Loc, Sampler->Space); + + assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) && + "By construction, parseFloatParam can't produce a NaN from a " + "float_literal token"); + + if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler->MaxAnisotropy)) + ReportError(Loc, 0, 16); + if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler->MipLODBias)) + ReportFloatError(Loc, -16.f, 15.99f); + } else if (const auto *Clause = + std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>( + &Elem)) { + VerifyRegister(Loc, Clause->Reg.Number); + VerifySpace(Loc, Clause->Space); + + if (!llvm::hlsl::rootsig::verifyNumDescriptors(Clause->NumDescriptors)) { + // NumDescriptor could techincally be ~0u but that is reserved for + // unbounded, so the diagnostic will not report that as a valid int + // value + ReportError(Loc, 1, 0xfffffffe); + } + + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + Version, llvm::to_underlying(Clause->Type), + llvm::to_underlying(Clause->Flags))) + ReportFlagError(Loc); + } + } + using RangeInfo = llvm::hlsl::rootsig::RangeInfo; using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges; + using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>; // 1. Collect RangeInfos - llvm::SmallVector<RangeInfo> Infos; - for (const llvm::hlsl::rootsig::RootElement &Elem : D->getRootElements()) { + llvm::SmallVector<InfoPairT> InfoPairs; + for (const hlsl::RootSignatureElement &RootSigElem : Elements) { + const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement(); if (const auto *Descriptor = std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) { RangeInfo Info; @@ -1095,7 +1189,8 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type)); Info.Space = Descriptor->Space; Info.Visibility = Descriptor->Visibility; - Infos.push_back(Info); + + InfoPairs.push_back({Info, &RootSigElem}); } else if (const auto *Constants = std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) { RangeInfo Info; @@ -1105,7 +1200,8 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, Info.Class = llvm::dxil::ResourceClass::CBuffer; Info.Space = Constants->Space; Info.Visibility = Constants->Visibility; - Infos.push_back(Info); + + InfoPairs.push_back({Info, &RootSigElem}); } else if (const auto *Sampler = std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) { RangeInfo Info; @@ -1115,13 +1211,17 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, Info.Class = llvm::dxil::ResourceClass::Sampler; Info.Space = Sampler->Space; Info.Visibility = Sampler->Visibility; - Infos.push_back(Info); + + InfoPairs.push_back({Info, &RootSigElem}); } else if (const auto *Clause = std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>( &Elem)) { RangeInfo Info; Info.LowerBound = Clause->Reg.Number; - assert(0 < Clause->NumDescriptors && "Verified as part of TODO(#129940)"); + // Relevant error will have already been reported above and needs to be + // fixed before we can conduct range analysis, so shortcut error return + if (Clause->NumDescriptors == 0) + return true; Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded ? RangeInfo::Unbounded : Info.LowerBound + Clause->NumDescriptors - @@ -1129,38 +1229,83 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, Info.Class = Clause->Type; Info.Space = Clause->Space; + // Note: Clause does not hold the visibility this will need to - Infos.push_back(Info); + InfoPairs.push_back({Info, &RootSigElem}); } else if (const auto *Table = std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) { // Table holds the Visibility of all owned Clauses in Table, so iterate // owned Clauses and update their corresponding RangeInfo - assert(Table->NumClauses <= Infos.size() && "RootElement"); + assert(Table->NumClauses <= InfoPairs.size() && "RootElement"); // The last Table->NumClauses elements of Infos are the owned Clauses // generated RangeInfo auto TableInfos = - MutableArrayRef<RangeInfo>(Infos).take_back(Table->NumClauses); - for (RangeInfo &Info : TableInfos) - Info.Visibility = Table->Visibility; + MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses); + for (InfoPairT &Pair : TableInfos) + Pair.first.Visibility = Table->Visibility; } } - // Helper to report diagnostics - auto ReportOverlap = [this, Loc](OverlappingRanges Overlap) { + // 2. Sort with the RangeInfo <operator to prepare it for findOverlapping + llvm::sort(InfoPairs, + [](InfoPairT A, InfoPairT B) { return A.first < B.first; }); + + llvm::SmallVector<RangeInfo> Infos; + for (const InfoPairT &Pair : InfoPairs) + Infos.push_back(Pair.first); + + // Helpers to report diagnostics + uint32_t DuplicateCounter = 0; + using ElemPair = std::pair<const hlsl::RootSignatureElement *, + const hlsl::RootSignatureElement *>; + auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter]( + OverlappingRanges Overlap) -> ElemPair { + // Given we sorted the InfoPairs (and by implication) Infos, and, + // that Overlap.B is the item retrieved from the ResourceRange. Then it is + // guarenteed that Overlap.B <= Overlap.A. + // + // So we will find Overlap.B first and then continue to find Overlap.A + // after + auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B); + auto DistB = std::distance(Infos.begin(), InfoB); + auto PairB = InfoPairs.begin(); + std::advance(PairB, DistB); + + auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A); + // Similarily, from the property that we have sorted the RangeInfos, + // all duplicates will be processed one after the other. So + // DuplicateCounter can be re-used for each set of duplicates we + // encounter as we handle incoming errors + DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0; + auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter; + auto PairA = PairB; + std::advance(PairA, DistA); + + return {PairA->second, PairB->second}; + }; + + auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) { + auto Pair = GetElemPair(Overlap); const RangeInfo *Info = Overlap.A; + const hlsl::RootSignatureElement *Elem = Pair.first; const RangeInfo *OInfo = Overlap.B; + auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All ? OInfo->Visibility : Info->Visibility; - this->Diag(Loc, diag::err_hlsl_resource_range_overlap) + this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap) << llvm::to_underlying(Info->Class) << Info->LowerBound << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded) << Info->UpperBound << llvm::to_underlying(OInfo->Class) << OInfo->LowerBound << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded) << OInfo->UpperBound << Info->Space << CommonVis; + + const hlsl::RootSignatureElement *OElem = Pair.second; + this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here); }; + // 3. Invoke find overlapping ranges llvm::SmallVector<OverlappingRanges> Overlaps = llvm::hlsl::rootsig::findOverlappingRanges(Infos); for (OverlappingRanges Overlap : Overlaps) |
