summaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/common/src/RPC.cpp
blob: faa2cbd4f02fe154cf5a8c376940ba86150e10d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RPC.h"

#include "Shared/Debug.h"

#include "PluginInterface.h"

#if defined(LIBOMPTARGET_RPC_SUPPORT)
#include "llvm-libc-types/rpc_opcodes_t.h"
#include "llvmlibc_rpc_server.h"
#endif

using namespace llvm;
using namespace omp;
using namespace target;

/// Construct the server state for \p Plugin: `Handles` is built from
/// Plugin.getNumDevices().
/// NOTE(review): presumably this yields one handle slot per device, since the
/// other methods in this file index `Handles` by device ID -- confirm against
/// the member declaration.
RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
    : Handles(Plugin.getNumDevices()) {}

/// Report whether \p Image contains the RPC client symbol, as determined by
/// asking \p Handler to look up `rpc_client_symbol_name` in the image for
/// \p Device. Always answers `false` when the plugin was built without
/// LIBOMPTARGET_RPC_SUPPORT.
llvm::Expected<bool>
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
#ifndef LIBOMPTARGET_RPC_SUPPORT
  // No RPC support compiled in, so no image can make use of it.
  return false;
#else
  // The lookup result (or its error) is forwarded to the caller unchanged.
  return Handler.isSymbolInImage(Device, Image, rpc_client_symbol_name);
#endif
}

/// Start an RPC server for \p Device and connect the device-side client found
/// in \p Image to it.
///
/// Steps performed when LIBOMPTARGET_RPC_SUPPORT is defined:
///  1. Initialize the server with min(requested ports, RPC_MAXIMUM_PORT_COUNT)
///     ports and the device's warp size, allocating through the device with
///     TARGET_ALLOC_HOST.
///  2. Register RPC_MALLOC / RPC_FREE callbacks that forward to the device's
///     allocate / free with TARGET_ALLOC_DEVICE_NON_BLOCKING.
///  3. Read the pointer stored in the device global `rpc_client_symbol_name`
///     and copy the server's client buffer to that address.
///  4. Record the server handle in `Handles` under the device ID.
///
/// If any step after the server has been created fails, the server is shut
/// down again before the error is returned so that nothing is leaked.
///
/// \returns Error::success() on success (and unconditionally when RPC support
/// is compiled out), otherwise the error describing the failed step.
Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
  // Allocation callback used by the RPC library; `Data` is the device that
  // was passed to rpc_server_init below.
  auto Alloc = [](uint64_t Size, void *Data) {
    plugin::GenericDeviceTy &Device =
        *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
    return Device.allocate(Size, nullptr, TARGET_ALLOC_HOST);
  };
  uint64_t NumPorts =
      std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT);
  rpc_device_t RPCDevice;
  if (rpc_status_t Err = rpc_server_init(&RPCDevice, NumPorts,
                                         Device.getWarpSize(), Alloc, &Device))
    return plugin::Plugin::error(
        "Failed to initialize RPC server for device %d: %d",
        Device.getDeviceId(), Err);

  // From here on the server exists. Every failure path below must release it
  // again, using the same host deallocation that deinitDevice uses.
  auto ShutdownOnFailure = [&RPCDevice, &Device]() {
    auto Dealloc = [](void *Ptr, void *Data) {
      plugin::GenericDeviceTy &Owner =
          *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
      Owner.free(Ptr, TARGET_ALLOC_HOST);
    };
    // Best effort only: the caller is told about the original failure, not
    // about a secondary failure while cleaning up.
    (void)rpc_server_shutdown(RPCDevice, Dealloc, &Device);
  };

  // Register a custom opcode handler to perform plugin specific allocation.
  // The requested size arrives in data[0] and is overwritten with the
  // resulting device pointer before the buffer is sent back.
  auto MallocHandler = [](rpc_port_t Port, void *Data) {
    rpc_recv_and_send(
        Port,
        [](rpc_buffer_t *Buffer, void *Data) {
          plugin::GenericDeviceTy &Device =
              *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
          Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
              Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
        },
        Data);
  };
  if (rpc_status_t Err = rpc_register_callback(RPCDevice, RPC_MALLOC,
                                               MallocHandler, &Device)) {
    ShutdownOnFailure();
    return plugin::Plugin::error(
        "Failed to register RPC malloc handler for device %d: %d",
        Device.getDeviceId(), Err);
  }

  // Register a custom opcode handler to perform plugin specific deallocation.
  // The pointer to release arrives in data[0]; nothing is sent back.
  auto FreeHandler = [](rpc_port_t Port, void *Data) {
    rpc_recv(
        Port,
        [](rpc_buffer_t *Buffer, void *Data) {
          plugin::GenericDeviceTy &Device =
              *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
          Device.free(reinterpret_cast<void *>(Buffer->data[0]),
                      TARGET_ALLOC_DEVICE_NON_BLOCKING);
        },
        Data);
  };
  if (rpc_status_t Err = rpc_register_callback(RPCDevice, RPC_FREE,
                                               FreeHandler, &Device)) {
    ShutdownOnFailure();
    return plugin::Plugin::error(
        "Failed to register RPC free handler for device %d: %d",
        Device.getDeviceId(), Err);
  }

  // Get the address of the RPC client from the device.
  void *ClientPtr = nullptr;
  plugin::GlobalTy ClientGlobal(rpc_client_symbol_name, sizeof(void *));
  if (auto Err =
          Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal)) {
    ShutdownOnFailure();
    return Err;
  }

  if (auto Err = Device.dataRetrieve(&ClientPtr, ClientGlobal.getPtr(),
                                     sizeof(void *), nullptr)) {
    ShutdownOnFailure();
    return Err;
  }

  // Copy the host-side image of the client to the address just retrieved.
  const void *ClientBuffer = rpc_get_client_buffer(RPCDevice);
  if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
                                   rpc_get_client_size(), nullptr)) {
    ShutdownOnFailure();
    return Err;
  }

  // Only publish the handle once the device is fully set up.
  Handles[Device.getDeviceId()] = RPCDevice.handle;
