// llama.cpp/tools/cli/mcp.hpp
//
// 570 lines
// 19 KiB
// C++

#pragma once
#include "../../vendor/sheredom/subprocess.h"
#include "log.h"
#include <algorithm>
#include <atomic>
#include <cerrno>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <future>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
using json = nlohmann::ordered_json;
// Description of a single tool exposed by an MCP server, as collected by
// mcp_server::list_tools() from a "tools/list" response.
struct mcp_tool {
// Tool identifier; taken from the "name" field of the tool entry.
std::string name;
// Free-form description; taken from the optional "description" field (empty if absent).
std::string description;
// Copy of the optional "inputSchema" field of the tool entry (null JSON if absent).
json input_schema;
// Name of the mcp_server that listed this tool; mcp_context::call_tool() uses it for routing.
std::string server_name;
};
// Low-level transport layer for subprocess-based MCP communication.
// Handles subprocess lifecycle, raw I/O, and line-based message framing.
//
// Threading: start() launches two reader threads (stdout and stderr) that keep a
// pointer to this object. A started transport therefore cannot be relocated in
// memory: the move operations stop the source first, and the destination always
// ends up stopped and has to be start()ed again.
// start(), stop(), send() and set_line_handler() are not synchronized against each
// other; set_line_handler() must be called before start().
class mcp_transport {
    std::string                        name;
    std::string                        command;
    std::vector<std::string>           args;
    std::map<std::string, std::string> env;
    struct subprocess_s                process = {};
    // True while the child's stdout is open; cleared by read_loop() on EOF and by stop().
    std::atomic<bool>                  running{ false };
    // True from a successful start() until stop() has joined the threads and destroyed
    // the process handle. Unlike `running` it is never touched by the reader threads,
    // so stop() can rely on it to know whether there is anything to clean up.
    bool                               spawned = false;
    std::thread                        read_thread;
    std::thread                        err_thread;
    std::atomic<bool>                  stop_read{ false };
    std::string                        read_buffer;
    std::function<void(const std::string &)> line_handler;

  public:
    // Stores the launch configuration; nothing is spawned until start().
    //   name - label used in log messages
    //   cmd  - executable to run (resolved through the user's PATH)
    //   args - arguments passed after the executable
    //   env  - environment for the child; when empty the parent environment is inherited
    mcp_transport(const std::string &                        name,
                  const std::string &                        cmd,
                  const std::vector<std::string> &           args,
                  const std::map<std::string, std::string> & env) :
        name(name),
        command(cmd),
        args(args),
        env(env) {}

    mcp_transport(const mcp_transport &)             = delete;
    mcp_transport & operator=(const mcp_transport &) = delete;

    // Move construction: stops `other` (its reader threads point at `other`, so they
    // cannot be adopted) and takes over its configuration and line handler.
    mcp_transport(mcp_transport && other) noexcept {
        other.stop();
        take_config(other);
    }

    // Move assignment: stops both sides, then takes over the configuration of `other`.
    mcp_transport & operator=(mcp_transport && other) noexcept {
        if (this != &other) {
            stop();
            other.stop();
            take_config(other);
        }
        return *this;
    }

    ~mcp_transport() { stop(); }

    // Installs the callback invoked (on the stdout reader thread) for every non-empty
    // line received from the child. Must be called before start().
    void set_line_handler(std::function<void(const std::string &)> handler) { line_handler = std::move(handler); }

    // True between a successful start() and the moment the child's stdout reaches EOF
    // or stop() completes.
    bool is_running() const { return running.load(); }

