diff options
Diffstat (limited to 'flang/runtime/pointer.cpp')
| -rw-r--r-- | flang/runtime/pointer.cpp | 67 |
1 files changed, 42 insertions, 25 deletions
diff --git a/flang/runtime/pointer.cpp b/flang/runtime/pointer.cpp index 08a1223764f3..aeed879f1a2e 100644 --- a/flang/runtime/pointer.cpp +++ b/flang/runtime/pointer.cpp @@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer, } } +RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) { + // Add space for a footer to validate during deallocation. + constexpr std::size_t align{sizeof(std::uintptr_t)}; + byteSize = ((byteSize / align) + 1) * align; + std::size_t total{byteSize + sizeof(std::uintptr_t)}; + void *p{std::malloc(total)}; + if (p) { + // Fill the footer word with the XOR of the ones' complement of + // the base address, which is a value that would be highly unlikely + // to appear accidentally at the right spot. + std::uintptr_t *footer{ + reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)}; + *footer = ~reinterpret_cast<std::uintptr_t>(p); + } + return p; +} + int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat, const Descriptor *errMsg, const char *sourceFile, int sourceLine) { Terminator terminator{sourceFile, sourceLine}; @@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat, elementBytes = pointer.raw().elem_len = 0; } std::size_t byteSize{pointer.Elements() * elementBytes}; - // Add space for a footer to validate during DEALLOCATE. - constexpr std::size_t align{sizeof(std::uintptr_t)}; - byteSize = ((byteSize + align - 1) / align) * align; - std::size_t total{byteSize + sizeof(std::uintptr_t)}; - void *p{std::malloc(total)}; + void *p{AllocateValidatedPointerPayload(byteSize)}; if (!p) { return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat); } pointer.set_base_addr(p); pointer.SetByteStrides(); - // Fill the footer word with the XOR of the ones' complement of - // the base address, which is a value that would be highly unlikely - // to appear accidentally at the right spot. - std::uintptr_t *footer{ - reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)}; - *footer = ~reinterpret_cast<std::uintptr_t>(p); int stat{StatOk}; if (const DescriptorAddendum * addendum{pointer.Addendum()}) { if (const auto *derived{addendum->derivedType()}) { @@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source, return stat; } +static RT_API_ATTRS std::size_t GetByteSize( + const ISO::CFI_cdesc_t &descriptor) { + std::size_t rank{descriptor.rank}; + const ISO::CFI_dim_t *dim{descriptor.dim}; + std::size_t byteSize{descriptor.elem_len}; + for (std::size_t j{0}; j < rank; ++j) { + byteSize *= dim[j].extent; + } + return byteSize; +} + +bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) { + std::size_t byteSize{GetByteSize(desc)}; + constexpr std::size_t align{sizeof(std::uintptr_t)}; + byteSize = ((byteSize / align) + 1) * align; + const void *p{desc.base_addr}; + const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>( + static_cast<const char *>(p) + byteSize)}; + return *footer == ~reinterpret_cast<std::uintptr_t>(p); +} + int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat, const Descriptor *errMsg, const char *sourceFile, int sourceLine) { Terminator terminator{sourceFile, sourceLine}; @@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat, if (!pointer.IsAllocated()) { return ReturnError(terminator, StatBaseNull, errMsg, hasStat); } - if (executionEnvironment.checkPointerDeallocation) { - // Validate the footer. This should fail if the pointer doesn't - // span the entire object, or the object was not allocated as a - // pointer. - std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()}; - constexpr std::size_t align{sizeof(std::uintptr_t)}; - byteSize = ((byteSize + align - 1) / align) * align; - void *p{pointer.raw().base_addr}; - std::uintptr_t *footer{ - reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)}; - if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) { - return ReturnError( - terminator, StatBadPointerDeallocation, errMsg, hasStat); - } + if (executionEnvironment.checkPointerDeallocation && + !ValidatePointerPayload(pointer.raw())) { + return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat); } return ReturnError(terminator, pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator), |
