diff options
Diffstat (limited to 'flang-rt/lib/cuda/memory.cpp')
| -rw-r--r-- | flang-rt/lib/cuda/memory.cpp | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/flang-rt/lib/cuda/memory.cpp b/flang-rt/lib/cuda/memory.cpp index d830580e6a06..78270fef07c3 100644 --- a/flang-rt/lib/cuda/memory.cpp +++ b/flang-rt/lib/cuda/memory.cpp @@ -25,23 +25,22 @@ extern "C" { void *RTDEF(CUFMemAlloc)( std::size_t bytes, unsigned type, const char *sourceFile, int sourceLine) { void *ptr = nullptr; - if (bytes != 0) { - if (type == kMemTypeDevice) { - if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) { - CUDA_REPORT_IF_ERROR( - cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); - } else { - CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes)); - } - } else if (type == kMemTypeManaged || type == kMemTypeUnified) { + bytes = bytes ? bytes : 1; + if (type == kMemTypeDevice) { + if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) { CUDA_REPORT_IF_ERROR( cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); - } else if (type == kMemTypePinned) { - CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes)); } else { - Terminator terminator{sourceFile, sourceLine}; - terminator.Crash("unsupported memory type"); + CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes)); } + } else if (type == kMemTypeManaged || type == kMemTypeUnified) { + CUDA_REPORT_IF_ERROR( + cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal)); + } else if (type == kMemTypePinned) { + CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes)); + } else { + Terminator terminator{sourceFile, sourceLine}; + terminator.Crash("unsupported memory type"); } return ptr; } |
