diff options
| author | Ross Brunton <ross@codeplay.com> | 2025-06-30 15:00:43 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-30 15:00:43 +0100 |
| commit | 67e73ba605ea78d757c293f85e32a42257f9c6ed (patch) | |
| tree | 7841f701db1e8de66d0314536bd3a9bf771d1185 /offload | |
| parent | 619f7afd716c520e9ab98e1cca30f75dafe40655 (diff) | |
[Offload] Refactor device/platform info queries (#146345)
This makes several small changes to how the platform and device info
queries are handled:
* ReturnHelper has been replaced with InfoWriter which is more explicit
in how it is invoked.
* InfoWriter consumes `llvm::Expected` rather than values directly, and
will early exit if it returns an error.
* As a result of the above, `GetInfoString` now correctly returns errors
rather than empty strings.
* The host device now has its own dedicated "getInfo" function rather
than being checked in multiple places.
Diffstat (limited to 'offload')
| -rw-r--r-- | offload/liboffload/src/Helpers.hpp | 54 | ||||
| -rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 86 |
2 files changed, 90 insertions, 50 deletions
diff --git a/offload/liboffload/src/Helpers.hpp b/offload/liboffload/src/Helpers.hpp index 425934b6760d..8b85945508b9 100644 --- a/offload/liboffload/src/Helpers.hpp +++ b/offload/liboffload/src/Helpers.hpp @@ -61,39 +61,41 @@ llvm::Error getInfoArray(size_t array_length, size_t ParamValueSize, array_length * sizeof(T), memcpy); } -template <> -inline llvm::Error -getInfo<const char *>(size_t ParamValueSize, void *ParamValue, - size_t *ParamValueSizeRet, const char *Value) { - return getInfoArray(strlen(Value) + 1, ParamValueSize, ParamValue, - ParamValueSizeRet, Value); +llvm::Error getInfoString(size_t ParamValueSize, void *ParamValue, + size_t *ParamValueSizeRet, llvm::StringRef Value) { + return getInfoArray(Value.size() + 1, ParamValueSize, ParamValue, + ParamValueSizeRet, Value.data()); } -class ReturnHelper { +class InfoWriter { public: - ReturnHelper(size_t ParamValueSize, void *ParamValue, - size_t *ParamValueSizeRet) - : ParamValueSize(ParamValueSize), ParamValue(ParamValue), - ParamValueSizeRet(ParamValueSizeRet) {} + InfoWriter(size_t Size, void *Target, size_t *SizeRet) + : Size(Size), Target(Target), SizeRet(SizeRet) {}; + InfoWriter() = delete; + InfoWriter(InfoWriter &) = delete; + ~InfoWriter() = default; - // A version where in/out info size is represented by a single pointer - // to a value which is updated on return - ReturnHelper(size_t *ParamValueSize, void *ParamValue) - : ParamValueSize(*ParamValueSize), ParamValue(ParamValue), - ParamValueSizeRet(ParamValueSize) {} + template <typename T> llvm::Error write(llvm::Expected<T> &&Val) { + if (Val) + return getInfo(Size, Target, SizeRet, *Val); + return Val.takeError(); + } - // Scalar return Value - template <class T> llvm::Error operator()(const T &t) { - return getInfo(ParamValueSize, ParamValue, ParamValueSizeRet, t); + template <typename T> + llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) { + if (Val) + return getInfoArray(Elems, Size, Target, SizeRet, *Val); + return Val.takeError(); } - // Array return Value - template <class T> llvm::Error operator()(const T *t, size_t s) { - return getInfoArray(s, ParamValueSize, ParamValue, ParamValueSizeRet, t); + llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) { + if (Val) + return getInfoString(Size, Target, SizeRet, *Val); + return Val.takeError(); } -protected: - size_t ParamValueSize; - void *ParamValue; - size_t *ParamValueSizeRet; +private: + size_t Size; + void *Target; + size_t *SizeRet; }; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 9d4f4f54a821..e7da4eddce54 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -13,6 +13,7 @@ #include "OffloadImpl.hpp" #include "Helpers.hpp" +#include "OffloadPrint.hpp" #include "PluginManager.h" #include "llvm/Support/FormatVariadic.h" #include <OffloadAPI.h> @@ -234,23 +235,22 @@ Error olShutDown_impl() { Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, ol_platform_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + InfoWriter Info(PropSize, PropValue, PropSizeRet); bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; switch (PropName) { case OL_PLATFORM_INFO_NAME: - return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName()); + return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName()); case OL_PLATFORM_INFO_VENDOR_NAME: // TODO: Implement this - return ReturnValue("Unknown platform vendor"); + return Info.writeString("Unknown platform vendor"); case OL_PLATFORM_INFO_VERSION: { - return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, - OL_VERSION_MINOR, OL_VERSION_PATCH) - .str() - .c_str()); + return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, + OL_VERSION_MINOR, OL_VERSION_PATCH) + .str()); } case OL_PLATFORM_INFO_BACKEND: { - return ReturnValue(Platform->BackendType); + return Info.write<ol_platform_backend_t>(Platform->BackendType); } default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, @@ -277,36 +277,68 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { + assert(Device != OffloadContext::get().HostDevice()); + InfoWriter Info(PropSize, PropValue, PropSizeRet); - ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + auto makeError = [&](ErrorCode Code, StringRef Err) { + std::string ErrBuffer; + llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err; + return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); + }; // Find the info if it exists under any of the given names - auto GetInfoString = [&](std::vector<std::string> Names) { - if (Device == OffloadContext::get().HostDevice()) - return "Host"; - - for (auto Name : Names) { - if (auto Entry = Device->Info.get(Name)) + auto getInfoString = + [&](std::vector<std::string> Names) -> llvm::Expected<const char *> { + for (auto &Name : Names) { + if (auto Entry = Device->Info.get(Name)) { + if (!std::holds_alternative<std::string>((*Entry)->Value)) + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type"); return std::get<std::string>((*Entry)->Value).c_str(); + } } - return ""; + return makeError(ErrorCode::UNIMPLEMENTED, + "plugin did not provide a response for this information"); }; switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return ReturnValue(Device->Platform); + return Info.write<void *>(Device->Platform); + case OL_DEVICE_INFO_TYPE: + return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU); + case OL_DEVICE_INFO_NAME: + return Info.writeString(getInfoString({"Device Name"})); + case OL_DEVICE_INFO_VENDOR: + return Info.writeString(getInfoString({"Vendor Name"})); + case OL_DEVICE_INFO_DRIVER_VERSION: + return Info.writeString( + getInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + default: + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "getDeviceInfo enum '%i' is invalid", PropName); + } + + return Error::success(); +} + +Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, + ol_device_info_t PropName, size_t PropSize, + void *PropValue, size_t *PropSizeRet) { + assert(Device == OffloadContext::get().HostDevice()); + InfoWriter Info(PropSize, PropValue, PropSizeRet); + + switch (PropName) { + case OL_DEVICE_INFO_PLATFORM: + return Info.write<void *>(Device->Platform); case OL_DEVICE_INFO_TYPE: - return Device == OffloadContext::get().HostDevice() - ? ReturnValue(OL_DEVICE_TYPE_HOST) - : ReturnValue(OL_DEVICE_TYPE_GPU); + return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST); case OL_DEVICE_INFO_NAME: - return ReturnValue(GetInfoString({"Device Name"})); + return Info.writeString("Virtual Host Device"); case OL_DEVICE_INFO_VENDOR: - return ReturnValue(GetInfoString({"Vendor Name"})); + return Info.writeString("Liboffload"); case OL_DEVICE_INFO_DRIVER_VERSION: - return ReturnValue( - GetInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + return Info.writeString(LLVM_VERSION_STRING); default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, "getDeviceInfo enum '%i' is invalid", PropName); @@ -317,12 +349,18 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { + if (Device == OffloadContext::get().HostDevice()) + return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue, + nullptr); return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, nullptr); } Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { + if (Device == OffloadContext::get().HostDevice()) + return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr, + PropSizeRet); return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); } |
