diff options
| author | Bill Wendling <morbo@google.com> | 2025-05-13 16:01:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-13 16:01:36 -0700 |
| commit | 9ae3bce17543f92ce0237597cc66503d58cce317 (patch) | |
| tree | 22967a72c40b010c52ab71af525830ca81760511 /clang/lib/CodeGen/CGExpr.cpp | |
| parent | 4e604d46681f722b1def10ce72c89046dac39e63 (diff) | |
[Clang][counted_by] Add support for 'counted_by' on struct pointers (#137250)
The 'counted_by' attribute is now available for pointers in structs.
It generates code for sanity checks as well as
__builtin_dynamic_object_size()
calculations. For example:
struct annotated_ptr {
int count;
char *buf __attribute__((counted_by(count)));
};
If the pointer's type is 'void *', use the 'sized_by' attribute, which
works similarly to 'counted_by', but can handle the 'void' base type:
struct annotated_ptr {
int count;
void *buf __attribute__((sized_by(count)));
};
If the 'count' field member occurs after the pointer, use the
'-fexperimental-late-parse-attributes' flag during compilation.
Note that 'counted_by' cannot be applied to a pointer to an incomplete
type, because the size isn't known.
struct foo;
struct annotated_ptr {
int count;
struct foo *buf __attribute__((counted_by(count))); /* invalid */
};
Signed-off-by: Bill Wendling <morbo@google.com>
Diffstat (limited to 'clang/lib/CodeGen/CGExpr.cpp')
| -rw-r--r-- | clang/lib/CodeGen/CGExpr.cpp | 137 |
1 files changed, 93 insertions, 44 deletions
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 0d03923951a1..ec01c87c13b1 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -4274,6 +4274,24 @@ static Address emitArraySubscriptGEP(CodeGenFunction &CGF, Address addr, return Address(eltPtr, CGF.ConvertTypeForMem(eltType), eltAlign); } +namespace { + +/// StructFieldAccess is a simple visitor class to grab the first l-value to +/// r-value cast Expr. +struct StructFieldAccess + : public ConstStmtVisitor<StructFieldAccess, const Expr *> { + const Expr *VisitCastExpr(const CastExpr *E) { + if (E->getCastKind() == CK_LValueToRValue) + return E; + return Visit(E->getSubExpr()); + } + const Expr *VisitParenExpr(const ParenExpr *E) { + return Visit(E->getSubExpr()); + } +}; + +} // end anonymous namespace + /// The offset of a field from the beginning of the record. static bool getFieldOffsetInBits(CodeGenFunction &CGF, const RecordDecl *RD, const FieldDecl *Field, int64_t &Offset) { @@ -4329,6 +4347,60 @@ static std::optional<int64_t> getOffsetDifferenceInBits(CodeGenFunction &CGF, return std::make_optional<int64_t>(FD1Offset - FD2Offset); } +/// EmitCountedByBoundsChecking - If the array being accessed has a "counted_by" +/// attribute, generate bounds checking code. The "count" field is at the top +/// level of the struct or in an anonymous struct, that's also at the top level. +/// Future expansions may allow the "count" to reside at any place in the +/// struct, but the value of "counted_by" will be a "simple" path to the count, +/// i.e. "a.b.count", so we shouldn't need the full force of EmitLValue or +/// similar to emit the correct GEP. +void CodeGenFunction::EmitCountedByBoundsChecking( + const Expr *E, llvm::Value *Idx, Address Addr, QualType IdxTy, + QualType ArrayTy, bool Accessed, bool FlexibleArray) { + const auto *ME = dyn_cast<MemberExpr>(E->IgnoreImpCasts()); + if (!ME || !ME->getMemberDecl()->getType()->isCountAttributedType()) + return; + + const LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel = + getLangOpts().getStrictFlexArraysLevel(); + if (FlexibleArray && + !ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel)) + return; + + const FieldDecl *FD = cast<FieldDecl>(ME->getMemberDecl()); + const FieldDecl *CountFD = FD->findCountedByField(); + if (!CountFD) + return; + + if (std::optional<int64_t> Diff = + getOffsetDifferenceInBits(*this, CountFD, FD)) { + if (!Addr.isValid()) { + // An invalid Address indicates we're checking a pointer array access. + // Emit the checked L-Value here. + LValue LV = EmitCheckedLValue(E, TCK_MemberAccess); + Addr = LV.getAddress(); + } + + // FIXME: The 'static_cast' is necessary, otherwise the result turns into a + // uint64_t, which messes things up if we have a negative offset difference. + Diff = *Diff / static_cast<int64_t>(CGM.getContext().getCharWidth()); + + // Create a GEP with the byte offset between the counted object and the + // count and use that to load the count value. + Addr = Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Int8PtrTy, Int8Ty); + + llvm::Type *CountTy = ConvertType(CountFD->getType()); + llvm::Value *Res = + Builder.CreateInBoundsGEP(Int8Ty, Addr.emitRawPointer(*this), + Builder.getInt32(*Diff), ".counted_by.gep"); + Res = Builder.CreateAlignedLoad(CountTy, Res, getIntAlign(), + ".counted_by.load"); + + // Now emit the bounds checking. + EmitBoundsCheckImpl(E, Res, Idx, IdxTy, ArrayTy, Accessed); + } +} + LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E, bool Accessed) { // The index must always be an integer, which is not an aggregate. Emit it @@ -4455,46 +4527,10 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E, ArrayLV = EmitLValue(Array); auto *Idx = EmitIdxAfterBase(/*Promote*/true); - if (SanOpts.has(SanitizerKind::ArrayBounds)) { - // If the array being accessed has a "counted_by" attribute, generate - // bounds checking code. The "count" field is at the top level of the - // struct or in an anonymous struct, that's also at the top level. Future - // expansions may allow the "count" to reside at any place in the struct, - // but the value of "counted_by" will be a "simple" path to the count, - // i.e. "a.b.count", so we shouldn't need the full force of EmitLValue or - // similar to emit the correct GEP. - const LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel = - getLangOpts().getStrictFlexArraysLevel(); - - if (const auto *ME = dyn_cast<MemberExpr>(Array); - ME && - ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel) && - ME->getMemberDecl()->getType()->isCountAttributedType()) { - const FieldDecl *FAMDecl = cast<FieldDecl>(ME->getMemberDecl()); - if (const FieldDecl *CountFD = FAMDecl->findCountedByField()) { - if (std::optional<int64_t> Diff = - getOffsetDifferenceInBits(*this, CountFD, FAMDecl)) { - CharUnits OffsetDiff = CGM.getContext().toCharUnitsFromBits(*Diff); - - // Create a GEP with a byte offset between the FAM and count and - // use that to load the count value. - Addr = Builder.CreatePointerBitCastOrAddrSpaceCast( - ArrayLV.getAddress(), Int8PtrTy, Int8Ty); - - llvm::Type *CountTy = ConvertType(CountFD->getType()); - llvm::Value *Res = Builder.CreateInBoundsGEP( - Int8Ty, Addr.emitRawPointer(*this), - Builder.getInt32(OffsetDiff.getQuantity()), ".counted_by.gep"); - Res = Builder.CreateAlignedLoad(CountTy, Res, getIntAlign(), - ".counted_by.load"); - - // Now emit the bounds checking. - EmitBoundsCheckImpl(E, Res, Idx, E->getIdx()->getType(), - Array->getType(), Accessed); - } - } - } - } + if (SanOpts.has(SanitizerKind::ArrayBounds)) + EmitCountedByBoundsChecking(Array, Idx, ArrayLV.getAddress(), + E->getIdx()->getType(), Array->getType(), + Accessed, /*FlexibleArray=*/true); // Propagate the alignment from the array itself to the result. QualType arrayType = Array->getType(); @@ -4506,12 +4542,25 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E, EltTBAAInfo = CGM.getTBAAInfoForSubobject(ArrayLV, E->getType()); } else { // The base must be a pointer; emit it with an estimate of its alignment. - Addr = EmitPointerWithAlignment(E->getBase(), &EltBaseInfo, &EltTBAAInfo); + Address BaseAddr = + EmitPointerWithAlignment(E->getBase(), &EltBaseInfo, &EltTBAAInfo); auto *Idx = EmitIdxAfterBase(/*Promote*/true); QualType ptrType = E->getBase()->getType(); - Addr = emitArraySubscriptGEP( - *this, Addr, Idx, E->getType(), !getLangOpts().PointerOverflowDefined, - SignedIndices, E->getExprLoc(), &ptrType, E->getBase()); + Addr = emitArraySubscriptGEP(*this, BaseAddr, Idx, E->getType(), + !getLangOpts().PointerOverflowDefined, + SignedIndices, E->getExprLoc(), &ptrType, + E->getBase()); + + if (SanOpts.has(SanitizerKind::ArrayBounds)) { + StructFieldAccess Visitor; + const Expr *Base = Visitor.Visit(E->getBase()); + + if (const auto *CE = dyn_cast_if_present<CastExpr>(Base); + CE && CE->getCastKind() == CK_LValueToRValue) + EmitCountedByBoundsChecking(CE, Idx, Address::invalid(), + E->getIdx()->getType(), ptrType, Accessed, + /*FlexibleArray=*/false); + } } LValue LV = MakeAddrLValue(Addr, E->getType(), EltBaseInfo, EltTBAAInfo); |
