summaryrefslogtreecommitdiff
path: root/offload/liboffload/src/OffloadImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r--offload/liboffload/src/OffloadImpl.cpp220
1 files changed, 163 insertions, 57 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index c2a35a245e2a..f9da63843670 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -13,6 +13,7 @@
#include "OffloadImpl.hpp"
#include "Helpers.hpp"
+#include "OffloadPrint.hpp"
#include "PluginManager.h"
#include "llvm/Support/FormatVariadic.h"
#include <OffloadAPI.h>
@@ -43,18 +44,19 @@ using namespace error;
// interface.
struct ol_device_impl_t {
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
- ol_platform_handle_t Platform)
- : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {}
+ ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
+ : DeviceNum(DeviceNum), Device(Device), Platform(Platform),
+ Info(std::forward<InfoTreeNode>(DevInfo)) {}
int DeviceNum;
GenericDeviceTy *Device;
ol_platform_handle_t Platform;
+ InfoTreeNode Info;
};
struct ol_platform_impl_t {
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
- std::vector<ol_device_impl_t> Devices,
ol_platform_backend_t BackendType)
- : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {}
+ : Plugin(std::move(Plugin)), BackendType(BackendType) {}
std::unique_ptr<GenericPluginTy> Plugin;
std::vector<ol_device_impl_t> Devices;
ol_platform_backend_t BackendType;
@@ -95,7 +97,10 @@ struct AllocInfo {
// Global shared state for liboffload
struct OffloadContext;
-static OffloadContext *OffloadContextVal;
+// This pointer is non-null if and only if the context is valid and fully
+// initialized
+static std::atomic<OffloadContext *> OffloadContextVal;
+std::mutex OffloadContextValMutex;
struct OffloadContext {
OffloadContext(OffloadContext &) = delete;
OffloadContext(OffloadContext &&) = delete;
@@ -106,6 +111,7 @@ struct OffloadContext {
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
+ size_t RefCount;
ol_device_handle_t HostDevice() {
// The host platform is always inserted last
@@ -144,21 +150,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
-void initPlugins() {
- auto *Context = new OffloadContext{};
-
+Error initPlugins(OffloadContext &Context) {
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
- Context->Platforms.emplace_back(ol_platform_impl_t{ \
+ Context.Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
- {}, \
pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"
// Preemptively initialize all devices in the plugin
- for (auto &Platform : Context->Platforms) {
+ for (auto &Platform : Context.Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
@@ -167,55 +170,87 @@ void initPlugins() {
for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
DevNum++) {
if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
- Platform.Devices.emplace_back(ol_device_impl_t{
- DevNum, &Platform.Plugin->getDevice(DevNum), &Platform});
+ auto Device = &Platform.Plugin->getDevice(DevNum);
+ auto Info = Device->obtainInfoImpl();
+ if (auto Err = Info.takeError())
+ return Err;
+ Platform.Devices.emplace_back(DevNum, Device, &Platform,
+ std::move(*Info));
}
}
}
// Add the special host device
- auto &HostPlatform = Context->Platforms.emplace_back(
- ol_platform_impl_t{nullptr,
- {ol_device_impl_t{-1, nullptr, nullptr}},
- OL_PLATFORM_BACKEND_HOST});
- Context->HostDevice()->Platform = &HostPlatform;
+ auto &HostPlatform = Context.Platforms.emplace_back(
+ ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
+ HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
+ Context.HostDevice()->Platform = &HostPlatform;
- Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
- Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
+ Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
+ Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
- OffloadContextVal = Context;
+ return Plugin::success();
}
-// TODO: We can properly reference count here and manage the resources in a more
-// clever way
Error olInit_impl() {
- static std::once_flag InitFlag;
- std::call_once(InitFlag, initPlugins);
+ std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
- return Error::success();
+ if (isOffloadInitialized()) {
+ OffloadContext::get().RefCount++;
+ return Plugin::success();
+ }
+
+ // Use a temporary to ensure that entry points querying OffloadContextVal do
+ // not get a partially initialized context
+ auto *NewContext = new OffloadContext{};
+ Error InitResult = initPlugins(*NewContext);
+ OffloadContextVal.store(NewContext);
+ OffloadContext::get().RefCount++;
+
+ return InitResult;
+}
+
+Error olShutDown_impl() {
+ std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+
+ if (--OffloadContext::get().RefCount != 0)
+ return Error::success();
+
+ llvm::Error Result = Error::success();
+ auto *OldContext = OffloadContextVal.exchange(nullptr);
+
+ for (auto &P : OldContext->Platforms) {
+ // Host plugin is nullptr and has no deinit
+ if (!P.Plugin)
+ continue;
+
+ if (auto Res = P.Plugin->deinit())
+ Result = llvm::joinErrors(std::move(Result), std::move(Res));
+ }
+
+ delete OldContext;
+ return Result;
}
-Error olShutDown_impl() { return Error::success(); }
Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
- ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
+ InfoWriter Info(PropSize, PropValue, PropSizeRet);
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
switch (PropName) {
case OL_PLATFORM_INFO_NAME:
- return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName());
+ return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
case OL_PLATFORM_INFO_VENDOR_NAME:
// TODO: Implement this
- return ReturnValue("Unknown platform vendor");
+ return Info.writeString("Unknown platform vendor");
case OL_PLATFORM_INFO_VERSION: {
- return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
- OL_VERSION_MINOR, OL_VERSION_PATCH)
- .str()
- .c_str());
+ return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
+ OL_VERSION_MINOR, OL_VERSION_PATCH)
+ .str());
}
case OL_PLATFORM_INFO_BACKEND: {
- return ReturnValue(Platform->BackendType);
+ return Info.write<ol_platform_backend_t>(Platform->BackendType);
}
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
@@ -242,43 +277,108 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
+ assert(Device != OffloadContext::get().HostDevice());
+ InfoWriter Info(PropSize, PropValue, PropSizeRet);
- ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
+ auto makeError = [&](ErrorCode Code, StringRef Err) {
+ std::string ErrBuffer;
+ llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
+ return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
+ };
// Find the info if it exists under any of the given names
- auto GetInfoString = [&](std::vector<std::string> Names) {
- if (Device == OffloadContext::get().HostDevice())
- return "Host";
-
- if (!Device->Device)
- return "";
+ auto getInfoString =
+ [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
+ for (auto &Name : Names) {
+ if (auto Entry = Device->Info.get(Name)) {
+ if (!std::holds_alternative<std::string>((*Entry)->Value))
+ return makeError(ErrorCode::BACKEND_FAILURE,
+ "plugin returned incorrect type");
+ return std::get<std::string>((*Entry)->Value).c_str();
+ }
+ }
- auto Info = Device->Device->obtainInfoImpl();
- if (auto Err = Info.takeError())
- return "";
+ return makeError(ErrorCode::UNIMPLEMENTED,
+ "plugin did not provide a response for this information");
+ };
- for (auto Name : Names) {
- if (auto Entry = Info->get(Name))
- return std::get<std::string>((*Entry)->Value).c_str();
+ auto getInfoXyz =
+ [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
+ for (auto &Name : Names) {
+ if (auto Entry = Device->Info.get(Name)) {
+ auto Node = *Entry;
+ ol_dimensions_t Out{0, 0, 0};
+
+ auto getField = [&](StringRef Name, uint32_t &Dest) {
+ if (auto F = Node->get(Name)) {
+ if (!std::holds_alternative<size_t>((*F)->Value))
+ return makeError(
+ ErrorCode::BACKEND_FAILURE,
+ "plugin returned incorrect type for dimensions element");
+ Dest = std::get<size_t>((*F)->Value);
+ } else
+ return makeError(ErrorCode::BACKEND_FAILURE,
+ "plugin didn't provide all values for dimensions");
+ return Plugin::success();
+ };
+
+ if (auto Res = getField("x", Out.x))
+ return Res;
+ if (auto Res = getField("y", Out.y))
+ return Res;
+ if (auto Res = getField("z", Out.z))
+ return Res;
+
+ return Out;
+ }
}
- return "";
+ return makeError(ErrorCode::UNIMPLEMENTED,
+ "plugin did not provide a response for this information");
};
switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
- return ReturnValue(Device->Platform);
+ return Info.write<void *>(Device->Platform);
+ case OL_DEVICE_INFO_TYPE:
+ return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
+ case OL_DEVICE_INFO_NAME:
+ return Info.writeString(getInfoString({"Device Name"}));
+ case OL_DEVICE_INFO_VENDOR:
+ return Info.writeString(getInfoString({"Vendor Name"}));
+ case OL_DEVICE_INFO_DRIVER_VERSION:
+ return Info.writeString(
+ getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+ return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
+ "Maximum Block Dimensions" /*CUDA*/}));
+ default:
+ return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+ "getDeviceInfo enum '%i' is invalid", PropName);
+ }
+
+ return Error::success();
+}
+
+Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
+ ol_device_info_t PropName, size_t PropSize,
+ void *PropValue, size_t *PropSizeRet) {
+ assert(Device == OffloadContext::get().HostDevice());
+ InfoWriter Info(PropSize, PropValue, PropSizeRet);
+
+ switch (PropName) {
+ case OL_DEVICE_INFO_PLATFORM:
+ return Info.write<void *>(Device->Platform);
case OL_DEVICE_INFO_TYPE:
- return Device == OffloadContext::get().HostDevice()
- ? ReturnValue(OL_DEVICE_TYPE_HOST)
- : ReturnValue(OL_DEVICE_TYPE_GPU);
+ return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
case OL_DEVICE_INFO_NAME:
- return ReturnValue(GetInfoString({"Device Name"}));
+ return Info.writeString("Virtual Host Device");
case OL_DEVICE_INFO_VENDOR:
- return ReturnValue(GetInfoString({"Vendor Name"}));
+ return Info.writeString("Liboffload");
case OL_DEVICE_INFO_DRIVER_VERSION:
- return ReturnValue(
- GetInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
+ return Info.writeString(LLVM_VERSION_STRING);
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+ return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1});
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
@@ -289,12 +389,18 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
size_t PropSize, void *PropValue) {
+ if (Device == OffloadContext::get().HostDevice())
+ return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
+ nullptr);
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
nullptr);
}
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName, size_t *PropSizeRet) {
+ if (Device == OffloadContext::get().HostDevice())
+ return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
+ PropSizeRet);
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
}