    // Spawns the child process and the two reader threads.
    // Returns true when the child is running afterwards (including the case where it
    // was already running), false when the process could not be created.
    bool start() {
        if (spawned) {
            if (running) {
                LOG_WRN("MCP server %s is already running\n", name.c_str());
                return true;
            }
            // The previous child exited on its own: join its threads and release the
            // handle before spawning a replacement.
            stop();
        }

        // Reset state left behind by a previous run.
        stop_read = false;
        read_buffer.clear();

        std::vector<const char *> cmd_args;
        cmd_args.reserve(args.size() + 2);
        cmd_args.push_back(command.c_str());
        for (const auto & arg : args) {
            cmd_args.push_back(arg.c_str());
        }
        cmd_args.push_back(nullptr);

        const int base_options = subprocess_option_search_user_path;
        int       result;
        if (env.empty()) {
            result = subprocess_create(cmd_args.data(), base_options | subprocess_option_inherit_environment, &process);
        } else {
            // env_strings owns the "KEY=VALUE" storage that env_vars points into.
            std::vector<std::string> env_strings;
            env_strings.reserve(env.size());
            for (const auto & kv : env) {
                env_strings.push_back(kv.first + "=" + kv.second);
            }
            std::vector<const char *> env_vars;
            env_vars.reserve(env_strings.size() + 1);
            for (const auto & s : env_strings) {
                env_vars.push_back(s.c_str());
            }
            env_vars.push_back(nullptr);
            result = subprocess_create_ex(cmd_args.data(), base_options, env_vars.data(), &process);
        }

        if (result != 0) {
            LOG_ERR("Failed to start MCP server %s: error %d (%s)\n", name.c_str(), errno, strerror(errno));
            process = {};  // make sure later calls never look at a half-initialized handle
            return false;
        }

        spawned     = true;
        running     = true;
        read_thread = std::thread(&mcp_transport::read_loop, this);
        err_thread  = std::thread(&mcp_transport::err_loop, this);
        return true;
    }

