diff --git a/common/arg.cpp b/common/arg.cpp index 5fbc9022c0..7b243abf76 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3784,6 +3784,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--mcp-config"}, "FILE", + "path to MCP configuration file", + [](common_params & params, const std::string & value) { + params.mcp_config = value; + } + ).set_examples({LLAMA_EXAMPLE_CLI})); + + add_opt(common_arg( + {"--mcp-yolo"}, + "auto-approve all MCP tool calls (no user confirmation)", + [](common_params & params) { + params.mcp_yolo = true; + } + ).set_examples({LLAMA_EXAMPLE_CLI})); + return ctx_arg; } diff --git a/common/common.h b/common/common.h index 398ebb0960..aa74464a92 100644 --- a/common/common.h +++ b/common/common.h @@ -543,6 +543,11 @@ struct common_params { std::map default_template_kwargs; + // MCP params + std::string mcp_config = ""; + std::string mcp_servers = ""; + bool mcp_yolo = false; + // webui configs bool webui = true; std::string webui_config_json; diff --git a/mcp_config.json b/mcp_config.json new file mode 100644 index 0000000000..700b12ee9f --- /dev/null +++ b/mcp_config.json @@ -0,0 +1,13 @@ +{ + "mcpServers": { + "serena": { + "command": "uvx", + "args": [ + "--from", + "git+https://github.com/oraios/serena", + "serena", + "start-mcp-server" + ] + } + } +} diff --git a/mcp_dummy.py b/mcp_dummy.py new file mode 100644 index 0000000000..32b8ff7716 --- /dev/null +++ b/mcp_dummy.py @@ -0,0 +1,83 @@ +import sys +import json +import logging + +logging.basicConfig(filename='mcp_dummy.log', level=logging.DEBUG) + + +def main(): + logging.info("Starting MCP Dummy Server") + while True: + try: + line = sys.stdin.readline() + if not line: + break + logging.info(f"Received: {line.strip()}") + try: + req = json.loads(line) + except json.JSONDecodeError: + continue + + if "method" in req: + method = req["method"] + req_id = 
req.get("id") + + resp = {"jsonrpc": "2.0", "id": req_id} + + if method == "initialize": + resp["result"] = { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "serverInfo": {"name": "dummy", "version": "1.0"} + } + elif method == "tools/list": + resp["result"] = { + "tools": [ + { + "name": "get_weather", + "description": "Get weather for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} + }, + "required": ["location"] + } + } + ] + } + elif method == "tools/call": + params = req.get("params", {}) + name = params.get("name") + args = params.get("arguments", {}) + + logging.info(f"Tool call: {name} with {args}") + + content = [{"type": "text", "text": f"Weather in {args.get('location')} is 25C"}] + # For simplicity, return raw content or follow MCP spec? + # MCP spec: result: { content: [ {type: "text", text: "..."} ] } + # My mcp.hpp returns res["result"]. + # My cli.cpp dumps res.dump(). + # So passing full result object is fine. 
+ resp["result"] = { + "content": content + } + else: + # Ignore notifications or other methods + if req_id is not None: + resp["error"] = {"code": -32601, "message": "Method not found"} + else: + continue + + logging.info(f"Sending: {json.dumps(resp)}") + if req_id is not None: + sys.stdout.write(json.dumps(resp) + "\n\n") + sys.stdout.flush() + except Exception as e: + logging.error(f"Error: {e}") + break + + +if __name__ == "__main__": + main() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c9436c5995..5f07075213 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,5 +1,9 @@ llama_add_compile_flags() +add_executable(test-mcp-integration test-mcp-integration.cpp) + target_link_libraries(test-mcp-integration PRIVATE common) +target_include_directories(test-mcp-integration PRIVATE ../tools/cli) + function(llama_build source) set(TEST_SOURCES ${source} ${ARGN}) diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 4378a8db71..e12805e2f3 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -69,8 +69,8 @@ static common_chat_msg normalize(const common_chat_msg & msg) { for (auto & tool_call : normalized.tool_calls) { try { tool_call.arguments = json::parse(tool_call.arguments).dump(); - } catch (const std::exception &) { - // Do nothing + } catch (const std::exception &e) { + LOG_DBG("Normalize failed on tool call: %s\n", e.what()); } } return normalized; @@ -183,7 +183,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha } } -common_chat_tool special_function_tool { +static common_chat_tool special_function_tool { /* .name = */ "special_function", /* .description = */ "I'm special", /* .parameters = */ R"({ @@ -197,7 +197,7 @@ common_chat_tool special_function_tool { "required": ["arg1"] })", }; -common_chat_tool special_function_tool_with_optional_param { +static common_chat_tool special_function_tool_with_optional_param { /* .name = */ "special_function_with_opt", /* .description = */ 
"I'm special but have optional stuff", /* .parameters = */ R"({ diff --git a/tests/test-mcp-integration.cpp b/tests/test-mcp-integration.cpp new file mode 100644 index 0000000000..fe7d189872 --- /dev/null +++ b/tests/test-mcp-integration.cpp @@ -0,0 +1,71 @@ +#include "../tools/cli/mcp.hpp" +#include "common.h" +#include "log.h" + +#include +#include +#include +#include +#include + +int main(int argc, char ** argv) { + if (argc < 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + std::string config_file = argv[1]; + + printf("Testing MCP integration with config: %s\n", config_file.c_str()); + + mcp_context mcp; + std::string tool_to_run; + json tool_args = json::object(); + std::string servers_arg = ""; + + if (argc >= 3) { + tool_to_run = argv[2]; + } + if (argc >= 4) { + try { + tool_args = json::parse(argv[3]); + } catch (const std::exception & e) { + fprintf(stderr, "Error parsing tool arguments JSON: %s\n", e.what()); + return 1; + } + } + + if (mcp.load_config(config_file, servers_arg)) { + printf("MCP config loaded successfully.\n"); + + auto tools = mcp.get_tools(); + printf("Found %zu tools.\n", tools.size()); + + for (const auto & tool : tools) { + printf("Tool: %s\n", tool.name.c_str()); + printf(" Description: %s\n", tool.description.c_str()); + } + + if (!tool_to_run.empty()) { + printf("Calling tool %s...\n", tool_to_run.c_str()); + mcp.set_yolo(true); + json res = mcp.call_tool(tool_to_run, tool_args); + printf("Result: %s\n", res.dump().c_str()); + } else if (!tools.empty()) { + printf("No tool specified. 
Calling first tool '%s' with empty args as smoke test...\n", tools[0].name.c_str()); + json args = json::object(); + mcp.set_yolo(true); + json res = mcp.call_tool(tools[0].name, args); + printf("Result: %s\n", res.dump().c_str()); + } + + } else { + printf("Failed to load MCP config.\n"); + return 1; + } + + // Allow some time for threads to shutdown if any + std::this_thread::sleep_for(std::chrono::seconds(1)); + + return 0; +} diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 02ccb72598..838128fe04 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -3,6 +3,7 @@ #include "console.h" // #include "log.h" +#include "mcp.hpp" #include "server-context.h" #include "server-task.h" @@ -48,10 +49,12 @@ static void signal_handler(int) { #endif struct cli_context { - server_context ctx_server; - json messages = json::array(); + server_context ctx_server; + json messages = json::array(); + json tools = json::array(); std::vector input_files; - task_params defaults; + task_params defaults; + std::set approved_requests; // thread for showing "loading" animation std::atomic loading_show; @@ -68,7 +71,7 @@ struct cli_context { // defaults.return_progress = true; // TODO: show progress } - std::string generate_completion(result_timings & out_timings) { + std::string generate_completion(result_timings & out_timings, server_task_result_ptr & out_result) { server_response_reader rd = ctx_server.get_response_reader(); auto chat_params = format_chat(); { @@ -110,11 +113,12 @@ struct cli_context { } else { console::error("Error: %s\n", err_data.dump().c_str()); } + out_result = std::move(result); return curr_content; } - auto res_partial = dynamic_cast(result.get()); + auto *res_partial = dynamic_cast(result.get()); if (res_partial) { - out_timings = std::move(res_partial->timings); + out_timings = res_partial->timings; for (const auto & diff : res_partial->oaicompat_msg_diffs) { if (!diff.content_delta.empty()) { if (is_thinking) { @@ -137,9 +141,10 @@ struct cli_context 
{ } } } - auto res_final = dynamic_cast(result.get()); + auto *res_final = dynamic_cast(result.get()); if (res_final) { - out_timings = std::move(res_final->timings); + out_timings = res_final->timings; + out_result = std::move(result); break; } result = rd.next(should_stop); @@ -160,20 +165,18 @@ struct cli_context { buf.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); input_files.push_back(std::move(buf)); return mtmd_default_marker(); - } else { - std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - return content; } + return std::string((std::istreambuf_iterator(file)), std::istreambuf_iterator()); } - common_chat_params format_chat() { + common_chat_params format_chat() const { auto meta = ctx_server.get_meta(); auto & chat_params = meta.chat_params; common_chat_templates_inputs inputs; inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = {}; // TODO - inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = tools.empty() ? COMMON_CHAT_TOOL_CHOICE_NONE : COMMON_CHAT_TOOL_CHOICE_AUTO; inputs.json_schema = ""; // TODO inputs.grammar = ""; // TODO inputs.use_jinja = chat_params.use_jinja; @@ -229,7 +232,29 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - console::log("\nLoading model... 
"); // followed by loading animation + // Initialize MCP + mcp_context mcp; + if (!params.mcp_config.empty()) { + if (mcp.load_config(params.mcp_config, params.mcp_servers)) { + mcp.set_yolo(params.mcp_yolo); + auto mcp_tools = mcp.get_tools(); + for (const auto & t : mcp_tools) { + json tool = { + { "type", "function" }, + { "function", + { { "name", t.name }, { "description", t.description }, { "parameters", t.input_schema } } } + }; + ctx_cli.tools.push_back(tool); + } + if (!ctx_cli.tools.empty()) { + console::log("Enabled %d MCP tools\n", (int) ctx_cli.tools.size()); + } + } else { + console::error("Failed to load MCP config: %s\n", params.mcp_config.c_str()); + } + } + + console::log("\nLoading model... "); // followed by loading animation console::spinner::start(); if (!ctx_cli.ctx_server.load_model(params)) { console::spinner::stop(); @@ -337,7 +362,8 @@ int main(int argc, char ** argv) { // process commands if (string_starts_with(buffer, "/exit")) { break; - } else if (string_starts_with(buffer, "/regen")) { + } + if (string_starts_with(buffer, "/regen")) { if (ctx_cli.messages.size() >= 2) { size_t last_idx = ctx_cli.messages.size() - 1; ctx_cli.messages.erase(last_idx); @@ -388,11 +414,85 @@ int main(int argc, char ** argv) { cur_msg.clear(); } result_timings timings; - std::string assistant_content = ctx_cli.generate_completion(timings); - ctx_cli.messages.push_back({ - {"role", "assistant"}, - {"content", assistant_content} - }); + while (true) { + server_task_result_ptr result; + std::string assistant_content = ctx_cli.generate_completion(timings, result); + auto * res_final = dynamic_cast(result.get()); + + if (res_final && !res_final->oaicompat_msg.tool_calls.empty()) { + ctx_cli.messages.push_back(res_final->oaicompat_msg.to_json_oaicompat()); + + for (const auto & tc : res_final->oaicompat_msg.tool_calls) { + json args; + try { + if (tc.arguments.empty()) { + args = json::object(); + } else { + args = json::parse(tc.arguments); + } + } catch (...) 
{ + json err_msg = { + { "role", "tool" }, + { "content", "Error parsing arguments" }, + { "tool_call_id", tc.id }, + { "name", tc.name } + }; + ctx_cli.messages.push_back(err_msg); + continue; + } + + std::string request_id = tc.name + tc.arguments; + if (!mcp.get_yolo() && + ctx_cli.approved_requests.find(request_id) == ctx_cli.approved_requests.end()) { + // Prompt user + fprintf(stdout, "\n\n\033[1;33mTool call: %s\033[0m\n", tc.name.c_str()); + fprintf(stdout, "Arguments: %s\n", args.dump(2).c_str()); + fprintf(stdout, "Approve? [y]es, [n]o, [A]lways allow feature: "); + fflush(stdout); + + char c = ' '; + std::string line; + + console::readline(line, false); + if (!line.empty()) { + c = line[0]; + } + + if (c == 'y' || c == 'Y') { + // approved once + } else if (c == 'A') { + ctx_cli.approved_requests.insert(request_id); + } else { + json err_msg = { + { "role", "tool" }, + { "content", "User denied tool execution" }, + { "tool_call_id", tc.id }, + { "name", tc.name } + }; + ctx_cli.messages.push_back(err_msg); + continue; + } + } + + json res = mcp.call_tool(tc.name, args); + json tool_msg = { + { "role", "tool" }, + { "content", res.dump() }, + { "tool_call_id", tc.id }, + { "name", tc.name } + }; + ctx_cli.messages.push_back(tool_msg); + } + // continue loop to generate with tool results + } else { + json assistant_msg = { + {"role", "assistant"}, + {"content", assistant_content} + }; + ctx_cli.messages.push_back(assistant_msg); + break; + } + } console::log("\n"); if (params.show_timings) { diff --git a/tools/cli/mcp.hpp b/tools/cli/mcp.hpp new file mode 100644 index 0000000000..e5d0d9d6c8 --- /dev/null +++ b/tools/cli/mcp.hpp @@ -0,0 +1,569 @@ +#pragma once + +#include "../../vendor/sheredom/subprocess.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +struct mcp_tool { + std::string name; + std::string description; + 
json input_schema; + std::string server_name; +}; + +// Low-level transport layer for subprocess-based MCP communication. +// Handles subprocess lifecycle, raw I/O, and line-based message framing. +class mcp_transport { + std::string name; + std::string command; + std::vector<std::string> args; + std::map<std::string, std::string> env; + + struct subprocess_s process = {}; + std::atomic<bool> running{ false }; + std::thread read_thread; + std::thread err_thread; + std::atomic<bool> stop_read{ false }; + + std::string read_buffer; + + std::function<void(const std::string &)> line_handler; + + public: + mcp_transport(const std::string & name, + const std::string & cmd, + const std::vector<std::string> & args, + const std::map<std::string, std::string> & env) : + name(name), + command(cmd), + args(args), + env(env) {} + + mcp_transport(const mcp_transport &) = delete; + mcp_transport & operator=(const mcp_transport &) = delete; + + mcp_transport(mcp_transport && other) noexcept : + name(std::move(other.name)), + command(std::move(other.command)), + args(std::move(other.args)), + env(std::move(other.env)), + process(other.process), + running(other.running.load()), + read_thread(std::move(other.read_thread)), + err_thread(std::move(other.err_thread)), + stop_read(other.stop_read.load()), + read_buffer(std::move(other.read_buffer)), + line_handler(std::move(other.line_handler)) { + other.process = {}; + other.running = false; + } + + mcp_transport & operator=(mcp_transport && other) noexcept { + if (this != &other) { + stop(); + + name = std::move(other.name); + command = std::move(other.command); + args = std::move(other.args); + env = std::move(other.env); + process = other.process; + running = other.running.load(); + read_thread = std::move(other.read_thread); + err_thread = std::move(other.err_thread); + stop_read = other.stop_read.load(); + read_buffer = std::move(other.read_buffer); + line_handler = std::move(other.line_handler); + + other.process = {}; + other.running = false; + } + return *this; + } + + ~mcp_transport() { stop(); } + + void set_line_handler(std::function<void(const std::string &)>
handler) { + line_handler = std::move(handler); + } + + bool is_running() const { return running; } + + bool start() { + std::vector cmd_args; + cmd_args.push_back(command.c_str()); + for (const auto & arg : args) { + cmd_args.push_back(arg.c_str()); + } + cmd_args.push_back(nullptr); + + std::vector env_vars; + std::vector env_strings; // keep strings alive + if (!env.empty()) { + for (const auto & kv : env) { + env_strings.push_back(kv.first + "=" + kv.second); + } + for (const auto & s : env_strings) { + env_vars.push_back(s.c_str()); + } + env_vars.push_back(nullptr); + } + + int options = subprocess_option_search_user_path; + int result; + if (env.empty()) { + options |= subprocess_option_inherit_environment; + result = subprocess_create(cmd_args.data(), options, &process); + } else { + result = subprocess_create_ex(cmd_args.data(), options, env_vars.data(), &process); + } + + if (result != 0) { + LOG_ERR("Failed to start MCP server %s: error %d (%s)\n", name.c_str(), errno, strerror(errno)); + return false; + } + + running = true; + read_thread = std::thread(&mcp_transport::read_loop, this); + err_thread = std::thread(&mcp_transport::err_loop, this); + + return true; + } + + void stop() { + if (!running) { + return; + } + LOG_INF("Stopping MCP server %s...\n", name.c_str()); + stop_read = true; + + // 1. Close stdin to signal EOF + if (process.stdin_file) { + fclose(process.stdin_file); + process.stdin_file = nullptr; + } + + // 2. Wait for 10 seconds for normal termination + bool terminated = false; + for (int i = 0; i < 100; ++i) { // 100 * 100ms = 10s + if (subprocess_alive(&process) == 0) { + terminated = true; + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + // 3. Terminate if still running + if (!terminated) { + LOG_WRN("MCP server %s did not exit gracefully, terminating...\n", name.c_str()); + subprocess_terminate(&process); + } + + // 4. 
Join threads + if (read_thread.joinable()) { + read_thread.join(); + } + if (err_thread.joinable()) { + err_thread.join(); + } + + // 5. Cleanup + if (running) { + subprocess_destroy(&process); + running = false; + } + LOG_INF("MCP server %s stopped.\n", name.c_str()); + } + + bool send(const std::string & data) { + FILE * f = subprocess_stdin(&process); + if (!f) { + return false; + } + fwrite(data.c_str(), 1, data.size(), f); + fflush(f); + return true; + } + + private: + void read_loop() { + LOG_DBG("MCP read_loop started for %s\n", name.c_str()); + char buffer[4096]; + while (!stop_read && running) { + unsigned bytes_read = subprocess_read_stdout(&process, buffer, sizeof(buffer)); + + if (bytes_read == 0) { + // If blocking read returns 0, it means EOF (process exited or pipe closed). + // We should NOT call subprocess_alive() here because it calls subprocess_join() + // which modifies the process struct (closes stdin) and causes race conditions/double-free + // when stop() is called concurrently. + // Just break the loop. The process is likely dead or dying. 
+ if (!stop_read) { + LOG_ERR("MCP server %s: read_loop exiting (stdout closed/EOF)\n", name.c_str()); + } + running = false; + break; + } + + read_buffer.append(buffer, bytes_read); + size_t pos; + while ((pos = read_buffer.find('\n')) != std::string::npos) { + std::string line = read_buffer.substr(0, pos); + read_buffer.erase(0, pos + 1); + + if (line.empty()) { + continue; + } + + if (line_handler) { + line_handler(line); + } + } + } + LOG_DBG("MCP read_loop exiting for %s (stop_read=%d, running=%d)\n", name.c_str(), stop_read.load(), + running.load()); + } + + void err_loop() { + char buffer[1024]; + while (!stop_read && running) { + unsigned bytes_read = subprocess_read_stderr(&process, buffer, sizeof(buffer)); + if (bytes_read > 0) { + if (stop_read) { + break; // Don't log stderr during shutdown + } + std::string err_str(buffer, bytes_read); + LOG_WRN("[%s stderr] %s", name.c_str(), err_str.c_str()); + } else { + // EOF + break; + } + } + } +}; + +// JSON-RPC layer for MCP server communication. +// Handles request/response correlation, message formatting, and MCP protocol methods. 
+class mcp_server { + std::string name; + mcp_transport transport; + + std::mutex mutex; + int next_id = 1; + std::map<int, std::promise<json>> pending_requests; + + public: + mcp_server(const std::string & name, + const std::string & cmd, + const std::vector<std::string> & args, + const std::map<std::string, std::string> & env) : + name(name), + transport(name, cmd, args, env) {} + + mcp_server(const mcp_server &) = delete; + mcp_server & operator=(const mcp_server &) = delete; + + mcp_server(mcp_server && other) noexcept : + name(std::move(other.name)), + transport(std::move(other.transport)), + next_id(other.next_id), + pending_requests(std::move(other.pending_requests)) {} + + mcp_server & operator=(mcp_server && other) noexcept { + if (this != &other) { + stop(); + + name = std::move(other.name); + transport = std::move(other.transport); + next_id = other.next_id; + pending_requests = std::move(other.pending_requests); + } + return *this; + } + + ~mcp_server() { stop(); } + + bool start() { + transport.set_line_handler([this](const std::string & line) { + handle_line(line); + }); + return transport.start(); + } + + void stop() { transport.stop(); } + + json send_request(const std::string & method, const json & params = json::object()) { + if (!transport.is_running()) { + LOG_ERR("Cannot send request to %s: server not running\n", name.c_str()); + return nullptr; + } + + int id; + std::future<json> future; + { + std::lock_guard<std::mutex> lock(mutex); + id = next_id++; + future = pending_requests[id].get_future(); + } + + json req = { + { "jsonrpc", "2.0" }, + { "id", id }, + { "method", method }, + { "params", params } + }; + + std::string req_str = req.dump() + "\n"; + LOG_DBG("MCP request to %s [id=%d]: %s\n", name.c_str(), id, method.c_str()); + if (!transport.send(req_str)) { + LOG_ERR("Cannot send request to %s: stdin is null\n", name.c_str()); + std::lock_guard<std::mutex> lock(mutex); + pending_requests.erase(id); + return nullptr; + } + + // Wait for response with timeout + if (future.wait_for(std::chrono::seconds(30)) ==
std::future_status::timeout) { + LOG_ERR("Timeout waiting for response from %s (method: %s, id: %d)\n", name.c_str(), method.c_str(), id); + std::lock_guard lock(mutex); + pending_requests.erase(id); + return nullptr; + } + + return future.get(); + } + + void send_notification(const std::string & method, const json & params = json::object()) { + json req = { + { "jsonrpc", "2.0" }, + { "method", method }, + { "params", params } + }; + + transport.send(req.dump() + "\n"); + } + + // Initialize handshake + bool initialize() { + // Send initialize + json init_params = { + { "protocolVersion", "2024-11-05" }, + { "capabilities", { { "roots", { { "listChanged", false } } }, { "sampling", json::object() } } }, + { "clientInfo", + { + { "name", "llama.cpp-cli" }, { "version", "0.1.0" } // TODO: use real version + } } + }; + + json res = send_request("initialize", init_params); + if (res.is_null() || res.contains("error")) { + LOG_ERR("Failed to initialize MCP server %s\n", name.c_str()); + return false; + } + + // Send initialized notification + send_notification("notifications/initialized"); + return true; + } + + std::vector list_tools() { + std::vector tools; + json res = send_request("tools/list"); + if (res.is_null() || res.contains("error")) { + LOG_ERR("Failed to list tools from %s\n", name.c_str()); + return tools; + } + + if (res.contains("result") && res["result"].contains("tools")) { + for (const auto & t : res["result"]["tools"]) { + mcp_tool tool; + tool.name = t["name"].get(); + if (t.contains("description")) { + tool.description = t["description"].get(); + } + if (t.contains("inputSchema")) { + tool.input_schema = t["inputSchema"]; + } + tool.server_name = name; + tools.push_back(tool); + } + } + return tools; + } + + json call_tool(const std::string & tool_name, const json & args) { + json params = { + { "name", tool_name }, + { "arguments", args } + }; + json res = send_request("tools/call", params); + if (res.is_null() || res.contains("error")) { + 
return { + { "error", res.contains("error") ? res["error"] : "Unknown error" } + }; + } + if (res.contains("result")) { + return res["result"]; + } + return { + { "error", "No result returned" } + }; + } + + private: + void handle_line(const std::string & line) { + try { + json msg = json::parse(line); + if (msg.contains("id") && !msg["id"].is_null()) { + // Response - handle both int and string IDs (JSON-RPC allows both) + int id; + if (msg["id"].is_string()) { + id = std::stoi(msg["id"].get()); + } else { + id = msg["id"].get(); + } + std::lock_guard lock(mutex); + if (pending_requests.count(id)) { + LOG_DBG("MCP response received from %s [id=%d]\n", name.c_str(), id); + pending_requests[id].set_value(msg); + pending_requests.erase(id); + } else { + LOG_WRN("MCP response for unknown id %d from %s: %s\n", id, name.c_str(), line.c_str()); + } + } else { + // Notification or request from server -> ignore for now or log + // MCP servers might send notifications (e.g. logging) + LOG_INF("MCP Notification from %s: %s\n", name.c_str(), line.c_str()); + } + } catch (const std::exception & e) { + // Not a full JSON yet? Or invalid? + // If it was a line, it should be valid JSON-RPC + LOG_WRN("Failed to parse JSON from %s: %s (line: %s)\n", name.c_str(), e.what(), line.c_str()); + } + } +}; + +class mcp_context { + std::map> servers; + std::vector tools; + bool yolo = false; + + public: + void set_yolo(bool y) { yolo = y; } + + bool load_config(const std::string & config_path, const std::string & enabled_servers_str) { + std::ifstream f(config_path); + if (!f) { + return false; + } + + json config; + try { + f >> config; + } catch (...) 
{ + return false; + } + + std::vector enabled_list; + std::stringstream ss(enabled_servers_str); + std::string item; + while (std::getline(ss, item, ',')) { + if (!item.empty()) { + enabled_list.push_back(item); + } + } + + if (config.contains("mcpServers")) { + std::string server_list; + for (auto & [key, val] : config["mcpServers"].items()) { + if (!server_list.empty()) { + server_list += ", "; + } + server_list += key; + } + LOG_INF("MCP configuration found with servers: %s\n", server_list.c_str()); + + for (auto & [key, val] : config["mcpServers"].items()) { + bool enabled = true; + if (!enabled_list.empty()) { + bool found = false; + for (const auto & s : enabled_list) { + if (s == key) { + found = true; + } + } + if (!found) { + enabled = false; + } + } + + if (enabled) { + std::string cmd = val["command"].get(); + std::vector args = val.value("args", std::vector{}); + std::map env; + if (val.contains("env")) { + for (auto & [ek, ev] : val["env"].items()) { + env[ek] = ev.get(); + } + } + + auto server = std::make_shared(key, cmd, args, env); + LOG_INF("Trying to start MCP server: %s...\n", key.c_str()); + if (server->start()) { + if (server->initialize()) { + servers[key] = server; + LOG_INF("MCP Server '%s' started and initialized.\n", key.c_str()); + + auto server_tools = server->list_tools(); + tools.insert(tools.end(), server_tools.begin(), server_tools.end()); + } else { + LOG_ERR("MCP Server '%s' failed to initialize.\n", key.c_str()); + } + } + } + } + } + return true; + } + + std::vector get_tools() const { return tools; } + + bool get_yolo() const { return yolo; } + + json call_tool(const std::string & tool_name, const json & args) { + // Find which server has this tool + std::string server_name; + for (const auto & t : tools) { + if (t.name == tool_name) { + server_name = t.server_name; + break; + } + } + + if (server_name.empty()) { + return { + { "error", "Tool not found" } + }; + } + + if (servers.count(server_name)) { + return 
servers[server_name]->call_tool(tool_name, args); + } + return { + { "error", "Server not found" } + }; + } +};