diff options
Diffstat (limited to 'offload/plugins-nextgen')
| -rw-r--r-- | offload/plugins-nextgen/amdgpu/src/rtl.cpp | 74 | ||||
| -rw-r--r-- | offload/plugins-nextgen/common/include/PluginInterface.h | 11 | ||||
| -rw-r--r-- | offload/plugins-nextgen/common/src/PluginInterface.cpp | 20 |
3 files changed, 87 insertions, 18 deletions
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index a7723b859881..0b03ef534d27 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -923,6 +923,10 @@ private: /// devices. This class relies on signals to implement streams and define the /// dependencies between asynchronous operations. struct AMDGPUStreamTy { +public: + /// Function pointer type for `pushHostCallback` + using HostFnType = void (*)(void *); + private: /// Utility struct holding arguments for async H2H memory copies. struct MemcpyArgsTy { @@ -1084,18 +1088,19 @@ private: /// Indicate to spread data transfers across all available SDMAs bool UseMultipleSdmaEngines; + struct CallbackDataType { + HostFnType UserFn; + void *UserData; + AMDGPUSignalTy *OutputSignal; + }; /// Wrapper function for implementing host callbacks - static void CallbackWrapper(AMDGPUSignalTy *InputSignal, - AMDGPUSignalTy *OutputSignal, - void (*Callback)(void *), void *UserData) { - // The wait call will not error in this context. - if (InputSignal) - if (auto Err = InputSignal->wait()) - reportFatalInternalError(std::move(Err)); - - Callback(UserData); - - OutputSignal->signal(); + static bool callbackWrapper([[maybe_unused]] hsa_signal_value_t Signal, + void *UserData) { + auto CallbackData = reinterpret_cast<CallbackDataType *>(UserData); + CallbackData->UserFn(CallbackData->UserData); + CallbackData->OutputSignal->signal(); + delete CallbackData; + return false; } /// Return the current number of asynchronous operations on the stream. @@ -1540,7 +1545,7 @@ public: OutputSignal->get()); } - Error pushHostCallback(void (*Callback)(void *), void *UserData) { + Error pushHostCallback(HostFnType Callback, void *UserData) { // Retrieve an available signal for the operation's output. AMDGPUSignalTy *OutputSignal = nullptr; if (auto Err = SignalManager.getResource(OutputSignal)) @@ -1556,12 +1561,21 @@ public: InputSignal = consume(OutputSignal).second; } - // "Leaking" the thread here is consistent with other work added to the - // queue. The input and output signals will remain valid until the output is - // signaled. - std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData) - .detach(); + auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal}; + if (InputSignal && InputSignal->load()) { + hsa_status_t Status = hsa_amd_signal_async_handler( + InputSignal->get(), HSA_SIGNAL_CONDITION_EQ, 0, callbackWrapper, + CallbackData); + return Plugin::check(Status, "error in hsa_amd_signal_async_handler: %s"); + } + + // No dependencies - schedule it now. + // Using a seperate thread because this function should run asynchronously + // and not block the main thread. + std::thread([](void *CallbackData) { callbackWrapper(0, CallbackData); }, + CallbackData) + .detach(); return Plugin::success(); } @@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { return Plugin::success(); } - Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData, + Error enqueueHostCallImpl(AMDGPUStreamTy::HostFnType Callback, void *UserData, AsyncInfoWrapperTy &AsyncInfo) override { AMDGPUStreamTy *Stream = nullptr; if (auto Err = getStream(AsyncInfo, Stream)) @@ -3048,6 +3062,30 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { return ((IsAPU || OMPX_ApuMaps) && IsXnackEnabled); } + Expected<bool> isAccessiblePtrImpl(const void *Ptr, size_t Size) override { + hsa_amd_pointer_info_t Info; + Info.size = sizeof(hsa_amd_pointer_info_t); + + hsa_agent_t *Agents = nullptr; + uint32_t Count = 0; + hsa_status_t Status = + hsa_amd_pointer_info(Ptr, &Info, malloc, &Count, &Agents); + + if (auto Err = Plugin::check(Status, "error in hsa_amd_pointer_info: %s")) + return std::move(Err); + + // Checks if the pointer is known by HSA and accessible by the device + for (uint32_t i = 0; i < Count; i++) { + if (Agents[i].handle == getAgent().handle) + return Info.sizeInBytes >= Size; + } + + // If the pointer is unknown to HSA it's assumed a host pointer + // in that case the device can access it on unified memory support is + // enabled + return IsXnackEnabled; + } + /// Getters and setters for stack and heap sizes. Error getDeviceStackSize(uint64_t &Value) override { Value = StackSize; diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index 8c530bba3882..f9bff9abd903 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -1066,6 +1066,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy { bool useAutoZeroCopy(); virtual bool useAutoZeroCopyImpl() { return false; } + /// Returns true if the plugin can guarantee that the associated + /// storage is accessible + Expected<bool> isAccessiblePtr(const void *Ptr, size_t Size); + virtual Expected<omp_interop_val_t *> createInterop(int32_t InteropType, interop_spec_t &InteropSpec) { return nullptr; @@ -1166,6 +1170,10 @@ private: /// Per device setting of MemoryManager's Threshold virtual size_t getMemoryManagerSizeThreshold() { return 0; } + virtual Expected<bool> isAccessiblePtrImpl(const void *Ptr, size_t Size) { + return false; + } + /// Environment variables defined by the OpenMP standard. Int32Envar OMP_TeamLimit; Int32Envar OMP_NumTeams; @@ -1492,6 +1500,9 @@ public: /// Returns if the plugin can support automatic copy. int32_t use_auto_zero_copy(int32_t DeviceId); + /// Returns if the associated storage is accessible for a given device. + int32_t is_accessible_ptr(int32_t DeviceId, const void *Ptr, size_t Size); + /// Look up a global symbol in the given binary. int32_t get_global(__tgt_device_binary Binary, uint64_t Size, const char *Name, void **DevicePtr); diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index db43cbe49cc2..36d643b65922 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1599,6 +1599,10 @@ Error GenericDeviceTy::syncEvent(void *EventPtr) { bool GenericDeviceTy::useAutoZeroCopy() { return useAutoZeroCopyImpl(); } +Expected<bool> GenericDeviceTy::isAccessiblePtr(const void *Ptr, size_t Size) { + return isAccessiblePtrImpl(Ptr, Size); +} + Error GenericPluginTy::init() { if (Initialized) return Plugin::success(); @@ -2133,6 +2137,22 @@ int32_t GenericPluginTy::use_auto_zero_copy(int32_t DeviceId) { return getDevice(DeviceId).useAutoZeroCopy(); } +int32_t GenericPluginTy::is_accessible_ptr(int32_t DeviceId, const void *Ptr, + size_t Size) { + auto HandleError = [&](Error Err) -> bool { + [[maybe_unused]] std::string ErrStr = toString(std::move(Err)); + DP("Failure while checking accessibility of pointer %p for device %d: %s", + Ptr, DeviceId, ErrStr.c_str()); + return false; + }; + + auto AccessibleOrErr = getDevice(DeviceId).isAccessiblePtr(Ptr, Size); + if (Error Err = AccessibleOrErr.takeError()) + return HandleError(std::move(Err)); + + return *AccessibleOrErr; +} + int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size, const char *Name, void **DevicePtr) { assert(Binary.handle && "Invalid device binary handle"); |