    // Shuts the child down and releases every resource acquired by start().
    // Safe to call repeatedly and after the child has already exited by itself:
    // the decision to clean up is based on `spawned`, not on `running`, because the
    // reader thread clears `running` as soon as it sees EOF.
    void stop() {
        if (!spawned) {
            return;
        }
        LOG_INF("Stopping MCP server %s...\n", name.c_str());
        stop_read = true;

        // 1. Close stdin to signal EOF
        if (process.stdin_file) {
            fclose(process.stdin_file);
            process.stdin_file = nullptr;
        }

        // 2. Wait for 10 seconds for normal termination
        bool terminated = false;
        for (int i = 0; i < 100; ++i) {  // 100 * 100ms = 10s
            if (subprocess_alive(&process) == 0) {
                terminated = true;
                break;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }

        // 3. Terminate if still running
        if (!terminated) {
            LOG_WRN("MCP server %s did not exit gracefully, terminating...\n", name.c_str());
            subprocess_terminate(&process);
        }

        // 4. Join threads (they exit once the child's pipes reach EOF)
        if (read_thread.joinable()) {
            read_thread.join();
        }
        if (err_thread.joinable()) {
            err_thread.join();
        }

        // 5. Cleanup - unconditional: the handle exists whenever `spawned` is set
        subprocess_destroy(&process);
        process = {};
        spawned = false;
        running = false;
        LOG_INF("MCP server %s stopped.\n", name.c_str());
    }

    // Writes `data` verbatim to the child's stdin and flushes.
    // Returns false when stdin is unavailable or the write/flush did not complete.
    bool send(const std::string & data) {
        FILE * f = subprocess_stdin(&process);
        if (!f) {
            return false;
        }
        const size_t written = fwrite(data.data(), 1, data.size(), f);
        if (written != data.size() || fflush(f) != 0) {
            LOG_ERR("MCP server %s: failed to write to stdin\n", name.c_str());
            return false;
        }
        return true;
    }

  private:
    // Takes over launch configuration and line handler of an already stopped `other`.
    // Process handle, threads and buffered input are deliberately not transferred.
    void take_config(mcp_transport & other) {
        name         = std::move(other.name);
        command      = std::move(other.command);
        args         = std::move(other.args);
        env          = std::move(other.env);
        line_handler = std::move(other.line_handler);
    }

    // Reader thread for the child's stdout: accumulates bytes, splits them on '\n'
    // and hands every non-empty line to line_handler.
    void read_loop() {
        LOG_DBG("MCP read_loop started for %s\n", name.c_str());
        char buffer[4096];
        while (!stop_read && running) {
            const unsigned bytes_read = subprocess_read_stdout(&process, buffer, sizeof(buffer));
            if (bytes_read == 0) {
                // If blocking read returns 0, it means EOF (process exited or pipe closed).
                // We should NOT call subprocess_alive() here because it calls subprocess_join()
                // which modifies the process struct (closes stdin) and causes race conditions/double-free
                // when stop() is called concurrently.
                // Just break the loop. The process is likely dead or dying.
                if (!stop_read) {
                    LOG_ERR("MCP server %s: read_loop exiting (stdout closed/EOF)\n", name.c_str());
                }
                running = false;
                break;
            }
            read_buffer.append(buffer, bytes_read);

            size_t pos;
            while ((pos = read_buffer.find('\n')) != std::string::npos) {
                const std::string line = read_buffer.substr(0, pos);
                read_buffer.erase(0, pos + 1);
                if (line.empty()) {
                    continue;
                }
                if (line_handler) {
                    line_handler(line);
                }
            }
        }
        LOG_DBG("MCP read_loop exiting for %s (stop_read=%d, running=%d)\n", name.c_str(), stop_read.load(),
                running.load());
    }

    // Reader thread for the child's stderr: forwards whatever arrives to the log.
    void err_loop() {
        char buffer[1024];
        while (!stop_read && running) {
            const unsigned bytes_read = subprocess_read_stderr(&process, buffer, sizeof(buffer));
            if (bytes_read == 0) {
                // EOF
                break;
            }
            if (stop_read) {
                break;  // Don't log stderr during shutdown
            }
            const std::string err_str(buffer, bytes_read);
            LOG_WRN("[%s stderr] %s", name.c_str(), err_str.c_str());
        }
    }
};
// JSON-RPC layer for MCP server communication.
// Handles request/response correlation, message formatting, and MCP protocol methods.
class mcp_server {
std::string name;
mcp_transport transport;
std::mutex mutex;
int next_id = 1;
std::map<int, std::promise<json>> pending_requests;
public:
mcp_server(const std::string & name,
const std::string & cmd,
const std::vector<std::string> & args,
const std::map<std::string, std::string> & env) :
name(name),
transport(name, cmd, args, env) {}
mcp_server(const mcp_server &) = delete;
mcp_server & operator=(const mcp_server &) = delete;
mcp_server(mcp_server && other) noexcept :
name(std::move(other.name)),
transport(std::move(other.transport)),
next_id(other.next_id),
pending_requests(std::move(other.pending_requests)) {}
mcp_server & operator=(mcp_server && other) noexcept {
if (this != &other) {
stop();
name = std::move(other.name);
transport = std::move(other.transport);
next_id = other.next_id;
pending_requests = std::move(other.pending_requests);
}
return *this;
}
~mcp_server() { stop(); }
bool start() {
transport.set_line_handler([this](const std::string & line) {
handle_line(line);
});
return transport.start();
}
void stop() { transport.stop(); }
json send_request(const std::string & method, const json & params = json::object()) {
if (!transport.is_running()) {
LOG_ERR("Cannot send request to %s: server not running\n", name.c_str());
return nullptr;
}
int id;
std::future<json> future;
{
std::lock_guard<std::mutex> lock(mutex);
id = next_id++;
future = pending_requests[id].get_future();
}
json req = {
{ "jsonrpc", "2.0" },
{ "id", id },
{ "method", method },
{ "params", params }
};
std::string req_str = req.dump() + "\n";
LOG_DBG("MCP request to %s [id=%d]: %s\n", name.c_str(), id, method.c_str());
if (!transport.send(req_str)) {
LOG_ERR("Cannot send request to %s: stdin is null\n", name.c_str());
std::lock_guard<std::mutex> lock(mutex);
pending_requests.erase(id);
return nullptr;
}
// Wait for response with timeout
if (future.wait_for(std::chrono::seconds(30)) == std::future_status::timeout) {
LOG_ERR("Timeout waiting for response from %s (method: %s, id: %d)\n", name.c_str(), method.c_str(), id);
std::lock_guard<std::mutex> lock(mutex);
pending_requests.erase(id);
return nullptr;
}
return future.get();
}
void send_notification(const std::string & method, const json & params = json::object()) {
json req = {
{ "jsonrpc", "2.0" },
{ "method", method },
{ "params", params }
};
transport.send(req.dump() + "\n");
}
// Initialize handshake
bool initialize() {
// Send initialize
json init_params = {
{ "protocolVersion", "2024-11-05" },
{ "capabilities", { { "roots", { { "listChanged", false } } }, { "sampling", json::object() } } },
{ "clientInfo",
{
{ "name", "llama.cpp-cli" }, { "version", "0.1.0" } // TODO: use real version
} }
};
json res = send_request("initialize", init_params);
if (res.is_null() || res.contains("error")) {
LOG_ERR("Failed to initialize MCP server %s\n", name.c_str());
return false;
}
// Send initialized notification
send_notification("notifications/initialized");
return true;
}
std::vector<mcp_tool> list_tools() {
std::vector<mcp_tool> tools;
json res = send_request("tools/list");
if (res.is_null() || res.contains("error")) {
LOG_ERR("Failed to list tools from %s\n", name.c_str());
return tools;
}
if (res.contains("result") && res["result"].contains("tools")) {
for (const auto & t : res["result"]["tools"]) {
mcp_tool tool;
tool.name = t["name"].get<std::string>();
if (t.contains("description")) {
tool.description = t["description"].get<std::string>();
}
if (t.contains("inputSchema")) {
tool.input_schema = t["inputSchema"];
}
tool.server_name = name;
tools.push_back(tool);
}
}
return tools;
}
json call_tool(const std::string & tool_name, const json & args) {
json params = {
{ "name", tool_name },
{ "arguments", args }
};
json res = send_request("tools/call", params);
if (res.is_null() || res.contains("error")) {
return {
{ "error", res.contains("error") ? res["error"] : "Unknown error" }
};
}
if (res.contains("result")) {
return res["result"];
}
return {
{ "error", "No result returned" }
};
}
private:
void handle_line(const std::string & line) {
try {
json msg = json::parse(line);
if (msg.contains("id") && !msg["id"].is_null()) {
// Response - handle both int and string IDs (JSON-RPC allows both)
int id;
if (msg["id"].is_string()) {
id = std::stoi(msg["id"].get<std::string>());
} else {
id = msg["id"].get<int>();
}
std::lock_guard<std::mutex> lock(mutex);
if (pending_requests.count(id)) {
LOG_DBG("MCP response received from %s [id=%d]\n", name.c_str(), id);
pending_requests[id].set_value(msg);
pending_requests.erase(id);
} else {
LOG_WRN("MCP response for unknown id %d from %s: %s\n", id, name.c_str(), line.c_str());
}
} else {
// Notification or request from server -> ignore for now or log
// MCP servers might send notifications (e.g. logging)
LOG_INF("MCP Notification from %s: %s\n", name.c_str(), line.c_str());
}
} catch (const std::exception & e) {
// Not a full JSON yet? Or invalid?
// If it was a line, it should be valid JSON-RPC
LOG_WRN("Failed to parse JSON from %s: %s (line: %s)\n", name.c_str(), e.what(), line.c_str());
}
}
};
// Owns the configured MCP servers and the combined list of tools they expose,
// and routes tool calls to the server that listed the tool.
class mcp_context {
    std::map<std::string, std::shared_ptr<mcp_server>> servers;
    std::vector<mcp_tool>                              tools;
    bool                                               yolo = false;

