summaryrefslogtreecommitdiff
path: root/clang/lib/Sema/SemaHLSL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/Sema/SemaHLSL.cpp')
-rw-r--r--clang/lib/Sema/SemaHLSL.cpp169
1 files changed, 114 insertions, 55 deletions
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index f87715950c74..6062f81d0aed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -236,9 +236,8 @@ static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
QualType T) {
constexpr unsigned CBufferAlign = 16;
- if (const RecordType *RT = T->getAs<RecordType>()) {
+ if (const auto *RD = T->getAsRecordDecl()) {
unsigned Size = 0;
- const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
for (const FieldDecl *Field : RD->fields()) {
QualType Ty = Field->getType();
unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
@@ -351,8 +350,8 @@ getResourceArrayHandleType(VarDecl *VD) {
assert(VD->getType()->isHLSLResourceRecordArray() &&
"expected array of resource records");
const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
- while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty))
- Ty = CAT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
+ while (const ArrayType *AT = dyn_cast<ArrayType>(Ty))
+ Ty = AT->getArrayElementTypeNoTypeQual()->getUnqualifiedDesugaredType();
return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
}
@@ -364,8 +363,8 @@ static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
Ty = Ty->getUnqualifiedDesugaredType();
if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
return true;
- if (Ty->isRecordType())
- return Ty->getAsCXXRecordDecl()->isEmpty();
+ if (const auto *RD = Ty->getAsCXXRecordDecl())
+ return RD->isEmpty();
if (Ty->isConstantArrayType() &&
isZeroSizedArray(cast<ConstantArrayType>(Ty)))
return true;
@@ -387,14 +386,14 @@ static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
QualType Ty = Field->getType();
if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
return true;
- if (Ty->isRecordType() &&
- requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
+ if (const auto *RD = Ty->getAsCXXRecordDecl();
+ RD && requiresImplicitBufferLayoutStructure(RD))
return true;
}
// check bases
for (const CXXBaseSpecifier &Base : RD->bases())
if (requiresImplicitBufferLayoutStructure(
- Base.getType()->getAsCXXRecordDecl()))
+ Base.getType()->castAsCXXRecordDecl()))
return true;
return false;
}
@@ -459,8 +458,7 @@ static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
if (isInvalidConstantBufferLeafElementType(Ty))
return nullptr;
- if (Ty->isRecordType()) {
- CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
+ if (auto *RD = Ty->getAsCXXRecordDecl()) {
if (requiresImplicitBufferLayoutStructure(RD)) {
RD = createHostLayoutStruct(S, RD);
if (!RD)
@@ -511,7 +509,7 @@ static CXXRecordDecl *createHostLayoutStruct(Sema &S,
assert(NumBases == 1 && "HLSL supports only one base type");
(void)NumBases;
CXXBaseSpecifier Base = *StructDecl->bases_begin();
- CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
+ CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
BaseDecl = createHostLayoutStruct(S, BaseDecl);
if (BaseDecl) {
@@ -729,6 +727,19 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
return;
+ // If we have specified a root signature to override the entry function then
+ // attach it now
+ HLSLRootSignatureDecl *SignatureDecl =
+ lookupRootSignatureOverrideDecl(FD->getDeclContext());
+ if (SignatureDecl) {
+ FD->dropAttr<RootSignatureAttr>();
+ // We could look up the SourceRange of the macro here as well
+ AttributeCommonInfo AL(RootSigOverrideIdent, AttributeScopeInfo(),
+ SourceRange(), ParsedAttr::Form::Microsoft());
+ FD->addAttr(::new (getASTContext()) RootSignatureAttr(
+ getASTContext(), AL, RootSigOverrideIdent, SignatureDecl));
+ }
+
llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
@@ -750,6 +761,8 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
case llvm::Triple::UnknownEnvironment:
case llvm::Triple::Library:
break;
+ case llvm::Triple::RootSignature:
+ llvm_unreachable("rootsig environment has no functions");
default:
llvm_unreachable("Unhandled environment in triple");
}
@@ -812,6 +825,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
}
}
break;
+ case llvm::Triple::RootSignature:
+ llvm_unreachable("rootsig environment has no function entry point");
default:
llvm_unreachable("Unhandled environment in triple");
}
@@ -1092,6 +1107,18 @@ void SemaHLSL::ActOnFinishRootSignatureDecl(
SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
}
+HLSLRootSignatureDecl *
+SemaHLSL::lookupRootSignatureOverrideDecl(DeclContext *DC) const {
+ if (RootSigOverrideIdent) {
+ LookupResult R(SemaRef, RootSigOverrideIdent, SourceLocation(),
+ Sema::LookupOrdinaryName);
+ if (SemaRef.LookupQualifiedName(R, DC))
+ return dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl());
+ }
+
+ return nullptr;
+}
+
namespace {
struct PerVisibilityBindingChecker {
@@ -1144,15 +1171,14 @@ struct PerVisibilityBindingChecker {
bool HadOverlap = false;
using llvm::hlsl::BindingInfoBuilder;
- auto ReportOverlap = [this, &HadOverlap](
- const BindingInfoBuilder &Builder,
- const BindingInfoBuilder::Binding &Reported) {
+ auto ReportOverlap = [this,
+ &HadOverlap](const BindingInfoBuilder &Builder,
+ const llvm::hlsl::Binding &Reported) {
HadOverlap = true;
const auto *Elem =
static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
- const BindingInfoBuilder::Binding &Previous =
- Builder.findOverlapping(Reported);
+ const llvm::hlsl::Binding &Previous = Builder.findOverlapping(Reported);
const auto *PrevElem =
static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
@@ -1269,9 +1295,8 @@ bool SemaHLSL::handleRootSignatureElements(
ReportError(Loc, 1, 0xfffffffe);
}
- if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
- Version, llvm::to_underlying(Clause->Type),
- llvm::to_underlying(Clause->Flags)))
+ if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type,
+ Clause->Flags))
ReportFlagError(Loc);
}
}
@@ -1317,12 +1342,48 @@ bool SemaHLSL::handleRootSignatureElements(
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
assert(UnboundClauses.size() == Table->NumClauses &&
"Number of unbound elements must match the number of clauses");
+ bool HasAnySampler = false;
+ bool HasAnyNonSampler = false;
+ uint32_t Offset = 0;
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
- uint32_t LowerBound(Clause->Reg.Number);
+ SourceLocation Loc = ClauseElem->getLocation();
+ if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
+ HasAnySampler = true;
+ else
+ HasAnyNonSampler = true;
+
+ if (HasAnySampler && HasAnyNonSampler)
+ Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
+
// Relevant error will have already been reported above and needs to be
- // fixed before we can conduct range analysis, so shortcut error return
+ // fixed before we can conduct further analysis, so shortcut error
+ // return
if (Clause->NumDescriptors == 0)
return true;
+
+ if (Clause->Offset !=
+ llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
+ // Manually specified the offset
+ Offset = Clause->Offset;
+ }
+
+ uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
+ Offset, Clause->NumDescriptors);
+
+ if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
+ // Trying to append onto unbound offset
+ Diag(Loc, diag::err_hlsl_appending_onto_unbound);
+ } else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound)) {
+ // Upper bound overflows maximum offset
+ Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
+ }
+
+ Offset = RangeBound == llvm::hlsl::rootsig::NumDescriptorsUnbounded
+ ? uint32_t(RangeBound)
+ : uint32_t(RangeBound + 1);
+
+ // Compute the register bounds and track resource binding
+ uint32_t LowerBound(Clause->Reg.Number);
uint32_t UpperBound = Clause->NumDescriptors == ~0u
? ~0u
: LowerBound + Clause->NumDescriptors - 1;
@@ -2008,9 +2069,11 @@ static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
}
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
- if (isa<VarDecl>(TheDecl)) {
- if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
- cast<ValueDecl>(TheDecl)->getType(),
+ if (VarDecl *VD = dyn_cast<VarDecl>(TheDecl)) {
+ QualType Ty = VD->getType();
+ if (const auto *IAT = dyn_cast<IncompleteArrayType>(Ty))
+ Ty = IAT->getElementType();
+ if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), Ty,
diag::err_incomplete_type))
return;
}
@@ -2838,8 +2901,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
if (SemaRef.checkArgCount(TheCall, 6) ||
CheckResourceHandle(&SemaRef, TheCall, 0) ||
CheckArgTypeMatches(&SemaRef, TheCall->getArg(1), AST.UnsignedIntTy) ||
- CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.IntTy) ||
- CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.UnsignedIntTy) ||
+ CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), AST.UnsignedIntTy) ||
+ CheckArgTypeMatches(&SemaRef, TheCall->getArg(3), AST.IntTy) ||
CheckArgTypeMatches(&SemaRef, TheCall->getArg(4), AST.UnsignedIntTy) ||
CheckArgTypeMatches(&SemaRef, TheCall->getArg(5),
AST.getPointerType(AST.CharTy.withConst())))
@@ -3194,10 +3257,7 @@ static void BuildFlattenedTypeList(QualType BaseTy,
List.insert(List.end(), VT->getNumElements(), VT->getElementType());
continue;
}
- if (const auto *RT = dyn_cast<RecordType>(T)) {
- const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
- assert(RD && "HLSL record types should all be CXXRecordDecls!");
-
+ if (const auto *RD = T->getAsCXXRecordDecl()) {
if (RD->isStandardLayout())
RD = RD->getStandardLayoutBaseWithFields();
@@ -3820,9 +3880,9 @@ void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
// Unwrap arrays
// FIXME: Calculate array size while unwrapping
const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
- while (Ty->isConstantArrayType()) {
- const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
- Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
+ while (Ty->isArrayType()) {
+ const ArrayType *AT = cast<ArrayType>(Ty);
+ Ty = AT->getElementType()->getUnqualifiedDesugaredType();
}
// Resource (or array of resources)
@@ -3967,19 +4027,19 @@ class InitListTransformer {
return true;
}
- if (auto *RTy = Ty->getAs<RecordType>()) {
- llvm::SmallVector<const RecordType *> RecordTypes;
- RecordTypes.push_back(RTy);
- while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
- CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+ if (auto *RD = Ty->getAsCXXRecordDecl()) {
+ llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+ RecordDecls.push_back(RD);
+ while (RecordDecls.back()->getNumBases()) {
+ CXXRecordDecl *D = RecordDecls.back();
assert(D->getNumBases() == 1 &&
"HLSL doesn't support multiple inheritance");
- RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+ RecordDecls.push_back(
+ D->bases_begin()->getType()->castAsCXXRecordDecl());
}
- while (!RecordTypes.empty()) {
- const RecordType *RT = RecordTypes.pop_back_val();
- for (auto *FD :
- RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+ while (!RecordDecls.empty()) {
+ CXXRecordDecl *RD = RecordDecls.pop_back_val();
+ for (auto *FD : RD->fields()) {
DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
ExprResult Res = S.BuildFieldReferenceExpr(
@@ -4016,21 +4076,20 @@ class InitListTransformer {
for (uint64_t I = 0; I < Size; ++I)
Inits.push_back(generateInitListsImpl(ElTy));
}
- if (auto *RTy = Ty->getAs<RecordType>()) {
- llvm::SmallVector<const RecordType *> RecordTypes;
- RecordTypes.push_back(RTy);
- while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
- CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+ if (auto *RD = Ty->getAsCXXRecordDecl()) {
+ llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+ RecordDecls.push_back(RD);
+ while (RecordDecls.back()->getNumBases()) {
+ CXXRecordDecl *D = RecordDecls.back();
assert(D->getNumBases() == 1 &&
"HLSL doesn't support multiple inheritance");
- RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+ RecordDecls.push_back(
+ D->bases_begin()->getType()->castAsCXXRecordDecl());
}
- while (!RecordTypes.empty()) {
- const RecordType *RT = RecordTypes.pop_back_val();
- for (auto *FD :
- RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+ while (!RecordDecls.empty()) {
+ CXXRecordDecl *RD = RecordDecls.pop_back_val();
+ for (auto *FD : RD->fields())
Inits.push_back(generateInitListsImpl(FD->getType()));
- }
}
}
auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),