summaryrefslogtreecommitdiff
path: root/offload/liboffload/src/OffloadImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r--offload/liboffload/src/OffloadImpl.cpp154
1 files changed, 138 insertions, 16 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 9d342e06127a..7e8e297831f4 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -47,10 +47,59 @@ struct ol_device_impl_t {
ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
Info(std::forward<InfoTreeNode>(DevInfo)) {}
+
+ ~ol_device_impl_t() {
+ assert(!OutstandingQueues.size() &&
+ "Device object dropped with outstanding queues");
+ }
+
int DeviceNum;
GenericDeviceTy *Device;
ol_platform_handle_t Platform;
InfoTreeNode Info;
+
+ llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
+ std::mutex OutstandingQueuesMutex;
+
+ /// If the device has any outstanding queues that are now complete, remove it
+ /// from the list and return it.
+ ///
+ /// Queues may be added to the outstanding queue list by olDestroyQueue if
+ /// they are destroyed but not completed.
+ __tgt_async_info *getOutstandingQueue() {
+ // Not locking the `size()` access is fine here - In the worst case we
+ // either miss a queue that exists or loop through an empty array after
+ // taking the lock. Both are sub-optimal but not that bad.
+ if (OutstandingQueues.size()) {
+ std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
+
+ // As queues are pulled and popped from this list, longer running queues
+ // naturally bubble to the start of the array. Hence looping backwards.
+ for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
+ Q++) {
+ if (!Device->hasPendingWork(*Q)) {
+ auto OutstandingQueue = *Q;
+ *Q = OutstandingQueues.back();
+ OutstandingQueues.pop_back();
+ return OutstandingQueue;
+ }
+ }
+ }
+ return nullptr;
+ }
+
+ /// Complete all pending work for this device and perform any needed cleanup.
+ ///
+ /// After calling this function, no liboffload functions should be called with
+ /// this device handle.
+ llvm::Error destroy() {
+ llvm::Error Result = Plugin::success();
+ for (auto Q : OutstandingQueues)
+ if (auto Err = Device->synchronize(Q, /*Release=*/true))
+ Result = llvm::joinErrors(std::move(Result), std::move(Err));
+ OutstandingQueues.clear();
+ return Result;
+ }
};
struct ol_platform_impl_t {
@@ -58,23 +107,51 @@ struct ol_platform_impl_t {
ol_platform_backend_t BackendType)
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
std::unique_ptr<GenericPluginTy> Plugin;
- std::vector<ol_device_impl_t> Devices;
+ llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
ol_platform_backend_t BackendType;
+
+ /// Complete all pending work for this platform and perform any needed
+ /// cleanup.
+ ///
+ /// After calling this function, no liboffload functions should be called with
+ /// this platform handle.
+ llvm::Error destroy() {
+ llvm::Error Result = Plugin::success();
+ for (auto &D : Devices)
+ if (auto Err = D->destroy())
+ Result = llvm::joinErrors(std::move(Result), std::move(Err));
+
+ if (auto Res = Plugin->deinit())
+ Result = llvm::joinErrors(std::move(Result), std::move(Res));
+
+ return Result;
+ }
};
struct ol_queue_impl_t {
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
- : AsyncInfo(AsyncInfo), Device(Device) {}
+ : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
__tgt_async_info *AsyncInfo;
ol_device_handle_t Device;
+ // A unique identifier for the queue
+ size_t Id;
+ static std::atomic<size_t> IdCounter;
};
+std::atomic<size_t> ol_queue_impl_t::IdCounter(0);
struct ol_event_impl_t {
- ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue)
- : EventInfo(EventInfo), Queue(Queue) {}
+ ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
+ ol_queue_handle_t Queue)
+ : EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
+ }
// EventInfo may be null, in which case the event should be considered always
// complete
void *EventInfo;
+ ol_device_handle_t Device;
+ size_t QueueId;
+ // Events may outlive the queue - don't assume this is always valid.
+ // It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
+ // for queue equality instead.
ol_queue_handle_t Queue;
};
@@ -131,7 +208,7 @@ struct OffloadContext {
ol_device_handle_t HostDevice() {
// The host platform is always inserted last
- return &Platforms.back().Devices[0];
+ return Platforms.back().Devices[0].get();
}
static OffloadContext &get() {
@@ -190,8 +267,8 @@ Error initPlugins(OffloadContext &Context) {
auto Info = Device->obtainInfoImpl();
if (auto Err = Info.takeError())
return Err;
- Platform.Devices.emplace_back(DevNum, Device, &Platform,
- std::move(*Info));
+ Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
+ DevNum, Device, &Platform, std::move(*Info)));
}
}
}
@@ -199,7 +276,8 @@ Error initPlugins(OffloadContext &Context) {
// Add the special host device
auto &HostPlatform = Context.Platforms.emplace_back(
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
- HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
+ HostPlatform.Devices.emplace_back(
+ std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
Context.HostDevice()->Platform = &HostPlatform;
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
@@ -240,7 +318,7 @@ Error olShutDown_impl() {
if (!P.Plugin || !P.Plugin->is_initialized())
continue;
- if (auto Res = P.Plugin->deinit())
+ if (auto Res = P.destroy())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
}
@@ -367,6 +445,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
// Retrieve properties from the plugin interface
switch (PropName) {
case OL_DEVICE_INFO_NAME:
+ case OL_DEVICE_INFO_PRODUCT_NAME:
case OL_DEVICE_INFO_VENDOR:
case OL_DEVICE_INFO_DRIVER_VERSION: {
// String values
@@ -377,6 +456,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
}
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+ case OL_DEVICE_INFO_MAX_WORK_SIZE:
case OL_DEVICE_INFO_VENDOR_ID:
case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
case OL_DEVICE_INFO_ADDRESS_BITS:
@@ -393,6 +473,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return Info.write(static_cast<uint32_t>(Value));
}
+ case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
// {x, y, z} triples
ol_dimensions_t Out{0, 0, 0};
@@ -431,6 +512,8 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
assert(Device == OffloadContext::get().HostDevice());
InfoWriter Info(PropSize, PropValue, PropSizeRet);
+ constexpr auto uint32_max = std::numeric_limits<uint32_t>::max();
+
switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
return Info.write<void *>(Device->Platform);
@@ -438,6 +521,8 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
case OL_DEVICE_INFO_NAME:
return Info.writeString("Virtual Host Device");
+ case OL_DEVICE_INFO_PRODUCT_NAME:
+ return Info.writeString("Virtual Host Device");
case OL_DEVICE_INFO_VENDOR:
return Info.writeString("Liboffload");
case OL_DEVICE_INFO_DRIVER_VERSION:
@@ -446,6 +531,11 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
return Info.write<uint32_t>(1);
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION:
return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1});
+ case OL_DEVICE_INFO_MAX_WORK_SIZE:
+ return Info.write<uint32_t>(uint32_max);
+ case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
+ return Info.write<ol_dimensions_t>(
+ ol_dimensions_t{uint32_max, uint32_max, uint32_max});
case OL_DEVICE_INFO_VENDOR_ID:
return Info.write<uint32_t>(0);
case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
@@ -505,7 +595,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform.Devices) {
- if (!Callback(&Device, UserData)) {
+ if (!Callback(Device.get(), UserData)) {
break;
}
}
@@ -566,14 +656,46 @@ Error olMemFree_impl(void *Address) {
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
- if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo)))
+
+ auto OutstandingQueue = Device->getOutstandingQueue();
+ if (OutstandingQueue) {
+ // The queue is empty, but we still need to sync it to release any temporary
+ // memory allocations or do other cleanup.
+ if (auto Err =
+ Device->Device->synchronize(OutstandingQueue, /*Release=*/false))
+ return Err;
+ CreatedQueue->AsyncInfo = OutstandingQueue;
+ } else if (auto Err =
+ Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) {
return Err;
+ }
*Queue = CreatedQueue.release();
return Error::success();
}
-Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
+Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
+ auto *Device = Queue->Device;
+ // This is safe; as soon as olDestroyQueue is called it is not possible to add
+ // any more work to the queue, so if it's finished now it will remain finished
+ // forever.
+ auto Res = Device->Device->hasPendingWork(Queue->AsyncInfo);
+ if (!Res)
+ return Res.takeError();
+
+ if (!*Res) {
+ // The queue is complete, so sync it and throw it back into the pool.
+ if (auto Err = Device->Device->synchronize(Queue->AsyncInfo,
+ /*Release=*/true))
+ return Err;
+ } else {
+ // The queue still has outstanding work. Store it so we can check it later.
+ std::lock_guard<std::mutex> Lock(Device->OutstandingQueuesMutex);
+ Device->OutstandingQueues.push_back(Queue->AsyncInfo);
+ }
+
+ return olDestroy(Queue);
+}
Error olSyncQueue_impl(ol_queue_handle_t Queue) {
// Host plugin doesn't have a queue set so it's not safe to call synchronize
@@ -601,7 +723,7 @@ Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
"olWaitEvents asked to wait on a NULL event");
// Do nothing if the event is for this queue or the event is always complete
- if (Event->Queue == Queue || !Event->EventInfo)
+ if (Event->QueueId == Queue->Id || !Event->EventInfo)
continue;
if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo))
@@ -649,7 +771,7 @@ Error olSyncEvent_impl(ol_event_handle_t Event) {
if (!Event->EventInfo)
return Plugin::success();
- if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
+ if (auto Res = Event->Device->Device->syncEvent(Event->EventInfo))
return Res;
return Error::success();
@@ -657,7 +779,7 @@ Error olSyncEvent_impl(ol_event_handle_t Event) {
Error olDestroyEvent_impl(ol_event_handle_t Event) {
if (Event->EventInfo)
- if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo))
+ if (auto Res = Event->Device->Device->destroyEvent(Event->EventInfo))
return Res;
return olDestroy(Event);
@@ -708,7 +830,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
if (auto Err = Pending.takeError())
return Err;
- *EventOut = new ol_event_impl_t(nullptr, Queue);
+ *EventOut = new ol_event_impl_t(nullptr, Queue->Device, Queue);
if (!*Pending)
// Queue is empty, don't record an event and consider the event always
// complete