  public:
    // Stores the "yolo" flag; this class only keeps it for get_yolo().
    void set_yolo(bool y) { yolo = y; }

    // Reads the JSON file at `config_path` and starts every server listed under
    // "mcpServers" whose name appears in `enabled_servers_str` (comma-separated;
    // surrounding whitespace is ignored; an empty list enables all servers).
    // Servers that start and initialize successfully are kept and their tools are
    // appended to the tool list. A malformed entry is logged and skipped without
    // affecting the other servers.
    // Returns false only when the file cannot be opened or parsed.
    bool load_config(const std::string & config_path, const std::string & enabled_servers_str) {
        std::ifstream f(config_path);
        if (!f) {
            return false;
        }
        json config;
        try {
            f >> config;
        } catch (const std::exception & e) {
            LOG_ERR("Failed to parse MCP configuration %s: %s\n", config_path.c_str(), e.what());
            return false;
        } catch (...) {
            return false;
        }

        std::vector<std::string> enabled_list;
        {
            std::stringstream ss(enabled_servers_str);
            std::string       item;
            while (std::getline(ss, item, ',')) {
                item = trim(item);
                if (!item.empty()) {
                    enabled_list.push_back(item);
                }
            }
        }

        if (!config.contains("mcpServers")) {
            return true;
        }

        std::string server_list;
        for (auto & [key, val] : config["mcpServers"].items()) {
            if (!server_list.empty()) {
                server_list += ", ";
            }
            server_list += key;
        }
        LOG_INF("MCP configuration found with servers: %s\n", server_list.c_str());

        for (auto & [key, val] : config["mcpServers"].items()) {
            const bool enabled = enabled_list.empty() ||
                                 std::find(enabled_list.begin(), enabled_list.end(), key) != enabled_list.end();
            if (!enabled) {
                continue;
            }
            try {
                // at()/get<>() throw on a missing key or a value of the wrong type;
                // the catch below turns that into a skipped entry.
                const std::string              cmd  = val.at("command").get<std::string>();
                const std::vector<std::string> args = val.value("args", std::vector<std::string>{});
                std::map<std::string, std::string> env;
                if (val.contains("env")) {
                    for (auto & [ek, ev] : val["env"].items()) {
                        env[ek] = ev.get<std::string>();
                    }
                }

                auto server = std::make_shared<mcp_server>(key, cmd, args, env);
                LOG_INF("Trying to start MCP server: %s...\n", key.c_str());
                if (!server->start()) {
                    continue;
                }
                if (!server->initialize()) {
                    LOG_ERR("MCP Server '%s' failed to initialize.\n", key.c_str());
                    continue;
                }
                servers[key] = server;
                LOG_INF("MCP Server '%s' started and initialized.\n", key.c_str());
                const auto server_tools = server->list_tools();
                tools.insert(tools.end(), server_tools.begin(), server_tools.end());
            } catch (const std::exception & e) {
                LOG_ERR("Invalid configuration for MCP server '%s': %s\n", key.c_str(), e.what());
            }
        }
        return true;
    }

