summaryrefslogtreecommitdiff
path: root/clang/lib/CodeGen/CGExpr.cpp
diff options
context:
space:
mode:
authorBill Wendling <morbo@google.com>2025-05-13 16:01:36 -0700
committerGitHub <noreply@github.com>2025-05-13 16:01:36 -0700
commit9ae3bce17543f92ce0237597cc66503d58cce317 (patch)
tree22967a72c40b010c52ab71af525830ca81760511 /clang/lib/CodeGen/CGExpr.cpp
parent4e604d46681f722b1def10ce72c89046dac39e63 (diff)
[Clang][counted_by] Add support for 'counted_by' on struct pointers (#137250)
The 'counted_by' attribute is now available for pointers in structs. It generates code for sanity checks as well as __builtin_dynamic_object_size() calculations. For example: struct annotated_ptr { int count; char *buf __attribute__((counted_by(count))); }; If the pointer's type is 'void *', use the 'sized_by' attribute, which works similarly to 'counted_by', but can handle the 'void' base type: struct annotated_ptr { int count; void *buf __attribute__((sized_by(count))); }; If the 'count' field member occurs after the pointer, use the '-fexperimental-late-parse-attributes' flag during compilation. Note that 'counted_by' cannot be applied to a pointer to an incomplete type, because the size isn't known. struct foo; struct annotated_ptr { int count; struct foo *buf __attribute__((counted_by(count))); /* invalid */ }; Signed-off-by: Bill Wendling <morbo@google.com>
Diffstat (limited to 'clang/lib/CodeGen/CGExpr.cpp')
-rw-r--r--clang/lib/CodeGen/CGExpr.cpp137
1 files changed, 93 insertions, 44 deletions
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 0d03923951a1..ec01c87c13b1 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4274,6 +4274,24 @@ static Address emitArraySubscriptGEP(CodeGenFunction &CGF, Address addr,
return Address(eltPtr, CGF.ConvertTypeForMem(eltType), eltAlign);
}
+namespace {
+
+/// StructFieldAccess is a simple visitor class to grab the first l-value to
+/// r-value cast Expr.
+struct StructFieldAccess
+ : public ConstStmtVisitor<StructFieldAccess, const Expr *> {
+ const Expr *VisitCastExpr(const CastExpr *E) {
+ if (E->getCastKind() == CK_LValueToRValue)
+ return E;
+ return Visit(E->getSubExpr());
+ }
+ const Expr *VisitParenExpr(const ParenExpr *E) {
+ return Visit(E->getSubExpr());
+ }
+};
+
+} // end anonymous namespace
+
/// The offset of a field from the beginning of the record.
static bool getFieldOffsetInBits(CodeGenFunction &CGF, const RecordDecl *RD,
const FieldDecl *Field, int64_t &Offset) {
@@ -4329,6 +4347,60 @@ static std::optional<int64_t> getOffsetDifferenceInBits(CodeGenFunction &CGF,
return std::make_optional<int64_t>(FD1Offset - FD2Offset);
}
+/// EmitCountedByBoundsChecking - If the array being accessed has a "counted_by"
+/// attribute, generate bounds checking code. The "count" field is at the top
+/// level of the struct or in an anonymous struct, that's also at the top level.
+/// Future expansions may allow the "count" to reside at any place in the
+/// struct, but the value of "counted_by" will be a "simple" path to the count,
+/// i.e. "a.b.count", so we shouldn't need the full force of EmitLValue or
+/// similar to emit the correct GEP.
+void CodeGenFunction::EmitCountedByBoundsChecking(
+ const Expr *E, llvm::Value *Idx, Address Addr, QualType IdxTy,
+ QualType ArrayTy, bool Accessed, bool FlexibleArray) {
+ const auto *ME = dyn_cast<MemberExpr>(E->IgnoreImpCasts());
+ if (!ME || !ME->getMemberDecl()->getType()->isCountAttributedType())
+ return;
+
+ const LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
+ getLangOpts().getStrictFlexArraysLevel();
+ if (FlexibleArray &&
+ !ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel))
+ return;
+
+ const FieldDecl *FD = cast<FieldDecl>(ME->getMemberDecl());
+ const FieldDecl *CountFD = FD->findCountedByField();
+ if (!CountFD)
+ return;
+
+ if (std::optional<int64_t> Diff =
+ getOffsetDifferenceInBits(*this, CountFD, FD)) {
+ if (!Addr.isValid()) {
+ // An invalid Address indicates we're checking a pointer array access.
+ // Emit the checked L-Value here.
+ LValue LV = EmitCheckedLValue(E, TCK_MemberAccess);
+ Addr = LV.getAddress();
+ }
+
+ // FIXME: The 'static_cast' is necessary, otherwise the result turns into a
+ // uint64_t, which messes things up if we have a negative offset difference.
+ Diff = *Diff / static_cast<int64_t>(CGM.getContext().getCharWidth());
+
+ // Create a GEP with the byte offset between the counted object and the
+ // count and use that to load the count value.
+ Addr = Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Int8PtrTy, Int8Ty);
+
+ llvm::Type *CountTy = ConvertType(CountFD->getType());
+ llvm::Value *Res =
+ Builder.CreateInBoundsGEP(Int8Ty, Addr.emitRawPointer(*this),
+ Builder.getInt32(*Diff), ".counted_by.gep");
+ Res = Builder.CreateAlignedLoad(CountTy, Res, getIntAlign(),
+ ".counted_by.load");
+
+ // Now emit the bounds checking.
+ EmitBoundsCheckImpl(E, Res, Idx, IdxTy, ArrayTy, Accessed);
+ }
+}
+
LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
bool Accessed) {
// The index must always be an integer, which is not an aggregate. Emit it
@@ -4455,46 +4527,10 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
ArrayLV = EmitLValue(Array);
auto *Idx = EmitIdxAfterBase(/*Promote*/true);
- if (SanOpts.has(SanitizerKind::ArrayBounds)) {
- // If the array being accessed has a "counted_by" attribute, generate
- // bounds checking code. The "count" field is at the top level of the
- // struct or in an anonymous struct, that's also at the top level. Future
- // expansions may allow the "count" to reside at any place in the struct,
- // but the value of "counted_by" will be a "simple" path to the count,
- // i.e. "a.b.count", so we shouldn't need the full force of EmitLValue or
- // similar to emit the correct GEP.
- const LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
- getLangOpts().getStrictFlexArraysLevel();
-
- if (const auto *ME = dyn_cast<MemberExpr>(Array);
- ME &&
- ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel) &&
- ME->getMemberDecl()->getType()->isCountAttributedType()) {
- const FieldDecl *FAMDecl = cast<FieldDecl>(ME->getMemberDecl());
- if (const FieldDecl *CountFD = FAMDecl->findCountedByField()) {
- if (std::optional<int64_t> Diff =
- getOffsetDifferenceInBits(*this, CountFD, FAMDecl)) {
- CharUnits OffsetDiff = CGM.getContext().toCharUnitsFromBits(*Diff);
-
- // Create a GEP with a byte offset between the FAM and count and
- // use that to load the count value.
- Addr = Builder.CreatePointerBitCastOrAddrSpaceCast(
- ArrayLV.getAddress(), Int8PtrTy, Int8Ty);
-
- llvm::Type *CountTy = ConvertType(CountFD->getType());
- llvm::Value *Res = Builder.CreateInBoundsGEP(
- Int8Ty, Addr.emitRawPointer(*this),
- Builder.getInt32(OffsetDiff.getQuantity()), ".counted_by.gep");
- Res = Builder.CreateAlignedLoad(CountTy, Res, getIntAlign(),
- ".counted_by.load");
-
- // Now emit the bounds checking.
- EmitBoundsCheckImpl(E, Res, Idx, E->getIdx()->getType(),
- Array->getType(), Accessed);
- }
- }
- }
- }
+ if (SanOpts.has(SanitizerKind::ArrayBounds))
+ EmitCountedByBoundsChecking(Array, Idx, ArrayLV.getAddress(),
+ E->getIdx()->getType(), Array->getType(),
+ Accessed, /*FlexibleArray=*/true);
// Propagate the alignment from the array itself to the result.
QualType arrayType = Array->getType();
@@ -4506,12 +4542,25 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
EltTBAAInfo = CGM.getTBAAInfoForSubobject(ArrayLV, E->getType());
} else {
// The base must be a pointer; emit it with an estimate of its alignment.
- Addr = EmitPointerWithAlignment(E->getBase(), &EltBaseInfo, &EltTBAAInfo);
+ Address BaseAddr =
+ EmitPointerWithAlignment(E->getBase(), &EltBaseInfo, &EltTBAAInfo);
auto *Idx = EmitIdxAfterBase(/*Promote*/true);
QualType ptrType = E->getBase()->getType();
- Addr = emitArraySubscriptGEP(
- *this, Addr, Idx, E->getType(), !getLangOpts().PointerOverflowDefined,
- SignedIndices, E->getExprLoc(), &ptrType, E->getBase());
+ Addr = emitArraySubscriptGEP(*this, BaseAddr, Idx, E->getType(),
+ !getLangOpts().PointerOverflowDefined,
+ SignedIndices, E->getExprLoc(), &ptrType,
+ E->getBase());
+
+ if (SanOpts.has(SanitizerKind::ArrayBounds)) {
+ StructFieldAccess Visitor;
+ const Expr *Base = Visitor.Visit(E->getBase());
+
+ if (const auto *CE = dyn_cast_if_present<CastExpr>(Base);
+ CE && CE->getCastKind() == CK_LValueToRValue)
+ EmitCountedByBoundsChecking(CE, Idx, Address::invalid(),
+ E->getIdx()->getType(), ptrType, Accessed,
+ /*FlexibleArray=*/false);
+ }
}
LValue LV = MakeAddrLValue(Addr, E->getType(), EltBaseInfo, EltTBAAInfo);