summaryrefslogtreecommitdiff
path: root/flang/runtime/pointer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/runtime/pointer.cpp')
-rw-r--r--flang/runtime/pointer.cpp67
1 files changed, 42 insertions, 25 deletions
diff --git a/flang/runtime/pointer.cpp b/flang/runtime/pointer.cpp
index 08a1223764f3..aeed879f1a2e 100644
--- a/flang/runtime/pointer.cpp
+++ b/flang/runtime/pointer.cpp
@@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
}
}
+RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) {
+ // Add space for a footer to validate during deallocation.
+ constexpr std::size_t align{sizeof(std::uintptr_t)};
+ byteSize = ((byteSize / align) + 1) * align;
+ std::size_t total{byteSize + sizeof(std::uintptr_t)};
+ void *p{std::malloc(total)};
+ if (p) {
+ // Fill the footer word with the XOR of the ones' complement of
+ // the base address, which is a value that would be highly unlikely
+ // to appear accidentally at the right spot.
+ std::uintptr_t *footer{
+ reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
+ *footer = ~reinterpret_cast<std::uintptr_t>(p);
+ }
+ return p;
+}
+
int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
@@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
elementBytes = pointer.raw().elem_len = 0;
}
std::size_t byteSize{pointer.Elements() * elementBytes};
- // Add space for a footer to validate during DEALLOCATE.
- constexpr std::size_t align{sizeof(std::uintptr_t)};
- byteSize = ((byteSize + align - 1) / align) * align;
- std::size_t total{byteSize + sizeof(std::uintptr_t)};
- void *p{std::malloc(total)};
+ void *p{AllocateValidatedPointerPayload(byteSize)};
if (!p) {
return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat);
}
pointer.set_base_addr(p);
pointer.SetByteStrides();
- // Fill the footer word with the XOR of the ones' complement of
- // the base address, which is a value that would be highly unlikely
- // to appear accidentally at the right spot.
- std::uintptr_t *footer{
- reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
- *footer = ~reinterpret_cast<std::uintptr_t>(p);
int stat{StatOk};
if (const DescriptorAddendum * addendum{pointer.Addendum()}) {
if (const auto *derived{addendum->derivedType()}) {
@@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source,
return stat;
}
+static RT_API_ATTRS std::size_t GetByteSize(
+ const ISO::CFI_cdesc_t &descriptor) {
+ std::size_t rank{descriptor.rank};
+ const ISO::CFI_dim_t *dim{descriptor.dim};
+ std::size_t byteSize{descriptor.elem_len};
+ for (std::size_t j{0}; j < rank; ++j) {
+ byteSize *= dim[j].extent;
+ }
+ return byteSize;
+}
+
+bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) {
+ std::size_t byteSize{GetByteSize(desc)};
+ constexpr std::size_t align{sizeof(std::uintptr_t)};
+ byteSize = ((byteSize / align) + 1) * align;
+ const void *p{desc.base_addr};
+ const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>(
+ static_cast<const char *>(p) + byteSize)};
+ return *footer == ~reinterpret_cast<std::uintptr_t>(p);
+}
+
int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
@@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
if (!pointer.IsAllocated()) {
return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
}
- if (executionEnvironment.checkPointerDeallocation) {
- // Validate the footer. This should fail if the pointer doesn't
- // span the entire object, or the object was not allocated as a
- // pointer.
- std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()};
- constexpr std::size_t align{sizeof(std::uintptr_t)};
- byteSize = ((byteSize + align - 1) / align) * align;
- void *p{pointer.raw().base_addr};
- std::uintptr_t *footer{
- reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
- if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) {
- return ReturnError(
- terminator, StatBadPointerDeallocation, errMsg, hasStat);
- }
+ if (executionEnvironment.checkPointerDeallocation &&
+ !ValidatePointerPayload(pointer.raw())) {
+ return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat);
}
return ReturnError(terminator,
pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator),