summaryrefslogtreecommitdiff
path: root/lldb/source/Protocol/MCP/Server.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lldb/source/Protocol/MCP/Server.cpp')
-rw-r--r--lldb/source/Protocol/MCP/Server.cpp291
1 files changed, 203 insertions, 88 deletions
diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp
index a9c1482e3e37..f3489c620832 100644
--- a/lldb/source/Protocol/MCP/Server.cpp
+++ b/lldb/source/Protocol/MCP/Server.cpp
@@ -7,13 +7,115 @@
//===----------------------------------------------------------------------===//
#include "lldb/Protocol/MCP/Server.h"
+#include "lldb/Host/File.h"
+#include "lldb/Host/FileSystem.h"
+#include "lldb/Host/HostInfo.h"
+#include "lldb/Host/JSONTransport.h"
#include "lldb/Protocol/MCP/MCPError.h"
+#include "lldb/Protocol/MCP/Protocol.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/Signals.h"
-using namespace lldb_protocol::mcp;
using namespace llvm;
+using namespace lldb_private;
+using namespace lldb_protocol::mcp;
+
+ServerInfoHandle::ServerInfoHandle() : ServerInfoHandle("") {}
+
+ServerInfoHandle::ServerInfoHandle(StringRef filename) : m_filename(filename) {
+ if (!m_filename.empty())
+ sys::RemoveFileOnSignal(m_filename);
+}
+
+ServerInfoHandle::~ServerInfoHandle() {
+ if (m_filename.empty())
+ return;
+
+ sys::fs::remove(m_filename);
+ sys::DontRemoveFileOnSignal(m_filename);
+ m_filename.clear();
+}
+
+ServerInfoHandle::ServerInfoHandle(ServerInfoHandle &&other)
+ : m_filename(other.m_filename) {
+ *this = std::move(other);
+}
+
+ServerInfoHandle &
+ServerInfoHandle::operator=(ServerInfoHandle &&other) noexcept {
+ m_filename = other.m_filename;
+ other.m_filename.clear();
+ return *this;
+}
+
+json::Value lldb_protocol::mcp::toJSON(const ServerInfo &SM) {
+ return json::Object{{"connection_uri", SM.connection_uri}};
+}
+
+bool lldb_protocol::mcp::fromJSON(const json::Value &V, ServerInfo &SM,
+ json::Path P) {
+ json::ObjectMapper O(V, P);
+ return O && O.map("connection_uri", SM.connection_uri);
+}
-Server::Server(std::string name, std::string version)
- : m_name(std::move(name)), m_version(std::move(version)) {
+Expected<ServerInfoHandle> ServerInfo::Write(const ServerInfo &info) {
+ std::string buf = formatv("{0}", toJSON(info)).str();
+ size_t num_bytes = buf.size();
+
+ FileSpec user_lldb_dir = HostInfo::GetUserLLDBDir();
+
+ Status error(sys::fs::create_directory(user_lldb_dir.GetPath()));
+ if (error.Fail())
+ return error.takeError();
+
+ FileSpec mcp_registry_entry_path = user_lldb_dir.CopyByAppendingPathComponent(
+ formatv("lldb-mcp-{0}.json", getpid()).str());
+
+ const File::OpenOptions flags = File::eOpenOptionWriteOnly |
+ File::eOpenOptionCanCreate |
+ File::eOpenOptionTruncate;
+ Expected<lldb::FileUP> file =
+ FileSystem::Instance().Open(mcp_registry_entry_path, flags);
+ if (!file)
+ return file.takeError();
+ if (llvm::Error error = (*file)->Write(buf.data(), num_bytes).takeError())
+ return error;
+ return ServerInfoHandle{mcp_registry_entry_path.GetPath()};
+}
+
+Expected<std::vector<ServerInfo>> ServerInfo::Load() {
+ namespace path = llvm::sys::path;
+ FileSpec user_lldb_dir = HostInfo::GetUserLLDBDir();
+ FileSystem &fs = FileSystem::Instance();
+ std::error_code EC;
+ vfs::directory_iterator it = fs.DirBegin(user_lldb_dir, EC);
+ vfs::directory_iterator end;
+ std::vector<ServerInfo> infos;
+ for (; it != end && !EC; it.increment(EC)) {
+ auto &entry = *it;
+ auto path = entry.path();
+ auto name = path::filename(path);
+ if (!name.starts_with("lldb-mcp-") || !name.ends_with(".json"))
+ continue;
+
+ auto buffer = fs.CreateDataBuffer(path);
+ auto info = json::parse<ServerInfo>(toStringRef(buffer->GetData()));
+ if (!info)
+ return info.takeError();
+
+ infos.emplace_back(std::move(*info));
+ }
+
+ return infos;
+}
+
+Server::Server(std::string name, std::string version,
+ std::unique_ptr<MCPTransport> transport_up,
+ lldb_private::MainLoop &loop)
+ : m_name(std::move(name)), m_version(std::move(version)),
+ m_transport_up(std::move(transport_up)), m_loop(loop) {
AddRequestHandlers();
}
@@ -30,7 +132,7 @@ void Server::AddRequestHandlers() {
this, std::placeholders::_1));
}
-llvm::Expected<Response> Server::Handle(Request request) {
+llvm::Expected<Response> Server::Handle(const Request &request) {
auto it = m_request_handlers.find(request.method);
if (it != m_request_handlers.end()) {
llvm::Expected<Response> response = it->second(request);
@@ -44,7 +146,7 @@ llvm::Expected<Response> Server::Handle(Request request) {
llvm::formatv("no handler for request: {0}", request.method).str());
}
-void Server::Handle(Notification notification) {
+void Server::Handle(const Notification &notification) {
auto it = m_notification_handlers.find(notification.method);
if (it != m_notification_handlers.end()) {
it->second(notification);
@@ -52,49 +154,7 @@ void Server::Handle(Notification notification) {
}
}
-llvm::Expected<std::optional<Message>>
-Server::HandleData(llvm::StringRef data) {
- auto message = llvm::json::parse<Message>(/*JSON=*/data);
- if (!message)
- return message.takeError();
-
- if (const Request *request = std::get_if<Request>(&(*message))) {
- llvm::Expected<Response> response = Handle(*request);
-
- // Handle failures by converting them into an Error message.
- if (!response) {
- Error protocol_error;
- llvm::handleAllErrors(
- response.takeError(),
- [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
- [&](const llvm::ErrorInfoBase &err) {
- protocol_error.code = MCPError::kInternalError;
- protocol_error.message = err.message();
- });
- Response error_response;
- error_response.id = request->id;
- error_response.result = std::move(protocol_error);
- return error_response;
- }
-
- return *response;
- }
-
- if (const Notification *notification =
- std::get_if<Notification>(&(*message))) {
- Handle(*notification);
- return std::nullopt;
- }
-
- if (std::get_if<Response>(&(*message)))
- return llvm::createStringError("unexpected MCP message: response");
-
- llvm_unreachable("all message types handled");
-}
-
void Server::AddTool(std::unique_ptr<Tool> tool) {
- std::lock_guard<std::mutex> guard(m_mutex);
-
if (!tool)
return;
m_tools[tool->GetName()] = std::move(tool);
@@ -102,42 +162,39 @@ void Server::AddTool(std::unique_ptr<Tool> tool) {
void Server::AddResourceProvider(
std::unique_ptr<ResourceProvider> resource_provider) {
- std::lock_guard<std::mutex> guard(m_mutex);
-
if (!resource_provider)
return;
m_resource_providers.push_back(std::move(resource_provider));
}
void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) {
- std::lock_guard<std::mutex> guard(m_mutex);
m_request_handlers[method] = std::move(handler);
}
void Server::AddNotificationHandler(llvm::StringRef method,
NotificationHandler handler) {
- std::lock_guard<std::mutex> guard(m_mutex);
m_notification_handlers[method] = std::move(handler);
}
llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
Response response;
- response.result = llvm::json::Object{
- {"protocolVersion", mcp::kProtocolVersion},
- {"capabilities", GetCapabilities()},
- {"serverInfo",
- llvm::json::Object{{"name", m_name}, {"version", m_version}}}};
+ InitializeResult result;
+ result.protocolVersion = mcp::kProtocolVersion;
+ result.capabilities = GetCapabilities();
+ result.serverInfo.name = m_name;
+ result.serverInfo.version = m_version;
+ response.result = std::move(result);
return response;
}
llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
Response response;
- llvm::json::Array tools;
+ ListToolsResult result;
for (const auto &tool : m_tools)
- tools.emplace_back(toJSON(tool.second->GetDefinition()));
+ result.tools.emplace_back(tool.second->GetDefinition());
- response.result = llvm::json::Object{{"tools", std::move(tools)}};
+ response.result = std::move(result);
return response;
}
@@ -147,16 +204,12 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
if (!request.params)
return llvm::createStringError("no tool parameters");
+ CallToolParams params;
+ json::Path::Root root("params");
+ if (!fromJSON(request.params, params, root))
+ return root.getError();
- const json::Object *param_obj = request.params->getAsObject();
- if (!param_obj)
- return llvm::createStringError("no tool parameters");
-
- const json::Value *name = param_obj->get("name");
- if (!name)
- return llvm::createStringError("no tool name");
-
- llvm::StringRef tool_name = name->getAsString().value_or("");
+ llvm::StringRef tool_name = params.name;
if (tool_name.empty())
return llvm::createStringError("no tool name");
@@ -165,10 +218,10 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
ToolArguments tool_args;
- if (const json::Value *args = param_obj->get("arguments"))
- tool_args = *args;
+ if (params.arguments)
+ tool_args = *params.arguments;
- llvm::Expected<TextResult> text_result = it->second->Call(tool_args);
+ llvm::Expected<CallToolResult> text_result = it->second->Call(tool_args);
if (!text_result)
return text_result.takeError();
@@ -180,15 +233,13 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
Response response;
- llvm::json::Array resources;
-
- std::lock_guard<std::mutex> guard(m_mutex);
+ ListResourcesResult result;
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
- m_resource_providers) {
+ m_resource_providers)
for (const Resource &resource : resource_provider_up->GetResources())
- resources.push_back(resource);
- }
- response.result = llvm::json::Object{{"resources", std::move(resources)}};
+ result.resources.push_back(resource);
+
+ response.result = std::move(result);
return response;
}
@@ -199,22 +250,18 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
if (!request.params)
return llvm::createStringError("no resource parameters");
- const json::Object *param_obj = request.params->getAsObject();
- if (!param_obj)
- return llvm::createStringError("no resource parameters");
-
- const json::Value *uri = param_obj->get("uri");
- if (!uri)
- return llvm::createStringError("no resource uri");
+ ReadResourceParams params;
+ json::Path::Root root("params");
+ if (!fromJSON(request.params, params, root))
+ return root.getError();
- llvm::StringRef uri_str = uri->getAsString().value_or("");
+ llvm::StringRef uri_str = params.uri;
if (uri_str.empty())
return llvm::createStringError("no resource uri");
- std::lock_guard<std::mutex> guard(m_mutex);
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
m_resource_providers) {
- llvm::Expected<ResourceResult> result =
+ llvm::Expected<ReadResourceResult> result =
resource_provider_up->ReadResource(uri_str);
if (result.errorIsA<UnsupportedURI>()) {
llvm::consumeError(result.takeError());
@@ -232,3 +279,71 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
MCPError::kResourceNotFound);
}
+
+ServerCapabilities Server::GetCapabilities() {
+ lldb_protocol::mcp::ServerCapabilities capabilities;
+ capabilities.supportsToolsList = true;
+ // FIXME: Support sending notifications when a debugger/target are
+ // added/removed.
+ capabilities.supportsResourcesList = false;
+ return capabilities;
+}
+
+llvm::Error Server::Run() {
+ auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this);
+ if (!handle)
+ return handle.takeError();
+
+ lldb_private::Status status = m_loop.Run();
+ if (status.Fail())
+ return status.takeError();
+
+ return llvm::Error::success();
+}
+
+void Server::Received(const Request &request) {
+ auto SendResponse = [this](const Response &response) {
+ if (llvm::Error error = m_transport_up->Send(response))
+ m_transport_up->Log(llvm::toString(std::move(error)));
+ };
+
+ llvm::Expected<Response> response = Handle(request);
+ if (response)
+ return SendResponse(*response);
+
+ lldb_protocol::mcp::Error protocol_error;
+ llvm::handleAllErrors(
+ response.takeError(),
+ [&](const MCPError &err) { protocol_error = err.toProtocolError(); },
+ [&](const llvm::ErrorInfoBase &err) {
+ protocol_error.code = MCPError::kInternalError;
+ protocol_error.message = err.message();
+ });
+ Response error_response;
+ error_response.id = request.id;
+ error_response.result = std::move(protocol_error);
+ SendResponse(error_response);
+}
+
+void Server::Received(const Response &response) {
+ m_transport_up->Log("unexpected MCP message: response");
+}
+
+void Server::Received(const Notification &notification) {
+ Handle(notification);
+}
+
+void Server::OnError(llvm::Error error) {
+ m_transport_up->Log(llvm::toString(std::move(error)));
+ TerminateLoop();
+}
+
+void Server::OnClosed() {
+ m_transport_up->Log("EOF");
+ TerminateLoop();
+}
+
+void Server::TerminateLoop() {
+ m_loop.AddPendingCallback(
+ [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
+}