summaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/common/src/PluginInterface.cpp
diff options
context:
space:
mode:
authorFangrui Song <i@maskray.me>2024-06-06 13:23:38 -0700
committerFangrui Song <i@maskray.me>2024-06-06 13:23:38 -0700
commit683ca4ab2cce926ca945b5eed9fa0bb3cf575de9 (patch)
treec32c6df233afdf9469e20f99733cde3f552e49de /offload/plugins-nextgen/common/src/PluginInterface.cpp
parentcf44857e7bce6b2defe3f174e0134e2bb7a0ac9d (diff)
parentfbcb92ca017ee7fbf84be808701133fbdf3b1c59 (diff)
Created using spr 1.3.5-bogner [skip ci]
Diffstat (limited to 'offload/plugins-nextgen/common/src/PluginInterface.cpp')
-rw-r--r--offload/plugins-nextgen/common/src/PluginInterface.cpp72
1 files changed, 57 insertions, 15 deletions
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 913721a15d71..5a53c479e33d 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -748,8 +748,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
if (ompt::Initialized) {
bool ExpectedStatus = false;
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true))
- performOmptCallback(device_initialize, /*device_num=*/DeviceId +
- Plugin.getDeviceIdStartIndex(),
+ performOmptCallback(device_initialize, Plugin.getUserId(DeviceId),
/*type=*/getComputeUnitKind().c_str(),
/*device=*/reinterpret_cast<ompt_device_t *>(this),
/*lookup=*/ompt::lookupCallbackByName,
@@ -847,9 +846,7 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
if (ompt::Initialized) {
bool ExpectedStatus = true;
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false))
- performOmptCallback(device_finalize,
- /*device_num=*/DeviceId +
- Plugin.getDeviceIdStartIndex());
+ performOmptCallback(device_finalize, Plugin.getUserId(DeviceId));
}
#endif
@@ -908,7 +905,7 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
size_t Bytes =
getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart);
performOmptCallback(
- device_load, /*device_num=*/DeviceId + Plugin.getDeviceIdStartIndex(),
+ device_load, Plugin.getUserId(DeviceId),
/*FileName=*/nullptr, /*FileOffset=*/0, /*VmaInFile=*/nullptr,
/*ImgSize=*/Bytes, /*HostAddr=*/InputTgtImage->ImageStart,
/*DeviceAddr=*/nullptr, /* FIXME: ModuleId */ 0);
@@ -1492,11 +1489,14 @@ Error GenericDeviceTy::syncEvent(void *EventPtr) {
bool GenericDeviceTy::useAutoZeroCopy() { return useAutoZeroCopyImpl(); }
Error GenericPluginTy::init() {
+ if (Initialized)
+ return Plugin::success();
+
auto NumDevicesOrErr = initImpl();
if (!NumDevicesOrErr)
return NumDevicesOrErr.takeError();
-
Initialized = true;
+
NumDevices = *NumDevicesOrErr;
if (NumDevices == 0)
return Plugin::success();
@@ -1517,6 +1517,8 @@ Error GenericPluginTy::init() {
}
Error GenericPluginTy::deinit() {
+ assert(Initialized && "Plugin was not initialized!");
+
// Deinitialize all active devices.
for (int32_t DeviceId = 0; DeviceId < NumDevices; ++DeviceId) {
if (Devices[DeviceId]) {
@@ -1537,7 +1539,11 @@ Error GenericPluginTy::deinit() {
delete RecordReplay;
// Perform last deinitializations on the plugin.
- return deinitImpl();
+ if (Error Err = deinitImpl())
+ return Err;
+ Initialized = false;
+
+ return Plugin::success();
}
Error GenericPluginTy::initDevice(int32_t DeviceId) {
@@ -1599,8 +1605,7 @@ Expected<bool> GenericPluginTy::checkBitcodeImage(StringRef Image) const {
int32_t GenericPluginTy::is_initialized() const { return Initialized; }
-int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
- bool Initialized) {
+int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) {
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
@@ -1618,11 +1623,43 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
auto MatchOrErr = checkELFImage(Buffer);
if (Error Err = MatchOrErr.takeError())
return HandleError(std::move(Err));
- if (!Initialized || !*MatchOrErr)
- return *MatchOrErr;
+ return *MatchOrErr;
+ }
+ case file_magic::bitcode: {
+ auto MatchOrErr = checkBitcodeImage(Buffer);
+ if (Error Err = MatchOrErr.takeError())
+ return HandleError(std::move(Err));
+ return *MatchOrErr;
+ }
+ default:
+ return false;
+ }
+}
+
+int32_t GenericPluginTy::is_device_compatible(int32_t DeviceId,
+ __tgt_device_image *Image) {
+ StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
+ target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
+
+ auto HandleError = [&](Error Err) -> bool {
+ [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
+ DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
+ return false;
+ };
+ switch (identify_magic(Buffer)) {
+ case file_magic::elf:
+ case file_magic::elf_relocatable:
+ case file_magic::elf_executable:
+ case file_magic::elf_shared_object:
+ case file_magic::elf_core: {
+ auto MatchOrErr = checkELFImage(Buffer);
+ if (Error Err = MatchOrErr.takeError())
+ return HandleError(std::move(Err));
+ if (!*MatchOrErr)
+ return false;
// Perform plugin-dependent checks for the specific architecture if needed.
- auto CompatibleOrErr = isELFCompatible(Buffer);
+ auto CompatibleOrErr = isELFCompatible(DeviceId, Buffer);
if (Error Err = CompatibleOrErr.takeError())
return HandleError(std::move(Err));
return *CompatibleOrErr;
@@ -1638,6 +1675,10 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
}
}
+int32_t GenericPluginTy::is_device_initialized(int32_t DeviceId) const {
+ return isValidDeviceId(DeviceId) && Devices[DeviceId] != nullptr;
+}
+
int32_t GenericPluginTy::init_device(int32_t DeviceId) {
auto Err = initDevice(DeviceId);
if (Err) {
@@ -1985,8 +2026,9 @@ int32_t GenericPluginTy::init_device_info(int32_t DeviceId,
return OFFLOAD_SUCCESS;
}
-int32_t GenericPluginTy::set_device_offset(int32_t DeviceIdOffset) {
- setDeviceIdStartIndex(DeviceIdOffset);
+int32_t GenericPluginTy::set_device_identifier(int32_t UserId,
+ int32_t DeviceId) {
+ UserDeviceIds[DeviceId] = UserId;
return OFFLOAD_SUCCESS;
}