diff options
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
| -rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 220 |
1 files changed, 163 insertions, 57 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index c2a35a245e2a..f9da63843670 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -13,6 +13,7 @@ #include "OffloadImpl.hpp" #include "Helpers.hpp" +#include "OffloadPrint.hpp" #include "PluginManager.h" #include "llvm/Support/FormatVariadic.h" #include <OffloadAPI.h> @@ -43,18 +44,19 @@ using namespace error; // interface. struct ol_device_impl_t { ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, - ol_platform_handle_t Platform) - : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {} + ol_platform_handle_t Platform, InfoTreeNode &&DevInfo) + : DeviceNum(DeviceNum), Device(Device), Platform(Platform), + Info(std::forward<InfoTreeNode>(DevInfo)) {} int DeviceNum; GenericDeviceTy *Device; ol_platform_handle_t Platform; + InfoTreeNode Info; }; struct ol_platform_impl_t { ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin, - std::vector<ol_device_impl_t> Devices, ol_platform_backend_t BackendType) - : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {} + : Plugin(std::move(Plugin)), BackendType(BackendType) {} std::unique_ptr<GenericPluginTy> Plugin; std::vector<ol_device_impl_t> Devices; ol_platform_backend_t BackendType; @@ -95,7 +97,10 @@ struct AllocInfo { // Global shared state for liboffload struct OffloadContext; -static OffloadContext *OffloadContextVal; +// This pointer is non-null if and only if the context is valid and fully +// initialized +static std::atomic<OffloadContext *> OffloadContextVal; +std::mutex OffloadContextValMutex; struct OffloadContext { OffloadContext(OffloadContext &) = delete; OffloadContext(OffloadContext &&) = delete; @@ -106,6 +111,7 @@ struct OffloadContext { bool ValidationEnabled = true; DenseMap<void *, AllocInfo> AllocInfoMap{}; SmallVector<ol_platform_impl_t, 4> Platforms{}; + size_t RefCount; ol_device_handle_t HostDevice() { // The host platform is always inserted last @@ -144,21 +150,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); #include "Shared/Targets.def" -void initPlugins() { - auto *Context = new OffloadContext{}; - +Error initPlugins(OffloadContext &Context) { // Attempt to create an instance of each supported plugin. #define PLUGIN_TARGET(Name) \ do { \ - Context->Platforms.emplace_back(ol_platform_impl_t{ \ + Context.Platforms.emplace_back(ol_platform_impl_t{ \ std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \ - {}, \ pluginNameToBackend(#Name)}); \ } while (false); #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin - for (auto &Platform : Context->Platforms) { + for (auto &Platform : Context.Platforms) { // Do not use the host plugin - it isn't supported. if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) continue; @@ -167,55 +170,87 @@ void initPlugins() { for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); DevNum++) { if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - Platform.Devices.emplace_back(ol_device_impl_t{ - DevNum, &Platform.Plugin->getDevice(DevNum), &Platform}); + auto Device = &Platform.Plugin->getDevice(DevNum); + auto Info = Device->obtainInfoImpl(); + if (auto Err = Info.takeError()) + return Err; + Platform.Devices.emplace_back(DevNum, Device, &Platform, + std::move(*Info)); } } } // Add the special host device - auto &HostPlatform = Context->Platforms.emplace_back( - ol_platform_impl_t{nullptr, - {ol_device_impl_t{-1, nullptr, nullptr}}, - OL_PLATFORM_BACKEND_HOST}); - Context->HostDevice()->Platform = &HostPlatform; + auto &HostPlatform = Context.Platforms.emplace_back( + ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); + HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{}); + Context.HostDevice()->Platform = &HostPlatform; - Context->TracingEnabled = std::getenv("OFFLOAD_TRACE"); - Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); + Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); + Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); - OffloadContextVal = Context; + return Plugin::success(); } -// TODO: We can properly reference count here and manage the resources in a more -// clever way Error olInit_impl() { - static std::once_flag InitFlag; - std::call_once(InitFlag, initPlugins); + std::lock_guard<std::mutex> Lock{OffloadContextValMutex}; - return Error::success(); + if (isOffloadInitialized()) { + OffloadContext::get().RefCount++; + return Plugin::success(); + } + + // Use a temporary to ensure that entry points querying OffloadContextVal do + // not get a partially initialized context + auto *NewContext = new OffloadContext{}; + Error InitResult = initPlugins(*NewContext); + OffloadContextVal.store(NewContext); + OffloadContext::get().RefCount++; + + return InitResult; +} + +Error olShutDown_impl() { + std::lock_guard<std::mutex> Lock{OffloadContextValMutex}; + + if (--OffloadContext::get().RefCount != 0) + return Error::success(); + + llvm::Error Result = Error::success(); + auto *OldContext = OffloadContextVal.exchange(nullptr); + + for (auto &P : OldContext->Platforms) { + // Host plugin is nullptr and has no deinit + if (!P.Plugin) + continue; + + if (auto Res = P.Plugin->deinit()) + Result = llvm::joinErrors(std::move(Result), std::move(Res)); + } + + delete OldContext; + return Result; } -Error olShutDown_impl() { return Error::success(); } Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + InfoWriter Info(PropSize, PropValue, PropSizeRet); bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; switch (PropName) { case OL_PLATFORM_INFO_NAME: - return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName()); + return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName()); case OL_PLATFORM_INFO_VENDOR_NAME: // TODO: Implement this - return ReturnValue("Unknown platform vendor"); + return Info.writeString("Unknown platform vendor"); case OL_PLATFORM_INFO_VERSION: { - return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, - OL_VERSION_MINOR, OL_VERSION_PATCH) - .str() - .c_str()); + return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, + OL_VERSION_MINOR, OL_VERSION_PATCH) + .str()); } case OL_PLATFORM_INFO_BACKEND: { - return ReturnValue(Platform->BackendType); + return Info.write<ol_platform_backend_t>(Platform->BackendType); } default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, @@ -242,43 +277,108 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { + assert(Device != OffloadContext::get().HostDevice()); + InfoWriter Info(PropSize, PropValue, PropSizeRet); - ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + auto makeError = [&](ErrorCode Code, StringRef Err) { + std::string ErrBuffer; + llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err; + return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); + }; // Find the info if it exists under any of the given names - auto GetInfoString = [&](std::vector<std::string> Names) { - if (Device == OffloadContext::get().HostDevice()) - return "Host"; - - if (!Device->Device) - return ""; + auto getInfoString = + [&](std::vector<std::string> Names) -> llvm::Expected<const char *> { + for (auto &Name : Names) { + if (auto Entry = Device->Info.get(Name)) { + if (!std::holds_alternative<std::string>((*Entry)->Value)) + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type"); + return std::get<std::string>((*Entry)->Value).c_str(); + } + } - auto Info = Device->Device->obtainInfoImpl(); - if (auto Err = Info.takeError()) - return ""; + return makeError(ErrorCode::UNIMPLEMENTED, + "plugin did not provide a response for this information"); + }; - for (auto Name : Names) { - if (auto Entry = Info->get(Name)) - return std::get<std::string>((*Entry)->Value).c_str(); + auto getInfoXyz = + [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> { + for (auto &Name : Names) { + if (auto Entry = Device->Info.get(Name)) { + auto Node = *Entry; + ol_dimensions_t Out{0, 0, 0}; + + auto getField = [&](StringRef Name, uint32_t &Dest) { + if (auto F = Node->get(Name)) { + if (!std::holds_alternative<size_t>((*F)->Value)) + return makeError( + ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type for dimensions element"); + Dest = std::get<size_t>((*F)->Value); + } else + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin didn't provide all values for dimensions"); + return Plugin::success(); + }; + + if (auto Res = getField("x", Out.x)) + return Res; + if (auto Res = getField("y", Out.y)) + return Res; + if (auto Res = getField("z", Out.z)) + return Res; + + return Out; + } } - return ""; + return makeError(ErrorCode::UNIMPLEMENTED, + "plugin did not provide a response for this information"); }; switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return ReturnValue(Device->Platform); + return Info.write<void *>(Device->Platform); + case OL_DEVICE_INFO_TYPE: + return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU); + case OL_DEVICE_INFO_NAME: + return Info.writeString(getInfoString({"Device Name"})); + case OL_DEVICE_INFO_VENDOR: + return Info.writeString(getInfoString({"Vendor Name"})); + case OL_DEVICE_INFO_DRIVER_VERSION: + return Info.writeString( + getInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: + return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/, + "Maximum Block Dimensions" /*CUDA*/})); + default: + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "getDeviceInfo enum '%i' is invalid", PropName); + } + + return Error::success(); +} + +Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, + ol_device_info_t PropName, size_t PropSize, + void *PropValue, size_t *PropSizeRet) { + assert(Device == OffloadContext::get().HostDevice()); + InfoWriter Info(PropSize, PropValue, PropSizeRet); + + switch (PropName) { + case OL_DEVICE_INFO_PLATFORM: + return Info.write<void *>(Device->Platform); case OL_DEVICE_INFO_TYPE: - return Device == OffloadContext::get().HostDevice() - ? ReturnValue(OL_DEVICE_TYPE_HOST) - : ReturnValue(OL_DEVICE_TYPE_GPU); + return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST); case OL_DEVICE_INFO_NAME: - return ReturnValue(GetInfoString({"Device Name"})); + return Info.writeString("Virtual Host Device"); case OL_DEVICE_INFO_VENDOR: - return ReturnValue(GetInfoString({"Vendor Name"})); + return Info.writeString("Liboffload"); case OL_DEVICE_INFO_DRIVER_VERSION: - return ReturnValue( - GetInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + return Info.writeString(LLVM_VERSION_STRING); + case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: + return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1}); default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, "getDeviceInfo enum '%i' is invalid", PropName); @@ -289,12 +389,18 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { + if (Device == OffloadContext::get().HostDevice()) + return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue, + nullptr); return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, nullptr); } Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { + if (Device == OffloadContext::get().HostDevice()) + return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr, + PropSizeRet); return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); } |
