diff options
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
| -rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 154 |
1 files changed, 138 insertions, 16 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 9d342e06127a..7e8e297831f4 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -47,10 +47,59 @@ struct ol_device_impl_t { ol_platform_handle_t Platform, InfoTreeNode &&DevInfo) : DeviceNum(DeviceNum), Device(Device), Platform(Platform), Info(std::forward<InfoTreeNode>(DevInfo)) {} + + ~ol_device_impl_t() { + assert(!OutstandingQueues.size() && + "Device object dropped with outstanding queues"); + } + int DeviceNum; GenericDeviceTy *Device; ol_platform_handle_t Platform; InfoTreeNode Info; + + llvm::SmallVector<__tgt_async_info *> OutstandingQueues; + std::mutex OutstandingQueuesMutex; + + /// If the device has any outstanding queues that are now complete, remove it + /// from the list and return it. + /// + /// Queues may be added to the outstanding queue list by olDestroyQueue if + /// they are destroyed but not completed. + __tgt_async_info *getOutstandingQueue() { + // Not locking the `size()` access is fine here - In the worst case we + // either miss a queue that exists or loop through an empty array after + // taking the lock. Both are sub-optimal but not that bad. + if (OutstandingQueues.size()) { + std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex); + + // As queues are pulled and popped from this list, longer running queues + // naturally bubble to the start of the array. Hence looping backwards. + for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend(); + Q++) { + if (!Device->hasPendingWork(*Q)) { + auto OutstandingQueue = *Q; + *Q = OutstandingQueues.back(); + OutstandingQueues.pop_back(); + return OutstandingQueue; + } + } + } + return nullptr; + } + + /// Complete all pending work for this device and perform any needed cleanup. + /// + /// After calling this function, no liboffload functions should be called with + /// this device handle. + llvm::Error destroy() { + llvm::Error Result = Plugin::success(); + for (auto Q : OutstandingQueues) + if (auto Err = Device->synchronize(Q, /*Release=*/true)) + Result = llvm::joinErrors(std::move(Result), std::move(Err)); + OutstandingQueues.clear(); + return Result; + } }; struct ol_platform_impl_t { @@ -58,23 +107,51 @@ struct ol_platform_impl_t { ol_platform_backend_t BackendType) : Plugin(std::move(Plugin)), BackendType(BackendType) {} std::unique_ptr<GenericPluginTy> Plugin; - std::vector<ol_device_impl_t> Devices; + llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices; ol_platform_backend_t BackendType; + + /// Complete all pending work for this platform and perform any needed + /// cleanup. + /// + /// After calling this function, no liboffload functions should be called with + /// this platform handle. + llvm::Error destroy() { + llvm::Error Result = Plugin::success(); + for (auto &D : Devices) + if (auto Err = D->destroy()) + Result = llvm::joinErrors(std::move(Result), std::move(Err)); + + if (auto Res = Plugin->deinit()) + Result = llvm::joinErrors(std::move(Result), std::move(Res)); + + return Result; + } }; struct ol_queue_impl_t { ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) - : AsyncInfo(AsyncInfo), Device(Device) {} + : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {} __tgt_async_info *AsyncInfo; ol_device_handle_t Device; + // A unique identifier for the queue + size_t Id; + static std::atomic<size_t> IdCounter; }; +std::atomic<size_t> ol_queue_impl_t::IdCounter(0); struct ol_event_impl_t { - ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue) - : EventInfo(EventInfo), Queue(Queue) {} + ol_event_impl_t(void *EventInfo, ol_device_handle_t Device, + ol_queue_handle_t Queue) + : EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) { + } // EventInfo may be null, in which case the event should be considered always // complete void *EventInfo; + ol_device_handle_t Device; + size_t QueueId; + // Events may outlive the queue - don't assume this is always valid. + // It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check + // for queue equality instead. ol_queue_handle_t Queue; }; @@ -131,7 +208,7 @@ struct OffloadContext { ol_device_handle_t HostDevice() { // The host platform is always inserted last - return &Platforms.back().Devices[0]; + return Platforms.back().Devices[0].get(); } static OffloadContext &get() { @@ -190,8 +267,8 @@ Error initPlugins(OffloadContext &Context) { auto Info = Device->obtainInfoImpl(); if (auto Err = Info.takeError()) return Err; - Platform.Devices.emplace_back(DevNum, Device, &Platform, - std::move(*Info)); + Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>( + DevNum, Device, &Platform, std::move(*Info))); } } } @@ -199,7 +276,8 @@ Error initPlugins(OffloadContext &Context) { // Add the special host device auto &HostPlatform = Context.Platforms.emplace_back( ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); - HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{}); + HostPlatform.Devices.emplace_back( + std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{})); Context.HostDevice()->Platform = &HostPlatform; Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); @@ -240,7 +318,7 @@ Error olShutDown_impl() { if (!P.Plugin || !P.Plugin->is_initialized()) continue; - if (auto Res = P.Plugin->deinit()) + if (auto Res = P.destroy()) Result = llvm::joinErrors(std::move(Result), std::move(Res)); } @@ -367,6 +445,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, // Retrieve properties from the plugin interface switch (PropName) { case OL_DEVICE_INFO_NAME: + case OL_DEVICE_INFO_PRODUCT_NAME: case OL_DEVICE_INFO_VENDOR: case OL_DEVICE_INFO_DRIVER_VERSION: { // String values @@ -377,6 +456,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, } case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: + case OL_DEVICE_INFO_MAX_WORK_SIZE: case OL_DEVICE_INFO_VENDOR_ID: case OL_DEVICE_INFO_NUM_COMPUTE_UNITS: case OL_DEVICE_INFO_ADDRESS_BITS: @@ -393,6 +473,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, return Info.write(static_cast<uint32_t>(Value)); } + case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION: case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: { // {x, y, z} triples ol_dimensions_t Out{0, 0, 0}; @@ -431,6 +512,8 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, assert(Device == OffloadContext::get().HostDevice()); InfoWriter Info(PropSize, PropValue, PropSizeRet); + constexpr auto uint32_max = std::numeric_limits<uint32_t>::max(); + switch (PropName) { case OL_DEVICE_INFO_PLATFORM: return Info.write<void *>(Device->Platform); @@ -438,6 +521,8 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST); case OL_DEVICE_INFO_NAME: return Info.writeString("Virtual Host Device"); + case OL_DEVICE_INFO_PRODUCT_NAME: + return Info.writeString("Virtual Host Device"); case OL_DEVICE_INFO_VENDOR: return Info.writeString("Liboffload"); case OL_DEVICE_INFO_DRIVER_VERSION: @@ -446,6 +531,11 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, return Info.write<uint32_t>(1); case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1}); + case OL_DEVICE_INFO_MAX_WORK_SIZE: + return Info.write<uint32_t>(uint32_max); + case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION: + return Info.write<ol_dimensions_t>( + ol_dimensions_t{uint32_max, uint32_max, uint32_max}); case OL_DEVICE_INFO_VENDOR_ID: return Info.write<uint32_t>(0); case OL_DEVICE_INFO_NUM_COMPUTE_UNITS: @@ -505,7 +595,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { for (auto &Platform : OffloadContext::get().Platforms) { for (auto &Device : Platform.Devices) { - if (!Callback(&Device, UserData)) { + if (!Callback(Device.get(), UserData)) { break; } } @@ -566,14 +656,46 @@ Error olMemFree_impl(void *Address) { Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) { auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device); - if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) + + auto OutstandingQueue = Device->getOutstandingQueue(); + if (OutstandingQueue) { + // The queue is empty, but we still need to sync it to release any temporary + // memory allocations or do other cleanup. + if (auto Err = + Device->Device->synchronize(OutstandingQueue, /*Release=*/false)) + return Err; + CreatedQueue->AsyncInfo = OutstandingQueue; + } else if (auto Err = + Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) { return Err; + } *Queue = CreatedQueue.release(); return Error::success(); } -Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); } +Error olDestroyQueue_impl(ol_queue_handle_t Queue) { + auto *Device = Queue->Device; + // This is safe; as soon as olDestroyQueue is called it is not possible to add + // any more work to the queue, so if it's finished now it will remain finished + // forever. + auto Res = Device->Device->hasPendingWork(Queue->AsyncInfo); + if (!Res) + return Res.takeError(); + + if (!*Res) { + // The queue is complete, so sync it and throw it back into the pool. + if (auto Err = Device->Device->synchronize(Queue->AsyncInfo, + /*Release=*/true)) + return Err; + } else { + // The queue still has outstanding work. Store it so we can check it later. + std::lock_guard<std::mutex> Lock(Device->OutstandingQueuesMutex); + Device->OutstandingQueues.push_back(Queue->AsyncInfo); + } + + return olDestroy(Queue); +} Error olSyncQueue_impl(ol_queue_handle_t Queue) { // Host plugin doesn't have a queue set so it's not safe to call synchronize @@ -601,7 +723,7 @@ Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events, "olWaitEvents asked to wait on a NULL event"); // Do nothing if the event is for this queue or the event is always complete - if (Event->Queue == Queue || !Event->EventInfo) + if (Event->QueueId == Queue->Id || !Event->EventInfo) continue; if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo)) @@ -649,7 +771,7 @@ Error olSyncEvent_impl(ol_event_handle_t Event) { if (!Event->EventInfo) return Plugin::success(); - if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo)) + if (auto Res = Event->Device->Device->syncEvent(Event->EventInfo)) return Res; return Error::success(); @@ -657,7 +779,7 @@ Error olSyncEvent_impl(ol_event_handle_t Event) { Error olDestroyEvent_impl(ol_event_handle_t Event) { if (Event->EventInfo) - if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo)) + if (auto Res = Event->Device->Device->destroyEvent(Event->EventInfo)) return Res; return olDestroy(Event); @@ -708,7 +830,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) { if (auto Err = Pending.takeError()) return Err; - *EventOut = new ol_event_impl_t(nullptr, Queue); + *EventOut = new ol_event_impl_t(nullptr, Queue->Device, Queue); if (!*Pending) // Queue is empty, don't record an event and consider the event always // complete |
