summaryrefslogtreecommitdiff
path: root/offload/liboffload/src/OffloadImpl.cpp
diff options
context:
space:
mode:
author: Aiden Grossman <aidengrossman@google.com> 2025-09-26 22:48:22 +0000
committer: Aiden Grossman <aidengrossman@google.com> 2025-09-26 22:48:22 +0000
commit: 76533872e149395812a6d1651aa49dbf53fb4921 (patch)
tree: 199a669fa57a4effc3116705d2ec89c07ff36c65 /offload/liboffload/src/OffloadImpl.cpp
parent: 54f5c1b2e17a9be61609d70dbbc8354ad41bb931 (diff)
parent: 37e7ad184d002db15f72771938755580433cf96d (diff)
[𝘀𝗽𝗿] changes introduced through rebase
branch: users/boomanaiden154/main.lit-remove-t-from-tests
Created using spr 1.3.6 [skip ci]
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r-- offload/liboffload/src/OffloadImpl.cpp 134
1 files changed, 118 insertions, 16 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index c5d083db7522..08a2e25b97d8 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -182,6 +182,9 @@ namespace offload {
struct AllocInfo { // Bookkeeping record for one allocation stored in AllocInfoMap.
ol_device_handle_t Device; // Device handle the allocation was requested on.
ol_alloc_type_t Type; // Allocation type that was passed to olMemAlloc.
+ void *Start; // Base address of the allocation; also its key in AllocInfoMap.
+ // One byte past the end of the allocation, i.e. Start + Size. This is an
+ // exclusive bound: the allocation covers the half-open range [Start, End).
+ void *End;
};
// Global shared state for liboffload
@@ -200,6 +203,9 @@ struct OffloadContext {
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
std::mutex AllocInfoMapMutex{};
+ // Partitioned list of memory base addresses. Each element in this list is a
+ // key in AllocInfoMap
+ llvm::SmallVector<void *> AllocBases{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
size_t RefCount;
@@ -244,17 +250,15 @@ Error initPlugins(OffloadContext &Context) {
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
- Context.Platforms.emplace_back(ol_platform_impl_t{ \
- std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
- pluginNameToBackend(#Name)}); \
+ if (StringRef(#Name) != "host") \
+ Context.Platforms.emplace_back(ol_platform_impl_t{ \
+ std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
+ pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"
// Preemptively initialize all devices in the plugin
for (auto &Platform : Context.Platforms) {
- // Do not use the host plugin - it isn't supported.
- if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
- continue;
auto Err = Platform.Plugin->init();
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
@@ -613,20 +617,61 @@ TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
}
}
+/// Upper bound on how many candidate allocations olMemAlloc_impl requests from
+/// the device before giving up on finding one that does not overlap a region
+/// that is already tracked.
+constexpr size_t MAX_ALLOC_TRIES = 50;
+/// Allocate \p Size bytes of \p Type memory on \p Device and start tracking it.
+///
+/// The new range [Start, Start + Size) must not overlap any allocation already
+/// recorded in OffloadContext::AllocInfoMap (e.g. one made by a device from
+/// another vendor). A candidate that does overlap is set aside and a fresh
+/// one is requested, up to MAX_ALLOC_TRIES times.
+///
+/// \param Device        Device to allocate on.
+/// \param Type          Kind of allocation to request.
+/// \param Size          Number of bytes to allocate.
+/// \param AllocationOut Receives the base address once a candidate has been
+///                      accepted and recorded.
+/// \returns Error::success() when an allocation was recorded and every
+///          rejected candidate was released. Otherwise returns the plugin's
+///          allocation error, or BACKEND_FAILURE when every attempt
+///          overlapped, or the error(s) from releasing rejected candidates.
+///          In the last case \p AllocationOut has still been written and the
+///          allocation remains tracked, matching the previous behaviour.
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
size_t Size, void **AllocationOut) {
- auto Alloc =
- Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
- if (!Alloc)
- return Alloc.takeError();
- *AllocationOut = *Alloc;
- {
- std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
- OffloadContext::get().AllocInfoMap.insert_or_assign(
- *Alloc, AllocInfo{Device, Type});
+  const TargetAllocTy PluginType = convertOlToPluginAllocTy(Type);
+  auto &AllocBases = OffloadContext::get().AllocBases;
+  auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
+
+  // Candidates that overlapped a tracked region. To avoid the next attempt
+  // being handed the memory we just rejected, they are held until the call
+  // is about to return.
+  SmallVector<void *> Rejects;
+
+  // Give every rejected candidate back to the device. This keeps going after
+  // a failed free so that one error cannot strand the remaining candidates;
+  // all failures are joined into the returned Error.
+  auto ReleaseRejects = [&]() -> Error {
+    Error Result = Error::success();
+    for (void *R : Rejects)
+      Result = llvm::joinErrors(std::move(Result),
+                                Device->Device->dataDelete(R, PluginType));
+    Rejects.clear();
+    return Result;
+  };
+
+  // Repeat the allocation up to a certain amount of times. If it happens to
+  // already be allocated (e.g. by a device from another vendor) set it aside
+  // and try again.
+  for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) {
+    auto NewAlloc = Device->Device->dataAlloc(Size, nullptr, PluginType);
+    if (!NewAlloc)
+      // Do not leak the candidates rejected on earlier iterations.
+      return llvm::joinErrors(NewAlloc.takeError(), ReleaseRejects());
+
+    void *NewStart = *NewAlloc;
+    void *NewEnd = static_cast<char *>(NewStart) + Size;
+    bool Tracked = false;
+    {
+      std::lock_guard<std::mutex> Lock(
+          OffloadContext::get().AllocInfoMapMutex);
+
+      // Check that this memory region doesn't overlap another one. That is,
+      // the start of this allocation needs to be at or after the previous
+      // allocation's end point, and the end of this allocation needs to be
+      // at or before the next one's start.
+      // `Gap` is the first alloc whose end is after the new alloc's start.
+      auto Gap =
+          std::lower_bound(AllocBases.begin(), AllocBases.end(), NewStart,
+                           [&](const void *Iter, const void *Val) {
+                             return AllocInfoMap.at(Iter).End <= Val;
+                           });
+      if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
+        // No conflict: record the allocation and keep AllocBases sorted.
+        AllocInfoMap.insert_or_assign(
+            NewStart, AllocInfo{Device, Type, NewStart, NewEnd});
+        AllocBases.insert(
+            std::lower_bound(AllocBases.begin(), AllocBases.end(), NewStart),
+            NewStart);
+        Tracked = true;
+      }
+    }
+
+    if (Tracked) {
+      *AllocationOut = NewStart;
+      // The rejects are released only after the mutex has been dropped, so
+      // device frees do not extend the critical section.
+      return ReleaseRejects();
+    }
+
+    Rejects.push_back(NewStart);
}
- return Error::success();
+
+  // We've tried multiple times, and can't allocate a non-overlapping region.
+  // Release the rejected candidates before reporting the failure so they are
+  // not leaked.
+  return llvm::joinErrors(
+      createOffloadError(ErrorCode::BACKEND_FAILURE,
+                         "failed to allocate non-overlapping memory"),
+      ReleaseRejects());
}
Error olMemFree_impl(void *Address) {
@@ -642,6 +687,9 @@ Error olMemFree_impl(void *Address) {
Device = AllocInfo.Device;
Type = AllocInfo.Type;
OffloadContext::get().AllocInfoMap.erase(Address);
+
+ auto &Bases = OffloadContext::get().AllocBases;
+ Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
}
if (auto Res =
@@ -651,6 +699,60 @@ Error olMemFree_impl(void *Address) {
return Error::success();
}
+/// Shared implementation behind olGetMemInfo_impl and olGetMemInfoSize_impl.
+///
+/// Finds the tracked allocation that contains \p Ptr and reports the property
+/// selected by \p PropName through an InfoWriter built from \p PropSize,
+/// \p PropValue and \p PropSizeRet.
+///
+/// \param Ptr         Address to look up. It may be an allocation's base
+///                    address or any address inside [Start, End).
+/// \param PropName    Which property to report.
+/// \param PropSize    Forwarded to InfoWriter (callers in this file pass 0
+///                    when only the size is wanted).
+/// \param PropValue   Forwarded to InfoWriter; null in the size-only caller.
+/// \param PropSizeRet Forwarded to InfoWriter; null in the value-only caller.
+/// \returns NOT_FOUND when no tracked allocation contains \p Ptr,
+///          INVALID_ENUMERATION for an unknown \p PropName, otherwise
+///          whatever InfoWriter::write returns.
+Error olGetMemInfoImplDetail(const void *Ptr, ol_mem_info_t PropName,
+                             size_t PropSize, void *PropValue,
+                             size_t *PropSizeRet) {
+  InfoWriter Info(PropSize, PropValue, PropSizeRet);
+  auto &Context = OffloadContext::get();
+  // The map and the sorted base list are only read while the mutex is held.
+  std::lock_guard<std::mutex> Lock(Context.AllocInfoMapMutex);
+
+  const auto &Bases = Context.AllocBases;
+  const auto &InfoMap = Context.AllocInfoMap;
+
+  const AllocInfo *Owner = nullptr;
+  if (auto Exact = InfoMap.find(Ptr); Exact != InfoMap.end()) {
+    // Fast path: Ptr is itself the base address of a tracked allocation.
+    Owner = &Exact->second;
+  } else {
+    // Slow path: binary-search the sorted bases for the first allocation
+    // whose end lies beyond Ptr, then confirm Ptr is not before its start.
+    const auto EndsAtOrBefore = [&](const void *Base, const void *Target) {
+      return InfoMap.at(Base).End <= Target;
+    };
+    const auto Candidate = std::lower_bound(Bases.begin(), Bases.end(), Ptr,
+                                            EndsAtOrBefore);
+    const bool Covered =
+        Candidate != Bases.end() && !(Ptr < InfoMap.at(*Candidate).Start);
+    if (!Covered)
+      return Plugin::error(ErrorCode::NOT_FOUND,
+                           "allocated memory information not found");
+    Owner = &InfoMap.at(*Candidate);
+  }
+
+  switch (PropName) {
+  case OL_MEM_INFO_DEVICE:
+    return Info.write<ol_device_handle_t>(Owner->Device);
+  case OL_MEM_INFO_BASE:
+    return Info.write<void *>(Owner->Start);
+  case OL_MEM_INFO_SIZE: {
+    // End is one past the last byte, so the byte distance is the size.
+    const auto *First = static_cast<const char *>(Owner->Start);
+    const auto *Last = static_cast<const char *>(Owner->End);
+    return Info.write<size_t>(Last - First);
+  }
+  case OL_MEM_INFO_TYPE:
+    return Info.write<ol_alloc_type_t>(Owner->Type);
+  default:
+    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+                              "olGetMemInfo enum '%i' is invalid", PropName);
+  }
+}
+
+/// Report the \p PropName property of the allocation containing \p Ptr by
+/// delegating to olGetMemInfoImplDetail with \p PropSize and \p PropValue.
+Error olGetMemInfo_impl(const void *Ptr, ol_mem_info_t PropName,
+                        size_t PropSize, void *PropValue) {
+  // This entry point has no size output, so none is passed down.
+  size_t *const NoSizeOut = nullptr;
+  return olGetMemInfoImplDetail(Ptr, PropName, PropSize, PropValue,
+                                NoSizeOut);
+}
+
+/// Size-only counterpart of olGetMemInfo_impl: delegates to
+/// olGetMemInfoImplDetail with no value buffer, passing \p PropSizeRet as the
+/// only output.
+Error olGetMemInfoSize_impl(const void *Ptr, ol_mem_info_t PropName,
+                            size_t *PropSizeRet) {
+  // No value buffer is supplied: zero capacity and a null destination.
+  constexpr size_t NoBufferSize = 0;
+  void *const NoBuffer = nullptr;
+  return olGetMemInfoImplDetail(Ptr, PropName, NoBufferSize, NoBuffer,
+                                PropSizeRet);
+}
+
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);