#endif
  return Error::success();
}

/// Let the RPC server belonging to \p Device handle its work by calling
/// rpc_handle_server on the handle stored for that device.
///
/// \returns Error::success() if the server reported no error (or if RPC
/// support is compiled out), otherwise an error carrying the device ID and
/// the RPC status code.
Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
  const auto DeviceId = Device.getDeviceId();

  // Rebuild the library's device descriptor from the handle stored for this
  // device.
  rpc_device_t Server{Handles[DeviceId]};

  const rpc_status_t Status = rpc_handle_server(Server);
  if (Status)
    return plugin::Plugin::error(
        "Error while running RPC server on device %d: %d", DeviceId, Status);
#endif
  return Error::success();
}

/// Shut down the RPC server associated with \p Device.
///
/// The server's memory is handed back through the device's `free` with
/// TARGET_ALLOC_HOST, the same allocation kind initDevice's allocation
/// callback uses.
///
/// \returns Error::success() if the shutdown reported no error (or if RPC
/// support is compiled out), otherwise an error carrying the device ID and
/// the RPC status code.
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
  // Deallocation callback for the RPC library; `Data` is the device pointer
  // passed to rpc_server_shutdown below.
  auto ReleaseHostMemory = [](void *Ptr, void *Data) {
    auto *Owner = reinterpret_cast<plugin::GenericDeviceTy *>(Data);
    Owner->free(Ptr, TARGET_ALLOC_HOST);
  };

  const auto DeviceId = Device.getDeviceId();
  rpc_device_t Server{Handles[DeviceId]};

  const rpc_status_t Status =
      rpc_server_shutdown(Server, ReleaseHostMemory, &Device);
  if (Status)
    return plugin::Plugin::error(
        "Failed to shut down RPC server for device %d: %d", DeviceId, Status);
#endif
  return Error::success();
}