diff options
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
| -rw-r--r-- | clang/lib/Sema/SemaHLSL.cpp | 169 |
1 files changed, 114 insertions, 55 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index f87715950c74..6062f81d0aed 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -236,9 +236,8 @@ static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context, static unsigned calculateLegacyCbufferSize(const ASTContext &Context, QualType T) { constexpr unsigned CBufferAlign = 16; - if (const RecordType *RT = T->getAs<RecordType>()) { + if (const auto *RD = T->getAsRecordDecl()) { unsigned Size = 0; - const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf(); for (const FieldDecl *Field : RD->fields()) { QualType Ty = Field->getType(); unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty); @@ -351,8 +350,8 @@ getResourceArrayHandleType(VarDecl *VD) { assert(VD->getType()->isHLSLResourceRecordArray() && "expected array of resource records"); const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); - while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty)) - Ty = CAT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType(); + while (const ArrayType *AT = dyn_cast<ArrayType>(Ty)) + Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType(); return HLSLAttributedResourceType::findHandleTypeOnResource(Ty); } @@ -364,8 +363,8 @@ static bool isInvalidConstantBufferLeafElementType(const Type *Ty) { Ty = Ty->getUnqualifiedDesugaredType(); if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray()) return true; - if (Ty->isRecordType()) - return Ty->getAsCXXRecordDecl()->isEmpty(); + if (const auto *RD = Ty->getAsCXXRecordDecl()) + return RD->isEmpty(); if (Ty->isConstantArrayType() && isZeroSizedArray(cast<ConstantArrayType>(Ty))) return true; @@ -387,14 +386,14 @@ static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) { QualType Ty = Field->getType(); if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr())) return true; - if (Ty->isRecordType() && - requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl())) + if (const auto *RD = Ty->getAsCXXRecordDecl(); + RD && requiresImplicitBufferLayoutStructure(RD)) return true; } // check bases for (const CXXBaseSpecifier &Base : RD->bases()) if (requiresImplicitBufferLayoutStructure( - Base.getType()->getAsCXXRecordDecl())) + Base.getType()->castAsCXXRecordDecl())) return true; return false; } @@ -459,8 +458,7 @@ static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty, if (isInvalidConstantBufferLeafElementType(Ty)) return nullptr; - if (Ty->isRecordType()) { - CXXRecordDecl *RD = Ty->getAsCXXRecordDecl(); + if (auto *RD = Ty->getAsCXXRecordDecl()) { if (requiresImplicitBufferLayoutStructure(RD)) { RD = createHostLayoutStruct(S, RD); if (!RD) @@ -511,7 +509,7 @@ static CXXRecordDecl *createHostLayoutStruct(Sema &S, assert(NumBases == 1 && "HLSL supports only one base type"); (void)NumBases; CXXBaseSpecifier Base = *StructDecl->bases_begin(); - CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl(); + CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl(); if (requiresImplicitBufferLayoutStructure(BaseDecl)) { BaseDecl = createHostLayoutStruct(S, BaseDecl); if (BaseDecl) { @@ -729,6 +727,19 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) return; + // If we have specified a root signature to override the entry function then + // attach it now + HLSLRootSignatureDecl *SignatureDecl = + lookupRootSignatureOverrideDecl(FD->getDeclContext()); + if (SignatureDecl) { + FD->dropAttr<RootSignatureAttr>(); + // We could look up the SourceRange of the macro here as well + AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(), + SourceRange(), ParsedAttr::Form::Microsoft()); + FD->addAttr(::new (getASTContext()) RootSignatureAttr( + getASTContext(), AL, RootSigOverrideIdent, SignatureDecl)); + } + llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment(); if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) { if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { @@ -750,6 +761,8 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { case llvm::Triple::UnknownEnvironment: case llvm::Triple::Library: break; + case llvm::Triple::RootSignature: + llvm_unreachable("rootsig environment has no functions"); default: llvm_unreachable("Unhandled environment in triple"); } @@ -812,6 +825,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { } } break; + case llvm::Triple::RootSignature: + llvm_unreachable("rootsig environment has no function entry point"); default: llvm_unreachable("Unhandled environment in triple"); } @@ -1092,6 +1107,18 @@ void SemaHLSL::ActOnFinishRootSignatureDecl( SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope()); } +HLSLRootSignatureDecl * +SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const { + if (RootSigOverrideIdent) { + LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(), + Sema::LookupOrdinaryName); + if (SemaRef.LookupQualifiedName(R, DC)) + return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl()); + } + + return nullptr; +} + namespace { struct PerVisibilityBindingChecker { @@ -1144,15 +1171,14 @@ struct PerVisibilityBindingChecker { bool HadOverlap = false; using llvm::hlsl::BindingInfoBuilder; - auto ReportOverlap = [this, &HadOverlap]( - const BindingInfoBuilder &Builder, - const BindingInfoBuilder::Binding &Reported) { + auto ReportOverlap = [this, + &HadOverlap](const BindingInfoBuilder &Builder, + const llvm::hlsl::Binding &Reported) { HadOverlap = true; const auto *Elem = static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie); - const BindingInfoBuilder::Binding &Previous = - Builder.findOverlapping(Reported); + const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported); const auto *PrevElem = static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie); @@ -1269,9 +1295,8 @@ bool SemaHLSL::handleRootSignatureElements( ReportError(Loc, 1, 0xfffffffe); } - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( - Version, llvm::to_underlying(Clause->Type), - llvm::to_underlying(Clause->Flags))) + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type, + Clause->Flags)) ReportFlagError(Loc); } } @@ -1317,12 +1342,48 @@ bool SemaHLSL::handleRootSignatureElements( std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) { assert(UnboundClauses.size() == Table->NumClauses && "Number of unbound elements must match the number of clauses"); + bool HasAnySampler = false; + bool HasAnyNonSampler = false; + uint32_t Offset = 0; for (const auto &[Clause, ClauseElem] : UnboundClauses) { - uint32_t LowerBound(Clause->Reg.Number); + SourceLocation Loc = ClauseElem->getLocation(); + if (Clause->Type == llvm::dxil::ResourceClass::Sampler) + HasAnySampler = true; + else + HasAnyNonSampler = true; + + if (HasAnySampler && HasAnyNonSampler) + Diag(Loc, diag::err_hlsl_invalid_mixed_resources); + // Relevant error will have already been reported above and needs to be - // fixed before we can conduct range analysis, so shortcut error return + // fixed before we can conduct further analysis, so shortcut error + // return if (Clause->NumDescriptors == 0) return true; + + if (Clause->Offset != + llvm::hlsl::rootsig::DescriptorTableOffsetAppend) { + // Manually specified the offset + Offset = Clause->Offset; + } + + uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound( + Offset, Clause->NumDescriptors); + + if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) { + // Trying to append onto unbound offset + Diag(Loc, diag::err_hlsl_appending_onto_unbound); + } else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound)) { + // Upper bound overflows maximum offset + Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound; + } + + Offset = RangeBound == llvm::hlsl::rootsig::NumDescriptorsUnbounded + ? uint32_t(RangeBound) + : uint32_t(RangeBound + 1); + + // Compute the register bounds and track resource binding + uint32_t LowerBound(Clause->Reg.Number); uint32_t UpperBound = Clause->NumDescriptors == ~0u ? ~0u : LowerBound + Clause->NumDescriptors - 1; @@ -2008,9 +2069,11 @@ static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { - if (isa<VarDecl>(TheDecl)) { - if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), - cast<ValueDecl>(TheDecl)->getType(), + if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) { + QualType Ty = VD->getType(); + if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty)) + Ty = IAT->getElementType(); + if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty, diag::err_incomplete_type)) return; } @@ -2838,8 +2901,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { if (SemaRef.checkArgCount(TheCall, 6) || CheckResourceHandle(&SemaRef, TheCall, 0) || CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) || - CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.IntTy) || - CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.UnsignedIntTy) || + CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.UnsignedIntTy) || + CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.IntTy) || CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) || CheckArgTypeMatches(&SemaRef, TheCall->getArg(5), AST.getPointerType(AST.CharTy.withConst()))) @@ -3194,10 +3257,7 @@ static void BuildFlattenedTypeList(QualType BaseTy, List.insert(List.end(), VT->getNumElements(), VT->getElementType()); continue; } - if (const auto *RT = dyn_cast<RecordType>(T)) { - const CXXRecordDecl *RD = RT->getAsCXXRecordDecl(); - assert(RD && "HLSL record types should all be CXXRecordDecls!"); - + if (const auto *RD = T->getAsCXXRecordDecl()) { if (RD->isStandardLayout()) RD = RD->getStandardLayoutBaseWithFields(); @@ -3820,9 +3880,9 @@ void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) { // Unwrap arrays // FIXME: Calculate array size while unwrapping const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); - while (Ty->isConstantArrayType()) { - const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); - Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); + while (Ty->isArrayType()) { + const ArrayType *AT = cast<ArrayType>(Ty); + Ty = AT->getElementType()->getUnqualifiedDesugaredType(); } // Resource (or array of resources) @@ -3967,19 +4027,19 @@ class InitListTransformer { return true; } - if (auto *RTy = Ty->getAs<RecordType>()) { - llvm::SmallVector<const RecordType *> RecordTypes; - RecordTypes.push_back(RTy); - while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) { - CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl(); + if (auto *RD = Ty->getAsCXXRecordDecl()) { + llvm::SmallVector<CXXRecordDecl *> RecordDecls; + RecordDecls.push_back(RD); + while (RecordDecls.back()->getNumBases()) { + CXXRecordDecl *D = RecordDecls.back(); assert(D->getNumBases() == 1 && "HLSL doesn't support multiple inheritance"); - RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>()); + RecordDecls.push_back( + D->bases_begin()->getType()->castAsCXXRecordDecl()); } - while (!RecordTypes.empty()) { - const RecordType *RT = RecordTypes.pop_back_val(); - for (auto *FD : - RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) { + while (!RecordDecls.empty()) { + CXXRecordDecl *RD = RecordDecls.pop_back_val(); + for (auto *FD : RD->fields()) { DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess()); DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc()); ExprResult Res = S.BuildFieldReferenceExpr( @@ -4016,21 +4076,20 @@ class InitListTransformer { for (uint64_t I = 0; I < Size; ++I) Inits.push_back(generateInitListsImpl(ElTy)); } - if (auto *RTy = Ty->getAs<RecordType>()) { - llvm::SmallVector<const RecordType *> RecordTypes; - RecordTypes.push_back(RTy); - while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) { - CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl(); + if (auto *RD = Ty->getAsCXXRecordDecl()) { + llvm::SmallVector<CXXRecordDecl *> RecordDecls; + RecordDecls.push_back(RD); + while (RecordDecls.back()->getNumBases()) { + CXXRecordDecl *D = RecordDecls.back(); assert(D->getNumBases() == 1 && "HLSL doesn't support multiple inheritance"); - RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>()); + RecordDecls.push_back( + D->bases_begin()->getType()->castAsCXXRecordDecl()); } - while (!RecordTypes.empty()) { - const RecordType *RT = RecordTypes.pop_back_val(); - for (auto *FD : - RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) { + while (!RecordDecls.empty()) { + CXXRecordDecl *RD = RecordDecls.pop_back_val(); + for (auto *FD : RD->fields()) Inits.push_back(generateInitListsImpl(FD->getType())); - } } } auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(), |
