diff options
| author | Valentin Clement (バレンタイン クレメン) <clementval@gmail.com> | 2025-09-26 08:17:24 -1000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-26 18:17:24 +0000 |
| commit | d48bda5421c5af9baa5bc98ba4e3a453937ff96a (patch) | |
| tree | fceb5383217acf5d57a1905e17f41544ef44326b /flang-rt | |
| parent | 24bc1a60978cf6871d3381dcf92211509f658c76 (diff) | |
[flang][cuda] Handle zero-sized allocation correctly (#160929)
As on the host, allocate 1 byte when a zero-size allocation is requested.
Diffstat (limited to 'flang-rt')
| -rw-r--r-- | flang-rt/lib/cuda/memory.cpp | 25 | ||||
| -rw-r--r-- | flang-rt/unittests/Runtime/CUDA/Memory.cpp | 6 |
2 files changed, 18 insertions, 13 deletions
diff --git a/flang-rt/lib/cuda/memory.cpp b/flang-rt/lib/cuda/memory.cpp index d830580e6a06..78270fef07c3 100644 --- a/flang-rt/lib/cuda/memory.cpp +++ b/flang-rt/lib/cuda/memory.cpp @@ -25,23 +25,22 @@ extern "C" { void *RTDEF(CUFMemAlloc)( std::size_t bytes, unsigned type, const char *sourceFile, int sourceLine) { void *ptr = nullptr; - if (bytes != 0) { - if (type == kMemTypeDevice) { - if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) { - CUDA_REPORT_IF_ERROR( - cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); - } else { - CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes)); - } - } else if (type == kMemTypeManaged || type == kMemTypeUnified) { + bytes = bytes ? bytes : 1; + if (type == kMemTypeDevice) { + if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) { CUDA_REPORT_IF_ERROR( cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); - } else if (type == kMemTypePinned) { - CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes)); } else { - Terminator terminator{sourceFile, sourceLine}; - terminator.Crash("unsupported memory type"); + CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes)); } + } else if (type == kMemTypeManaged || type == kMemTypeUnified) { + CUDA_REPORT_IF_ERROR( + cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); + } else if (type == kMemTypePinned) { + CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes)); + } else { + Terminator terminator{sourceFile, sourceLine}; + terminator.Crash("unsupported memory type"); } return ptr; } diff --git a/flang-rt/unittests/Runtime/CUDA/Memory.cpp b/flang-rt/unittests/Runtime/CUDA/Memory.cpp index f2e17870f799..c84c54a1376e 100644 --- a/flang-rt/unittests/Runtime/CUDA/Memory.cpp +++ b/flang-rt/unittests/Runtime/CUDA/Memory.cpp @@ -35,6 +35,12 @@ TEST(MemoryCUFTest, SimpleAllocTramsferFree) { RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__); } +TEST(MemoryCUFTest, AllocZero) { + int *dev = (int 
*)RTNAME(CUFMemAlloc)(0, kMemTypeDevice, __FILE__, __LINE__); + EXPECT_TRUE(dev != 0); + RTNAME(CUFMemFree)((void *)dev, kMemTypeDevice, __FILE__, __LINE__); +} + static OwningPtr<Descriptor> createAllocatable( Fortran::common::TypeCategory tc, int kind, int rank = 1) { return Descriptor::Create(TypeCode{tc, kind}, kind, nullptr, rank, nullptr, |