    // Returns a copy of all tools collected from the started servers.
    std::vector<mcp_tool> get_tools() const { return tools; }

    // Returns the flag stored by set_yolo().
    bool get_yolo() const { return yolo; }

    // Calls `tool_name` with `args` on the server that listed it (first match wins).
    // Returns the server's answer, or {"error": "Tool not found"} /
    // {"error": "Server not found"} when the call cannot be routed.
    json call_tool(const std::string & tool_name, const json & args) {
        // Find which server has this tool
        const auto tool_it = std::find_if(tools.begin(), tools.end(),
                                          [&tool_name](const mcp_tool & t) { return t.name == tool_name; });
        if (tool_it == tools.end() || tool_it->server_name.empty()) {
            return {
                { "error", "Tool not found" }
            };
        }
        const auto server_it = servers.find(tool_it->server_name);
        if (server_it == servers.end()) {
            return {
                { "error", "Server not found" }
            };
        }
        return server_it->second->call_tool(tool_name, args);
    }

  private:
    // Returns `s` without leading and trailing spaces, tabs, CR and LF.
    static std::string trim(const std::string & s) {
        static const char * const whitespace = " \t\r\n";
        const size_t              first      = s.find_first_not_of(whitespace);
        if (first == std::string::npos) {
            return "";
        }
        const size_t last = s.find_last_not_of(whitespace);
        return s.substr(first, last - first + 1);
    }
};