summaryrefslogtreecommitdiff
path: root/flang-rt
diff options
context:
space:
mode:
authorValentin Clement (バレンタイン クレメン) <clementval@gmail.com>2025-09-26 08:17:24 -1000
committerGitHub <noreply@github.com>2025-09-26 18:17:24 +0000
commitd48bda5421c5af9baa5bc98ba4e3a453937ff96a (patch)
treefceb5383217acf5d57a1905e17f41544ef44326b /flang-rt
parent24bc1a60978cf6871d3381dcf92211509f658c76 (diff)
[flang][cuda] Handle zero sized allocation correctly (#160929)
Like on the host allocate 1 byte when zero size is requested.
Diffstat (limited to 'flang-rt')
-rw-r--r--flang-rt/lib/cuda/memory.cpp25
-rw-r--r--flang-rt/unittests/Runtime/CUDA/Memory.cpp6
2 files changed, 18 insertions, 13 deletions
diff --git a/flang-rt/lib/cuda/memory.cpp b/flang-rt/lib/cuda/memory.cpp
index d830580e6a06..78270fef07c3 100644
--- a/flang-rt/lib/cuda/memory.cpp
+++ b/flang-rt/lib/cuda/memory.cpp
@@ -25,23 +25,22 @@ extern "C" {
void *RTDEF(CUFMemAlloc)(
std::size_t bytes, unsigned type, const char *sourceFile, int sourceLine) {
void *ptr = nullptr;
- if (bytes != 0) {
- if (type == kMemTypeDevice) {
- if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) {
- CUDA_REPORT_IF_ERROR(
- cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
- } else {
- CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes));
- }
- } else if (type == kMemTypeManaged || type == kMemTypeUnified) {
+ bytes = bytes ? bytes : 1;
+ if (type == kMemTypeDevice) {
+ if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) {
CUDA_REPORT_IF_ERROR(
cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
- } else if (type == kMemTypePinned) {
- CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes));
} else {
- Terminator terminator{sourceFile, sourceLine};
- terminator.Crash("unsupported memory type");
+ CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes));
}
+ } else if (type == kMemTypeManaged || type == kMemTypeUnified) {
+ CUDA_REPORT_IF_ERROR(
+ cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
+ } else if (type == kMemTypePinned) {
+ CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes));
+ } else {
+ Terminator terminator{sourceFile, sourceLine};
+ terminator.Crash("unsupported memory type");
}
return ptr;
}
diff --git a/flang-rt/unittests/Runtime/CUDA/Memory.cpp b/flang-rt/unittests/Runtime/CUDA/Memory.cpp
index f2e17870f799..c84c54a1376e 100644
--- a/flang-rt/unittests/Runtime/CUDA/Memory.cpp
+++ b/flang-rt/unittests/Runtime/CUDA/Memory.cpp
@@ -35,6 +35,12 @@ TEST(MemoryCUFTest, SimpleAllocTramsferFree) {
RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__);
}
+TEST(MemoryCUFTest, AllocZero) {
+ int *dev = (int *)RTNAME(CUFMemAlloc)(0, kMemTypeDevice, __FILE__, __LINE__);
+ EXPECT_TRUE(dev != 0);
+ RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__);
+}
+
static OwningPtr<Descriptor> createAllocatable(
Fortran::common::TypeCategory tc, int kind, int rank = 1) {
return Descriptor::Create(TypeCode{tc, kind}, kind, nullptr, rank, nullptr,