summaryrefslogtreecommitdiff
path: root/flang-rt/lib/cuda/memory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang-rt/lib/cuda/memory.cpp')
-rw-r--r--flang-rt/lib/cuda/memory.cpp25
1 files changed, 12 insertions, 13 deletions
diff --git a/flang-rt/lib/cuda/memory.cpp b/flang-rt/lib/cuda/memory.cpp
index d830580e6a06..78270fef07c3 100644
--- a/flang-rt/lib/cuda/memory.cpp
+++ b/flang-rt/lib/cuda/memory.cpp
@@ -25,23 +25,22 @@ extern "C" {
void *RTDEF(CUFMemAlloc)(
std::size_t bytes, unsigned type, const char *sourceFile, int sourceLine) {
void *ptr = nullptr;
- if (bytes != 0) {
- if (type == kMemTypeDevice) {
- if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) {
- CUDA_REPORT_IF_ERROR(
- cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
- } else {
- CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes));
- }
- } else if (type == kMemTypeManaged || type == kMemTypeUnified) {
+ bytes = bytes ? bytes : 1;
+ if (type == kMemTypeDevice) {
+ if (Fortran::runtime::executionEnvironment.cudaDeviceIsManaged) {
CUDA_REPORT_IF_ERROR(
cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
- } else if (type == kMemTypePinned) {
- CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes));
} else {
- Terminator terminator{sourceFile, sourceLine};
- terminator.Crash("unsupported memory type");
+ CUDA_REPORT_IF_ERROR(cudaMalloc((void **)&ptr, bytes));
}
+ } else if (type == kMemTypeManaged || type == kMemTypeUnified) {
+ CUDA_REPORT_IF_ERROR(
+ cudaMallocManaged((void **)&ptr, bytes, cudaMemAttachGlobal));
+ } else if (type == kMemTypePinned) {
+ CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&ptr, bytes));
+ } else {
+ Terminator terminator{sourceFile, sourceLine};
+ terminator.Crash("unsupported memory type");
}
return ptr;
}