diff options
Diffstat (limited to 'offload/libomptarget/device.cpp')
| -rw-r--r-- | offload/libomptarget/device.cpp | 84 |
1 files changed, 83 insertions, 1 deletions
diff --git a/offload/libomptarget/device.cpp b/offload/libomptarget/device.cpp index 6585286bf428..71423ae0c94d 100644 --- a/offload/libomptarget/device.cpp +++ b/offload/libomptarget/device.cpp @@ -37,6 +37,8 @@ using namespace llvm::omp::target::ompt; #endif +using namespace llvm::omp::target::plugin; + int HostDataToTargetTy::addEventIfNecessary(DeviceTy &Device, AsyncInfoTy &AsyncInfo) const { // First, check if the user disabled atomic map transfer/malloc/dealloc. @@ -97,7 +99,55 @@ llvm::Error DeviceTy::init() { return llvm::Error::success(); } -// Load binary to device. +// Extract the mapping of host function pointers to device function pointers +// from the entry table. Functions marked as 'indirect' in OpenMP will have +// offloading entries generated for them which map the host's function pointer +// to a global containing the corresponding function pointer on the device. +static llvm::Expected<std::pair<void *, uint64_t>> +setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image, + __tgt_device_binary Binary) { + AsyncInfoTy AsyncInfo(Device); + llvm::ArrayRef<llvm::offloading::EntryTy> Entries(Image->EntriesBegin, + Image->EntriesEnd); + llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable; + for (const auto &Entry : Entries) { + if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP || + Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT)) + continue; + + assert(Entry.Size == sizeof(void *) && "Global not a function pointer?"); + auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back(); + + void *Ptr; + if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to load %s", Entry.SymbolName); + + HstPtr = Entry.Address; + if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to load %s", Entry.SymbolName); + } + + // If we do not have any indirect globals we exit early. + if (IndirectCallTable.empty()) + return std::pair{nullptr, 0}; + + // Sort the array to allow for more efficient lookup of device pointers. + llvm::sort(IndirectCallTable, + [](const auto &x, const auto &y) { return x.first < y.first; }); + + uint64_t TableSize = + IndirectCallTable.size() * sizeof(std::pair<void *, void *>); + void *DevicePtr = Device.allocData(TableSize, nullptr, TARGET_ALLOC_DEVICE); + if (Device.submitData(DevicePtr, IndirectCallTable.data(), TableSize, + AsyncInfo)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to copy data"); + return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size()); +} + +// Load binary to device and perform global initialization if needed. llvm::Expected<__tgt_device_binary> DeviceTy::loadBinary(__tgt_device_image *Img) { __tgt_device_binary Binary; @@ -105,6 +155,38 @@ DeviceTy::loadBinary(__tgt_device_image *Img) { if (RTL->load_binary(RTLDeviceID, Img, &Binary) != OFFLOAD_SUCCESS) return error::createOffloadError(error::ErrorCode::INVALID_BINARY, "failed to load binary %p", Img); + + // This symbol is optional. + void *DeviceEnvironmentPtr; + if (RTL->get_global(Binary, sizeof(DeviceEnvironmentTy), + "__omp_rtl_device_environment", &DeviceEnvironmentPtr)) + return Binary; + + // Obtain a table mapping host function pointers to device function pointers. + auto CallTablePairOrErr = setupIndirectCallTable(*this, Img, Binary); + if (!CallTablePairOrErr) + return CallTablePairOrErr.takeError(); + + GenericDeviceTy &GenericDevice = RTL->getDevice(RTLDeviceID); + DeviceEnvironmentTy DeviceEnvironment; + DeviceEnvironment.DeviceDebugKind = GenericDevice.getDebugKind(); + DeviceEnvironment.NumDevices = RTL->getNumDevices(); + // TODO: The device ID used here is not the real device ID used by OpenMP. + DeviceEnvironment.DeviceNum = RTLDeviceID; + DeviceEnvironment.DynamicMemSize = GenericDevice.getDynamicMemorySize(); + DeviceEnvironment.ClockFrequency = GenericDevice.getClockFrequency(); + DeviceEnvironment.IndirectCallTable = + reinterpret_cast<uintptr_t>(CallTablePairOrErr->first); + DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second; + DeviceEnvironment.HardwareParallelism = + GenericDevice.getHardwareParallelism(); + + AsyncInfoTy AsyncInfo(*this); + if (submitData(DeviceEnvironmentPtr, &DeviceEnvironment, + sizeof(DeviceEnvironment), AsyncInfo)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to copy data"); + return Binary; } |
