summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCallum Fare <callum@codeplay.com>2025-04-22 19:27:50 +0100
committerGitHub <noreply@github.com>2025-04-22 13:27:50 -0500
commit800d949bb315349a116a980e99d0f36645ffefd3 (patch)
treee13178e6c2d33cbe0235274e43a278d9b85e89d7
parentfcb309715e4bd46d96dda7bdf99291ebf394d130 (diff)
[Offload] Implement the remaining initial Offload API (#122106)
Implement the complete initial version of the Offload API, to the extent that is usable for simple offloading programs. Tested with a basic SYCL program. As far as possible, these are simple wrappers over existing functionality in the plugins. * Allocating and freeing memory (host, device, shared). * Creating a program * Creating a queue (wrapper over asynchronous stream resource) * Enqueuing memcpy operations * Enqueuing kernel executions * Waiting on (optional) output events from the enqueue operations * Waiting on a queue to finish Objects created with the API have reference counting semantics to handle their lifetime. They are created with an initial reference count of 1, which can be incremented and decremented with retain and release functions. They are freed when their reference count reaches 0. Platform and device objects are not reference counted, as they are expected to persist as long as the library is in use, and it's not meaningful for users to create or destroy them. Tests have been added to `offload.unittests`, including device code for testing program and kernel related functionality. The API should still be considered unstable and it's very likely we will need to change the existing entry points.
-rw-r--r--offload/liboffload/API/APIDefs.td2
-rw-r--r--offload/liboffload/API/Common.td28
-rw-r--r--offload/liboffload/API/Device.td37
-rw-r--r--offload/liboffload/API/Event.td31
-rw-r--r--offload/liboffload/API/Kernel.td61
-rw-r--r--offload/liboffload/API/Memory.td68
-rw-r--r--offload/liboffload/API/OffloadAPI.td5
-rw-r--r--offload/liboffload/API/Platform.td43
-rw-r--r--offload/liboffload/API/Program.td34
-rw-r--r--offload/liboffload/API/Queue.td42
-rw-r--r--offload/liboffload/API/README.md6
-rw-r--r--offload/liboffload/include/OffloadImpl.hpp1
-rw-r--r--offload/liboffload/include/generated/OffloadAPI.h703
-rw-r--r--offload/liboffload/include/generated/OffloadEntryPoints.inc763
-rw-r--r--offload/liboffload/include/generated/OffloadFuncs.inc34
-rw-r--r--offload/liboffload/include/generated/OffloadImplFuncDecls.inc52
-rw-r--r--offload/liboffload/include/generated/OffloadPrint.hpp380
-rw-r--r--offload/liboffload/src/OffloadImpl.cpp451
-rw-r--r--offload/liboffload/src/OffloadLib.cpp7
-rw-r--r--offload/test/tools/offload-tblgen/entry_points.td2
-rw-r--r--offload/test/tools/offload-tblgen/functions_ranged_param.td2
-rw-r--r--offload/test/tools/offload-tblgen/print_enum.td2
-rw-r--r--offload/test/tools/offload-tblgen/print_function.td2
-rw-r--r--offload/test/tools/offload-tblgen/type_tagged_enum.td4
-rw-r--r--offload/tools/offload-tblgen/APIGen.cpp24
-rw-r--r--offload/tools/offload-tblgen/EntryPointGen.cpp17
-rw-r--r--offload/tools/offload-tblgen/PrintGen.cpp70
-rw-r--r--offload/tools/offload-tblgen/RecordTypes.hpp20
-rw-r--r--offload/unittests/OffloadAPI/CMakeLists.txt24
-rw-r--r--offload/unittests/OffloadAPI/common/Environment.cpp166
-rw-r--r--offload/unittests/OffloadAPI/common/Environment.hpp7
-rw-r--r--offload/unittests/OffloadAPI/common/Fixtures.hpp86
-rw-r--r--offload/unittests/OffloadAPI/device/olGetDevice.cpp39
-rw-r--r--offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp28
-rw-r--r--offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp6
-rw-r--r--offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp4
-rw-r--r--offload/unittests/OffloadAPI/device/olIterateDevices.cpp45
-rw-r--r--offload/unittests/OffloadAPI/device_code/CMakeLists.txt67
-rw-r--r--offload/unittests/OffloadAPI/device_code/bar.c5
-rw-r--r--offload/unittests/OffloadAPI/device_code/foo.c5
-rw-r--r--offload/unittests/OffloadAPI/kernel/olGetKernel.cpp30
-rw-r--r--offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp83
-rw-r--r--offload/unittests/OffloadAPI/memory/olMemAlloc.cpp45
-rw-r--r--offload/unittests/OffloadAPI/memory/olMemFree.cpp38
-rw-r--r--offload/unittests/OffloadAPI/memory/olMemcpy.cpp106
-rw-r--r--offload/unittests/OffloadAPI/platform/olGetPlatform.cpp28
-rw-r--r--offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp2
-rw-r--r--offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp2
-rw-r--r--offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp1
-rw-r--r--offload/unittests/OffloadAPI/program/olCreateProgram.cpp27
-rw-r--r--offload/unittests/OffloadAPI/program/olDestroyProgram.cpp (renamed from offload/unittests/OffloadAPI/platform/olGetPlatformCount.cpp)14
-rw-r--r--offload/unittests/OffloadAPI/queue/olCreateQueue.cpp28
-rw-r--r--offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp22
-rw-r--r--offload/unittests/OffloadAPI/queue/olWaitQueue.cpp17
54 files changed, 3016 insertions, 800 deletions
diff --git a/offload/liboffload/API/APIDefs.td b/offload/liboffload/API/APIDefs.td
index cee4adea1d9f..640932dcf846 100644
--- a/offload/liboffload/API/APIDefs.td
+++ b/offload/liboffload/API/APIDefs.td
@@ -199,7 +199,7 @@ class Typedef : APIObject { string value; }
class FptrTypedef : APIObject {
list<Param> params;
- list<Return> returns;
+ string return;
}
class Macro : APIObject {
diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td
index 5b19d1d47129..de7502b54061 100644
--- a/offload/liboffload/API/Common.td
+++ b/offload/liboffload/API/Common.td
@@ -62,6 +62,27 @@ def : Handle {
let desc = "Handle of context object";
}
+def : Handle {
+ let name = "ol_queue_handle_t";
+ let desc = "Handle of queue object";
+}
+
+def : Handle {
+ let name = "ol_event_handle_t";
+ let desc = "Handle of event object";
+}
+
+def : Handle {
+ let name = "ol_program_handle_t";
+ let desc = "Handle of program object";
+}
+
+def : Typedef {
+ let name = "ol_kernel_handle_t";
+ let desc = "Handle of kernel object";
+ let value = "void *";
+}
+
def : Enum {
let name = "ol_errc_t";
let desc = "Defines Return/Error codes";
@@ -69,12 +90,11 @@ def : Enum {
Etor<"SUCCESS", "Success">,
Etor<"INVALID_VALUE", "Invalid Value">,
Etor<"INVALID_PLATFORM", "Invalid platform">,
- Etor<"DEVICE_NOT_FOUND", "Device not found">,
Etor<"INVALID_DEVICE", "Invalid device">,
- Etor<"DEVICE_LOST", "Device hung, reset, was removed, or driver update occurred">,
- Etor<"UNINITIALIZED", "plugin is not initialized or specific entry-point is not implemented">,
+ Etor<"INVALID_QUEUE", "Invalid queue">,
+ Etor<"INVALID_EVENT", "Invalid event">,
+ Etor<"INVALID_KERNEL_NAME", "Named kernel not found in the program binary">,
Etor<"OUT_OF_RESOURCES", "Out of resources">,
- Etor<"UNSUPPORTED_VERSION", "generic error code for unsupported versions">,
Etor<"UNSUPPORTED_FEATURE", "generic error code for unsupported features">,
Etor<"INVALID_ARGUMENT", "generic error code for invalid arguments">,
Etor<"INVALID_NULL_HANDLE", "handle argument is not valid">,
diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td
index 30c0b71fe7b3..28c96bb5d291 100644
--- a/offload/liboffload/API/Device.td
+++ b/offload/liboffload/API/Device.td
@@ -12,7 +12,7 @@
def : Enum {
let name = "ol_device_type_t";
- let desc = "Supported device types";
+ let desc = "Supported device types.";
let etors =[
Etor<"DEFAULT", "The default device type as preferred by the runtime">,
Etor<"ALL", "Devices of all types">,
@@ -23,7 +23,7 @@ def : Enum {
def : Enum {
let name = "ol_device_info_t";
- let desc = "Supported device info";
+ let desc = "Supported device info.";
let is_typed = 1;
let etors =[
TaggedEtor<"TYPE", "ol_device_type_t", "type of the device">,
@@ -34,39 +34,34 @@ def : Enum {
];
}
-def : Function {
- let name = "olGetDeviceCount";
- let desc = "Retrieves the number of available devices within a platform";
+def : FptrTypedef {
+ let name = "ol_device_iterate_cb_t";
+ let desc = "User-provided function to be used with `olIterateDevices`";
let params = [
- Param<"ol_platform_handle_t", "Platform", "handle of the platform instance", PARAM_IN>,
- Param<"uint32_t*", "NumDevices", "pointer to the number of devices.", PARAM_OUT>
+ Param<"ol_device_handle_t", "Device", "the device handle of the current iteration", PARAM_IN>,
+ Param<"void*", "UserData", "optional user data", PARAM_IN_OPTIONAL>
];
- let returns = [];
+ let return = "bool";
}
def : Function {
- let name = "olGetDevice";
- let desc = "Retrieves devices within a platform";
+ let name = "olIterateDevices";
+ let desc = "Iterates over all available devices, calling the callback for each device.";
let details = [
- "Multiple calls to this function will return identical device handles, in the same order.",
+ "If the user-provided callback returns `false`, the iteration is stopped."
];
let params = [
- Param<"ol_platform_handle_t", "Platform", "handle of the platform instance", PARAM_IN>,
- Param<"uint32_t", "NumEntries", "the number of devices to be added to phDevices, which must be greater than zero", PARAM_IN>,
- RangedParam<"ol_device_handle_t*", "Devices", "Array of device handles. "
- "If NumEntries is less than the number of devices available, then this function shall only retrieve that number of devices.", PARAM_OUT,
- Range<"0", "NumEntries">>
+ Param<"ol_device_iterate_cb_t", "Callback", "User-provided function called for each available device", PARAM_IN>,
+ Param<"void*", "UserData", "Optional user data to pass to the callback", PARAM_IN_OPTIONAL>
];
let returns = [
- Return<"OL_ERRC_INVALID_SIZE", [
- "`NumEntries == 0`"
- ]>
+ Return<"OL_ERRC_INVALID_DEVICE">
];
}
def : Function {
let name = "olGetDeviceInfo";
- let desc = "Queries the given property of the device";
+ let desc = "Queries the given property of the device.";
let details = [];
let params = [
Param<"ol_device_handle_t", "Device", "handle of the device instance", PARAM_IN>,
@@ -90,7 +85,7 @@ def : Function {
def : Function {
let name = "olGetDeviceInfoSize";
- let desc = "Returns the storage size of the given device query";
+ let desc = "Returns the storage size of the given device query.";
let details = [];
let params = [
Param<"ol_device_handle_t", "Device", "handle of the device instance", PARAM_IN>,
diff --git a/offload/liboffload/API/Event.td b/offload/liboffload/API/Event.td
new file mode 100644
index 000000000000..c9f79159cf26
--- /dev/null
+++ b/offload/liboffload/API/Event.td
@@ -0,0 +1,31 @@
+//===-- Event.td - Event definitions for Offload -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains Offload API definitions related to the event handle
+//
+//===----------------------------------------------------------------------===//
+
+def : Function {
+ let name = "olDestroyEvent";
+ let desc = "Destroy the event and free all underlying resources.";
+ let details = [];
+ let params = [
+ Param<"ol_event_handle_t", "Event", "handle of the event", PARAM_IN>
+ ];
+ let returns = [];
+}
+
+def : Function {
+ let name = "olWaitEvent";
+ let desc = "Wait for the event to be complete.";
+ let details = [];
+ let params = [
+ Param<"ol_event_handle_t", "Event", "handle of the event", PARAM_IN>
+ ];
+ let returns = [];
+}
diff --git a/offload/liboffload/API/Kernel.td b/offload/liboffload/API/Kernel.td
new file mode 100644
index 000000000000..247f9c1bf5b6
--- /dev/null
+++ b/offload/liboffload/API/Kernel.td
@@ -0,0 +1,61 @@
+//===-- Kernel.td - Kernel definitions for Offload ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains Offload API definitions related to the kernel handle
+//
+//===----------------------------------------------------------------------===//
+
+def : Function {
+ let name = "olGetKernel";
+ let desc = "Get a kernel from the function identified by `KernelName` in the given program.";
+ let details = [
+ "The kernel handle returned is owned by the device so does not need to be destroyed."
+ ];
+ let params = [
+ Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
+ Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>,
+ Param<"ol_kernel_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT>
+ ];
+ let returns = [];
+}
+
+def : Struct {
+ let name = "ol_kernel_launch_size_args_t";
+ let desc = "Size-related arguments for a kernel launch.";
+ let members = [
+ StructMember<"size_t", "Dimensions", "Number of work dimensions">,
+ StructMember<"size_t", "NumGroupsX", "Number of work groups on the X dimension">,
+ StructMember<"size_t", "NumGroupsY", "Number of work groups on the Y dimension">,
+ StructMember<"size_t", "NumGroupsZ", "Number of work groups on the Z dimension">,
+ StructMember<"size_t", "GroupSizeX", "Size of a work group on the X dimension.">,
+ StructMember<"size_t", "GroupSizeY", "Size of a work group on the Y dimension.">,
+ StructMember<"size_t", "GroupSizeZ", "Size of a work group on the Z dimension.">,
+ StructMember<"size_t", "DynSharedMemory", "Size of dynamic shared memory in bytes.">
+ ];
+}
+
+def : Function {
+ let name = "olLaunchKernel";
+ let desc = "Enqueue a kernel launch with the specified size and parameters.";
+ let details = [
+ "If a queue is not specified, kernel execution happens synchronously"
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN_OPTIONAL>,
+ Param<"ol_device_handle_t", "Device", "handle of the device to execute on", PARAM_IN>,
+ Param<"ol_kernel_handle_t", "Kernel", "handle of the kernel", PARAM_IN>,
+ Param<"const void*", "ArgumentsData", "pointer to the kernel argument struct", PARAM_IN>,
+ Param<"size_t", "ArgumentsSize", "size of the kernel argument struct", PARAM_IN>,
+ Param<"const ol_kernel_launch_size_args_t*", "LaunchSizeArgs", "pointer to the struct containing launch size parameters", PARAM_IN>,
+ Param<"ol_event_handle_t*", "EventOut", "optional recorded event for the enqueued operation", PARAM_OUT_OPTIONAL>
+ ];
+ let returns = [
+ Return<"OL_ERRC_INVALID_ARGUMENT", ["`Queue == NULL && EventOut != NULL`"]>,
+ Return<"OL_ERRC_INVALID_DEVICE", ["If Queue is non-null but does not belong to Device"]>,
+ ];
+}
diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td
new file mode 100644
index 000000000000..9cd1ef6362e1
--- /dev/null
+++ b/offload/liboffload/API/Memory.td
@@ -0,0 +1,68 @@
+//===-- Memory.td - Memory definitions for Offload ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains Offload API definitions related to memory allocations
+//
+//===----------------------------------------------------------------------===//
+
+def : Enum {
+ let name = "ol_alloc_type_t";
+ let desc = "Represents the type of allocation made with olMemAlloc.";
+ let etors = [
+ Etor<"HOST", "Host allocation">,
+ Etor<"DEVICE", "Device allocation">,
+ Etor<"MANAGED", "Managed allocation">
+ ];
+}
+
+def : Function {
+ let name = "olMemAlloc";
+ let desc = "Creates a memory allocation on the specified device.";
+ let params = [
+ Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>,
+ Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>,
+ Param<"size_t", "Size", "size of the allocation in bytes", PARAM_IN>,
+ Param<"void**", "AllocationOut", "output for the allocated pointer", PARAM_OUT>
+ ];
+ let returns = [
+ Return<"OL_ERRC_INVALID_SIZE", [
+ "`Size == 0`"
+ ]>
+ ];
+}
+
+def : Function {
+ let name = "olMemFree";
+ let desc = "Frees a memory allocation previously made by olMemAlloc.";
+ let params = [
+ Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
+ ];
+ let returns = [];
+}
+
+def : Function {
+ let name = "olMemcpy";
+ let desc = "Enqueue a memcpy operation.";
+ let details = [
+ "For host pointers, use the host device belonging to the OL_PLATFORM_BACKEND_HOST platform.",
+ "If a queue is specified, at least one device must be a non-host device",
+ "If a queue is not specified, the memcpy happens synchronously"
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "Queue", "handle of the queue.", PARAM_IN_OPTIONAL>,
+ Param<"void*", "DstPtr", "pointer to copy to", PARAM_IN>,
+ Param<"ol_device_handle_t", "DstDevice", "device that DstPtr belongs to", PARAM_IN>,
+ Param<"void*", "SrcPtr", "pointer to copy from", PARAM_IN>,
+ Param<"ol_device_handle_t", "SrcDevice", "device that SrcPtr belongs to", PARAM_IN>,
+ Param<"size_t", "Size", "size in bytes of data to copy", PARAM_IN>,
+ Param<"ol_event_handle_t*", "EventOut", "optional recorded event for the enqueued operation", PARAM_OUT_OPTIONAL>
+ ];
+ let returns = [
+ Return<"OL_ERRC_INVALID_ARGUMENT", ["`Queue == NULL && EventOut != NULL`"]>
+ ];
+}
diff --git a/offload/liboffload/API/OffloadAPI.td b/offload/liboffload/API/OffloadAPI.td
index 8a0c3c405812..f9829155b6ce 100644
--- a/offload/liboffload/API/OffloadAPI.td
+++ b/offload/liboffload/API/OffloadAPI.td
@@ -13,3 +13,8 @@ include "APIDefs.td"
include "Common.td"
include "Platform.td"
include "Device.td"
+include "Memory.td"
+include "Queue.td"
+include "Event.td"
+include "Program.td"
+include "Kernel.td"
diff --git a/offload/liboffload/API/Platform.td b/offload/liboffload/API/Platform.td
index 03e70cf96ac9..97c2cc2d0570 100644
--- a/offload/liboffload/API/Platform.td
+++ b/offload/liboffload/API/Platform.td
@@ -9,44 +9,10 @@
// This file contains Offload API definitions related to the Platform handle
//
//===----------------------------------------------------------------------===//
-def : Function {
- let name = "olGetPlatform";
- let desc = "Retrieves all available platforms";
- let details = [
- "Multiple calls to this function will return identical platforms handles, in the same order.",
- ];
- let params = [
- Param<"uint32_t", "NumEntries",
- "The number of platforms to be added to Platforms. NumEntries must be "
- "greater than zero.",
- PARAM_IN>,
- RangedParam<"ol_platform_handle_t*", "Platforms",
- "Array of handle of platforms. If NumEntries is less than the number of "
- "platforms available, then olGetPlatform shall only retrieve that "
- "number of platforms.",
- PARAM_OUT, Range<"0", "NumEntries">>
- ];
- let returns = [
- Return<"OL_ERRC_INVALID_SIZE", [
- "`NumEntries == 0`"
- ]>
- ];
-}
-
-def : Function {
- let name = "olGetPlatformCount";
- let desc = "Retrieves the number of available platforms";
- let params = [
- Param<"uint32_t*",
- "NumPlatforms", "returns the total number of platforms available.",
- PARAM_OUT>
- ];
- let returns = [];
-}
def : Enum {
let name = "ol_platform_info_t";
- let desc = "Supported platform info";
+ let desc = "Supported platform info.";
let is_typed = 1;
let etors = [
TaggedEtor<"NAME", "char[]", "The string denoting name of the platform. The size of the info needs to be dynamically queried.">,
@@ -58,17 +24,18 @@ def : Enum {
def : Enum {
let name = "ol_platform_backend_t";
- let desc = "Identifies the native backend of the platform";
+ let desc = "Identifies the native backend of the platform.";
let etors =[
Etor<"UNKNOWN", "The backend is not recognized">,
Etor<"CUDA", "The backend is CUDA">,
Etor<"AMDGPU", "The backend is AMDGPU">,
+ Etor<"HOST", "The backend is the host">,
];
}
def : Function {
let name = "olGetPlatformInfo";
- let desc = "Queries the given property of the platform";
+ let desc = "Queries the given property of the platform.";
let details = [
"`olGetPlatformInfoSize` can be used to query the storage size "
"required for the given query."
@@ -96,7 +63,7 @@ def : Function {
def : Function {
let name = "olGetPlatformInfoSize";
- let desc = "Returns the storage size of the given platform query";
+ let desc = "Returns the storage size of the given platform query.";
let details = [];
let params = [
Param<"ol_platform_handle_t", "Platform", "handle of the platform", PARAM_IN>,
diff --git a/offload/liboffload/API/Program.td b/offload/liboffload/API/Program.td
new file mode 100644
index 000000000000..8c88fe6e21e6
--- /dev/null
+++ b/offload/liboffload/API/Program.td
@@ -0,0 +1,34 @@
+//===-- Program.td - Program definitions for Offload -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains Offload API definitions related to the program handle
+//
+//===----------------------------------------------------------------------===//
+
+def : Function {
+ let name = "olCreateProgram";
+ let desc = "Create a program for the device from the binary image pointed to by `ProgData`.";
+ let details = [];
+ let params = [
+ Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>,
+ Param<"const void*", "ProgData", "pointer to the program binary data", PARAM_IN>,
+ Param<"size_t", "ProgDataSize", "size of the program binary in bytes", PARAM_IN>,
+ Param<"ol_program_handle_t*", "Program", "output pointer for the created program", PARAM_OUT>
+ ];
+ let returns = [];
+}
+
+def : Function {
+ let name = "olDestroyProgram";
+ let desc = "Destroy the program and free all underlying resources.";
+ let details = [];
+ let params = [
+ Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>
+ ];
+ let returns = [];
+}
diff --git a/offload/liboffload/API/Queue.td b/offload/liboffload/API/Queue.td
new file mode 100644
index 000000000000..b5bb619c5751
--- /dev/null
+++ b/offload/liboffload/API/Queue.td
@@ -0,0 +1,42 @@
+//===-- Queue.td - Queue definitions for Offload -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains Offload API definitions related to the queue handle
+//
+//===----------------------------------------------------------------------===//
+
+def : Function {
+ let name = "olCreateQueue";
+ let desc = "Create a queue for the given device.";
+ let details = [];
+ let params = [
+ Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>,
+ Param<"ol_queue_handle_t*", "Queue", "output pointer for the created queue", PARAM_OUT>
+ ];
+ let returns = [];
+}
+
+def : Function {
+ let name = "olDestroyQueue";
+ let desc = "Destroy the queue and free all underlying resources.";
+ let details = [];
+ let params = [
+ Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>
+ ];
+ let returns = [];
+}
+
+def : Function {
+ let name = "olWaitQueue";
+ let desc = "Wait for the enqueued work on a queue to complete.";
+ let details = [];
+ let params = [
+ Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>
+ ];
+ let returns = [];
+}
diff --git a/offload/liboffload/API/README.md b/offload/liboffload/API/README.md
index b59ac2782a2b..fda1ad39fa93 100644
--- a/offload/liboffload/API/README.md
+++ b/offload/liboffload/API/README.md
@@ -138,8 +138,8 @@ allow more backends to be easily added in future.
A new object can be added to the API by adding to one of the existing `.td`
files. It is also possible to add a new tablegen file to the API by adding it
-to the includes in `OffloadAPI.td`. When the offload target is rebuilt, the
-new definition will be included in the generated files.
+to the includes in `OffloadAPI.td`. When the `OffloadGenerate` target is
+rebuilt, the new definition will be included in the generated files.
### Adding a new entry point
@@ -147,4 +147,4 @@ When a new entry point is added (e.g. `offloadDeviceFoo`), the actual entry
point is automatically generated, which contains validation and tracing code.
It expects an implementation function (`offloadDeviceFoo_impl`) to be defined,
which it will call into. The definition of this implementation function should
-be added to `src/offload_impl.cpp`
+be added to `src/OffloadImpl.cpp`
diff --git a/offload/liboffload/include/OffloadImpl.hpp b/offload/liboffload/include/OffloadImpl.hpp
index 6d745095f310..ec470a355309 100644
--- a/offload/liboffload/include/OffloadImpl.hpp
+++ b/offload/liboffload/include/OffloadImpl.hpp
@@ -22,6 +22,7 @@
struct OffloadConfig {
bool TracingEnabled = false;
+ bool ValidationEnabled = true;
};
OffloadConfig &offloadConfig();
diff --git a/offload/liboffload/include/generated/OffloadAPI.h b/offload/liboffload/include/generated/OffloadAPI.h
index 11fcc96625ab..ace31c57cf2f 100644
--- a/offload/liboffload/include/generated/OffloadAPI.h
+++ b/offload/liboffload/include/generated/OffloadAPI.h
@@ -75,15 +75,31 @@ extern "C" {
///////////////////////////////////////////////////////////////////////////////
/// @brief Handle of a platform instance
-typedef struct ol_platform_handle_t_ *ol_platform_handle_t;
+typedef struct ol_platform_impl_t *ol_platform_handle_t;
///////////////////////////////////////////////////////////////////////////////
/// @brief Handle of platform's device object
-typedef struct ol_device_handle_t_ *ol_device_handle_t;
+typedef struct ol_device_impl_t *ol_device_handle_t;
///////////////////////////////////////////////////////////////////////////////
/// @brief Handle of context object
-typedef struct ol_context_handle_t_ *ol_context_handle_t;
+typedef struct ol_context_impl_t *ol_context_handle_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Handle of queue object
+typedef struct ol_queue_impl_t *ol_queue_handle_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Handle of event object
+typedef struct ol_event_impl_t *ol_event_handle_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Handle of program object
+typedef struct ol_program_impl_t *ol_program_handle_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Handle of kernel object
+typedef void *ol_kernel_handle_t;
///////////////////////////////////////////////////////////////////////////////
/// @brief Defines Return/Error codes
@@ -94,34 +110,32 @@ typedef enum ol_errc_t {
OL_ERRC_INVALID_VALUE = 1,
/// Invalid platform
OL_ERRC_INVALID_PLATFORM = 2,
- /// Device not found
- OL_ERRC_DEVICE_NOT_FOUND = 3,
/// Invalid device
- OL_ERRC_INVALID_DEVICE = 4,
- /// Device hung, reset, was removed, or driver update occurred
- OL_ERRC_DEVICE_LOST = 5,
- /// plugin is not initialized or specific entry-point is not implemented
- OL_ERRC_UNINITIALIZED = 6,
+ OL_ERRC_INVALID_DEVICE = 3,
+ /// Invalid queue
+ OL_ERRC_INVALID_QUEUE = 4,
+ /// Invalid event
+ OL_ERRC_INVALID_EVENT = 5,
+ /// Named kernel not found in the program binary
+ OL_ERRC_INVALID_KERNEL_NAME = 6,
/// Out of resources
OL_ERRC_OUT_OF_RESOURCES = 7,
- /// generic error code for unsupported versions
- OL_ERRC_UNSUPPORTED_VERSION = 8,
/// generic error code for unsupported features
- OL_ERRC_UNSUPPORTED_FEATURE = 9,
+ OL_ERRC_UNSUPPORTED_FEATURE = 8,
/// generic error code for invalid arguments
- OL_ERRC_INVALID_ARGUMENT = 10,
+ OL_ERRC_INVALID_ARGUMENT = 9,
/// handle argument is not valid
- OL_ERRC_INVALID_NULL_HANDLE = 11,
+ OL_ERRC_INVALID_NULL_HANDLE = 10,
/// pointer argument may not be nullptr
- OL_ERRC_INVALID_NULL_POINTER = 12,
+ OL_ERRC_INVALID_NULL_POINTER = 11,
/// invalid size or dimensions (e.g., must not be zero, or is out of bounds)
- OL_ERRC_INVALID_SIZE = 13,
+ OL_ERRC_INVALID_SIZE = 12,
/// enumerator argument is not valid
- OL_ERRC_INVALID_ENUMERATION = 14,
+ OL_ERRC_INVALID_ENUMERATION = 13,
/// enumerator argument is not supported by the device
- OL_ERRC_UNSUPPORTED_ENUMERATION = 15,
+ OL_ERRC_UNSUPPORTED_ENUMERATION = 14,
/// Unknown or internal error
- OL_ERRC_UNKNOWN = 16,
+ OL_ERRC_UNKNOWN = 15,
/// @cond
OL_ERRC_FORCE_UINT32 = 0x7fffffff
/// @endcond
@@ -188,48 +202,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olInit();
OL_APIEXPORT ol_result_t OL_APICALL olShutDown();
///////////////////////////////////////////////////////////////////////////////
-/// @brief Retrieves all available platforms
-///
-/// @details
-/// - Multiple calls to this function will return identical platforms
-/// handles, in the same order.
-///
-/// @returns
-/// - ::OL_RESULT_SUCCESS
-/// - ::OL_ERRC_UNINITIALIZED
-/// - ::OL_ERRC_DEVICE_LOST
-/// - ::OL_ERRC_INVALID_SIZE
-/// + `NumEntries == 0`
-/// - ::OL_ERRC_INVALID_NULL_HANDLE
-/// - ::OL_ERRC_INVALID_NULL_POINTER
-/// + `NULL == Platforms`
-OL_APIEXPORT ol_result_t OL_APICALL olGetPlatform(
- // [in] The number of platforms to be added to Platforms. NumEntries must be
- // greater than zero.
- uint32_t NumEntries,
- // [out] Array of handle of platforms. If NumEntries is less than the number
- // of platforms available, then olGetPlatform shall only retrieve that
- // number of platforms.
- ol_platform_handle_t *Platforms);
-
-///////////////////////////////////////////////////////////////////////////////
-/// @brief Retrieves the number of available platforms
-///
-/// @details
-///
-/// @returns
-/// - ::OL_RESULT_SUCCESS
-/// - ::OL_ERRC_UNINITIALIZED
-/// - ::OL_ERRC_DEVICE_LOST
-/// - ::OL_ERRC_INVALID_NULL_HANDLE
-/// - ::OL_ERRC_INVALID_NULL_POINTER
-/// + `NULL == NumPlatforms`
-OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformCount(
- // [out] returns the total number of platforms available.
- uint32_t *NumPlatforms);
-
-///////////////////////////////////////////////////////////////////////////////
-/// @brief Supported platform info
+/// @brief Supported platform info.
typedef enum ol_platform_info_t {
/// [char[]] The string denoting name of the platform. The size of the info
/// needs to be dynamically queried.
@@ -249,7 +222,7 @@ typedef enum ol_platform_info_t {
} ol_platform_info_t;
///////////////////////////////////////////////////////////////////////////////
-/// @brief Identifies the native backend of the platform
+/// @brief Identifies the native backend of the platform.
typedef enum ol_platform_backend_t {
/// The backend is not recognized
OL_PLATFORM_BACKEND_UNKNOWN = 0,
@@ -257,6 +230,8 @@ typedef enum ol_platform_backend_t {
OL_PLATFORM_BACKEND_CUDA = 1,
/// The backend is AMDGPU
OL_PLATFORM_BACKEND_AMDGPU = 2,
+ /// The backend is the host
+ OL_PLATFORM_BACKEND_HOST = 3,
/// @cond
OL_PLATFORM_BACKEND_FORCE_UINT32 = 0x7fffffff
/// @endcond
@@ -264,7 +239,7 @@ typedef enum ol_platform_backend_t {
} ol_platform_backend_t;
///////////////////////////////////////////////////////////////////////////////
-/// @brief Queries the given property of the platform
+/// @brief Queries the given property of the platform.
///
/// @details
/// - `olGetPlatformInfoSize` can be used to query the storage size required
@@ -298,7 +273,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfo(
void *PropValue);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Returns the storage size of the given platform query
+/// @brief Returns the storage size of the given platform query.
///
/// @details
///
@@ -322,7 +297,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfoSize(
size_t *PropSizeRet);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Supported device types
+/// @brief Supported device types.
typedef enum ol_device_type_t {
/// The default device type as preferred by the runtime
OL_DEVICE_TYPE_DEFAULT = 0,
@@ -339,7 +314,7 @@ typedef enum ol_device_type_t {
} ol_device_type_t;
///////////////////////////////////////////////////////////////////////////////
-/// @brief Supported device info
+/// @brief Supported device info.
typedef enum ol_device_info_t {
/// [ol_device_type_t] type of the device
OL_DEVICE_INFO_TYPE = 0,
@@ -358,54 +333,36 @@ typedef enum ol_device_info_t {
} ol_device_info_t;
///////////////////////////////////////////////////////////////////////////////
-/// @brief Retrieves the number of available devices within a platform
-///
-/// @details
-///
-/// @returns
-/// - ::OL_RESULT_SUCCESS
-/// - ::OL_ERRC_UNINITIALIZED
-/// - ::OL_ERRC_DEVICE_LOST
-/// - ::OL_ERRC_INVALID_NULL_HANDLE
-/// + `NULL == Platform`
-/// - ::OL_ERRC_INVALID_NULL_POINTER
-/// + `NULL == NumDevices`
-OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceCount(
- // [in] handle of the platform instance
- ol_platform_handle_t Platform,
- // [out] pointer to the number of devices.
- uint32_t *NumDevices);
+/// @brief User-provided function to be used with `olIterateDevices`
+typedef bool (*ol_device_iterate_cb_t)(
+ // the device handle of the current iteration
+ ol_device_handle_t Device,
+ // optional user data
+ void *UserData);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Retrieves devices within a platform
+/// @brief Iterates over all available devices, calling the callback for each
+/// device.
///
/// @details
-/// - Multiple calls to this function will return identical device handles,
-/// in the same order.
+/// - If the user-provided callback returns `false`, the iteration is
+/// stopped.
///
/// @returns
/// - ::OL_RESULT_SUCCESS
/// - ::OL_ERRC_UNINITIALIZED
/// - ::OL_ERRC_DEVICE_LOST
-/// - ::OL_ERRC_INVALID_SIZE
-/// + `NumEntries == 0`
+/// - ::OL_ERRC_INVALID_DEVICE
/// - ::OL_ERRC_INVALID_NULL_HANDLE
-/// + `NULL == Platform`
/// - ::OL_ERRC_INVALID_NULL_POINTER
-/// + `NULL == Devices`
-OL_APIEXPORT ol_result_t OL_APICALL olGetDevice(
- // [in] handle of the platform instance
- ol_platform_handle_t Platform,
- // [in] the number of devices to be added to phDevices, which must be
- // greater than zero
- uint32_t NumEntries,
- // [out] Array of device handles. If NumEntries is less than the number of
- // devices available, then this function shall only retrieve that number of
- // devices.
- ol_device_handle_t *Devices);
+OL_APIEXPORT ol_result_t OL_APICALL olIterateDevices(
+ // [in] User-provided function called for each available device
+ ol_device_iterate_cb_t Callback,
+ // [in][optional] Optional user data to pass to the callback
+ void *UserData);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Queries the given property of the device
+/// @brief Queries the given property of the device.
///
/// @details
///
@@ -437,7 +394,7 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo(
void *PropValue);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Returns the storage size of the given device query
+/// @brief Returns the storage size of the given device query.
///
/// @details
///
@@ -461,19 +418,294 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize(
size_t *PropSizeRet);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Function parameters for olGetPlatform
-/// @details Each entry is a pointer to the parameter passed to the function;
-typedef struct ol_get_platform_params_t {
- uint32_t *pNumEntries;
- ol_platform_handle_t **pPlatforms;
-} ol_get_platform_params_t;
+/// @brief Represents the type of allocation made with olMemAlloc.
+typedef enum ol_alloc_type_t {
+ /// Host allocation
+ OL_ALLOC_TYPE_HOST = 0,
+ /// Device allocation
+ OL_ALLOC_TYPE_DEVICE = 1,
+ /// Managed allocation
+ OL_ALLOC_TYPE_MANAGED = 2,
+ /// @cond
+ OL_ALLOC_TYPE_FORCE_UINT32 = 0x7fffffff
+ /// @endcond
+
+} ol_alloc_type_t;
///////////////////////////////////////////////////////////////////////////////
-/// @brief Function parameters for olGetPlatformCount
-/// @details Each entry is a pointer to the parameter passed to the function;
-typedef struct ol_get_platform_count_params_t {
- uint32_t **pNumPlatforms;
-} ol_get_platform_count_params_t;
+/// @brief Creates a memory allocation on the specified device.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_SIZE
+/// + `Size == 0`
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Device`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == AllocationOut`
+OL_APIEXPORT ol_result_t OL_APICALL olMemAlloc(
+ // [in] handle of the device to allocate on
+ ol_device_handle_t Device,
+ // [in] type of the allocation
+ ol_alloc_type_t Type,
+ // [in] size of the allocation in bytes
+ size_t Size,
+ // [out] output for the allocated pointer
+ void **AllocationOut);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Frees a memory allocation previously made by olMemAlloc.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == Address`
+OL_APIEXPORT ol_result_t OL_APICALL olMemFree(
+ // [in] address of the allocation to free
+ void *Address);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Enqueue a memcpy operation.
+///
+/// @details
+/// - For host pointers, use the device returned by olGetHostDevice
+/// - If a queue is specified, at least one device must be a non-host device
+/// - If a queue is not specified, the memcpy happens synchronously
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_ARGUMENT
+/// + `Queue == NULL && EventOut != NULL`
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == DstDevice`
+/// + `NULL == SrcDevice`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == DstPtr`
+/// + `NULL == SrcPtr`
+OL_APIEXPORT ol_result_t OL_APICALL olMemcpy(
+ // [in][optional] handle of the queue.
+ ol_queue_handle_t Queue,
+ // [in] pointer to copy to
+ void *DstPtr,
+ // [in] device that DstPtr belongs to
+ ol_device_handle_t DstDevice,
+ // [in] pointer to copy from
+ void *SrcPtr,
+ // [in] device that SrcPtr belongs to
+ ol_device_handle_t SrcDevice,
+ // [in] size in bytes of data to copy
+ size_t Size,
+ // [out][optional] optional recorded event for the enqueued operation
+ ol_event_handle_t *EventOut);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Create a queue for the given device.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Device`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == Queue`
+OL_APIEXPORT ol_result_t OL_APICALL olCreateQueue(
+ // [in] handle of the device
+ ol_device_handle_t Device,
+ // [out] output pointer for the created queue
+ ol_queue_handle_t *Queue);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Destroy the queue and free all underlying resources.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Queue`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueue(
+ // [in] handle of the queue
+ ol_queue_handle_t Queue);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Wait for the enqueued work on a queue to complete.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Queue`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+OL_APIEXPORT ol_result_t OL_APICALL olWaitQueue(
+ // [in] handle of the queue
+ ol_queue_handle_t Queue);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Destroy the event and free all underlying resources.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Event`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyEvent(
+ // [in] handle of the event
+ ol_event_handle_t Event);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Wait for the event to be complete.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Event`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+OL_APIEXPORT ol_result_t OL_APICALL olWaitEvent(
+ // [in] handle of the event
+ ol_event_handle_t Event);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Create a program for the device from the binary image pointed to by
+/// `ProgData`.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Device`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == ProgData`
+/// + `NULL == Program`
+OL_APIEXPORT ol_result_t OL_APICALL olCreateProgram(
+ // [in] handle of the device
+ ol_device_handle_t Device,
+ // [in] pointer to the program binary data
+ const void *ProgData,
+ // [in] size of the program binary in bytes
+ size_t ProgDataSize,
+ // [out] output pointer for the created program
+ ol_program_handle_t *Program);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Destroy the program and free all underlying resources.
+///
+/// @details
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Program`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyProgram(
+ // [in] handle of the program
+ ol_program_handle_t Program);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Get a kernel from the function identified by `KernelName` in the
+/// given program.
+///
+/// @details
+/// - The kernel handle returned is owned by the device so does not need to
+/// be destroyed.
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Program`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == KernelName`
+/// + `NULL == Kernel`
+OL_APIEXPORT ol_result_t OL_APICALL olGetKernel(
+ // [in] handle of the program
+ ol_program_handle_t Program,
+ // [in] name of the kernel entry point in the program
+ const char *KernelName,
+ // [out] output pointer for the fetched kernel
+ ol_kernel_handle_t *Kernel);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Size-related arguments for a kernel launch.
+typedef struct ol_kernel_launch_size_args_t {
+ size_t Dimensions; /// Number of work dimensions
+ size_t NumGroupsX; /// Number of work groups on the X dimension
+ size_t NumGroupsY; /// Number of work groups on the Y dimension
+ size_t NumGroupsZ; /// Number of work groups on the Z dimension
+ size_t GroupSizeX; /// Size of a work group on the X dimension.
+ size_t GroupSizeY; /// Size of a work group on the Y dimension.
+ size_t GroupSizeZ; /// Size of a work group on the Z dimension.
+ size_t DynSharedMemory; /// Size of dynamic shared memory in bytes.
+} ol_kernel_launch_size_args_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Enqueue a kernel launch with the specified size and parameters.
+///
+/// @details
+/// - If a queue is not specified, kernel execution happens synchronously
+///
+/// @returns
+/// - ::OL_RESULT_SUCCESS
+/// - ::OL_ERRC_UNINITIALIZED
+/// - ::OL_ERRC_DEVICE_LOST
+/// - ::OL_ERRC_INVALID_ARGUMENT
+/// + `Queue == NULL && EventOut != NULL`
+/// - ::OL_ERRC_INVALID_DEVICE
+/// + If Queue is non-null but does not belong to Device
+/// - ::OL_ERRC_INVALID_NULL_HANDLE
+/// + `NULL == Device`
+/// + `NULL == Kernel`
+/// - ::OL_ERRC_INVALID_NULL_POINTER
+/// + `NULL == ArgumentsData`
+/// + `NULL == LaunchSizeArgs`
+OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernel(
+ // [in][optional] handle of the queue
+ ol_queue_handle_t Queue,
+ // [in] handle of the device to execute on
+ ol_device_handle_t Device,
+ // [in] handle of the kernel
+ ol_kernel_handle_t Kernel,
+ // [in] pointer to the kernel argument struct
+ const void *ArgumentsData,
+ // [in] size of the kernel argument struct
+ size_t ArgumentsSize,
+ // [in] pointer to the struct containing launch size parameters
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ // [out][optional] optional recorded event for the enqueued operation
+ ol_event_handle_t *EventOut);
///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for olGetPlatformInfo
@@ -495,21 +727,12 @@ typedef struct ol_get_platform_info_size_params_t {
} ol_get_platform_info_size_params_t;
///////////////////////////////////////////////////////////////////////////////
-/// @brief Function parameters for olGetDeviceCount
+/// @brief Function parameters for olIterateDevices
/// @details Each entry is a pointer to the parameter passed to the function;
-typedef struct ol_get_device_count_params_t {
- ol_platform_handle_t *pPlatform;
- uint32_t **pNumDevices;
-} ol_get_device_count_params_t;
-
-///////////////////////////////////////////////////////////////////////////////
-/// @brief Function parameters for olGetDevice
-/// @details Each entry is a pointer to the parameter passed to the function;
-typedef struct ol_get_device_params_t {
- ol_platform_handle_t *pPlatform;
- uint32_t *pNumEntries;
- ol_device_handle_t **pDevices;
-} ol_get_device_params_t;
+typedef struct ol_iterate_devices_params_t {
+ ol_device_iterate_cb_t *pCallback;
+ void **pUserData;
+} ol_iterate_devices_params_t;
///////////////////////////////////////////////////////////////////////////////
/// @brief Function parameters for olGetDeviceInfo
@@ -531,6 +754,111 @@ typedef struct ol_get_device_info_size_params_t {
} ol_get_device_info_size_params_t;
///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olMemAlloc
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_mem_alloc_params_t {
+ ol_device_handle_t *pDevice;
+ ol_alloc_type_t *pType;
+ size_t *pSize;
+ void ***pAllocationOut;
+} ol_mem_alloc_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olMemFree
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_mem_free_params_t {
+ void **pAddress;
+} ol_mem_free_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olMemcpy
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_memcpy_params_t {
+ ol_queue_handle_t *pQueue;
+ void **pDstPtr;
+ ol_device_handle_t *pDstDevice;
+ void **pSrcPtr;
+ ol_device_handle_t *pSrcDevice;
+ size_t *pSize;
+ ol_event_handle_t **pEventOut;
+} ol_memcpy_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olCreateQueue
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_create_queue_params_t {
+ ol_device_handle_t *pDevice;
+ ol_queue_handle_t **pQueue;
+} ol_create_queue_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olDestroyQueue
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_destroy_queue_params_t {
+ ol_queue_handle_t *pQueue;
+} ol_destroy_queue_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olWaitQueue
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_wait_queue_params_t {
+ ol_queue_handle_t *pQueue;
+} ol_wait_queue_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olDestroyEvent
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_destroy_event_params_t {
+ ol_event_handle_t *pEvent;
+} ol_destroy_event_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olWaitEvent
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_wait_event_params_t {
+ ol_event_handle_t *pEvent;
+} ol_wait_event_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olCreateProgram
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_create_program_params_t {
+ ol_device_handle_t *pDevice;
+ const void **pProgData;
+ size_t *pProgDataSize;
+ ol_program_handle_t **pProgram;
+} ol_create_program_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olDestroyProgram
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_destroy_program_params_t {
+ ol_program_handle_t *pProgram;
+} ol_destroy_program_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olGetKernel
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_get_kernel_params_t {
+ ol_program_handle_t *pProgram;
+ const char **pKernelName;
+ ol_kernel_handle_t **pKernel;
+} ol_get_kernel_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Function parameters for olLaunchKernel
+/// @details Each entry is a pointer to the parameter passed to the function;
+typedef struct ol_launch_kernel_params_t {
+ ol_queue_handle_t *pQueue;
+ ol_device_handle_t *pDevice;
+ ol_kernel_handle_t *pKernel;
+ const void **pArgumentsData;
+ size_t *pArgumentsSize;
+ const ol_kernel_launch_size_args_t **pLaunchSizeArgs;
+ ol_event_handle_t **pEventOut;
+} ol_launch_kernel_params_t;
+
+///////////////////////////////////////////////////////////////////////////////
/// @brief Variant of olInit that also sets source code location information
/// @details See also ::olInit
OL_APIEXPORT ol_result_t OL_APICALL
@@ -543,21 +871,6 @@ OL_APIEXPORT ol_result_t OL_APICALL
olShutDownWithCodeLoc(ol_code_location_t *CodeLocation);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Variant of olGetPlatform that also sets source code location
-/// information
-/// @details See also ::olGetPlatform
-OL_APIEXPORT ol_result_t OL_APICALL
-olGetPlatformWithCodeLoc(uint32_t NumEntries, ol_platform_handle_t *Platforms,
- ol_code_location_t *CodeLocation);
-
-///////////////////////////////////////////////////////////////////////////////
-/// @brief Variant of olGetPlatformCount that also sets source code location
-/// information
-/// @details See also ::olGetPlatformCount
-OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformCountWithCodeLoc(
- uint32_t *NumPlatforms, ol_code_location_t *CodeLocation);
-
-///////////////////////////////////////////////////////////////////////////////
/// @brief Variant of olGetPlatformInfo that also sets source code location
/// information
/// @details See also ::olGetPlatformInfo
@@ -574,22 +887,14 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformInfoSizeWithCodeLoc(
size_t *PropSizeRet, ol_code_location_t *CodeLocation);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Variant of olGetDeviceCount that also sets source code location
+/// @brief Variant of olIterateDevices that also sets source code location
/// information
-/// @details See also ::olGetDeviceCount
+/// @details See also ::olIterateDevices
OL_APIEXPORT ol_result_t OL_APICALL
-olGetDeviceCountWithCodeLoc(ol_platform_handle_t Platform, uint32_t *NumDevices,
+olIterateDevicesWithCodeLoc(ol_device_iterate_cb_t Callback, void *UserData,
ol_code_location_t *CodeLocation);
///////////////////////////////////////////////////////////////////////////////
-/// @brief Variant of olGetDevice that also sets source code location
-/// information
-/// @details See also ::olGetDevice
-OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceWithCodeLoc(
- ol_platform_handle_t Platform, uint32_t NumEntries,
- ol_device_handle_t *Devices, ol_code_location_t *CodeLocation);
-
-///////////////////////////////////////////////////////////////////////////////
/// @brief Variant of olGetDeviceInfo that also sets source code location
/// information
/// @details See also ::olGetDeviceInfo
@@ -605,6 +910,96 @@ OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSizeWithCodeLoc(
ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet,
ol_code_location_t *CodeLocation);
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olMemAlloc that also sets source code location information
+/// @details See also ::olMemAlloc
+OL_APIEXPORT ol_result_t OL_APICALL olMemAllocWithCodeLoc(
+ ol_device_handle_t Device, ol_alloc_type_t Type, size_t Size,
+ void **AllocationOut, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olMemFree that also sets source code location information
+/// @details See also ::olMemFree
+OL_APIEXPORT ol_result_t OL_APICALL
+olMemFreeWithCodeLoc(void *Address, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olMemcpy that also sets source code location information
+/// @details See also ::olMemcpy
+OL_APIEXPORT ol_result_t OL_APICALL olMemcpyWithCodeLoc(
+ ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice,
+ void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size,
+ ol_event_handle_t *EventOut, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olCreateQueue that also sets source code location
+/// information
+/// @details See also ::olCreateQueue
+OL_APIEXPORT ol_result_t OL_APICALL
+olCreateQueueWithCodeLoc(ol_device_handle_t Device, ol_queue_handle_t *Queue,
+ ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olDestroyQueue that also sets source code location
+/// information
+/// @details See also ::olDestroyQueue
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueueWithCodeLoc(
+ ol_queue_handle_t Queue, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olWaitQueue that also sets source code location
+/// information
+/// @details See also ::olWaitQueue
+OL_APIEXPORT ol_result_t OL_APICALL olWaitQueueWithCodeLoc(
+ ol_queue_handle_t Queue, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olDestroyEvent that also sets source code location
+/// information
+/// @details See also ::olDestroyEvent
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyEventWithCodeLoc(
+ ol_event_handle_t Event, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olWaitEvent that also sets source code location
+/// information
+/// @details See also ::olWaitEvent
+OL_APIEXPORT ol_result_t OL_APICALL olWaitEventWithCodeLoc(
+ ol_event_handle_t Event, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olCreateProgram that also sets source code location
+/// information
+/// @details See also ::olCreateProgram
+OL_APIEXPORT ol_result_t OL_APICALL olCreateProgramWithCodeLoc(
+ ol_device_handle_t Device, const void *ProgData, size_t ProgDataSize,
+ ol_program_handle_t *Program, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olDestroyProgram that also sets source code location
+/// information
+/// @details See also ::olDestroyProgram
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyProgramWithCodeLoc(
+ ol_program_handle_t Program, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olGetKernel that also sets source code location
+/// information
+/// @details See also ::olGetKernel
+OL_APIEXPORT ol_result_t OL_APICALL olGetKernelWithCodeLoc(
+ ol_program_handle_t Program, const char *KernelName,
+ ol_kernel_handle_t *Kernel, ol_code_location_t *CodeLocation);
+
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Variant of olLaunchKernel that also sets source code location
+/// information
+/// @details See also ::olLaunchKernel
+OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernelWithCodeLoc(
+ ol_queue_handle_t Queue, ol_device_handle_t Device,
+ ol_kernel_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize,
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ ol_event_handle_t *EventOut, ol_code_location_t *CodeLocation);
+
#if defined(__cplusplus)
} // extern "C"
#endif
diff --git a/offload/liboffload/include/generated/OffloadEntryPoints.inc b/offload/liboffload/include/generated/OffloadEntryPoints.inc
index 49c1c8169615..d70ebed934dc 100644
--- a/offload/liboffload/include/generated/OffloadEntryPoints.inc
+++ b/offload/liboffload/include/generated/OffloadEntryPoints.inc
@@ -8,30 +8,30 @@
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olInit_val() {
- if (true /*enableParameterValidation*/) {
+ if (offloadConfig().ValidationEnabled) {
}
- return olInit_impl();
+ return llvm::offload::olInit_impl();
}
OL_APIEXPORT ol_result_t OL_APICALL olInit() {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olInit";
+ llvm::errs() << "---> olInit";
}
ol_result_t Result = olInit_val();
if (offloadConfig().TracingEnabled) {
- std::cout << "()";
- std::cout << "-> " << Result << "\n";
+ llvm::errs() << "()";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
ol_result_t olInitWithCodeLoc(ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olInit();
+ ol_result_t Result = ::olInit();
currentCodeLocation() = nullptr;
return Result;
@@ -39,124 +39,184 @@ ol_result_t olInitWithCodeLoc(ol_code_location_t *CodeLocation) {
///////////////////////////////////////////////////////////////////////////////
ol_impl_result_t olShutDown_val() {
- if (true /*enableParameterValidation*/) {
+ if (offloadConfig().ValidationEnabled) {
}
- return olShutDown_impl();
+ return llvm::offload::olShutDown_impl();
}
OL_APIEXPORT ol_result_t OL_APICALL olShutDown() {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olShutDown";
+ llvm::errs() << "---> olShutDown";
}
ol_result_t Result = olShutDown_val();
if (offloadConfig().TracingEnabled) {
- std::cout << "()";
- std::cout << "-> " << Result << "\n";
+ llvm::errs() << "()";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
ol_result_t olShutDownWithCodeLoc(ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olShutDown();
+ ol_result_t Result = ::olShutDown();
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetPlatform_val(uint32_t NumEntries,
- ol_platform_handle_t *Platforms) {
- if (true /*enableParameterValidation*/) {
- if (NumEntries == 0) {
+ol_impl_result_t olGetPlatformInfo_val(ol_platform_handle_t Platform,
+ ol_platform_info_t PropName,
+ size_t PropSize, void *PropValue) {
+ if (offloadConfig().ValidationEnabled) {
+ if (PropSize == 0) {
return OL_ERRC_INVALID_SIZE;
}
- if (NULL == Platforms) {
+ if (NULL == Platform) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+
+ if (NULL == PropValue) {
return OL_ERRC_INVALID_NULL_POINTER;
}
}
- return olGetPlatform_impl(NumEntries, Platforms);
+ return llvm::offload::olGetPlatformInfo_impl(Platform, PropName, PropSize,
+ PropValue);
}
OL_APIEXPORT ol_result_t OL_APICALL
-olGetPlatform(uint32_t NumEntries, ol_platform_handle_t *Platforms) {
+olGetPlatformInfo(ol_platform_handle_t Platform, ol_platform_info_t PropName,
+ size_t PropSize, void *PropValue) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetPlatform";
+ llvm::errs() << "---> olGetPlatformInfo";
}
- ol_result_t Result = olGetPlatform_val(NumEntries, Platforms);
+ ol_result_t Result =
+ olGetPlatformInfo_val(Platform, PropName, PropSize, PropValue);
if (offloadConfig().TracingEnabled) {
- ol_get_platform_params_t Params = {&NumEntries, &Platforms};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_get_platform_info_params_t Params = {&Platform, &PropName, &PropSize,
+ &PropValue};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetPlatformWithCodeLoc(uint32_t NumEntries,
- ol_platform_handle_t *Platforms,
- ol_code_location_t *CodeLocation) {
+ol_result_t olGetPlatformInfoWithCodeLoc(ol_platform_handle_t Platform,
+ ol_platform_info_t PropName,
+ size_t PropSize, void *PropValue,
+ ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetPlatform(NumEntries, Platforms);
+ ol_result_t Result =
+ ::olGetPlatformInfo(Platform, PropName, PropSize, PropValue);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetPlatformCount_val(uint32_t *NumPlatforms) {
- if (true /*enableParameterValidation*/) {
- if (NULL == NumPlatforms) {
+ol_impl_result_t olGetPlatformInfoSize_val(ol_platform_handle_t Platform,
+ ol_platform_info_t PropName,
+ size_t *PropSizeRet) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Platform) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+
+ if (NULL == PropSizeRet) {
return OL_ERRC_INVALID_NULL_POINTER;
}
}
- return olGetPlatformCount_impl(NumPlatforms);
+ return llvm::offload::olGetPlatformInfoSize_impl(Platform, PropName,
+ PropSizeRet);
}
-OL_APIEXPORT ol_result_t OL_APICALL olGetPlatformCount(uint32_t *NumPlatforms) {
+OL_APIEXPORT ol_result_t OL_APICALL
+olGetPlatformInfoSize(ol_platform_handle_t Platform,
+ ol_platform_info_t PropName, size_t *PropSizeRet) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetPlatformCount";
+ llvm::errs() << "---> olGetPlatformInfoSize";
}
- ol_result_t Result = olGetPlatformCount_val(NumPlatforms);
+ ol_result_t Result =
+ olGetPlatformInfoSize_val(Platform, PropName, PropSizeRet);
if (offloadConfig().TracingEnabled) {
- ol_get_platform_count_params_t Params = {&NumPlatforms};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_get_platform_info_size_params_t Params = {&Platform, &PropName,
+ &PropSizeRet};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetPlatformCountWithCodeLoc(uint32_t *NumPlatforms,
- ol_code_location_t *CodeLocation) {
+ol_result_t olGetPlatformInfoSizeWithCodeLoc(ol_platform_handle_t Platform,
+ ol_platform_info_t PropName,
+ size_t *PropSizeRet,
+ ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetPlatformCount(NumPlatforms);
+ ol_result_t Result = ::olGetPlatformInfoSize(Platform, PropName, PropSizeRet);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetPlatformInfo_val(ol_platform_handle_t Platform,
- ol_platform_info_t PropName,
- size_t PropSize, void *PropValue) {
- if (true /*enableParameterValidation*/) {
+ol_impl_result_t olIterateDevices_val(ol_device_iterate_cb_t Callback,
+ void *UserData) {
+ if (offloadConfig().ValidationEnabled) {
+ }
+
+ return llvm::offload::olIterateDevices_impl(Callback, UserData);
+}
+OL_APIEXPORT ol_result_t OL_APICALL
+olIterateDevices(ol_device_iterate_cb_t Callback, void *UserData) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olIterateDevices";
+ }
+
+ ol_result_t Result = olIterateDevices_val(Callback, UserData);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_iterate_devices_params_t Params = {&Callback, &UserData};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olIterateDevicesWithCodeLoc(ol_device_iterate_cb_t Callback,
+ void *UserData,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olIterateDevices(Callback, UserData);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olGetDeviceInfo_val(ol_device_handle_t Device,
+ ol_device_info_t PropName, size_t PropSize,
+ void *PropValue) {
+ if (offloadConfig().ValidationEnabled) {
if (PropSize == 0) {
return OL_ERRC_INVALID_SIZE;
}
- if (NULL == Platform) {
+ if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
}
@@ -165,47 +225,48 @@ ol_impl_result_t olGetPlatformInfo_val(ol_platform_handle_t Platform,
}
}
- return olGetPlatformInfo_impl(Platform, PropName, PropSize, PropValue);
+ return llvm::offload::olGetDeviceInfo_impl(Device, PropName, PropSize,
+ PropValue);
}
-OL_APIEXPORT ol_result_t OL_APICALL
-olGetPlatformInfo(ol_platform_handle_t Platform, ol_platform_info_t PropName,
- size_t PropSize, void *PropValue) {
+OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo(ol_device_handle_t Device,
+ ol_device_info_t PropName,
+ size_t PropSize,
+ void *PropValue) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetPlatformInfo";
+ llvm::errs() << "---> olGetDeviceInfo";
}
ol_result_t Result =
- olGetPlatformInfo_val(Platform, PropName, PropSize, PropValue);
+ olGetDeviceInfo_val(Device, PropName, PropSize, PropValue);
if (offloadConfig().TracingEnabled) {
- ol_get_platform_info_params_t Params = {&Platform, &PropName, &PropSize,
- &PropValue};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_get_device_info_params_t Params = {&Device, &PropName, &PropSize,
+ &PropValue};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetPlatformInfoWithCodeLoc(ol_platform_handle_t Platform,
- ol_platform_info_t PropName,
- size_t PropSize, void *PropValue,
- ol_code_location_t *CodeLocation) {
+ol_result_t olGetDeviceInfoWithCodeLoc(ol_device_handle_t Device,
+ ol_device_info_t PropName,
+ size_t PropSize, void *PropValue,
+ ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result =
- olGetPlatformInfo(Platform, PropName, PropSize, PropValue);
+ ol_result_t Result = ::olGetDeviceInfo(Device, PropName, PropSize, PropValue);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetPlatformInfoSize_val(ol_platform_handle_t Platform,
- ol_platform_info_t PropName,
- size_t *PropSizeRet) {
- if (true /*enableParameterValidation*/) {
- if (NULL == Platform) {
+ol_impl_result_t olGetDeviceInfoSize_val(ol_device_handle_t Device,
+ ol_device_info_t PropName,
+ size_t *PropSizeRet) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
}
@@ -214,227 +275,585 @@ ol_impl_result_t olGetPlatformInfoSize_val(ol_platform_handle_t Platform,
}
}
- return olGetPlatformInfoSize_impl(Platform, PropName, PropSizeRet);
+ return llvm::offload::olGetDeviceInfoSize_impl(Device, PropName, PropSizeRet);
}
-OL_APIEXPORT ol_result_t OL_APICALL
-olGetPlatformInfoSize(ol_platform_handle_t Platform,
- ol_platform_info_t PropName, size_t *PropSizeRet) {
+OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize(
+ ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetPlatformInfoSize";
+ llvm::errs() << "---> olGetDeviceInfoSize";
}
- ol_result_t Result =
- olGetPlatformInfoSize_val(Platform, PropName, PropSizeRet);
+ ol_result_t Result = olGetDeviceInfoSize_val(Device, PropName, PropSizeRet);
if (offloadConfig().TracingEnabled) {
- ol_get_platform_info_size_params_t Params = {&Platform, &PropName,
- &PropSizeRet};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_get_device_info_size_params_t Params = {&Device, &PropName,
+ &PropSizeRet};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetPlatformInfoSizeWithCodeLoc(ol_platform_handle_t Platform,
- ol_platform_info_t PropName,
- size_t *PropSizeRet,
- ol_code_location_t *CodeLocation) {
+ol_result_t olGetDeviceInfoSizeWithCodeLoc(ol_device_handle_t Device,
+ ol_device_info_t PropName,
+ size_t *PropSizeRet,
+ ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetPlatformInfoSize(Platform, PropName, PropSizeRet);
+ ol_result_t Result = ::olGetDeviceInfoSize(Device, PropName, PropSizeRet);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetDeviceCount_val(ol_platform_handle_t Platform,
- uint32_t *NumDevices) {
- if (true /*enableParameterValidation*/) {
- if (NULL == Platform) {
+ol_impl_result_t olMemAlloc_val(ol_device_handle_t Device, ol_alloc_type_t Type,
+ size_t Size, void **AllocationOut) {
+ if (offloadConfig().ValidationEnabled) {
+ if (Size == 0) {
+ return OL_ERRC_INVALID_SIZE;
+ }
+
+ if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
}
- if (NULL == NumDevices) {
+ if (NULL == AllocationOut) {
return OL_ERRC_INVALID_NULL_POINTER;
}
}
- return olGetDeviceCount_impl(Platform, NumDevices);
+ return llvm::offload::olMemAlloc_impl(Device, Type, Size, AllocationOut);
}
-OL_APIEXPORT ol_result_t OL_APICALL
-olGetDeviceCount(ol_platform_handle_t Platform, uint32_t *NumDevices) {
+OL_APIEXPORT ol_result_t OL_APICALL olMemAlloc(ol_device_handle_t Device,
+ ol_alloc_type_t Type,
+ size_t Size,
+ void **AllocationOut) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetDeviceCount";
+ llvm::errs() << "---> olMemAlloc";
}
- ol_result_t Result = olGetDeviceCount_val(Platform, NumDevices);
+ ol_result_t Result = olMemAlloc_val(Device, Type, Size, AllocationOut);
if (offloadConfig().TracingEnabled) {
- ol_get_device_count_params_t Params = {&Platform, &NumDevices};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_mem_alloc_params_t Params = {&Device, &Type, &Size, &AllocationOut};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetDeviceCountWithCodeLoc(ol_platform_handle_t Platform,
- uint32_t *NumDevices,
- ol_code_location_t *CodeLocation) {
+ol_result_t olMemAllocWithCodeLoc(ol_device_handle_t Device,
+ ol_alloc_type_t Type, size_t Size,
+ void **AllocationOut,
+ ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetDeviceCount(Platform, NumDevices);
+ ol_result_t Result = ::olMemAlloc(Device, Type, Size, AllocationOut);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetDevice_val(ol_platform_handle_t Platform,
- uint32_t NumEntries,
- ol_device_handle_t *Devices) {
- if (true /*enableParameterValidation*/) {
- if (NumEntries == 0) {
- return OL_ERRC_INVALID_SIZE;
+ol_impl_result_t olMemFree_val(void *Address) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Address) {
+ return OL_ERRC_INVALID_NULL_POINTER;
}
+ }
- if (NULL == Platform) {
+ return llvm::offload::olMemFree_impl(Address);
+}
+OL_APIEXPORT ol_result_t OL_APICALL olMemFree(void *Address) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olMemFree";
+ }
+
+ ol_result_t Result = olMemFree_val(Address);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_mem_free_params_t Params = {&Address};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olMemFreeWithCodeLoc(void *Address,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olMemFree(Address);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olMemcpy_val(ol_queue_handle_t Queue, void *DstPtr,
+ ol_device_handle_t DstDevice, void *SrcPtr,
+ ol_device_handle_t SrcDevice, size_t Size,
+ ol_event_handle_t *EventOut) {
+ if (offloadConfig().ValidationEnabled) {
+ if (Queue == NULL && EventOut != NULL) {
+ return OL_ERRC_INVALID_ARGUMENT;
+ }
+
+ if (NULL == DstDevice) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+
+ if (NULL == SrcDevice) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+
+ if (NULL == DstPtr) {
+ return OL_ERRC_INVALID_NULL_POINTER;
+ }
+
+ if (NULL == SrcPtr) {
+ return OL_ERRC_INVALID_NULL_POINTER;
+ }
+ }
+
+ return llvm::offload::olMemcpy_impl(Queue, DstPtr, DstDevice, SrcPtr,
+ SrcDevice, Size, EventOut);
+}
+OL_APIEXPORT ol_result_t OL_APICALL
+olMemcpy(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice,
+ void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size,
+ ol_event_handle_t *EventOut) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olMemcpy";
+ }
+
+ ol_result_t Result =
+ olMemcpy_val(Queue, DstPtr, DstDevice, SrcPtr, SrcDevice, Size, EventOut);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_memcpy_params_t Params = {&Queue, &DstPtr, &DstDevice, &SrcPtr,
+ &SrcDevice, &Size, &EventOut};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olMemcpyWithCodeLoc(ol_queue_handle_t Queue, void *DstPtr,
+ ol_device_handle_t DstDevice, void *SrcPtr,
+ ol_device_handle_t SrcDevice, size_t Size,
+ ol_event_handle_t *EventOut,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result =
+ ::olMemcpy(Queue, DstPtr, DstDevice, SrcPtr, SrcDevice, Size, EventOut);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olCreateQueue_val(ol_device_handle_t Device,
+ ol_queue_handle_t *Queue) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
}
- if (NULL == Devices) {
+ if (NULL == Queue) {
return OL_ERRC_INVALID_NULL_POINTER;
}
}
- return olGetDevice_impl(Platform, NumEntries, Devices);
+ return llvm::offload::olCreateQueue_impl(Device, Queue);
}
-OL_APIEXPORT ol_result_t OL_APICALL olGetDevice(ol_platform_handle_t Platform,
- uint32_t NumEntries,
- ol_device_handle_t *Devices) {
+OL_APIEXPORT ol_result_t OL_APICALL olCreateQueue(ol_device_handle_t Device,
+ ol_queue_handle_t *Queue) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetDevice";
+ llvm::errs() << "---> olCreateQueue";
}
- ol_result_t Result = olGetDevice_val(Platform, NumEntries, Devices);
+ ol_result_t Result = olCreateQueue_val(Device, Queue);
if (offloadConfig().TracingEnabled) {
- ol_get_device_params_t Params = {&Platform, &NumEntries, &Devices};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_create_queue_params_t Params = {&Device, &Queue};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetDeviceWithCodeLoc(ol_platform_handle_t Platform,
- uint32_t NumEntries,
- ol_device_handle_t *Devices,
+ol_result_t olCreateQueueWithCodeLoc(ol_device_handle_t Device,
+ ol_queue_handle_t *Queue,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olCreateQueue(Device, Queue);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olDestroyQueue_val(ol_queue_handle_t Queue) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Queue) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+ }
+
+ return llvm::offload::olDestroyQueue_impl(Queue);
+}
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyQueue(ol_queue_handle_t Queue) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olDestroyQueue";
+ }
+
+ ol_result_t Result = olDestroyQueue_val(Queue);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_destroy_queue_params_t Params = {&Queue};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olDestroyQueueWithCodeLoc(ol_queue_handle_t Queue,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olDestroyQueue(Queue);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olWaitQueue_val(ol_queue_handle_t Queue) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Queue) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+ }
+
+ return llvm::offload::olWaitQueue_impl(Queue);
+}
+OL_APIEXPORT ol_result_t OL_APICALL olWaitQueue(ol_queue_handle_t Queue) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olWaitQueue";
+ }
+
+ ol_result_t Result = olWaitQueue_val(Queue);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_wait_queue_params_t Params = {&Queue};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olWaitQueueWithCodeLoc(ol_queue_handle_t Queue,
ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetDevice(Platform, NumEntries, Devices);
+ ol_result_t Result = ::olWaitQueue(Queue);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetDeviceInfo_val(ol_device_handle_t Device,
- ol_device_info_t PropName, size_t PropSize,
- void *PropValue) {
- if (true /*enableParameterValidation*/) {
- if (PropSize == 0) {
- return OL_ERRC_INVALID_SIZE;
+ol_impl_result_t olDestroyEvent_val(ol_event_handle_t Event) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Event) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+ }
+
+ return llvm::offload::olDestroyEvent_impl(Event);
+}
+OL_APIEXPORT ol_result_t OL_APICALL olDestroyEvent(ol_event_handle_t Event) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olDestroyEvent";
+ }
+
+ ol_result_t Result = olDestroyEvent_val(Event);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_destroy_event_params_t Params = {&Event};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olDestroyEventWithCodeLoc(ol_event_handle_t Event,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olDestroyEvent(Event);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olWaitEvent_val(ol_event_handle_t Event) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Event) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
}
+ }
+
+ return llvm::offload::olWaitEvent_impl(Event);
+}
+OL_APIEXPORT ol_result_t OL_APICALL olWaitEvent(ol_event_handle_t Event) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olWaitEvent";
+ }
+
+ ol_result_t Result = olWaitEvent_val(Event);
+ if (offloadConfig().TracingEnabled) {
+ ol_wait_event_params_t Params = {&Event};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olWaitEventWithCodeLoc(ol_event_handle_t Event,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olWaitEvent(Event);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olCreateProgram_val(ol_device_handle_t Device,
+ const void *ProgData, size_t ProgDataSize,
+ ol_program_handle_t *Program) {
+ if (offloadConfig().ValidationEnabled) {
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
}
- if (NULL == PropValue) {
+ if (NULL == ProgData) {
+ return OL_ERRC_INVALID_NULL_POINTER;
+ }
+
+ if (NULL == Program) {
return OL_ERRC_INVALID_NULL_POINTER;
}
}
- return olGetDeviceInfo_impl(Device, PropName, PropSize, PropValue);
+ return llvm::offload::olCreateProgram_impl(Device, ProgData, ProgDataSize,
+ Program);
}
-OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfo(ol_device_handle_t Device,
- ol_device_info_t PropName,
- size_t PropSize,
- void *PropValue) {
+OL_APIEXPORT ol_result_t OL_APICALL
+olCreateProgram(ol_device_handle_t Device, const void *ProgData,
+ size_t ProgDataSize, ol_program_handle_t *Program) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetDeviceInfo";
+ llvm::errs() << "---> olCreateProgram";
}
ol_result_t Result =
- olGetDeviceInfo_val(Device, PropName, PropSize, PropValue);
+ olCreateProgram_val(Device, ProgData, ProgDataSize, Program);
if (offloadConfig().TracingEnabled) {
- ol_get_device_info_params_t Params = {&Device, &PropName, &PropSize,
- &PropValue};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_create_program_params_t Params = {&Device, &ProgData, &ProgDataSize,
+ &Program};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetDeviceInfoWithCodeLoc(ol_device_handle_t Device,
- ol_device_info_t PropName,
- size_t PropSize, void *PropValue,
+ol_result_t olCreateProgramWithCodeLoc(ol_device_handle_t Device,
+ const void *ProgData,
+ size_t ProgDataSize,
+ ol_program_handle_t *Program,
ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetDeviceInfo(Device, PropName, PropSize, PropValue);
+ ol_result_t Result =
+ ::olCreateProgram(Device, ProgData, ProgDataSize, Program);
currentCodeLocation() = nullptr;
return Result;
}
///////////////////////////////////////////////////////////////////////////////
-ol_impl_result_t olGetDeviceInfoSize_val(ol_device_handle_t Device,
- ol_device_info_t PropName,
- size_t *PropSizeRet) {
- if (true /*enableParameterValidation*/) {
+ol_impl_result_t olDestroyProgram_val(ol_program_handle_t Program) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Program) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+ }
+
+ return llvm::offload::olDestroyProgram_impl(Program);
+}
+OL_APIEXPORT ol_result_t OL_APICALL
+olDestroyProgram(ol_program_handle_t Program) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olDestroyProgram";
+ }
+
+ ol_result_t Result = olDestroyProgram_val(Program);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_destroy_program_params_t Params = {&Program};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olDestroyProgramWithCodeLoc(ol_program_handle_t Program,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olDestroyProgram(Program);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t olGetKernel_val(ol_program_handle_t Program,
+ const char *KernelName,
+ ol_kernel_handle_t *Kernel) {
+ if (offloadConfig().ValidationEnabled) {
+ if (NULL == Program) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+
+ if (NULL == KernelName) {
+ return OL_ERRC_INVALID_NULL_POINTER;
+ }
+
+ if (NULL == Kernel) {
+ return OL_ERRC_INVALID_NULL_POINTER;
+ }
+ }
+
+ return llvm::offload::olGetKernel_impl(Program, KernelName, Kernel);
+}
+OL_APIEXPORT ol_result_t OL_APICALL olGetKernel(ol_program_handle_t Program,
+ const char *KernelName,
+ ol_kernel_handle_t *Kernel) {
+ if (offloadConfig().TracingEnabled) {
+ llvm::errs() << "---> olGetKernel";
+ }
+
+ ol_result_t Result = olGetKernel_val(Program, KernelName, Kernel);
+
+ if (offloadConfig().TracingEnabled) {
+ ol_get_kernel_params_t Params = {&Program, &KernelName, &Kernel};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
+ if (Result && Result->Details) {
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
+ }
+ }
+ return Result;
+}
+ol_result_t olGetKernelWithCodeLoc(ol_program_handle_t Program,
+ const char *KernelName,
+ ol_kernel_handle_t *Kernel,
+ ol_code_location_t *CodeLocation) {
+ currentCodeLocation() = CodeLocation;
+ ol_result_t Result = ::olGetKernel(Program, KernelName, Kernel);
+
+ currentCodeLocation() = nullptr;
+ return Result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+ol_impl_result_t
+olLaunchKernel_val(ol_queue_handle_t Queue, ol_device_handle_t Device,
+ ol_kernel_handle_t Kernel, const void *ArgumentsData,
+ size_t ArgumentsSize,
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ ol_event_handle_t *EventOut) {
+ if (offloadConfig().ValidationEnabled) {
+ if (Queue == NULL && EventOut != NULL) {
+ return OL_ERRC_INVALID_ARGUMENT;
+ }
+
if (NULL == Device) {
return OL_ERRC_INVALID_NULL_HANDLE;
}
- if (NULL == PropSizeRet) {
+ if (NULL == Kernel) {
+ return OL_ERRC_INVALID_NULL_HANDLE;
+ }
+
+ if (NULL == ArgumentsData) {
+ return OL_ERRC_INVALID_NULL_POINTER;
+ }
+
+ if (NULL == LaunchSizeArgs) {
return OL_ERRC_INVALID_NULL_POINTER;
}
}
- return olGetDeviceInfoSize_impl(Device, PropName, PropSizeRet);
+ return llvm::offload::olLaunchKernel_impl(Queue, Device, Kernel,
+ ArgumentsData, ArgumentsSize,
+ LaunchSizeArgs, EventOut);
}
-OL_APIEXPORT ol_result_t OL_APICALL olGetDeviceInfoSize(
- ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) {
+OL_APIEXPORT ol_result_t OL_APICALL olLaunchKernel(
+ ol_queue_handle_t Queue, ol_device_handle_t Device,
+ ol_kernel_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize,
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ ol_event_handle_t *EventOut) {
if (offloadConfig().TracingEnabled) {
- std::cout << "---> olGetDeviceInfoSize";
+ llvm::errs() << "---> olLaunchKernel";
}
- ol_result_t Result = olGetDeviceInfoSize_val(Device, PropName, PropSizeRet);
+ ol_result_t Result =
+ olLaunchKernel_val(Queue, Device, Kernel, ArgumentsData, ArgumentsSize,
+ LaunchSizeArgs, EventOut);
if (offloadConfig().TracingEnabled) {
- ol_get_device_info_size_params_t Params = {&Device, &PropName,
- &PropSizeRet};
- std::cout << "(" << &Params << ")";
- std::cout << "-> " << Result << "\n";
+ ol_launch_kernel_params_t Params = {
+ &Queue, &Device, &Kernel, &ArgumentsData,
+ &ArgumentsSize, &LaunchSizeArgs, &EventOut};
+ llvm::errs() << "(" << &Params << ")";
+ llvm::errs() << "-> " << Result << "\n";
if (Result && Result->Details) {
- std::cout << " *Error Details* " << Result->Details << " \n";
+ llvm::errs() << " *Error Details* " << Result->Details << " \n";
}
}
return Result;
}
-ol_result_t olGetDeviceInfoSizeWithCodeLoc(ol_device_handle_t Device,
- ol_device_info_t PropName,
- size_t *PropSizeRet,
- ol_code_location_t *CodeLocation) {
+ol_result_t olLaunchKernelWithCodeLoc(
+ ol_queue_handle_t Queue, ol_device_handle_t Device,
+ ol_kernel_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize,
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ ol_event_handle_t *EventOut, ol_code_location_t *CodeLocation) {
currentCodeLocation() = CodeLocation;
- ol_result_t Result = olGetDeviceInfoSize(Device, PropName, PropSizeRet);
+ ol_result_t Result =
+ ::olLaunchKernel(Queue, Device, Kernel, ArgumentsData, ArgumentsSize,
+ LaunchSizeArgs, EventOut);
currentCodeLocation() = nullptr;
return Result;
diff --git a/offload/liboffload/include/generated/OffloadFuncs.inc b/offload/liboffload/include/generated/OffloadFuncs.inc
index 48115493c790..78ff9ddb8279 100644
--- a/offload/liboffload/include/generated/OffloadFuncs.inc
+++ b/offload/liboffload/include/generated/OffloadFuncs.inc
@@ -12,23 +12,41 @@
OFFLOAD_FUNC(olInit)
OFFLOAD_FUNC(olShutDown)
-OFFLOAD_FUNC(olGetPlatform)
-OFFLOAD_FUNC(olGetPlatformCount)
OFFLOAD_FUNC(olGetPlatformInfo)
OFFLOAD_FUNC(olGetPlatformInfoSize)
-OFFLOAD_FUNC(olGetDeviceCount)
-OFFLOAD_FUNC(olGetDevice)
+OFFLOAD_FUNC(olIterateDevices)
OFFLOAD_FUNC(olGetDeviceInfo)
OFFLOAD_FUNC(olGetDeviceInfoSize)
+OFFLOAD_FUNC(olMemAlloc)
+OFFLOAD_FUNC(olMemFree)
+OFFLOAD_FUNC(olMemcpy)
+OFFLOAD_FUNC(olCreateQueue)
+OFFLOAD_FUNC(olDestroyQueue)
+OFFLOAD_FUNC(olWaitQueue)
+OFFLOAD_FUNC(olDestroyEvent)
+OFFLOAD_FUNC(olWaitEvent)
+OFFLOAD_FUNC(olCreateProgram)
+OFFLOAD_FUNC(olDestroyProgram)
+OFFLOAD_FUNC(olGetKernel)
+OFFLOAD_FUNC(olLaunchKernel)
OFFLOAD_FUNC(olInitWithCodeLoc)
OFFLOAD_FUNC(olShutDownWithCodeLoc)
-OFFLOAD_FUNC(olGetPlatformWithCodeLoc)
-OFFLOAD_FUNC(olGetPlatformCountWithCodeLoc)
OFFLOAD_FUNC(olGetPlatformInfoWithCodeLoc)
OFFLOAD_FUNC(olGetPlatformInfoSizeWithCodeLoc)
-OFFLOAD_FUNC(olGetDeviceCountWithCodeLoc)
-OFFLOAD_FUNC(olGetDeviceWithCodeLoc)
+OFFLOAD_FUNC(olIterateDevicesWithCodeLoc)
OFFLOAD_FUNC(olGetDeviceInfoWithCodeLoc)
OFFLOAD_FUNC(olGetDeviceInfoSizeWithCodeLoc)
+OFFLOAD_FUNC(olMemAllocWithCodeLoc)
+OFFLOAD_FUNC(olMemFreeWithCodeLoc)
+OFFLOAD_FUNC(olMemcpyWithCodeLoc)
+OFFLOAD_FUNC(olCreateQueueWithCodeLoc)
+OFFLOAD_FUNC(olDestroyQueueWithCodeLoc)
+OFFLOAD_FUNC(olWaitQueueWithCodeLoc)
+OFFLOAD_FUNC(olDestroyEventWithCodeLoc)
+OFFLOAD_FUNC(olWaitEventWithCodeLoc)
+OFFLOAD_FUNC(olCreateProgramWithCodeLoc)
+OFFLOAD_FUNC(olDestroyProgramWithCodeLoc)
+OFFLOAD_FUNC(olGetKernelWithCodeLoc)
+OFFLOAD_FUNC(olLaunchKernelWithCodeLoc)
#undef OFFLOAD_FUNC
diff --git a/offload/liboffload/include/generated/OffloadImplFuncDecls.inc b/offload/liboffload/include/generated/OffloadImplFuncDecls.inc
index 5b26b2653a05..ced659c2a4bd 100644
--- a/offload/liboffload/include/generated/OffloadImplFuncDecls.inc
+++ b/offload/liboffload/include/generated/OffloadImplFuncDecls.inc
@@ -9,11 +9,6 @@ ol_impl_result_t olInit_impl();
ol_impl_result_t olShutDown_impl();
-ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries,
- ol_platform_handle_t *Platforms);
-
-ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms);
-
ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t PropSize, void *PropValue);
@@ -22,12 +17,8 @@ ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t *PropSizeRet);
-ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform,
- uint32_t *NumDevices);
-
-ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform,
- uint32_t NumEntries,
- ol_device_handle_t *Devices);
+ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback,
+ void *UserData);
ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device,
ol_device_info_t PropName,
@@ -36,3 +27,42 @@ ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device,
ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t *PropSizeRet);
+
+ol_impl_result_t olMemAlloc_impl(ol_device_handle_t Device,
+ ol_alloc_type_t Type, size_t Size,
+ void **AllocationOut);
+
+ol_impl_result_t olMemFree_impl(void *Address);
+
+ol_impl_result_t olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
+ ol_device_handle_t DstDevice, void *SrcPtr,
+ ol_device_handle_t SrcDevice, size_t Size,
+ ol_event_handle_t *EventOut);
+
+ol_impl_result_t olCreateQueue_impl(ol_device_handle_t Device,
+ ol_queue_handle_t *Queue);
+
+ol_impl_result_t olDestroyQueue_impl(ol_queue_handle_t Queue);
+
+ol_impl_result_t olWaitQueue_impl(ol_queue_handle_t Queue);
+
+ol_impl_result_t olDestroyEvent_impl(ol_event_handle_t Event);
+
+ol_impl_result_t olWaitEvent_impl(ol_event_handle_t Event);
+
+ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device,
+ const void *ProgData, size_t ProgDataSize,
+ ol_program_handle_t *Program);
+
+ol_impl_result_t olDestroyProgram_impl(ol_program_handle_t Program);
+
+ol_impl_result_t olGetKernel_impl(ol_program_handle_t Program,
+ const char *KernelName,
+ ol_kernel_handle_t *Kernel);
+
+ol_impl_result_t
+olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
+ ol_kernel_handle_t Kernel, const void *ArgumentsData,
+ size_t ArgumentsSize,
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ ol_event_handle_t *EventOut);
diff --git a/offload/liboffload/include/generated/OffloadPrint.hpp b/offload/liboffload/include/generated/OffloadPrint.hpp
index 8981bb054a4c..7f5e33aea6f7 100644
--- a/offload/liboffload/include/generated/OffloadPrint.hpp
+++ b/offload/liboffload/include/generated/OffloadPrint.hpp
@@ -11,31 +11,40 @@
#pragma once
#include <OffloadAPI.h>
-#include <ostream>
+#include <llvm/Support/raw_ostream.h>
template <typename T>
-inline ol_result_t printPtr(std::ostream &os, const T *ptr);
+inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr);
template <typename T>
-inline void printTagged(std::ostream &os, const void *ptr, T value,
+inline void printTagged(llvm::raw_ostream &os, const void *ptr, T value,
size_t size);
template <typename T> struct is_handle : std::false_type {};
template <> struct is_handle<ol_platform_handle_t> : std::true_type {};
template <> struct is_handle<ol_device_handle_t> : std::true_type {};
template <> struct is_handle<ol_context_handle_t> : std::true_type {};
+template <> struct is_handle<ol_queue_handle_t> : std::true_type {};
+template <> struct is_handle<ol_event_handle_t> : std::true_type {};
+template <> struct is_handle<ol_program_handle_t> : std::true_type {};
template <typename T> inline constexpr bool is_handle_v = is_handle<T>::value;
-inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value);
-inline std::ostream &operator<<(std::ostream &os,
- enum ol_platform_info_t value);
-inline std::ostream &operator<<(std::ostream &os,
- enum ol_platform_backend_t value);
-inline std::ostream &operator<<(std::ostream &os, enum ol_device_type_t value);
-inline std::ostream &operator<<(std::ostream &os, enum ol_device_info_t value);
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_errc_t value);
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_platform_info_t value);
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_platform_backend_t value);
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_device_type_t value);
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_device_info_t value);
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_alloc_type_t value);
///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ol_errc_t type
-/// @returns std::ostream &
-inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value) {
+/// @returns llvm::raw_ostream &
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_errc_t value) {
switch (value) {
case OL_ERRC_SUCCESS:
os << "OL_ERRC_SUCCESS";
@@ -46,24 +55,21 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value) {
case OL_ERRC_INVALID_PLATFORM:
os << "OL_ERRC_INVALID_PLATFORM";
break;
- case OL_ERRC_DEVICE_NOT_FOUND:
- os << "OL_ERRC_DEVICE_NOT_FOUND";
- break;
case OL_ERRC_INVALID_DEVICE:
os << "OL_ERRC_INVALID_DEVICE";
break;
- case OL_ERRC_DEVICE_LOST:
- os << "OL_ERRC_DEVICE_LOST";
+ case OL_ERRC_INVALID_QUEUE:
+ os << "OL_ERRC_INVALID_QUEUE";
+ break;
+ case OL_ERRC_INVALID_EVENT:
+ os << "OL_ERRC_INVALID_EVENT";
break;
- case OL_ERRC_UNINITIALIZED:
- os << "OL_ERRC_UNINITIALIZED";
+ case OL_ERRC_INVALID_KERNEL_NAME:
+ os << "OL_ERRC_INVALID_KERNEL_NAME";
break;
case OL_ERRC_OUT_OF_RESOURCES:
os << "OL_ERRC_OUT_OF_RESOURCES";
break;
- case OL_ERRC_UNSUPPORTED_VERSION:
- os << "OL_ERRC_UNSUPPORTED_VERSION";
- break;
case OL_ERRC_UNSUPPORTED_FEATURE:
os << "OL_ERRC_UNSUPPORTED_FEATURE";
break;
@@ -97,9 +103,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_errc_t value) {
///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ol_platform_info_t type
-/// @returns std::ostream &
-inline std::ostream &operator<<(std::ostream &os,
- enum ol_platform_info_t value) {
+/// @returns llvm::raw_ostream &
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_platform_info_t value) {
switch (value) {
case OL_PLATFORM_INFO_NAME:
os << "OL_PLATFORM_INFO_NAME";
@@ -122,9 +128,9 @@ inline std::ostream &operator<<(std::ostream &os,
///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged ol_platform_info_t enum value
-/// @returns std::ostream &
+/// @returns llvm::raw_ostream &
template <>
-inline void printTagged(std::ostream &os, const void *ptr,
+inline void printTagged(llvm::raw_ostream &os, const void *ptr,
ol_platform_info_t value, size_t size) {
if (ptr == NULL) {
printPtr(os, ptr);
@@ -159,9 +165,9 @@ inline void printTagged(std::ostream &os, const void *ptr,
}
///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ol_platform_backend_t type
-/// @returns std::ostream &
-inline std::ostream &operator<<(std::ostream &os,
- enum ol_platform_backend_t value) {
+/// @returns llvm::raw_ostream &
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_platform_backend_t value) {
switch (value) {
case OL_PLATFORM_BACKEND_UNKNOWN:
os << "OL_PLATFORM_BACKEND_UNKNOWN";
@@ -172,6 +178,9 @@ inline std::ostream &operator<<(std::ostream &os,
case OL_PLATFORM_BACKEND_AMDGPU:
os << "OL_PLATFORM_BACKEND_AMDGPU";
break;
+ case OL_PLATFORM_BACKEND_HOST:
+ os << "OL_PLATFORM_BACKEND_HOST";
+ break;
default:
os << "unknown enumerator";
break;
@@ -181,8 +190,9 @@ inline std::ostream &operator<<(std::ostream &os,
///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ol_device_type_t type
-/// @returns std::ostream &
-inline std::ostream &operator<<(std::ostream &os, enum ol_device_type_t value) {
+/// @returns llvm::raw_ostream &
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_device_type_t value) {
switch (value) {
case OL_DEVICE_TYPE_DEFAULT:
os << "OL_DEVICE_TYPE_DEFAULT";
@@ -205,8 +215,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_device_type_t value) {
///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the ol_device_info_t type
-/// @returns std::ostream &
-inline std::ostream &operator<<(std::ostream &os, enum ol_device_info_t value) {
+/// @returns llvm::raw_ostream &
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_device_info_t value) {
switch (value) {
case OL_DEVICE_INFO_TYPE:
os << "OL_DEVICE_INFO_TYPE";
@@ -232,9 +243,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ol_device_info_t value) {
///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged ol_device_info_t enum value
-/// @returns std::ostream &
+/// @returns llvm::raw_ostream &
template <>
-inline void printTagged(std::ostream &os, const void *ptr,
+inline void printTagged(llvm::raw_ostream &os, const void *ptr,
ol_device_info_t value, size_t size) {
if (ptr == NULL) {
printPtr(os, ptr);
@@ -274,9 +285,30 @@ inline void printTagged(std::ostream &os, const void *ptr,
break;
}
}
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Print operator for the ol_alloc_type_t type
+/// @returns llvm::raw_ostream &
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ enum ol_alloc_type_t value) {
+ switch (value) {
+ case OL_ALLOC_TYPE_HOST:
+ os << "OL_ALLOC_TYPE_HOST";
+ break;
+ case OL_ALLOC_TYPE_DEVICE:
+ os << "OL_ALLOC_TYPE_DEVICE";
+ break;
+ case OL_ALLOC_TYPE_MANAGED:
+ os << "OL_ALLOC_TYPE_MANAGED";
+ break;
+ default:
+ os << "unknown enumerator";
+ break;
+ }
+ return os;
+}
-inline std::ostream &operator<<(std::ostream &os,
- const ol_error_struct_t *Err) {
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const ol_error_struct_t *Err) {
if (Err == nullptr) {
os << "OL_SUCCESS";
} else {
@@ -284,34 +316,64 @@ inline std::ostream &operator<<(std::ostream &os,
}
return os;
}
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Print operator for the ol_code_location_t type
+/// @returns llvm::raw_ostream &
-inline std::ostream &operator<<(std::ostream &os,
- const struct ol_get_platform_params_t *params) {
- os << ".NumEntries = ";
- os << *params->pNumEntries;
- os << ", ";
- os << ".Platforms = ";
- os << "{";
- for (size_t i = 0; i < *params->pNumEntries; i++) {
- if (i > 0) {
- os << ", ";
- }
- printPtr(os, (*params->pPlatforms)[i]);
- }
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const struct ol_code_location_t params) {
+ os << "(struct ol_code_location_t){";
+ os << ".FunctionName = ";
+ printPtr(os, params.FunctionName);
+ os << ", ";
+ os << ".SourceFile = ";
+ printPtr(os, params.SourceFile);
+ os << ", ";
+ os << ".LineNumber = ";
+ os << params.LineNumber;
+ os << ", ";
+ os << ".ColumnNumber = ";
+ os << params.ColumnNumber;
os << "}";
return os;
}
+///////////////////////////////////////////////////////////////////////////////
+/// @brief Print operator for the ol_kernel_launch_size_args_t type
+/// @returns llvm::raw_ostream &
-inline std::ostream &
-operator<<(std::ostream &os,
- const struct ol_get_platform_count_params_t *params) {
- os << ".NumPlatforms = ";
- printPtr(os, *params->pNumPlatforms);
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_kernel_launch_size_args_t params) {
+ os << "(struct ol_kernel_launch_size_args_t){";
+ os << ".Dimensions = ";
+ os << params.Dimensions;
+ os << ", ";
+ os << ".NumGroupsX = ";
+ os << params.NumGroupsX;
+ os << ", ";
+ os << ".NumGroupsY = ";
+ os << params.NumGroupsY;
+ os << ", ";
+ os << ".NumGroupsZ = ";
+ os << params.NumGroupsZ;
+ os << ", ";
+ os << ".GroupSizeX = ";
+ os << params.GroupSizeX;
+ os << ", ";
+ os << ".GroupSizeY = ";
+ os << params.GroupSizeY;
+ os << ", ";
+ os << ".GroupSizeZ = ";
+ os << params.GroupSizeZ;
+ os << ", ";
+ os << ".DynSharedMemory = ";
+ os << params.DynSharedMemory;
+ os << "}";
return os;
}
-inline std::ostream &
-operator<<(std::ostream &os,
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
const struct ol_get_platform_info_params_t *params) {
os << ".Platform = ";
printPtr(os, *params->pPlatform);
@@ -327,8 +389,8 @@ operator<<(std::ostream &os,
return os;
}
-inline std::ostream &
-operator<<(std::ostream &os,
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
const struct ol_get_platform_info_size_params_t *params) {
os << ".Platform = ";
printPtr(os, *params->pPlatform);
@@ -341,39 +403,20 @@ operator<<(std::ostream &os,
return os;
}
-inline std::ostream &
-operator<<(std::ostream &os,
- const struct ol_get_device_count_params_t *params) {
- os << ".Platform = ";
- printPtr(os, *params->pPlatform);
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_iterate_devices_params_t *params) {
+ os << ".Callback = ";
+ os << reinterpret_cast<void *>(*params->pCallback);
os << ", ";
- os << ".NumDevices = ";
- printPtr(os, *params->pNumDevices);
- return os;
-}
-
-inline std::ostream &operator<<(std::ostream &os,
- const struct ol_get_device_params_t *params) {
- os << ".Platform = ";
- printPtr(os, *params->pPlatform);
- os << ", ";
- os << ".NumEntries = ";
- os << *params->pNumEntries;
- os << ", ";
- os << ".Devices = ";
- os << "{";
- for (size_t i = 0; i < *params->pNumEntries; i++) {
- if (i > 0) {
- os << ", ";
- }
- printPtr(os, (*params->pDevices)[i]);
- }
- os << "}";
+ os << ".UserData = ";
+ printPtr(os, *params->pUserData);
return os;
}
-inline std::ostream &
-operator<<(std::ostream &os, const struct ol_get_device_info_params_t *params) {
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_get_device_info_params_t *params) {
os << ".Device = ";
printPtr(os, *params->pDevice);
os << ", ";
@@ -388,8 +431,8 @@ operator<<(std::ostream &os, const struct ol_get_device_info_params_t *params) {
return os;
}
-inline std::ostream &
-operator<<(std::ostream &os,
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
const struct ol_get_device_info_size_params_t *params) {
os << ".Device = ";
printPtr(os, *params->pDevice);
@@ -402,10 +445,163 @@ operator<<(std::ostream &os,
return os;
}
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const struct ol_mem_alloc_params_t *params) {
+ os << ".Device = ";
+ printPtr(os, *params->pDevice);
+ os << ", ";
+ os << ".Type = ";
+ os << *params->pType;
+ os << ", ";
+ os << ".Size = ";
+ os << *params->pSize;
+ os << ", ";
+ os << ".AllocationOut = ";
+ printPtr(os, *params->pAllocationOut);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const struct ol_mem_free_params_t *params) {
+ os << ".Address = ";
+ printPtr(os, *params->pAddress);
+ return os;
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const struct ol_memcpy_params_t *params) {
+ os << ".Queue = ";
+ printPtr(os, *params->pQueue);
+ os << ", ";
+ os << ".DstPtr = ";
+ printPtr(os, *params->pDstPtr);
+ os << ", ";
+ os << ".DstDevice = ";
+ printPtr(os, *params->pDstDevice);
+ os << ", ";
+ os << ".SrcPtr = ";
+ printPtr(os, *params->pSrcPtr);
+ os << ", ";
+ os << ".SrcDevice = ";
+ printPtr(os, *params->pSrcDevice);
+ os << ", ";
+ os << ".Size = ";
+ os << *params->pSize;
+ os << ", ";
+ os << ".EventOut = ";
+ printPtr(os, *params->pEventOut);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_create_queue_params_t *params) {
+ os << ".Device = ";
+ printPtr(os, *params->pDevice);
+ os << ", ";
+ os << ".Queue = ";
+ printPtr(os, *params->pQueue);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_destroy_queue_params_t *params) {
+ os << ".Queue = ";
+ printPtr(os, *params->pQueue);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const struct ol_wait_queue_params_t *params) {
+ os << ".Queue = ";
+ printPtr(os, *params->pQueue);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_destroy_event_params_t *params) {
+ os << ".Event = ";
+ printPtr(os, *params->pEvent);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const struct ol_wait_event_params_t *params) {
+ os << ".Event = ";
+ printPtr(os, *params->pEvent);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_create_program_params_t *params) {
+ os << ".Device = ";
+ printPtr(os, *params->pDevice);
+ os << ", ";
+ os << ".ProgData = ";
+ printPtr(os, *params->pProgData);
+ os << ", ";
+ os << ".ProgDataSize = ";
+ os << *params->pProgDataSize;
+ os << ", ";
+ os << ".Program = ";
+ printPtr(os, *params->pProgram);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_destroy_program_params_t *params) {
+ os << ".Program = ";
+ printPtr(os, *params->pProgram);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const struct ol_get_kernel_params_t *params) {
+ os << ".Program = ";
+ printPtr(os, *params->pProgram);
+ os << ", ";
+ os << ".KernelName = ";
+ printPtr(os, *params->pKernelName);
+ os << ", ";
+ os << ".Kernel = ";
+ printPtr(os, *params->pKernel);
+ return os;
+}
+
+inline llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+ const struct ol_launch_kernel_params_t *params) {
+ os << ".Queue = ";
+ printPtr(os, *params->pQueue);
+ os << ", ";
+ os << ".Device = ";
+ printPtr(os, *params->pDevice);
+ os << ", ";
+ os << ".Kernel = ";
+ printPtr(os, *params->pKernel);
+ os << ", ";
+ os << ".ArgumentsData = ";
+ printPtr(os, *params->pArgumentsData);
+ os << ", ";
+ os << ".ArgumentsSize = ";
+ os << *params->pArgumentsSize;
+ os << ", ";
+ os << ".LaunchSizeArgs = ";
+ printPtr(os, *params->pLaunchSizeArgs);
+ os << ", ";
+ os << ".EventOut = ";
+ printPtr(os, *params->pEventOut);
+ return os;
+}
+
///////////////////////////////////////////////////////////////////////////////
// @brief Print pointer value
template <typename T>
-inline ol_result_t printPtr(std::ostream &os, const T *ptr) {
+inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr) {
if (ptr == nullptr) {
os << "nullptr";
} else if constexpr (std::is_pointer_v<T>) {
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 457f1053f163..d956d274b5eb 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -19,27 +19,6 @@
#include <mutex>
-using namespace llvm;
-using namespace llvm::omp::target::plugin;
-
-// Handle type definitions. Ideally these would be 1:1 with the plugins
-struct ol_device_handle_t_ {
- int DeviceNum;
- GenericDeviceTy &Device;
- ol_platform_handle_t Platform;
-};
-
-struct ol_platform_handle_t_ {
- std::unique_ptr<GenericPluginTy> Plugin;
- std::vector<ol_device_handle_t_> Devices;
-};
-
-using PlatformVecT = SmallVector<ol_platform_handle_t_, 4>;
-PlatformVecT &Platforms() {
- static PlatformVecT Platforms;
- return Platforms;
-}
-
// TODO: Some plugins expect to be linked into libomptarget which defines these
// symbols to implement ompt callbacks. The least invasive workaround here is to
// define them in libLLVMOffload as false/null so they are never used. In future
@@ -55,6 +34,97 @@ ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target
#endif
+using namespace llvm::omp::target;
+using namespace llvm::omp::target::plugin;
+
+// Handle type definitions. Ideally these would be 1:1 with the plugins, but
+// we add some additional data here for now to avoid churn in the plugin
+// interface.
+struct ol_device_impl_t {
+ ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
+ ol_platform_handle_t Platform)
+ : DeviceNum(DeviceNum), Device(Device), Platform(Platform) {}
+ int DeviceNum;
+ GenericDeviceTy *Device;
+ ol_platform_handle_t Platform;
+};
+
+struct ol_platform_impl_t {
+ ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
+ std::vector<ol_device_impl_t> Devices,
+ ol_platform_backend_t BackendType)
+ : Plugin(std::move(Plugin)), Devices(Devices), BackendType(BackendType) {}
+ std::unique_ptr<GenericPluginTy> Plugin;
+ std::vector<ol_device_impl_t> Devices;
+ ol_platform_backend_t BackendType;
+};
+
+struct ol_queue_impl_t {
+ ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
+ : AsyncInfo(AsyncInfo), Device(Device) {}
+ __tgt_async_info *AsyncInfo;
+ ol_device_handle_t Device;
+};
+
+struct ol_event_impl_t {
+ ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue)
+ : EventInfo(EventInfo), Queue(Queue) {}
+ ~ol_event_impl_t() { (void)Queue->Device->Device->destroyEvent(EventInfo); }
+ void *EventInfo;
+ ol_queue_handle_t Queue;
+};
+
+struct ol_program_impl_t {
+ ol_program_impl_t(plugin::DeviceImageTy *Image,
+ std::unique_ptr<llvm::MemoryBuffer> ImageData,
+ const __tgt_device_image &DeviceImage)
+ : Image(Image), ImageData(std::move(ImageData)),
+ DeviceImage(DeviceImage) {}
+ plugin::DeviceImageTy *Image;
+ std::unique_ptr<llvm::MemoryBuffer> ImageData;
+ __tgt_device_image DeviceImage;
+};
+
+namespace llvm {
+namespace offload {
+
+struct AllocInfo {
+ ol_device_handle_t Device;
+ ol_alloc_type_t Type;
+};
+
+using AllocInfoMapT = DenseMap<void *, AllocInfo>;
+AllocInfoMapT &allocInfoMap() {
+ static AllocInfoMapT AllocInfoMap{};
+ return AllocInfoMap;
+}
+
+using PlatformVecT = SmallVector<ol_platform_impl_t, 4>;
+PlatformVecT &Platforms() {
+ static PlatformVecT Platforms;
+ return Platforms;
+}
+
+ol_device_handle_t HostDevice() {
+ // The host platform is always inserted last
+ return &Platforms().back().Devices[0];
+}
+
+template <typename HandleT> ol_impl_result_t olDestroy(HandleT Handle) {
+ delete Handle;
+ return OL_SUCCESS;
+}
+
+constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
+ if (Name == "amdgpu") {
+ return OL_PLATFORM_BACKEND_AMDGPU;
+ } else if (Name == "cuda") {
+ return OL_PLATFORM_BACKEND_CUDA;
+ } else {
+ return OL_PLATFORM_BACKEND_UNKNOWN;
+ }
+}
+
// Every plugin exports this method to create an instance of the plugin type.
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
@@ -63,26 +133,36 @@ void initPlugins() {
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
- Platforms().emplace_back(ol_platform_handle_t_{ \
- std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), {}}); \
+ Platforms().emplace_back(ol_platform_impl_t{ \
+ std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
+ {}, \
+ pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"
- // Preemptively initialize all devices in the plugin so we can just return
- // them from deviceGet
+ // Preemptively initialize all devices in the plugin
for (auto &Platform : Platforms()) {
auto Err = Platform.Plugin->init();
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
DevNum++) {
if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
- Platform.Devices.emplace_back(ol_device_handle_t_{
- DevNum, Platform.Plugin->getDevice(DevNum), &Platform});
+ Platform.Devices.emplace_back(ol_device_impl_t{
+ DevNum, &Platform.Plugin->getDevice(DevNum), &Platform});
}
}
}
+ // Add the special host device
+ auto &HostPlatform = Platforms().emplace_back(
+ ol_platform_impl_t{nullptr,
+ {ol_device_impl_t{-1, nullptr, nullptr}},
+ OL_PLATFORM_BACKEND_HOST});
+ HostDevice()->Platform = &HostPlatform;
+
offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
+ offloadConfig().ValidationEnabled =
+ !std::getenv("OFFLOAD_DISABLE_VALIDATION");
}
// TODO: We can properly reference count here and manage the resources in a more
@@ -95,36 +175,16 @@ ol_impl_result_t olInit_impl() {
}
ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; }
-ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms) {
- *NumPlatforms = Platforms().size();
- return OL_SUCCESS;
-}
-
-ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries,
- ol_platform_handle_t *PlatformsOut) {
- if (NumEntries > Platforms().size()) {
- return {OL_ERRC_INVALID_SIZE,
- std::string{formatv("{0} platform(s) available but {1} requested.",
- Platforms().size(), NumEntries)}};
- }
-
- for (uint32_t PlatformIndex = 0; PlatformIndex < NumEntries;
- PlatformIndex++) {
- PlatformsOut[PlatformIndex] = &(Platforms())[PlatformIndex];
- }
-
- return OL_SUCCESS;
-}
-
ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName,
size_t PropSize, void *PropValue,
size_t *PropSizeRet) {
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
+ bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
switch (PropName) {
case OL_PLATFORM_INFO_NAME:
- return ReturnValue(Platform->Plugin->getName());
+ return ReturnValue(IsHost ? "Host" : Platform->Plugin->getName());
case OL_PLATFORM_INFO_VENDOR_NAME:
// TODO: Implement this
return ReturnValue("Unknown platform vendor");
@@ -135,14 +195,7 @@ ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
.c_str());
}
case OL_PLATFORM_INFO_BACKEND: {
- auto PluginName = Platform->Plugin->getName();
- if (PluginName == StringRef("CUDA")) {
- return ReturnValue(OL_PLATFORM_BACKEND_CUDA);
- } else if (PluginName == StringRef("AMDGPU")) {
- return ReturnValue(OL_PLATFORM_BACKEND_AMDGPU);
- } else {
- return ReturnValue(OL_PLATFORM_BACKEND_UNKNOWN);
- }
+ return ReturnValue(Platform->BackendType);
}
default:
return OL_ERRC_INVALID_ENUMERATION;
@@ -165,27 +218,6 @@ ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
PropSizeRet);
}
-ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform,
- uint32_t *pNumDevices) {
- *pNumDevices = static_cast<uint32_t>(Platform->Devices.size());
-
- return OL_SUCCESS;
-}
-
-ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform,
- uint32_t NumEntries,
- ol_device_handle_t *Devices) {
- if (NumEntries > Platform->Devices.size()) {
- return OL_ERRC_INVALID_SIZE;
- }
-
- for (uint32_t DeviceIndex = 0; DeviceIndex < NumEntries; DeviceIndex++) {
- Devices[DeviceIndex] = &(Platform->Devices[DeviceIndex]);
- }
-
- return OL_SUCCESS;
-}
-
ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName,
size_t PropSize, void *PropValue,
@@ -193,12 +225,12 @@ ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
- InfoQueueTy DevInfo;
- if (auto Err = Device->Device.obtainInfoImpl(DevInfo))
- return OL_ERRC_OUT_OF_RESOURCES;
-
// Find the info if it exists under any of the given names
- auto GetInfo = [&DevInfo](std::vector<std::string> Names) {
+ auto GetInfo = [&](std::vector<std::string> Names) {
+ InfoQueueTy DevInfo;
+ if (auto Err = Device->Device->obtainInfoImpl(DevInfo))
+ return std::string("");
+
for (auto Name : Names) {
auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
return Info.Key == Name;
@@ -245,3 +277,256 @@ ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device,
size_t *PropSizeRet) {
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
}
+
+ol_impl_result_t olIterateDevices_impl(ol_device_iterate_cb_t Callback,
+ void *UserData) {
+ for (auto &Platform : Platforms()) {
+ for (auto &Device : Platform.Devices) {
+ if (!Callback(&Device, UserData)) {
+ break;
+ }
+ }
+ }
+
+ return OL_SUCCESS;
+}
+
+TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
+ switch (Type) {
+ case OL_ALLOC_TYPE_DEVICE:
+ return TARGET_ALLOC_DEVICE;
+ case OL_ALLOC_TYPE_HOST:
+ return TARGET_ALLOC_HOST;
+ case OL_ALLOC_TYPE_MANAGED:
+ default:
+ return TARGET_ALLOC_SHARED;
+ }
+}
+
+ol_impl_result_t olMemAlloc_impl(ol_device_handle_t Device,
+ ol_alloc_type_t Type, size_t Size,
+ void **AllocationOut) {
+ auto Alloc =
+ Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
+ if (!Alloc)
+ return {OL_ERRC_OUT_OF_RESOURCES,
+ formatv("Could not create allocation on device {0}", Device).str()};
+
+ *AllocationOut = *Alloc;
+ allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type});
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olMemFree_impl(void *Address) {
+ if (!allocInfoMap().contains(Address))
+ return {OL_ERRC_INVALID_ARGUMENT, "Address is not a known allocation"};
+
+ auto AllocInfo = allocInfoMap().at(Address);
+ auto Device = AllocInfo.Device;
+ auto Type = AllocInfo.Type;
+
+ auto Res =
+ Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type));
+ if (Res)
+ return {OL_ERRC_OUT_OF_RESOURCES, "Could not free allocation"};
+
+ allocInfoMap().erase(Address);
+
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olCreateQueue_impl(ol_device_handle_t Device,
+ ol_queue_handle_t *Queue) {
+ auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
+ auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo));
+ if (Err)
+ return {OL_ERRC_UNKNOWN, "Could not initialize stream resource"};
+
+ *Queue = CreatedQueue.release();
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olDestroyQueue_impl(ol_queue_handle_t Queue) {
+ return olDestroy(Queue);
+}
+
+ol_impl_result_t olWaitQueue_impl(ol_queue_handle_t Queue) {
+ // Host plugin doesn't have a queue set so it's not safe to call synchronize
+ // on it, but we have nothing to synchronize in that situation anyway.
+ if (Queue->AsyncInfo->Queue) {
+ auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo);
+ if (Err)
+ return {OL_ERRC_INVALID_QUEUE, "The queue failed to synchronize"};
+ }
+
+ // Recreate the stream resource so the queue can be reused
+ // TODO: Would be easier for the synchronization to (optionally) not release
+ // it to begin with.
+ auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo);
+ if (Res)
+ return {OL_ERRC_UNKNOWN, "Could not reinitialize the stream resource"};
+
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olWaitEvent_impl(ol_event_handle_t Event) {
+ auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo);
+ if (Res)
+ return {OL_ERRC_INVALID_EVENT, "The event failed to synchronize"};
+
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olDestroyEvent_impl(ol_event_handle_t Event) {
+ return olDestroy(Event);
+}
+
+ol_event_handle_t makeEvent(ol_queue_handle_t Queue) {
+ auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue);
+ auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo);
+ if (Res)
+ return nullptr;
+
+ Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo,
+ Queue->AsyncInfo);
+ if (Res)
+ return nullptr;
+
+ return EventImpl.release();
+}
+
+ol_impl_result_t olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
+ ol_device_handle_t DstDevice, void *SrcPtr,
+ ol_device_handle_t SrcDevice, size_t Size,
+ ol_event_handle_t *EventOut) {
+ if (DstDevice == HostDevice() && SrcDevice == HostDevice()) {
+ if (!Queue) {
+ std::memcpy(DstPtr, SrcPtr, Size);
+ return OL_SUCCESS;
+ } else {
+ return {OL_ERRC_INVALID_ARGUMENT,
+ "One of DstDevice and SrcDevice must be a non-host device if "
+ "Queue is specified"};
+ }
+ }
+
+ // If no queue is given the memcpy will be synchronous
+ auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
+
+ if (DstDevice == HostDevice()) {
+ auto Res = SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl);
+ if (Res)
+ return {OL_ERRC_UNKNOWN, "The data retrieve operation failed"};
+ } else if (SrcDevice == HostDevice()) {
+ auto Res = DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl);
+ if (Res)
+ return {OL_ERRC_UNKNOWN, "The data submit operation failed"};
+ } else {
+ auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device,
+ DstPtr, Size, QueueImpl);
+ if (Res)
+ return {OL_ERRC_UNKNOWN, "The data exchange operation failed"};
+ }
+
+ if (EventOut)
+ *EventOut = makeEvent(Queue);
+
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olCreateProgram_impl(ol_device_handle_t Device,
+ const void *ProgData, size_t ProgDataSize,
+ ol_program_handle_t *Program) {
+ // Make a copy of the program binary in case it is released by the caller.
+ auto ImageData = MemoryBuffer::getMemBufferCopy(
+ StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize));
+
+ auto DeviceImage = __tgt_device_image{
+ const_cast<char *>(ImageData->getBuffer().data()),
+ const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr,
+ nullptr};
+
+ ol_program_handle_t Prog =
+ new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage);
+
+ auto Res =
+ Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
+ if (!Res) {
+ delete Prog;
+ return OL_ERRC_INVALID_VALUE;
+ }
+
+ Prog->Image = *Res;
+ *Program = Prog;
+
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t olDestroyProgram_impl(ol_program_handle_t Program) {
+ return olDestroy(Program);
+}
+
+ol_impl_result_t olGetKernel_impl(ol_program_handle_t Program,
+ const char *KernelName,
+ ol_kernel_handle_t *Kernel) {
+
+ auto &Device = Program->Image->getDevice();
+ auto KernelImpl = Device.constructKernel(KernelName);
+ if (!KernelImpl)
+ return OL_ERRC_INVALID_KERNEL_NAME;
+
+ auto Err = KernelImpl->init(Device, *Program->Image);
+ if (Err)
+ return {OL_ERRC_UNKNOWN, "Could not initialize the kernel"};
+
+ *Kernel = &*KernelImpl;
+
+ return OL_SUCCESS;
+}
+
+ol_impl_result_t
+olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
+ ol_kernel_handle_t Kernel, const void *ArgumentsData,
+ size_t ArgumentsSize,
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+ ol_event_handle_t *EventOut) {
+ auto *DeviceImpl = Device->Device;
+ if (Queue && Device != Queue->Device) {
+ return {OL_ERRC_INVALID_DEVICE,
+ "Device specified does not match the device of the given queue"};
+ }
+
+ auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
+ AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
+ KernelArgsTy LaunchArgs{};
+ LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroupsX;
+ LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroupsY;
+ LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroupsZ;
+ LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSizeX;
+ LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSizeY;
+ LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSizeZ;
+ LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;
+
+ KernelLaunchParamsTy Params;
+ Params.Data = const_cast<void *>(ArgumentsData);
+ Params.Size = ArgumentsSize;
+ LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params);
+ // Don't do anything with pointer indirection; use arg data as-is
+ LaunchArgs.Flags.IsCUDA = true;
+
+ auto *KernelImpl = reinterpret_cast<GenericKernelTy *>(Kernel);
+ auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
+ LaunchArgs, AsyncInfoWrapper);
+
+ AsyncInfoWrapper.finalize(Err);
+ if (Err)
+ return {OL_ERRC_UNKNOWN, "Could not finalize the AsyncInfoWrapper"};
+
+ if (EventOut)
+ *EventOut = makeEvent(Queue);
+
+ return OL_SUCCESS;
+}
+
+} // namespace offload
+} // namespace llvm
diff --git a/offload/liboffload/src/OffloadLib.cpp b/offload/liboffload/src/OffloadLib.cpp
index 70e1ce1f84d8..8662d3a44124 100644
--- a/offload/liboffload/src/OffloadLib.cpp
+++ b/offload/liboffload/src/OffloadLib.cpp
@@ -11,11 +11,10 @@
//===----------------------------------------------------------------------===//
#include "OffloadImpl.hpp"
+#include "llvm/Support/raw_ostream.h"
#include <OffloadAPI.h>
#include <OffloadPrint.hpp>
-#include <iostream>
-
llvm::StringSet<> &errorStrs() {
static llvm::StringSet<> ErrorStrs;
return ErrorStrs;
@@ -36,9 +35,13 @@ OffloadConfig &offloadConfig() {
return Config;
}
+namespace llvm {
+namespace offload {
// Pull in the declarations for the implementation functions. The actual entry
// points in this file wrap these.
#include "OffloadImplFuncDecls.inc"
+} // namespace offload
+} // namespace llvm
// Pull in the tablegen'd entry point definitions.
#include "OffloadEntryPoints.inc"
diff --git a/offload/test/tools/offload-tblgen/entry_points.td b/offload/test/tools/offload-tblgen/entry_points.td
index a66ddb927992..cfddb84aa5b0 100644
--- a/offload/test/tools/offload-tblgen/entry_points.td
+++ b/offload/test/tools/offload-tblgen/entry_points.td
@@ -20,7 +20,7 @@ def : Function {
// The validation function should call the implementation function
// CHECK: FunctionA_val
-// CHECK: return FunctionA_impl(ParamA, ParamB);
+// CHECK: return llvm::offload::FunctionA_impl(ParamA, ParamB);
// CHECK: ol_result_t{{.*}} FunctionA(
diff --git a/offload/test/tools/offload-tblgen/functions_ranged_param.td b/offload/test/tools/offload-tblgen/functions_ranged_param.td
index 21a84d8a7033..d0996b231973 100644
--- a/offload/test/tools/offload-tblgen/functions_ranged_param.td
+++ b/offload/test/tools/offload-tblgen/functions_ranged_param.td
@@ -25,7 +25,7 @@ def : Function {
let returns = [];
}
-// CHECK: inline std::ostream &operator<<(std::ostream &os, const struct function_a_params_t *params) {
+// CHECK: inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct function_a_params_t *params) {
// CHECK: os << ".OutPtr = ";
// CHECK: for (size_t i = 0; i < *params->pOutCount; i++) {
// CHECK: if (i > 0) {
diff --git a/offload/test/tools/offload-tblgen/print_enum.td b/offload/test/tools/offload-tblgen/print_enum.td
index 0b5506009bec..97f869689293 100644
--- a/offload/test/tools/offload-tblgen/print_enum.td
+++ b/offload/test/tools/offload-tblgen/print_enum.td
@@ -15,7 +15,7 @@ def : Enum {
];
}
-// CHECK: inline std::ostream &operator<<(std::ostream &os, enum my_enum_t value)
+// CHECK: inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, enum my_enum_t value)
// CHECK: switch (value) {
// CHECK: case MY_ENUM_VALUE_ONE:
// CHECK: os << "MY_ENUM_VALUE_ONE";
diff --git a/offload/test/tools/offload-tblgen/print_function.td b/offload/test/tools/offload-tblgen/print_function.td
index 3f4944df6594..ce1fe4c52760 100644
--- a/offload/test/tools/offload-tblgen/print_function.td
+++ b/offload/test/tools/offload-tblgen/print_function.td
@@ -27,7 +27,7 @@ def : Function {
// CHECK-API-NEXT: ol_foo_handle_t* pParamHandle;
// CHECK-API-NEXT: uint32_t** pParamPointer;
-// CHECK-PRINT: inline std::ostream &operator<<(std::ostream &os, const struct function_a_params_t *params)
+// CHECK-PRINT: inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct function_a_params_t *params)
// CHECK-PRINT: os << ".ParamValue = ";
// CHECK-PRINT: os << *params->pParamValue;
// CHECK-PRINT: os << ", ";
diff --git a/offload/test/tools/offload-tblgen/type_tagged_enum.td b/offload/test/tools/offload-tblgen/type_tagged_enum.td
index 49e91e43bb6e..95964e32f0c9 100644
--- a/offload/test/tools/offload-tblgen/type_tagged_enum.td
+++ b/offload/test/tools/offload-tblgen/type_tagged_enum.td
@@ -50,7 +50,7 @@ def : Function {
}
// Check that a tagged enum print function definition is generated
-// CHECK-PRINT: void printTagged(std::ostream &os, const void *ptr, my_type_tagged_enum_t value, size_t size) {
+// CHECK-PRINT: void printTagged(llvm::raw_ostream &os, const void *ptr, my_type_tagged_enum_t value, size_t size) {
// CHECK-PRINT: case MY_TYPE_TAGGED_ENUM_VALUE_ONE: {
// CHECK-PRINT: const uint32_t * const tptr = (const uint32_t * const)ptr;
// CHECK-PRINT: os << (const void *)tptr << " (";
@@ -71,6 +71,6 @@ def : Function {
// CHECK-PRINT: }
// Check that the tagged type information is used when printing function parameters
-// CHECK-PRINT: std::ostream &operator<<(std::ostream &os, const struct function_a_params_t *params) {
+// CHECK-PRINT: llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct function_a_params_t *params) {
// CHECK-PRINT: os << ".PropValue = "
// CHECK-PRINT-NEXT: printTagged(os, *params->pPropValue, *params->pPropName, *params->pPropSize);
diff --git a/offload/tools/offload-tblgen/APIGen.cpp b/offload/tools/offload-tblgen/APIGen.cpp
index 97a2464f7a75..800c9cadfe38 100644
--- a/offload/tools/offload-tblgen/APIGen.cpp
+++ b/offload/tools/offload-tblgen/APIGen.cpp
@@ -41,9 +41,16 @@ static std::string MakeComment(StringRef in) {
}
static void ProcessHandle(const HandleRec &H, raw_ostream &OS) {
+ if (!H.getName().ends_with("_handle_t")) {
+ errs() << "Handle type name (" << H.getName()
+ << ") must end with '_handle_t'!\n";
+ exit(1);
+ }
+
+ auto ImplName = H.getName().substr(0, H.getName().size() - 9) + "_impl_t";
OS << CommentsHeader;
OS << formatv("/// @brief {0}\n", H.getDesc());
- OS << formatv("typedef struct {0}_ *{0};\n", H.getName());
+ OS << formatv("typedef struct {0} *{1};\n", ImplName, H.getName());
}
static void ProcessTypedef(const TypedefRec &T, raw_ostream &OS) {
@@ -158,6 +165,19 @@ static void ProcessStruct(const StructRec &Struct, raw_ostream &OS) {
OS << formatv("} {0};\n\n", Struct.getName());
}
+static void ProcessFptrTypedef(const FptrTypedefRec &F, raw_ostream &OS) {
+ OS << CommentsHeader;
+ OS << formatv("/// @brief {0}\n", F.getDesc());
+ OS << formatv("typedef {0} (*{1})(", F.getReturn(), F.getName());
+ for (const auto &Param : F.getParams()) {
+ OS << formatv("\n // {0}\n {1} {2}", Param.getDesc(), Param.getType(),
+ Param.getName());
+ if (Param != F.getParams().back())
+ OS << ",";
+ }
+ OS << ");\n";
+}
+
static void ProcessFuncParamStruct(const FunctionRec &Func, raw_ostream &OS) {
if (Func.getParams().size() == 0) {
return;
@@ -213,6 +233,8 @@ void EmitOffloadAPI(const RecordKeeper &Records, raw_ostream &OS) {
ProcessEnum(EnumRec{R}, OS);
} else if (R->isSubClassOf("Struct")) {
ProcessStruct(StructRec{R}, OS);
+ } else if (R->isSubClassOf("FptrTypedef")) {
+ ProcessFptrTypedef(FptrTypedefRec{R}, OS);
}
}
diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp
index 990ff96a3121..66b9665292e1 100644
--- a/offload/tools/offload-tblgen/EntryPointGen.cpp
+++ b/offload/tools/offload-tblgen/EntryPointGen.cpp
@@ -35,7 +35,7 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
}
OS << ") {\n";
- OS << TAB_1 "if (true /*enableParameterValidation*/) {\n";
+ OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n";
// Emit validation checks
for (const auto &Return : F.getReturns()) {
for (auto &Condition : Return.getConditions()) {
@@ -51,7 +51,8 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
// Perform actual function call to the implementation
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
- OS << formatv(TAB_1 "return {0}_impl({1});\n\n", F.getName(), ParamNameList);
+ OS << formatv(TAB_1 "return llvm::offload::{0}_impl({1});\n\n", F.getName(),
+ ParamNameList);
OS << "}\n";
}
@@ -72,7 +73,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
// Emit pre-call prints
OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n";
- OS << formatv(TAB_2 "std::cout << \"---> {0}\";\n", F.getName());
+ OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
OS << TAB_1 "}\n\n";
// Perform actual function call to the validation wrapper
@@ -91,13 +92,13 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
}
}
OS << formatv("};\n");
- OS << TAB_2 "std::cout << \"(\" << &Params << \")\";\n";
+ OS << TAB_2 "llvm::errs() << \"(\" << &Params << \")\";\n";
} else {
- OS << TAB_2 "std::cout << \"()\";\n";
+ OS << TAB_2 "llvm::errs() << \"()\";\n";
}
- OS << TAB_2 "std::cout << \"-> \" << Result << \"\\n\";\n";
+ OS << TAB_2 "llvm::errs() << \"-> \" << Result << \"\\n\";\n";
OS << TAB_2 "if (Result && Result->Details) {\n";
- OS << TAB_3 "std::cout << \" *Error Details* \" << Result->Details "
+ OS << TAB_3 "llvm::errs() << \" *Error Details* \" << Result->Details "
"<< \" \\n\";\n";
OS << TAB_2 "}\n";
OS << TAB_1 "}\n";
@@ -121,7 +122,7 @@ static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
OS << "ol_code_location_t *CodeLocation";
OS << ") {\n";
OS << TAB_1 "currentCodeLocation() = CodeLocation;\n";
- OS << formatv(TAB_1 "{0}_result_t Result = {1}({2});\n\n", PrefixLower,
+ OS << formatv(TAB_1 "{0}_result_t Result = ::{1}({2});\n\n", PrefixLower,
F.getName(), ParamNameList);
OS << TAB_1 "currentCodeLocation() = nullptr;\n";
OS << TAB_1 "return Result;\n";
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index 2a7c63c3dfd1..a964ff09d0f6 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -20,24 +20,24 @@
using namespace llvm;
using namespace offload::tblgen;
-constexpr auto PrintEnumHeader =
+constexpr auto PrintTypeHeader =
R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print operator for the {0} type
-/// @returns std::ostream &
+/// @returns llvm::raw_ostream &
)";
constexpr auto PrintTaggedEnumHeader =
R"(///////////////////////////////////////////////////////////////////////////////
/// @brief Print type-tagged {0} enum value
-/// @returns std::ostream &
+/// @returns llvm::raw_ostream &
)";
static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) {
- OS << formatv(PrintEnumHeader, Enum.getName());
- OS << formatv(
- "inline std::ostream &operator<<(std::ostream &os, enum {0} value) "
- "{{\n" TAB_1 "switch (value) {{\n",
- Enum.getName());
+ OS << formatv(PrintTypeHeader, Enum.getName());
+ OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, "
+ "enum {0} value) "
+ "{{\n" TAB_1 "switch (value) {{\n",
+ Enum.getName());
for (const auto &Val : Enum.getValues()) {
auto Name = Enum.getEnumValNamePrefix() + "_" + Val.getName();
@@ -56,7 +56,7 @@ static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) {
OS << formatv(PrintTaggedEnumHeader, Enum.getName());
OS << formatv(R"""(template <>
-inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t size) {{
+inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_t size) {{
if (ptr == NULL) {{
printPtr(os, ptr);
return;
@@ -96,7 +96,7 @@ inline void printTagged(std::ostream &os, const void *ptr, {0} value, size_t siz
static void EmitResultPrint(raw_ostream &OS) {
OS << R""(
-inline std::ostream &operator<<(std::ostream &os,
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ol_error_struct_t *Err) {
if (Err == nullptr) {
os << "OL_SUCCESS";
@@ -115,7 +115,7 @@ static void EmitFunctionParamStructPrint(const FunctionRec &Func,
}
OS << formatv(R"(
-inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} *params) {{
)",
Func.getParamStructName());
@@ -139,6 +139,9 @@ inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{
Param.getName(), TypeInfo->first, TypeInfo->second);
} else if (Param.isPointerType() || Param.isHandleType()) {
OS << formatv(TAB_1 "printPtr(os, *params->p{0});\n", Param.getName());
+ } else if (Param.isFptrType()) {
+ OS << formatv(TAB_1 "os << reinterpret_cast<void*>(*params->p{0});\n",
+ Param.getName());
} else {
OS << formatv(TAB_1 "os << *params->p{0};\n", Param.getName());
}
@@ -150,6 +153,32 @@ inline std::ostream &operator<<(std::ostream &os, const struct {0} *params) {{
OS << TAB_1 "return os;\n}\n";
}
+void ProcessStruct(const StructRec &Struct, raw_ostream &OS) {
+ if (Struct.getName() == "ol_error_struct_t") {
+ return;
+ }
+ OS << formatv(PrintTypeHeader, Struct.getName());
+ OS << formatv(R"(
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const struct {0} params) {{
+)",
+ Struct.getName());
+ OS << formatv(TAB_1 "os << \"(struct {0}){{\";\n", Struct.getName());
+ for (const auto &Member : Struct.getMembers()) {
+ OS << formatv(TAB_1 "os << \".{0} = \";\n", Member.getName());
+ if (Member.isPointerType() || Member.isHandleType()) {
+ OS << formatv(TAB_1 "printPtr(os, params.{0});\n", Member.getName());
+ } else {
+ OS << formatv(TAB_1 "os << params.{0};\n", Member.getName());
+ }
+ if (Member.getName() != Struct.getMembers().back().getName()) {
+ OS << TAB_1 "os << \", \";\n";
+ }
+ }
+ OS << TAB_1 "os << \"}\";\n";
+ OS << TAB_1 "return os;\n";
+ OS << "}\n";
+}
+
void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) {
OS << GenericHeader;
OS << R"""(
@@ -158,11 +187,11 @@ void EmitOffloadPrintHeader(const RecordKeeper &Records, raw_ostream &OS) {
#pragma once
#include <OffloadAPI.h>
-#include <ostream>
+#include <llvm/Support/raw_ostream.h>
-template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr);
-template <typename T> inline void printTagged(std::ostream &os, const void *ptr, T value, size_t size);
+template <typename T> inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr);
+template <typename T> inline void printTagged(llvm::raw_ostream &os, const void *ptr, T value, size_t size);
)""";
// ==========
@@ -180,9 +209,9 @@ template <typename T> inline void printTagged(std::ostream &os, const void *ptr,
// use each other.
OS << "\n";
for (auto *R : Records.getAllDerivedDefinitions("Enum")) {
- OS << formatv(
- "inline std::ostream &operator<<(std::ostream &os, enum {0} value);\n",
- EnumRec{R}.getName());
+ OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, "
+ "enum {0} value);\n",
+ EnumRec{R}.getName());
}
OS << "\n";
@@ -193,6 +222,11 @@ template <typename T> inline void printTagged(std::ostream &os, const void *ptr,
}
EmitResultPrint(OS);
+ for (auto *R : Records.getAllDerivedDefinitions("Struct")) {
+ StructRec S{R};
+ ProcessStruct(S, OS);
+ }
+
// Emit print functions for the function param structs
for (auto *R : Records.getAllDerivedDefinitions("Function")) {
EmitFunctionParamStructPrint(FunctionRec{R}, OS);
@@ -201,7 +235,7 @@ template <typename T> inline void printTagged(std::ostream &os, const void *ptr,
OS << R"""(
///////////////////////////////////////////////////////////////////////////////
// @brief Print pointer value
-template <typename T> inline ol_result_t printPtr(std::ostream &os, const T *ptr) {
+template <typename T> inline ol_result_t printPtr(llvm::raw_ostream &os, const T *ptr) {
if (ptr == nullptr) {
os << "nullptr";
} else if constexpr (std::is_pointer_v<T>) {
diff --git a/offload/tools/offload-tblgen/RecordTypes.hpp b/offload/tools/offload-tblgen/RecordTypes.hpp
index 0bf3256c525d..686634ed778a 100644
--- a/offload/tools/offload-tblgen/RecordTypes.hpp
+++ b/offload/tools/offload-tblgen/RecordTypes.hpp
@@ -103,6 +103,8 @@ public:
StringRef getType() const { return rec->getValueAsString("type"); }
StringRef getName() const { return rec->getValueAsString("name"); }
StringRef getDesc() const { return rec->getValueAsString("desc"); }
+ bool isPointerType() const { return getType().ends_with('*'); }
+ bool isHandleType() const { return getType().ends_with("_handle_t"); }
private:
const Record *rec;
@@ -153,6 +155,7 @@ public:
StringRef getType() const { return rec->getValueAsString("type"); }
bool isPointerType() const { return getType().ends_with('*'); }
bool isHandleType() const { return getType().ends_with("_handle_t"); }
+ bool isFptrType() const { return getType().ends_with("_cb_t"); }
StringRef getDesc() const { return rec->getValueAsString("desc"); }
bool isIn() const { return dyn_cast<BitInit>(flags->getBit(0))->getValue(); }
bool isOut() const { return dyn_cast<BitInit>(flags->getBit(1))->getValue(); }
@@ -222,6 +225,23 @@ private:
const Record *rec;
};
+class FptrTypedefRec {
+public:
+ explicit FptrTypedefRec(const Record *rec) : rec(rec) {
+ for (auto &Param : rec->getValueAsListOfDefs("params"))
+ params.emplace_back(Param);
+ }
+ StringRef getName() const { return rec->getValueAsString("name"); }
+ StringRef getDesc() const { return rec->getValueAsString("desc"); }
+ StringRef getReturn() const { return rec->getValueAsString("return"); }
+ const std::vector<ParamRec> &getParams() const { return params; }
+
+private:
+ std::vector<ParamRec> params;
+
+ const Record *rec;
+};
+
} // namespace tblgen
} // namespace offload
} // namespace llvm
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index 033ee2b6ec74..c4d628a5a87f 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -1,16 +1,28 @@
set(PLUGINS_TEST_COMMON LLVMOffload)
set(PLUGINS_TEST_INCLUDE ${LIBOMPTARGET_INCLUDE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/common)
+add_subdirectory(device_code)
+message(${OFFLOAD_TEST_DEVICE_CODE_PATH})
+
add_libompt_unittest("offload.unittests"
${CMAKE_CURRENT_SOURCE_DIR}/common/Environment.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatform.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatformCount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatformInfo.cpp
${CMAKE_CURRENT_SOURCE_DIR}/platform/olGetPlatformInfoSize.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDevice.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceCount.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/device/olIterateDevices.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceInfo.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceInfoSize.cpp)
-add_dependencies("offload.unittests" ${PLUGINS_TEST_COMMON})
+ ${CMAKE_CURRENT_SOURCE_DIR}/device/olGetDeviceInfoSize.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/queue/olCreateQueue.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/queue/olWaitQueue.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/queue/olDestroyQueue.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/memory/olMemAlloc.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/memory/olMemFree.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/memory/olMemcpy.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/program/olCreateProgram.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/program/olDestroyProgram.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/kernel/olGetKernel.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/kernel/olLaunchKernel.cpp
+ )
+add_dependencies("offload.unittests" ${PLUGINS_TEST_COMMON} LibomptUnitTestsDeviceBins)
+target_compile_definitions("offload.unittests" PRIVATE DEVICE_CODE_PATH="${OFFLOAD_TEST_DEVICE_CODE_PATH}")
target_link_libraries("offload.unittests" PRIVATE ${PLUGINS_TEST_COMMON})
target_include_directories("offload.unittests" PRIVATE ${PLUGINS_TEST_INCLUDE})
diff --git a/offload/unittests/OffloadAPI/common/Environment.cpp b/offload/unittests/OffloadAPI/common/Environment.cpp
index f07a66cda218..88cf33e45f3d 100644
--- a/offload/unittests/OffloadAPI/common/Environment.cpp
+++ b/offload/unittests/OffloadAPI/common/Environment.cpp
@@ -9,7 +9,9 @@
#include "Environment.hpp"
#include "Fixtures.hpp"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MemoryBuffer.h"
#include <OffloadAPI.h>
+#include <fstream>
using namespace llvm;
@@ -25,8 +27,8 @@ static cl::opt<std::string>
SelectedPlatform("platform", cl::desc("Only test the specified platform"),
cl::value_desc("platform"));
-std::ostream &operator<<(std::ostream &Out,
- const ol_platform_handle_t &Platform) {
+raw_ostream &operator<<(raw_ostream &Out,
+ const ol_platform_handle_t &Platform) {
size_t Size;
olGetPlatformInfoSize(Platform, OL_PLATFORM_INFO_NAME, &Size);
std::vector<char> Name(Size);
@@ -35,62 +37,132 @@ std::ostream &operator<<(std::ostream &Out,
return Out;
}
-std::ostream &operator<<(std::ostream &Out,
- const std::vector<ol_platform_handle_t> &Platforms) {
- for (auto Platform : Platforms) {
- Out << "\n * \"" << Platform << "\"";
- }
- return Out;
-}
+void printPlatforms() {
+ SmallDenseSet<ol_platform_handle_t> Platforms;
+ using DeviceVecT = SmallVector<ol_device_handle_t, 8>;
+ DeviceVecT Devices{};
-const std::vector<ol_platform_handle_t> &TestEnvironment::getPlatforms() {
- static std::vector<ol_platform_handle_t> Platforms{};
+ olIterateDevices(
+ [](ol_device_handle_t D, void *Data) {
+ static_cast<DeviceVecT *>(Data)->push_back(D);
+ return true;
+ },
+ &Devices);
- if (Platforms.empty()) {
- uint32_t PlatformCount = 0;
- olGetPlatformCount(&PlatformCount);
- if (PlatformCount > 0) {
- Platforms.resize(PlatformCount);
- olGetPlatform(PlatformCount, Platforms.data());
- }
+ for (auto &Device : Devices) {
+ ol_platform_handle_t Platform;
+ olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
+ &Platform);
+ Platforms.insert(Platform);
}
- return Platforms;
+ for (const auto &Platform : Platforms) {
+ errs() << " * " << Platform << "\n";
+ }
}
-// Get a single platform, which may be selected by the user.
-ol_platform_handle_t TestEnvironment::getPlatform() {
- static ol_platform_handle_t Platform = nullptr;
- const auto &Platforms = getPlatforms();
+ol_device_handle_t TestEnvironment::getDevice() {
+ static ol_device_handle_t Device = nullptr;
- if (!Platform) {
+ if (!Device) {
if (SelectedPlatform != "") {
- for (const auto CandidatePlatform : Platforms) {
- std::stringstream PlatformName;
- PlatformName << CandidatePlatform;
- if (SelectedPlatform == PlatformName.str()) {
- Platform = CandidatePlatform;
- return Platform;
- }
+ olIterateDevices(
+ [](ol_device_handle_t D, void *Data) {
+ ol_platform_handle_t Platform;
+ olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
+ &Platform);
+
+ std::string PlatformName;
+ raw_string_ostream S(PlatformName);
+ S << Platform;
+
+ if (PlatformName == SelectedPlatform) {
+ *(static_cast<ol_device_handle_t *>(Data)) = D;
+ return false;
+ }
+
+ return true;
+ },
+ &Device);
+
+ if (Device == nullptr) {
+ errs() << "No device found with the platform \"" << SelectedPlatform
+ << "\". Choose from:"
+ << "\n";
+ printPlatforms();
+ std::exit(1);
}
- std::cout << "No platform found with the name \"" << SelectedPlatform
- << "\". Choose from:" << Platforms << "\n";
- std::exit(1);
} else {
- // Pick a single platform. We prefer one that has available devices, but
- // just pick the first initially in case none have any devices.
- Platform = Platforms[0];
- for (auto CandidatePlatform : Platforms) {
- uint32_t NumDevices = 0;
- if (olGetDeviceCount(CandidatePlatform, &NumDevices) == OL_SUCCESS) {
- if (NumDevices > 0) {
- Platform = CandidatePlatform;
- break;
- }
- }
- }
+ olIterateDevices(
+ [](ol_device_handle_t D, void *Data) {
+ *(static_cast<ol_device_handle_t *>(Data)) = D;
+ return false;
+ },
+ &Device);
}
}
- return Platform;
+ return Device;
+}
+
+ol_device_handle_t TestEnvironment::getHostDevice() {
+ static ol_device_handle_t HostDevice = nullptr;
+
+ if (!HostDevice) {
+ olIterateDevices(
+ [](ol_device_handle_t D, void *Data) {
+ ol_platform_handle_t Platform;
+ olGetDeviceInfo(D, OL_DEVICE_INFO_PLATFORM, sizeof(Platform),
+ &Platform);
+ ol_platform_backend_t Backend;
+ olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend),
+ &Backend);
+
+ if (Backend == OL_PLATFORM_BACKEND_HOST) {
+ *(static_cast<ol_device_handle_t *>(Data)) = D;
+ return false;
+ }
+
+ return true;
+ },
+ &HostDevice);
+ }
+
+ return HostDevice;
+}
+
+// TODO: Allow overriding via cmd line arg
+const std::string DeviceBinsDirectory = DEVICE_CODE_PATH;
+
+bool TestEnvironment::loadDeviceBinary(
+ const std::string &BinaryName, ol_device_handle_t Device,
+ std::unique_ptr<MemoryBuffer> &BinaryOut) {
+
+ // Get the platform type
+ ol_platform_handle_t Platform;
+ olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, sizeof(Platform), &Platform);
+ ol_platform_backend_t Backend = OL_PLATFORM_BACKEND_UNKNOWN;
+ olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend),
+ &Backend);
+ std::string FileExtension;
+ if (Backend == OL_PLATFORM_BACKEND_AMDGPU) {
+ FileExtension = ".amdgpu.bin";
+ } else if (Backend == OL_PLATFORM_BACKEND_CUDA) {
+ FileExtension = ".nvptx64.bin";
+ } else {
+ errs() << "Unsupported platform type for a device binary test.\n";
+ return false;
+ }
+
+ std::string SourcePath =
+ DeviceBinsDirectory + "/" + BinaryName + FileExtension;
+
+ auto SourceFile = MemoryBuffer::getFile(SourcePath, false, false);
+ if (!SourceFile) {
+ errs() << "failed to read device binary file: " + SourcePath;
+ return false;
+ }
+
+ BinaryOut = std::move(SourceFile.get());
+ return true;
}
diff --git a/offload/unittests/OffloadAPI/common/Environment.hpp b/offload/unittests/OffloadAPI/common/Environment.hpp
index 6dba2381eb0b..a0bf688b4551 100644
--- a/offload/unittests/OffloadAPI/common/Environment.hpp
+++ b/offload/unittests/OffloadAPI/common/Environment.hpp
@@ -8,10 +8,13 @@
#pragma once
+#include "llvm/Support/MemoryBuffer.h"
#include <OffloadAPI.h>
#include <gtest/gtest.h>
namespace TestEnvironment {
-const std::vector<ol_platform_handle_t> &getPlatforms();
-ol_platform_handle_t getPlatform();
+ol_device_handle_t getDevice();
+ol_device_handle_t getHostDevice();
+bool loadDeviceBinary(const std::string &BinaryName, ol_device_handle_t Device,
+ std::unique_ptr<llvm::MemoryBuffer> &BinaryOut);
} // namespace TestEnvironment
diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp
index 410a435dee1b..028ebf43d5cd 100644
--- a/offload/unittests/OffloadAPI/common/Fixtures.hpp
+++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp
@@ -27,6 +27,14 @@
} while (0)
#endif
+#ifndef ASSERT_ANY_ERROR
+#define ASSERT_ANY_ERROR(ACTUAL) \
+ do { \
+ ol_result_t Res = ACTUAL; \
+ ASSERT_TRUE(Res); \
+ } while (0)
+#endif
+
#define RETURN_ON_FATAL_FAILURE(...) \
__VA_ARGS__; \
if (this->HasFatalFailure() || this->IsSkipped()) { \
@@ -34,31 +42,81 @@
} \
(void)0
-struct offloadTest : ::testing::Test {
- // No special behavior now, but just in case we need to override it in future
+struct OffloadTest : ::testing::Test {
+ ol_device_handle_t Host = TestEnvironment::getHostDevice();
};
-struct offloadPlatformTest : offloadTest {
+struct OffloadDeviceTest : OffloadTest {
void SetUp() override {
- RETURN_ON_FATAL_FAILURE(offloadTest::SetUp());
+ RETURN_ON_FATAL_FAILURE(OffloadTest::SetUp());
+
+ Device = TestEnvironment::getDevice();
+ if (Device == nullptr)
+ GTEST_SKIP() << "No available devices.";
+ }
- Platform = TestEnvironment::getPlatform();
+ ol_device_handle_t Device = nullptr;
+};
+
+struct OffloadPlatformTest : OffloadDeviceTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp());
+
+ ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
+ sizeof(Platform), &Platform));
ASSERT_NE(Platform, nullptr);
}
- ol_platform_handle_t Platform;
+ ol_platform_handle_t Platform = nullptr;
+};
+
+// Fixture for a generic program test. If you want a different program, use
+// offloadQueueTest and create your own program handle with the binary you want.
+struct OffloadProgramTest : OffloadDeviceTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp());
+ ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
+ ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
+ ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(),
+ DeviceBin->getBufferSize(), &Program));
+ }
+
+ void TearDown() override {
+ if (Program) {
+ olDestroyProgram(Program);
+ }
+ RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::TearDown());
+ }
+
+ ol_program_handle_t Program = nullptr;
+ std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
};
-struct offloadDeviceTest : offloadPlatformTest {
+struct OffloadKernelTest : OffloadProgramTest {
void SetUp() override {
- RETURN_ON_FATAL_FAILURE(offloadPlatformTest::SetUp());
+ RETURN_ON_FATAL_FAILURE(OffloadProgramTest::SetUp());
+ ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel));
+ }
+
+ void TearDown() override {
+ RETURN_ON_FATAL_FAILURE(OffloadProgramTest::TearDown());
+ }
+
+ ol_kernel_handle_t Kernel = nullptr;
+};
+
+struct OffloadQueueTest : OffloadDeviceTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp());
+ ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+ }
- uint32_t NumDevices;
- ASSERT_SUCCESS(olGetDeviceCount(Platform, &NumDevices));
- if (NumDevices == 0)
- GTEST_SKIP() << "No available devices on this platform.";
- ASSERT_SUCCESS(olGetDevice(Platform, 1, &Device));
+ void TearDown() override {
+ if (Queue) {
+ olDestroyQueue(Queue);
+ }
+ RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::TearDown());
}
- ol_device_handle_t Device;
+ ol_queue_handle_t Queue = nullptr;
};
diff --git a/offload/unittests/OffloadAPI/device/olGetDevice.cpp b/offload/unittests/OffloadAPI/device/olGetDevice.cpp
deleted file mode 100644
index 68d4682dd335..000000000000
--- a/offload/unittests/OffloadAPI/device/olGetDevice.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-//===------- Offload API tests - olGetDevice -------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "../common/Fixtures.hpp"
-#include <OffloadAPI.h>
-#include <gtest/gtest.h>
-
-using olGetDeviceTest = offloadPlatformTest;
-
-TEST_F(olGetDeviceTest, Success) {
- uint32_t Count = 0;
- ASSERT_SUCCESS(olGetDeviceCount(Platform, &Count));
- if (Count == 0)
- GTEST_SKIP() << "No available devices on this platform.";
-
- std::vector<ol_device_handle_t> Devices(Count);
- ASSERT_SUCCESS(olGetDevice(Platform, Count, Devices.data()));
- for (auto Device : Devices) {
- ASSERT_NE(nullptr, Device);
- }
-}
-
-TEST_F(olGetDeviceTest, SuccessSubsetOfDevices) {
- uint32_t Count;
- ASSERT_SUCCESS(olGetDeviceCount(Platform, &Count));
- if (Count < 2)
- GTEST_SKIP() << "Only one device is available on this platform.";
-
- std::vector<ol_device_handle_t> Devices(Count - 1);
- ASSERT_SUCCESS(olGetDevice(Platform, Count - 1, Devices.data()));
- for (auto Device : Devices) {
- ASSERT_NE(nullptr, Device);
- }
-}
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp
deleted file mode 100644
index ef377d671bf6..000000000000
--- a/offload/unittests/OffloadAPI/device/olGetDeviceCount.cpp
+++ /dev/null
@@ -1,28 +0,0 @@
-//===------- Offload API tests - olGetDeviceCount --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "../common/Fixtures.hpp"
-#include <OffloadAPI.h>
-#include <gtest/gtest.h>
-
-using olGetDeviceCountTest = offloadPlatformTest;
-
-TEST_F(olGetDeviceCountTest, Success) {
- uint32_t Count = 0;
- ASSERT_SUCCESS(olGetDeviceCount(Platform, &Count));
-}
-
-TEST_F(olGetDeviceCountTest, InvalidNullPlatform) {
- uint32_t Count = 0;
- ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olGetDeviceCount(nullptr, &Count));
-}
-
-TEST_F(olGetDeviceCountTest, InvalidNullPointer) {
- ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
- olGetDeviceCount(Platform, nullptr));
-}
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
index c936802fb1e4..f71f60a2c057 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
@@ -11,10 +11,10 @@
#include <OffloadAPI.h>
#include <gtest/gtest.h>
-struct olGetDeviceInfoTest : offloadDeviceTest,
+struct olGetDeviceInfoTest : OffloadDeviceTest,
::testing::WithParamInterface<ol_device_info_t> {
- void SetUp() override { RETURN_ON_FATAL_FAILURE(offloadDeviceTest::SetUp()); }
+ void SetUp() override { RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); }
};
INSTANTIATE_TEST_SUITE_P(
@@ -37,7 +37,7 @@ TEST_P(olGetDeviceInfoTest, Success) {
if (InfoType == OL_DEVICE_INFO_PLATFORM) {
auto *ReturnedPlatform =
reinterpret_cast<ol_platform_handle_t *>(InfoData.data());
- ASSERT_EQ(Platform, *ReturnedPlatform);
+ ASSERT_NE(nullptr, *ReturnedPlatform);
}
}
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
index 9e792d1c3e25..b4b5042dbfd8 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
@@ -12,10 +12,10 @@
#include "olDeviceInfo.hpp"
struct olGetDeviceInfoSizeTest
- : offloadDeviceTest,
+ : OffloadDeviceTest,
::testing::WithParamInterface<ol_device_info_t> {
- void SetUp() override { RETURN_ON_FATAL_FAILURE(offloadDeviceTest::SetUp()); }
+ void SetUp() override { RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); }
};
// TODO: We could autogenerate the list of enum values
diff --git a/offload/unittests/OffloadAPI/device/olIterateDevices.cpp b/offload/unittests/OffloadAPI/device/olIterateDevices.cpp
new file mode 100644
index 000000000000..5bdbd17e9e97
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device/olIterateDevices.cpp
@@ -0,0 +1,45 @@
+//===------- Offload API tests - olIterateDevices -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olIterateDevicesTest = OffloadTest;
+
+TEST_F(olIterateDevicesTest, SuccessEmptyCallback) {
+ ASSERT_SUCCESS(olIterateDevices(
+ [](ol_device_handle_t, void *) { return false; }, nullptr));
+}
+
+TEST_F(olIterateDevicesTest, SuccessGetDevice) {
+ uint32_t DeviceCount = 0;
+ ol_device_handle_t Device = nullptr;
+
+ ASSERT_SUCCESS(olIterateDevices(
+ [](ol_device_handle_t, void *Data) {
+ auto Count = static_cast<uint32_t *>(Data);
+ *Count += 1;
+ return false;
+ },
+ &DeviceCount));
+
+ if (DeviceCount == 0) {
+ GTEST_SKIP() << "No available devices.";
+ }
+
+ ASSERT_SUCCESS(olIterateDevices(
+ [](ol_device_handle_t D, void *Data) {
+ auto DevicePtr = static_cast<ol_device_handle_t *>(Data);
+ *DevicePtr = D;
+ return true;
+ },
+ &Device));
+
+ ASSERT_NE(Device, nullptr);
+}
diff --git a/offload/unittests/OffloadAPI/device_code/CMakeLists.txt b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
new file mode 100644
index 000000000000..ded555b3a3cf
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
@@ -0,0 +1,67 @@
+macro(add_offload_test_device_code test_filename test_name)
+ set(SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${test_filename})
+
+ # Build for NVPTX
+ if(OFFLOAD_TEST_TARGET_NVIDIA)
+ set(BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/${test_name}.nvptx64.bin)
+ add_custom_command(OUTPUT ${BIN_PATH}
+ COMMAND
+ ${CMAKE_C_COMPILER} --target=nvptx64-nvidia-cuda
+ -march=${LIBOMPTARGET_DEP_CUDA_ARCH}
+ --cuda-path=${CUDA_ROOT}
+ ${SRC_PATH} -o ${BIN_PATH}
+ DEPENDS ${SRC_PATH}
+ )
+ list(APPEND BIN_PATHS ${BIN_PATH})
+ endif()
+
+ # Build for AMDGPU
+ if(OFFLOAD_TEST_TARGET_AMDGPU)
+ set(BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/${test_name}.amdgpu.bin)
+ add_custom_command(OUTPUT ${BIN_PATH}
+ COMMAND
+ ${CMAKE_C_COMPILER} --target=amdgcn-amd-amdhsa -nogpulib
+ -mcpu=${LIBOMPTARGET_DEP_AMDGPU_ARCH}
+ ${SRC_PATH} -o ${BIN_PATH}
+ DEPENDS ${SRC_PATH}
+ )
+ list(APPEND BIN_PATHS ${BIN_PATH})
+ endif()
+
+ # TODO: Build for host CPU
+endmacro()
+
+
+# Decide what device targets to build for. LibomptargetGetDependencies is
+# included at the top-level so the GPUs present on the system are already
+# detected.
+set(OFFLOAD_TESTS_FORCE_NVIDIA_ARCH "" CACHE STRING
+ "Force building of NVPTX device code for Offload unit tests with the given arch, e.g. sm_61")
+set(OFFLOAD_TESTS_FORCE_AMDGPU_ARCH "" CACHE STRING
+ "Force building of AMDGPU device code for Offload unit tests with the given arch, e.g. gfx1030")
+
+find_package(CUDAToolkit QUIET)
+if(CUDAToolkit_FOUND)
+ get_filename_component(CUDA_ROOT "${CUDAToolkit_BIN_DIR}" DIRECTORY ABSOLUTE)
+endif()
+if (OFFLOAD_TESTS_FORCE_NVIDIA_ARCH)
+ set(LIBOMPTARGET_DEP_CUDA_ARCH ${OFFLOAD_TESTS_FORCE_NVIDIA_ARCH})
+ set(OFFLOAD_TEST_TARGET_NVIDIA ON)
+elseif (LIBOMPTARGET_FOUND_NVIDIA_GPU AND CUDA_ROOT AND "cuda" IN_LIST LIBOMPTARGET_PLUGINS_TO_BUILD)
+ set(OFFLOAD_TEST_TARGET_NVIDIA ON)
+endif()
+
+if (OFFLOAD_TESTS_FORCE_AMDGPU_ARCH)
+ set(LIBOMPTARGET_DEP_AMDGPU_ARCH ${OFFLOAD_TESTS_FORCE_AMDGPU_ARCH})
+ set(OFFLOAD_TEST_TARGET_AMDGPU ON)
+elseif (LIBOMPTARGET_FOUND_AMDGPU_GPU AND "amdgpu" IN_LIST LIBOMPTARGET_PLUGINS_TO_BUILD)
+ list(GET LIBOMPTARGET_AMDGPU_DETECTED_ARCH_LIST 0 LIBOMPTARGET_DEP_AMDGPU_ARCH)
+ set(OFFLOAD_TEST_TARGET_AMDGPU ON)
+endif()
+
+add_offload_test_device_code(foo.c foo)
+add_offload_test_device_code(bar.c bar)
+
+add_custom_target(LibomptUnitTestsDeviceBins DEPENDS ${BIN_PATHS})
+
+set(OFFLOAD_TEST_DEVICE_CODE_PATH ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)
diff --git a/offload/unittests/OffloadAPI/device_code/bar.c b/offload/unittests/OffloadAPI/device_code/bar.c
new file mode 100644
index 000000000000..786aa2f5d61e
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device_code/bar.c
@@ -0,0 +1,5 @@
+#include <gpuintrin.h>
+
+__gpu_kernel void foo(int *out) {
+ out[__gpu_thread_id(0)] = __gpu_thread_id(0) + 1;
+}
diff --git a/offload/unittests/OffloadAPI/device_code/foo.c b/offload/unittests/OffloadAPI/device_code/foo.c
new file mode 100644
index 000000000000..5bc893961d49
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device_code/foo.c
@@ -0,0 +1,5 @@
+#include <gpuintrin.h>
+
+__gpu_kernel void foo(int *out) {
+ out[__gpu_thread_id(0)] = __gpu_thread_id(0);
+}
diff --git a/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp
new file mode 100644
index 000000000000..f320d191ad58
--- /dev/null
+++ b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp
@@ -0,0 +1,30 @@
+//===------- Offload API tests - olGetKernel ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olGetKernelTest = OffloadProgramTest;
+
+TEST_F(olGetKernelTest, Success) {
+ ol_kernel_handle_t Kernel = nullptr;
+ ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel));
+ ASSERT_NE(Kernel, nullptr);
+}
+
+TEST_F(olGetKernelTest, InvalidNullProgram) {
+ ol_kernel_handle_t Kernel = nullptr;
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+ olGetKernel(nullptr, "foo", &Kernel));
+}
+
+TEST_F(olGetKernelTest, InvalidNullKernelPointer) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
+ olGetKernel(Program, "foo", nullptr));
+}
diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
new file mode 100644
index 000000000000..2e51a48b9a7a
--- /dev/null
+++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
@@ -0,0 +1,83 @@
+//===------- Offload API tests - olLaunchKernel --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+struct olLaunchKernelTest : OffloadQueueTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp());
+ ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
+ ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
+ ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(),
+ DeviceBin->getBufferSize(), &Program));
+ ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel));
+ LaunchArgs.Dimensions = 1;
+ LaunchArgs.GroupSizeX = 64;
+ LaunchArgs.GroupSizeY = 1;
+ LaunchArgs.GroupSizeZ = 1;
+
+ LaunchArgs.NumGroupsX = 1;
+ LaunchArgs.NumGroupsY = 1;
+ LaunchArgs.NumGroupsZ = 1;
+
+ LaunchArgs.DynSharedMemory = 0;
+ }
+
+ void TearDown() override {
+ if (Program) {
+ olDestroyProgram(Program);
+ }
+ RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown());
+ }
+
+ std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
+ ol_program_handle_t Program = nullptr;
+ ol_kernel_handle_t Kernel = nullptr;
+ ol_kernel_launch_size_args_t LaunchArgs{};
+};
+
+TEST_F(olLaunchKernelTest, Success) {
+ void *Mem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 64, &Mem));
+ struct {
+ void *Mem;
+ } Args{Mem};
+
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
+ &LaunchArgs, nullptr));
+
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ int *Data = (int *)Mem;
+ for (int i = 0; i < 64; i++) {
+ ASSERT_EQ(Data[i], i);
+ }
+
+ ASSERT_SUCCESS(olMemFree(Mem));
+}
+
+TEST_F(olLaunchKernelTest, SuccessSynchronous) {
+ void *Mem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 64, &Mem));
+
+ struct {
+ void *Mem;
+ } Args{Mem};
+
+ ASSERT_SUCCESS(olLaunchKernel(nullptr, Device, Kernel, &Args, sizeof(Args),
+ &LaunchArgs, nullptr));
+
+ int *Data = (int *)Mem;
+ for (int i = 0; i < 64; i++) {
+ ASSERT_EQ(Data[i], i);
+ }
+
+ ASSERT_SUCCESS(olMemFree(Mem));
+}
diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
new file mode 100644
index 000000000000..580ba022954e
--- /dev/null
+++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
@@ -0,0 +1,45 @@
+//===------- Offload API tests - olMemAlloc -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olMemAllocTest = OffloadDeviceTest;
+
+TEST_F(olMemAllocTest, SuccessAllocManaged) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+ ASSERT_NE(Alloc, nullptr);
+ olMemFree(Alloc);
+}
+
+TEST_F(olMemAllocTest, SuccessAllocHost) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+ ASSERT_NE(Alloc, nullptr);
+ olMemFree(Alloc);
+}
+
+TEST_F(olMemAllocTest, SuccessAllocDevice) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+ ASSERT_NE(Alloc, nullptr);
+ olMemFree(Alloc);
+}
+
+TEST_F(olMemAllocTest, InvalidNullDevice) {
+ void *Alloc = nullptr;
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+ olMemAlloc(nullptr, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+}
+
+TEST_F(olMemAllocTest, InvalidNullOutPtr) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
+ olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, nullptr));
+}
diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
new file mode 100644
index 000000000000..99ad389f27fb
--- /dev/null
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -0,0 +1,38 @@
+//===------- Offload API tests - olMemFree --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olMemFreeTest = OffloadDeviceTest;
+
+TEST_F(olMemFreeTest, SuccessFreeManaged) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+ ASSERT_SUCCESS(olMemFree(Alloc));
+}
+
+TEST_F(olMemFreeTest, SuccessFreeHost) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+ ASSERT_SUCCESS(olMemFree(Alloc));
+}
+
+TEST_F(olMemFreeTest, SuccessFreeDevice) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+ ASSERT_SUCCESS(olMemFree(Alloc));
+}
+
+TEST_F(olMemFreeTest, InvalidNullPtr) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr));
+ ASSERT_SUCCESS(olMemFree(Alloc));
+}
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
new file mode 100644
index 000000000000..b00ded9b53ed
--- /dev/null
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -0,0 +1,106 @@
+//===------- Offload API tests - olMemcpy --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olMemcpyTest = OffloadQueueTest;
+
+TEST_F(olMemcpyTest, SuccessHtoD) {
+ constexpr size_t Size = 1024;
+ void *Alloc;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &Alloc));
+ std::vector<uint8_t> Input(Size, 42);
+ ASSERT_SUCCESS(
+ olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size, nullptr));
+ olWaitQueue(Queue);
+ olMemFree(Alloc);
+}
+
+TEST_F(olMemcpyTest, SuccessDtoH) {
+ constexpr size_t Size = 1024;
+ void *Alloc;
+ std::vector<uint8_t> Input(Size, 42);
+ std::vector<uint8_t> Output(Size, 0);
+
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &Alloc));
+ ASSERT_SUCCESS(
+ olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size, nullptr));
+ ASSERT_SUCCESS(
+ olMemcpy(Queue, Output.data(), Host, Alloc, Device, Size, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ for (uint8_t Val : Output) {
+ ASSERT_EQ(Val, 42);
+ }
+ ASSERT_SUCCESS(olMemFree(Alloc));
+}
+
+TEST_F(olMemcpyTest, SuccessDtoD) {
+ constexpr size_t Size = 1024;
+ void *AllocA;
+ void *AllocB;
+ std::vector<uint8_t> Input(Size, 42);
+ std::vector<uint8_t> Output(Size, 0);
+
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &AllocA));
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &AllocB));
+ ASSERT_SUCCESS(
+ olMemcpy(Queue, AllocA, Device, Input.data(), Host, Size, nullptr));
+ ASSERT_SUCCESS(
+ olMemcpy(Queue, AllocB, Device, AllocA, Device, Size, nullptr));
+ ASSERT_SUCCESS(
+ olMemcpy(Queue, Output.data(), Host, AllocB, Device, Size, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ for (uint8_t Val : Output) {
+ ASSERT_EQ(Val, 42);
+ }
+ ASSERT_SUCCESS(olMemFree(AllocA));
+ ASSERT_SUCCESS(olMemFree(AllocB));
+}
+
+TEST_F(olMemcpyTest, SuccessHtoHSync) {
+ constexpr size_t Size = 1024;
+ std::vector<uint8_t> Input(Size, 42);
+ std::vector<uint8_t> Output(Size, 0);
+
+ ASSERT_SUCCESS(olMemcpy(nullptr, Output.data(), Host, Input.data(), Host,
+ Size, nullptr));
+
+ for (uint8_t Val : Output) {
+ ASSERT_EQ(Val, 42);
+ }
+}
+
+TEST_F(olMemcpyTest, SuccessDtoHSync) {
+ constexpr size_t Size = 1024;
+ void *Alloc;
+ std::vector<uint8_t> Input(Size, 42);
+ std::vector<uint8_t> Output(Size, 0);
+
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, Size, &Alloc));
+ ASSERT_SUCCESS(
+ olMemcpy(nullptr, Alloc, Device, Input.data(), Host, Size, nullptr));
+ ASSERT_SUCCESS(
+ olMemcpy(nullptr, Output.data(), Host, Alloc, Device, Size, nullptr));
+ for (uint8_t Val : Output) {
+ ASSERT_EQ(Val, 42);
+ }
+ ASSERT_SUCCESS(olMemFree(Alloc));
+}
+
+TEST_F(olMemcpyTest, SuccessSizeZero) {
+ constexpr size_t Size = 1024;
+ std::vector<uint8_t> Input(Size, 42);
+ std::vector<uint8_t> Output(Size, 0);
+
+ // As with std::memcpy, size 0 is allowed. Keep all other arguments valid even
+ // if they aren't used.
+ ASSERT_SUCCESS(
+ olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
+}
diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatform.cpp b/offload/unittests/OffloadAPI/platform/olGetPlatform.cpp
deleted file mode 100644
index 4a2f9e8ac774..000000000000
--- a/offload/unittests/OffloadAPI/platform/olGetPlatform.cpp
+++ /dev/null
@@ -1,28 +0,0 @@
-//===------- Offload API tests - olGetPlatform -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "../common/Fixtures.hpp"
-#include <OffloadAPI.h>
-#include <gtest/gtest.h>
-
-using olGetPlatformTest = offloadTest;
-
-TEST_F(olGetPlatformTest, Success) {
- uint32_t PlatformCount;
- ASSERT_SUCCESS(olGetPlatformCount(&PlatformCount));
- std::vector<ol_platform_handle_t> Platforms(PlatformCount);
- ASSERT_SUCCESS(olGetPlatform(PlatformCount, Platforms.data()));
-}
-
-TEST_F(olGetPlatformTest, InvalidNumEntries) {
- uint32_t PlatformCount;
- ASSERT_SUCCESS(olGetPlatformCount(&PlatformCount));
- std::vector<ol_platform_handle_t> Platforms(PlatformCount);
- ASSERT_ERROR(OL_ERRC_INVALID_SIZE,
- olGetPlatform(PlatformCount + 1, Platforms.data()));
-}
diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp b/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp
index c646bdc50b7d..bd6ad3f84e77 100644
--- a/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp
+++ b/offload/unittests/OffloadAPI/platform/olGetPlatformInfo.cpp
@@ -12,7 +12,7 @@
#include "olPlatformInfo.hpp"
struct olGetPlatformInfoTest
- : offloadPlatformTest,
+ : OffloadPlatformTest,
::testing::WithParamInterface<ol_platform_info_t> {};
INSTANTIATE_TEST_SUITE_P(
diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp b/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp
index 7c9274082e8e..5f6067e2e259 100644
--- a/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/platform/olGetPlatformInfoSize.cpp
@@ -12,7 +12,7 @@
#include "olPlatformInfo.hpp"
struct olGetPlatformInfoSizeTest
- : offloadPlatformTest,
+ : OffloadPlatformTest,
::testing::WithParamInterface<ol_platform_info_t> {};
INSTANTIATE_TEST_SUITE_P(
diff --git a/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp b/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp
index d49cdb90d321..f61bca0cf52f 100644
--- a/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp
+++ b/offload/unittests/OffloadAPI/platform/olPlatformInfo.hpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#pragma once
+#include <unordered_map>
#include <vector>
// TODO: We could autogenerate these
diff --git a/offload/unittests/OffloadAPI/program/olCreateProgram.cpp b/offload/unittests/OffloadAPI/program/olCreateProgram.cpp
new file mode 100644
index 000000000000..c586c0459620
--- /dev/null
+++ b/offload/unittests/OffloadAPI/program/olCreateProgram.cpp
@@ -0,0 +1,27 @@
+//===------- Offload API tests - olCreateProgram --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olCreateProgramTest = OffloadDeviceTest;
+
+TEST_F(olCreateProgramTest, Success) {
+
+ std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
+ ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
+ ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
+
+ ol_program_handle_t Program;
+ ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(),
+ DeviceBin->getBufferSize(), &Program));
+ ASSERT_NE(Program, nullptr);
+
+ ASSERT_SUCCESS(olDestroyProgram(Program));
+}
diff --git a/offload/unittests/OffloadAPI/platform/olGetPlatformCount.cpp b/offload/unittests/OffloadAPI/program/olDestroyProgram.cpp
index 15b4b6abcd70..ea21dadb59c4 100644
--- a/offload/unittests/OffloadAPI/platform/olGetPlatformCount.cpp
+++ b/offload/unittests/OffloadAPI/program/olDestroyProgram.cpp
@@ -1,4 +1,4 @@
-//===------- Offload API tests - olGetPlatformCount ------------------===//
+//===------- Offload API tests - olDestroyProgram -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,13 +10,13 @@
#include <OffloadAPI.h>
#include <gtest/gtest.h>
-using olGetPlatformCountTest = offloadTest;
+using olDestroyProgramTest = OffloadProgramTest;
-TEST_F(olGetPlatformCountTest, Success) {
- uint32_t PlatformCount;
- ASSERT_SUCCESS(olGetPlatformCount(&PlatformCount));
+TEST_F(olDestroyProgramTest, Success) {
+ ASSERT_SUCCESS(olDestroyProgram(Program));
+ Program = nullptr;
}
-TEST_F(olGetPlatformCountTest, InvalidNullPointer) {
- ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olGetPlatformCount(nullptr));
+TEST_F(olDestroyProgramTest, InvalidNullHandle) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olDestroyProgram(nullptr));
}
diff --git a/offload/unittests/OffloadAPI/queue/olCreateQueue.cpp b/offload/unittests/OffloadAPI/queue/olCreateQueue.cpp
new file mode 100644
index 000000000000..0534debed055
--- /dev/null
+++ b/offload/unittests/OffloadAPI/queue/olCreateQueue.cpp
@@ -0,0 +1,28 @@
+//===------- Offload API tests - olCreateQueue ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olCreateQueueTest = OffloadDeviceTest;
+
+TEST_F(olCreateQueueTest, Success) {
+ ol_queue_handle_t Queue = nullptr;
+ ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+ ASSERT_NE(Queue, nullptr);
+}
+
+TEST_F(olCreateQueueTest, InvalidNullHandleDevice) {
+ ol_queue_handle_t Queue = nullptr;
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olCreateQueue(nullptr, &Queue));
+}
+
+TEST_F(olCreateQueueTest, InvalidNullPointerQueue) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olCreateQueue(Device, nullptr));
+}
diff --git a/offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp b/offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp
new file mode 100644
index 000000000000..b54694e0c798
--- /dev/null
+++ b/offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp
@@ -0,0 +1,22 @@
+//===------- Offload API tests - olDestroyQueue ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olDestroyQueueTest = OffloadQueueTest;
+
+TEST_F(olDestroyQueueTest, Success) {
+ ASSERT_SUCCESS(olDestroyQueue(Queue));
+ Queue = nullptr;
+}
+
+TEST_F(olDestroyQueueTest, InvalidNullHandle) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olDestroyQueue(nullptr));
+}
diff --git a/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp b/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp
new file mode 100644
index 000000000000..07ef774583ae
--- /dev/null
+++ b/offload/unittests/OffloadAPI/queue/olWaitQueue.cpp
@@ -0,0 +1,17 @@
+//===------- Offload API tests - olWaitQueue ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+using olWaitQueueTest = OffloadQueueTest;
+
+TEST_F(olWaitQueueTest, SuccessEmptyQueue) {
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+}