diff options
| author | Aiden Grossman <aidengrossman@google.com> | 2025-09-26 22:48:22 +0000 |
|---|---|---|
| committer | Aiden Grossman <aidengrossman@google.com> | 2025-09-26 22:48:22 +0000 |
| commit | 76533872e149395812a6d1651aa49dbf53fb4921 (patch) | |
| tree | 199a669fa57a4effc3116705d2ec89c07ff36c65 /offload/liboffload/src/OffloadImpl.cpp | |
| parent | 54f5c1b2e17a9be61609d70dbbc8354ad41bb931 (diff) | |
| parent | 37e7ad184d002db15f72771938755580433cf96d (diff) | |
[𝘀𝗽𝗿] changes introduced through rebase
users/boomanaiden154/main.lit-remove-t-from-tests
Created using spr 1.3.6
[skip ci]
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
| -rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 134 |
1 files changed, 118 insertions, 16 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index c5d083db7522..08a2e25b97d8 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -182,6 +182,9 @@ namespace offload { struct AllocInfo { ol_device_handle_t Device; ol_alloc_type_t Type; + void *Start; + // One byte past the end + void *End; }; // Global shared state for liboffload @@ -200,6 +203,9 @@ struct OffloadContext { bool ValidationEnabled = true; DenseMap<void *, AllocInfo> AllocInfoMap{}; std::mutex AllocInfoMapMutex{}; + // Partitioned list of memory base addresses. Each element in this list is a + // key in AllocInfoMap + llvm::SmallVector<void *> AllocBases{}; SmallVector<ol_platform_impl_t, 4> Platforms{}; size_t RefCount; @@ -244,17 +250,15 @@ Error initPlugins(OffloadContext &Context) { // Attempt to create an instance of each supported plugin. #define PLUGIN_TARGET(Name) \ do { \ - Context.Platforms.emplace_back(ol_platform_impl_t{ \ - std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \ - pluginNameToBackend(#Name)}); \ + if (StringRef(#Name) != "host") \ + Context.Platforms.emplace_back(ol_platform_impl_t{ \ + std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \ + pluginNameToBackend(#Name)}); \ } while (false); #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin for (auto &Platform : Context.Platforms) { - // Do not use the host plugin - it isn't supported. 
- if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) - continue; auto Err = Platform.Plugin->init(); [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); @@ -613,20 +617,61 @@ TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) { } } +constexpr size_t MAX_ALLOC_TRIES = 50; Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type, size_t Size, void **AllocationOut) { - auto Alloc = - Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type)); - if (!Alloc) - return Alloc.takeError(); + SmallVector<void *> Rejects; + + // Repeat the allocation up to a certain amount of times. If it happens to + // already be allocated (e.g. by a device from another vendor) throw it away + // and try again. + for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) { + auto NewAlloc = Device->Device->dataAlloc(Size, nullptr, + convertOlToPluginAllocTy(Type)); + if (!NewAlloc) + return NewAlloc.takeError(); + + void *NewEnd = &static_cast<char *>(*NewAlloc)[Size]; + auto &AllocBases = OffloadContext::get().AllocBases; + auto &AllocInfoMap = OffloadContext::get().AllocInfoMap; + { + std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex); + + // Check that this memory region doesn't overlap another one + // That is, the start of this allocation needs to be after another + // allocation's end point, and the end of this allocation needs to be + // before the next one's start. + // `Gap` is the first alloc who ends after the new alloc's start point. 
+ auto Gap = + std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc, + [&](const void *Iter, const void *Val) { + return AllocInfoMap.at(Iter).End <= Val; + }); + if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) { + // Success, no conflict + AllocInfoMap.insert_or_assign( + *NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd}); + AllocBases.insert( + std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc), + *NewAlloc); + *AllocationOut = *NewAlloc; + + for (void *R : Rejects) + if (auto Err = + Device->Device->dataDelete(R, convertOlToPluginAllocTy(Type))) + return Err; + return Error::success(); + } - *AllocationOut = *Alloc; - { - std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex); - OffloadContext::get().AllocInfoMap.insert_or_assign( - *Alloc, AllocInfo{Device, Type}); + // To avoid the next attempt allocating the same memory we just freed, we + // hold onto it until we complete the allocation + Rejects.push_back(*NewAlloc); + } } - return Error::success(); + + // We've tried multiple times, and can't allocate a non-overlapping region. 
+ return createOffloadError(ErrorCode::BACKEND_FAILURE, + "failed to allocate non-overlapping memory"); } Error olMemFree_impl(void *Address) { @@ -642,6 +687,9 @@ Error olMemFree_impl(void *Address) { Device = AllocInfo.Device; Type = AllocInfo.Type; OffloadContext::get().AllocInfoMap.erase(Address); + + auto &Bases = OffloadContext::get().AllocBases; + Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address)); } if (auto Res = @@ -651,6 +699,60 @@ Error olMemFree_impl(void *Address) { return Error::success(); } +Error olGetMemInfoImplDetail(const void *Ptr, ol_mem_info_t PropName, + size_t PropSize, void *PropValue, + size_t *PropSizeRet) { + InfoWriter Info(PropSize, PropValue, PropSizeRet); + std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex); + + auto &AllocBases = OffloadContext::get().AllocBases; + auto &AllocInfoMap = OffloadContext::get().AllocInfoMap; + const AllocInfo *Alloc = nullptr; + if (AllocInfoMap.contains(Ptr)) { + // Fast case, we have been given the base pointer directly + Alloc = &AllocInfoMap.at(Ptr); + } else { + // Slower case, we need to look up the base pointer first + // Find the first memory allocation whose end is after the target pointer, + // and then check to see if it is in range + auto Loc = std::lower_bound(AllocBases.begin(), AllocBases.end(), Ptr, + [&](const void *Iter, const void *Val) { + return AllocInfoMap.at(Iter).End <= Val; + }); + if (Loc == AllocBases.end() || Ptr < AllocInfoMap.at(*Loc).Start) + return Plugin::error(ErrorCode::NOT_FOUND, + "allocated memory information not found"); + Alloc = &AllocInfoMap.at(*Loc); + } + + switch (PropName) { + case OL_MEM_INFO_DEVICE: + return Info.write<ol_device_handle_t>(Alloc->Device); + case OL_MEM_INFO_BASE: + return Info.write<void *>(Alloc->Start); + case OL_MEM_INFO_SIZE: + return Info.write<size_t>(static_cast<char *>(Alloc->End) - + static_cast<char *>(Alloc->Start)); + case OL_MEM_INFO_TYPE: + return 
Info.write<ol_alloc_type_t>(Alloc->Type); + default: + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "olGetMemInfo enum '%i' is invalid", PropName); + } + + return Error::success(); +} + +Error olGetMemInfo_impl(const void *Ptr, ol_mem_info_t PropName, + size_t PropSize, void *PropValue) { + return olGetMemInfoImplDetail(Ptr, PropName, PropSize, PropValue, nullptr); +} + +Error olGetMemInfoSize_impl(const void *Ptr, ol_mem_info_t PropName, + size_t *PropSizeRet) { + return olGetMemInfoImplDetail(Ptr, PropName, 0, nullptr, PropSizeRet); +} + Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) { auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device); |
