summaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/amdgpu/src/rtl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/plugins-nextgen/amdgpu/src/rtl.cpp')
-rw-r--r--offload/plugins-nextgen/amdgpu/src/rtl.cpp49
1 files changed, 27 insertions, 22 deletions
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 663cfdc5fdf0..e6643d3260eb 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -20,6 +20,7 @@
#include <unistd.h>
#include <unordered_map>
+#include "Shared/APITypes.h"
#include "Shared/Debug.h"
#include "Shared/Environment.h"
#include "Shared/Utils.h"
@@ -57,12 +58,12 @@
#endif
#if defined(__has_include)
-#if __has_include("hsa/hsa.h")
-#include "hsa/hsa.h"
-#include "hsa/hsa_ext_amd.h"
-#elif __has_include("hsa.h")
+#if __has_include("hsa.h")
#include "hsa.h"
#include "hsa_ext_amd.h"
+#elif __has_include("hsa/hsa.h")
+#include "hsa/hsa.h"
+#include "hsa/hsa_ext_amd.h"
#endif
#else
#include "hsa/hsa.h"
@@ -558,7 +559,8 @@ struct AMDGPUKernelTy : public GenericKernelTy {
/// Launch the AMDGPU kernel function.
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
- uint64_t NumBlocks, KernelArgsTy &KernelArgs, void *Args,
+ uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+ KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
/// Print more elaborate kernel launch info for AMDGPU
@@ -2802,9 +2804,10 @@ private:
AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
KernelArgsTy KernelArgs = {};
- if (auto Err = AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
- /*NumBlocks=*/1ul, KernelArgs,
- /*Args=*/nullptr, AsyncInfoWrapper))
+ if (auto Err =
+ AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
+ /*NumBlocks=*/1ul, KernelArgs,
+ KernelLaunchParamsTy{}, AsyncInfoWrapper))
return Err;
Error Err = Plugin::success();
@@ -3265,11 +3268,11 @@ private:
Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
uint32_t NumThreads, uint64_t NumBlocks,
- KernelArgsTy &KernelArgs, void *Args,
+ KernelArgsTy &KernelArgs,
+ KernelLaunchParamsTy LaunchParams,
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
- const uint32_t KernelArgsSize = KernelArgs.NumArgs * sizeof(void *);
-
- if (ArgsSize < KernelArgsSize)
+ if (ArgsSize != LaunchParams.Size &&
+ ArgsSize != LaunchParams.Size + getImplicitArgsSize())
return Plugin::error("Mismatch of kernel arguments size");
AMDGPUPluginTy &AMDGPUPlugin =
@@ -3292,20 +3295,21 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
if (auto Err = GenericDevice.getDeviceStackSize(StackSize))
return Err;
- // Initialize implicit arguments.
- utils::AMDGPUImplicitArgsTy *ImplArgs =
- reinterpret_cast<utils::AMDGPUImplicitArgsTy *>(
- advanceVoidPtr(AllArgs, KernelArgsSize));
+ utils::AMDGPUImplicitArgsTy *ImplArgs = nullptr;
+ if (ArgsSize == LaunchParams.Size + getImplicitArgsSize()) {
+ // Initialize implicit arguments.
+ ImplArgs = reinterpret_cast<utils::AMDGPUImplicitArgsTy *>(
+ advanceVoidPtr(AllArgs, LaunchParams.Size));
- // Initialize the implicit arguments to zero.
- std::memset(ImplArgs, 0, ImplicitArgsSize);
+ // Initialize the implicit arguments to zero.
+ std::memset(ImplArgs, 0, getImplicitArgsSize());
+ }
// Copy the explicit arguments.
// TODO: We should expose the args memory manager alloc to the common part as
// alternative to copying them twice.
- if (KernelArgs.NumArgs)
- std::memcpy(AllArgs, *static_cast<void **>(Args),
- sizeof(void *) * KernelArgs.NumArgs);
+ if (LaunchParams.Size)
+ std::memcpy(AllArgs, LaunchParams.Data, LaunchParams.Size);
AMDGPUDeviceTy &AMDGPUDevice = static_cast<AMDGPUDeviceTy &>(GenericDevice);
@@ -3318,7 +3322,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
Stream->setRPCServer(GenericDevice.getRPCServer());
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
- if (getImplicitArgsSize() == sizeof(utils::AMDGPUImplicitArgsTy)) {
+ if (ImplArgs &&
+ getImplicitArgsSize() == sizeof(utils::AMDGPUImplicitArgsTy)) {
ImplArgs->BlockCountX = NumBlocks;
ImplArgs->BlockCountY = 1;
ImplArgs->BlockCountZ = 1;