diff options
| author | Mingming Liu <mingmingl@google.com> | 2025-09-10 15:25:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-10 15:25:31 -0700 |
| commit | 1417dafa1db9cb1b2b09438aa9f53ea5ab6e36e2 (patch) | |
| tree | 57f4b1f313c8cf74eed8819870f39c36ea263c68 /lldb/source/Plugins/Protocol/MCP | |
| parent | 898b813bc8a6d0276bf0f4769f5f2f64b34e632d (diff) | |
| parent | b8cefcb601ddaa18482555c4ff363c01a270c2fe (diff) | |
Merge branch 'main' into users/mingmingl-llvm/samplefdo-profile-formatusers/mingmingl-llvm/samplefdo-profile-format
Diffstat (limited to 'lldb/source/Plugins/Protocol/MCP')
| -rw-r--r-- | lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp | 129 | ||||
| -rw-r--r-- | lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h | 24 | ||||
| -rw-r--r-- | lldb/source/Plugins/Protocol/MCP/Resource.cpp | 15 | ||||
| -rw-r--r-- | lldb/source/Plugins/Protocol/MCP/Resource.h | 15 | ||||
| -rw-r--r-- | lldb/source/Plugins/Protocol/MCP/Tool.cpp | 12 | ||||
| -rw-r--r-- | lldb/source/Plugins/Protocol/MCP/Tool.h | 8 |
6 files changed, 86 insertions, 117 deletions
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index c359663239dc..dc18c8e06803 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -10,14 +10,15 @@ #include "Resource.h" #include "Tool.h" #include "lldb/Core/PluginManager.h" -#include "lldb/Protocol/MCP/MCPError.h" -#include "lldb/Protocol/MCP/Tool.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Protocol/MCP/Server.h" #include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Error.h" #include "llvm/Support/Threading.h" #include <thread> -#include <variant> using namespace lldb_private; using namespace lldb_private::mcp; @@ -26,24 +27,10 @@ using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) -static constexpr size_t kChunkSize = 1024; static constexpr llvm::StringLiteral kName = "lldb-mcp"; static constexpr llvm::StringLiteral kVersion = "0.1.0"; -ProtocolServerMCP::ProtocolServerMCP() - : ProtocolServer(), - lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { - AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); - - AddTool( - std::make_unique<CommandTool>("lldb_command", "Run an lldb command.")); - - AddResourceProvider(std::make_unique<DebuggerResourceProvider>()); -} +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {} ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -53,6 +40,8 @@ void ProtocolServerMCP::Initialize() { } void ProtocolServerMCP::Terminate() { + if (llvm::Error error = ProtocolServer::Terminate()) + LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); PluginManager::UnregisterPlugin(CreateInstance); } @@ -64,57 +53,37 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } +void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { + server.AddNotificationHandler("notifications/initialized", + [](const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), + "MCP initialization complete"); + }); + server.AddTool( + std::make_unique<CommandTool>("lldb_command", "Run an lldb command.")); + server.AddResourceProvider(std::make_unique<DebuggerResourceProvider>()); +} + void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { - LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", - m_clients.size() + 1); + Log *log = GetLog(LLDBLog::Host); + std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); - auto client_up = std::make_unique<Client>(); - client_up->io_sp = io_sp; - Client *client = client_up.get(); - - Status status; - auto read_handle_up = m_loop.RegisterReadObject( - io_sp, - [this, client](MainLoopBase &loop) { - if (llvm::Error error = ReadCallback(*client)) { - LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); - client->read_handle_up.reset(); - } - }, - status); - if (status.Fail()) + auto transport_up = std::make_unique<lldb_protocol::mcp::Transport>( + io_sp, io_sp, [client_name](llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); + }); + auto instance_up = std::make_unique<lldb_protocol::mcp::Server>( + std::string(kName), std::string(kVersion), std::move(transport_up), + m_loop); + Extend(*instance_up); + llvm::Error error = instance_up->Run(); + if (error) { + LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); return; - - client_up->read_handle_up = std::move(read_handle_up); - m_clients.emplace_back(std::move(client_up)); -} - -llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { - char chunk[kChunkSize]; - size_t bytes_read = sizeof(chunk); - if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) - return status.takeError(); - client.buffer.append(chunk, bytes_read); - - for (std::string::size_type pos; - (pos = client.buffer.find('\n')) != std::string::npos;) { - llvm::Expected<std::optional<lldb_protocol::mcp::Message>> message = - HandleData(StringRef(client.buffer.data(), pos)); - client.buffer = client.buffer.erase(0, pos + 1); - if (!message) - return message.takeError(); - - if (*message) { - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; - size_t num_bytes = Output.size(); - return client.io_sp->Write(Output.data(), num_bytes).takeError(); - } } - - return llvm::Error::success(); + m_instances.push_back(std::move(instance_up)); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -138,7 +107,19 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { if (llvm::Error error = handles.takeError()) return error; + auto listening_uris = m_listener->GetListeningConnectionURI(); + if (listening_uris.empty()) + return createStringError("failed to get listening connections"); + std::string address = + llvm::join(m_listener->GetListeningConnectionURI(), ", "); + + ServerInfo info{listening_uris[0]}; + llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info); + if (!handle) + return handle.takeError(); + m_running = true; + m_server_info_handle = std::move(*handle); m_listen_handlers = std::move(*handles); m_loop_thread = std::thread([=] { llvm::set_thread_name("protocol-server.mcp"); @@ -158,27 +139,15 @@ llvm::Error ProtocolServerMCP::Stop() { // Stop the main loop. m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); // Wait for the main loop to exit. if (m_loop_thread.joinable()) m_loop_thread.join(); - { - std::lock_guard<std::mutex> guard(m_mutex); - m_listener.reset(); - m_listen_handlers.clear(); - m_clients.clear(); - } + m_listen_handlers.clear(); + m_server_info_handle = ServerInfoHandle(); + m_instances.clear(); return llvm::Error::success(); } - -lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; - // FIXME: Support sending notifications when a debugger/target are - // added/removed. - capabilities.resources.listChanged = false; - return capabilities; -} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 7fe909a728b8..0251664a2acc 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -18,8 +18,7 @@ namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer, - public lldb_protocol::mcp::Server { +class ProtocolServerMCP : public ProtocolServer { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -39,26 +38,25 @@ public: Socket *GetSocket() const override { return m_listener.get(); } +protected: + // This adds tools and resource providers that + // are specific to this server. Overridable by the unit tests. + virtual void Extend(lldb_protocol::mcp::Server &server) const; + private: void AcceptCallback(std::unique_ptr<Socket> socket); - lldb_protocol::mcp::Capabilities GetCapabilities() override; - bool m_running = false; - MainLoop m_loop; + lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; + lldb_private::MainLoop m_loop; std::thread m_loop_thread; + std::mutex m_mutex; std::unique_ptr<Socket> m_listener; - std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers; - struct Client { - lldb::IOObjectSP io_sp; - MainLoopBase::ReadHandleUP read_handle_up; - std::string buffer; - }; - llvm::Error ReadCallback(Client &client); - std::vector<std::unique_ptr<Client>> m_clients; + std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers; + std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.cpp b/lldb/source/Plugins/Protocol/MCP/Resource.cpp index e94d2cdd65e0..581424510d4c 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Resource.cpp @@ -8,7 +8,6 @@ #include "lldb/Core/Debugger.h" #include "lldb/Core/Module.h" #include "lldb/Protocol/MCP/MCPError.h" -#include "lldb/Target/Platform.h" using namespace lldb_private; using namespace lldb_private::mcp; @@ -124,7 +123,7 @@ DebuggerResourceProvider::GetResources() const { return resources; } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ReadResourceResult> DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { auto [protocol, path] = uri.split("://"); @@ -161,7 +160,7 @@ DebuggerResourceProvider::ReadResource(llvm::StringRef uri) const { return ReadDebuggerResource(uri, debugger_idx); } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ReadResourceResult> DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id) { lldb::DebuggerSP debugger_sp = Debugger::FindDebuggerWithID(debugger_id); @@ -173,17 +172,17 @@ DebuggerResourceProvider::ReadDebuggerResource(llvm::StringRef uri, debugger_resource.name = debugger_sp->GetInstanceName(); debugger_resource.num_targets = debugger_sp->GetTargetList().GetNumTargets(); - lldb_protocol::mcp::ResourceContents contents; + lldb_protocol::mcp::TextResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(debugger_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ReadResourceResult result; result.contents.push_back(contents); return result; } -llvm::Expected<lldb_protocol::mcp::ResourceResult> +llvm::Expected<lldb_protocol::mcp::ReadResourceResult> DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx) { @@ -209,12 +208,12 @@ DebuggerResourceProvider::ReadTargetResource(llvm::StringRef uri, if (lldb::PlatformSP platform_sp = target_sp->GetPlatform()) target_resource.platform = platform_sp->GetName(); - lldb_protocol::mcp::ResourceContents contents; + lldb_protocol::mcp::TextResourceContents contents; contents.uri = uri; contents.mimeType = kMimeTypeJSON; contents.text = llvm::formatv("{0}", toJSON(target_resource)); - lldb_protocol::mcp::ResourceResult result; + lldb_protocol::mcp::ReadResourceResult result; result.contents.push_back(contents); return result; } diff --git a/lldb/source/Plugins/Protocol/MCP/Resource.h b/lldb/source/Plugins/Protocol/MCP/Resource.h index e2382a74f796..0c6576602905 100644 --- a/lldb/source/Plugins/Protocol/MCP/Resource.h +++ b/lldb/source/Plugins/Protocol/MCP/Resource.h @@ -11,7 +11,11 @@ #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" -#include "lldb/lldb-private.h" +#include "lldb/lldb-forward.h" +#include "lldb/lldb-types.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include <cstddef> #include <vector> namespace lldb_private::mcp { @@ -21,9 +25,8 @@ public: using ResourceProvider::ResourceProvider; virtual ~DebuggerResourceProvider() = default; - virtual std::vector<lldb_protocol::mcp::Resource> - GetResources() const override; - virtual llvm::Expected<lldb_protocol::mcp::ResourceResult> + std::vector<lldb_protocol::mcp::Resource> GetResources() const override; + llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadResource(llvm::StringRef uri) const override; private: @@ -31,9 +34,9 @@ private: static lldb_protocol::mcp::Resource GetTargetResource(size_t target_idx, Target &target); - static llvm::Expected<lldb_protocol::mcp::ResourceResult> + static llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadDebuggerResource(llvm::StringRef uri, lldb::user_id_t debugger_id); - static llvm::Expected<lldb_protocol::mcp::ResourceResult> + static llvm::Expected<lldb_protocol::mcp::ReadResourceResult> ReadTargetResource(llvm::StringRef uri, lldb::user_id_t debugger_id, size_t target_idx); }; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 143470702a6f..2f451bf76e81 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "Tool.h" -#include "lldb/Core/Module.h" #include "lldb/Interpreter/CommandInterpreter.h" #include "lldb/Interpreter/CommandReturnObject.h" +#include "lldb/Protocol/MCP/Protocol.h" using namespace lldb_private; using namespace lldb_protocol; @@ -29,10 +29,10 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, O.mapOptional("arguments", A.arguments); } -/// Helper function to create a TextResult from a string output. -static lldb_protocol::mcp::TextResult createTextResult(std::string output, - bool is_error = false) { - lldb_protocol::mcp::TextResult text_result; +/// Helper function to create a CallToolResult from a string output. +static lldb_protocol::mcp::CallToolResult +createTextResult(std::string output, bool is_error = false) { + lldb_protocol::mcp::CallToolResult text_result; text_result.content.emplace_back( lldb_protocol::mcp::TextContent{{std::move(output)}}); text_result.isError = is_error; @@ -41,7 +41,7 @@ static lldb_protocol::mcp::TextResult createTextResult(std::string output, } // namespace -llvm::Expected<lldb_protocol::mcp::TextResult> +llvm::Expected<lldb_protocol::mcp::CallToolResult> CommandTool::Call(const lldb_protocol::mcp::ToolArguments &args) { if (!std::holds_alternative<json::Value>(args)) return createStringError("CommandTool requires arguments"); diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index b7b1756eb38d..1886525b9168 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -9,11 +9,11 @@ #ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H #define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H -#include "lldb/Core/Debugger.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Tool.h" +#include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" -#include <string> +#include <optional> namespace lldb_private::mcp { @@ -22,10 +22,10 @@ public: using lldb_protocol::mcp::Tool::Tool; ~CommandTool() = default; - virtual llvm::Expected<lldb_protocol::mcp::TextResult> + llvm::Expected<lldb_protocol::mcp::CallToolResult> Call(const lldb_protocol::mcp::ToolArguments &args) override; - virtual std::optional<llvm::json::Value> GetSchema() const override; + std::optional<llvm::json::Value> GetSchema() const override; }; } // namespace lldb_private::mcp |
