diff --git a/common/arg.cpp b/common/arg.cpp
index 5fbc9022c0..7b243abf76 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3784,6 +3784,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
 
+    add_opt(common_arg(
+        {"--mcp-config"}, "FILE",
+        "path to MCP configuration file",
+        [](common_params & params, const std::string & value) {
+            params.mcp_config = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_CLI}));
+
+    add_opt(common_arg(
+        {"--mcp-yolo"},
+        "auto-approve all MCP tool calls (no user confirmation)",
+        [](common_params & params) {
+            params.mcp_yolo = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_CLI}));
+
     return ctx_arg;
 }
 
diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
index 29819e48d3..b4460fd01f 100644
--- a/common/chat-parser.cpp
+++ b/common/chat-parser.cpp
@@ -1140,6 +1140,9 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
         return;
     }
 
+    // DEBUG: trace input
+    // fprintf(stderr, "hermes_parse: input='%s'\n", builder.input().substr(builder.pos()).c_str());
+
     static const common_regex open_regex(
         "(?:"
             "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
@@ -1160,6 +1163,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
     );
 
     while (auto res = builder.try_find_regex(open_regex)) {
+        // fprintf(stderr, "hermes_parse: found regex match\n");
        const auto & block_start = res->groups[1];
        std::string block_end = block_start.empty() ? "" : "```";
 
@@ -1172,16 +1176,23 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
 
             if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
                 if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
+                    fprintf(stderr, "hermes: json tool add failed or partial (partial=%d)\n", tool_call->is_partial);
                     throw common_chat_msg_partial_exception("incomplete tool call");
                 }
                 builder.consume_spaces();
-                builder.consume_literal(close_tag);
+                // builder.consume_literal(close_tag); // replaced: tolerate a mismatched or missing close tag
+                if (!builder.try_consume_literal(close_tag)) {
+                    fprintf(stderr, "hermes: failed to consume close tag '%s'. Remaining: '%s'\n", close_tag.c_str(), builder.input().substr(builder.pos()).c_str());
+                    // A missing close tag usually means it has not been streamed yet, so treat the message as partial.
+                    throw common_chat_msg_partial_exception("missing close tag");
+                }
                 builder.consume_spaces();
                 if (!block_end.empty()) {
                     builder.consume_literal(block_end);
                     builder.consume_spaces();
                 }
             } else {
+                fprintf(stderr, "hermes: failed to consume json\n");
                 throw common_chat_msg_partial_exception("failed to parse tool call");
             }
         } else {
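Note on the chat-parser change above: the parser now treats a missing close tag as a partial (still-streaming) message rather than a hard failure. A minimal illustration of the two cases, assuming the Hermes 2 Pro convention where `close_tag` is `</tool_call>`:

```cpp
// Complete call: try_consume_literal(close_tag) succeeds.
const char * complete = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}\n</tool_call>";
// Still streaming: the close tag has not arrived yet, so the parser now throws
// common_chat_msg_partial_exception("missing close tag") instead of failing hard.
const char * partial  = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}\n";
```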
diff --git a/common/common.h b/common/common.h
index 398ebb0960..aa74464a92 100644
--- a/common/common.h
+++ b/common/common.h
@@ -543,6 +543,11 @@ struct common_params {
 
     std::map<std::string, std::string> default_template_kwargs;
 
+    // MCP params
+    std::string mcp_config  = "";
+    std::string mcp_servers = "";
+    bool        mcp_yolo    = false;
+
     // webui configs
     bool webui = true;
     std::string webui_config_json;
diff --git a/mcp_config.json b/mcp_config.json
new file mode 100644
index 0000000000..d531881607
--- /dev/null
+++ b/mcp_config.json
@@ -0,0 +1,13 @@
+{
+    "mcpServers": {
+        "serena": {
+            "command": "uvx",
+            "args": [
+                "--from",
+                "git+https://github.com/oraios/serena",
+                "serena",
+                "start-mcp-server"
+            ]
+        }
+    }
+}
\ No newline at end of file
diff --git a/mcp_dummy.py b/mcp_dummy.py
new file mode 100644
index 0000000000..1fff089407
--- /dev/null
+++ b/mcp_dummy.py
@@ -0,0 +1,82 @@
+import sys
+import json
+import logging
+
+# log to the working directory so the dummy server stays portable across machines
+logging.basicConfig(filename='mcp_dummy.log', level=logging.DEBUG)
+
+def main():
+    logging.info("Starting MCP Dummy Server")
+    while True:
+        try:
+            line = sys.stdin.readline()
+            if not line:
+                break
+            logging.info(f"Received: {line.strip()}")
+            try:
+                req = json.loads(line)
+            except json.JSONDecodeError:
+                continue
+
+            if "method" in req:
+                method = req["method"]
+                req_id = req.get("id")
+
+                resp = {"jsonrpc": "2.0", "id": req_id}
+
+                if method == "initialize":
+                    resp["result"] = {
+                        "protocolVersion": "2024-11-05",
+                        "capabilities": {},
+                        "serverInfo": {"name": "dummy", "version": "1.0"}
+                    }
+                elif method == "tools/list":
+                    resp["result"] = {
+                        "tools": [
+                            {
+                                "name": "get_weather",
+                                "description": "Get weather for a location",
+                                "inputSchema": {
+                                    "type": "object",
+                                    "properties": {
+                                        "location": {"type": "string"},
+                                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
+                                    },
+                                    "required": ["location"]
+                                }
+                            }
+                        ]
+                    }
+                elif method == "tools/call":
+                    params = req.get("params", {})
+                    name = params.get("name")
+                    args = params.get("arguments", {})
+
+                    logging.info(f"Tool call: {name} with {args}")
+
+                    content = [{"type": "text", "text": f"Weather in {args.get('location')} is 25C"}]
+                    # Per the MCP spec a tools/call result is { content: [ {type: "text", text: ...} ] }.
+                    # mcp.hpp extracts res["result"] and cli.cpp dumps it verbatim, so
+                    # returning the full result object here is sufficient.
+                    resp["result"] = {
+                        "content": content
+                    }
+                else:
+                    # Ignore notifications or other methods
+                    if req_id is not None:
+                        resp["error"] = {"code": -32601, "message": "Method not found"}
+                    else:
+                        continue
+
+                logging.info(f"Sending: {json.dumps(resp)}")
+                if req_id is not None:
+                    # JSON-RPC over stdio is newline-delimited: exactly one object per line.
+                    sys.stdout.write(json.dumps(resp) + "\n")
+                    sys.stdout.flush()
+        except Exception as e:
+            logging.error(f"Error: {e}")
+            break
+
+if __name__ == "__main__":
+    main()
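Both `mcp_dummy.py` above and `mcp.hpp` below speak newline-delimited JSON-RPC 2.0 over the child process's stdio: one object per line, no Content-Length framing. A client-side sketch of that framing (nlohmann::json assumed, as elsewhere in the tree):

```cpp
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    json req = {
        { "jsonrpc", "2.0" },
        { "id",      1 },
        { "method",  "tools/call" },
        { "params",  { { "name", "get_weather" }, { "arguments", { { "location", "Paris" } } } } },
    };
    // Same framing as mcp_server::send_request(): dump() plus a single newline.
    printf("%s\n", req.dump().c_str());
    return 0;
}
```

Piping that line into `python mcp_dummy.py` should yield a one-line response carrying the canned weather text.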
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index c9436c5995..5f07075213 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,5 +1,9 @@
 llama_add_compile_flags()
 
+add_executable(test-mcp-integration test-mcp-integration.cpp)
+target_link_libraries(test-mcp-integration PRIVATE common)
+target_include_directories(test-mcp-integration PRIVATE ../tools/cli)
+
 function(llama_build source)
     set(TEST_SOURCES ${source} ${ARGN})
 
diff --git a/tests/test-mcp-integration.cpp b/tests/test-mcp-integration.cpp
new file mode 100644
index 0000000000..fe7d189872
--- /dev/null
+++ b/tests/test-mcp-integration.cpp
@@ -0,0 +1,71 @@
+#include "../tools/cli/mcp.hpp"
+#include "common.h"
+#include "log.h"
+
+#include <chrono>
+#include <cstdio>
+#include <string>
+#include <thread>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s <config_file> [tool_name] [tool_args_json]\n", argv[0]);
+        return 1;
+    }
+
+    std::string config_file = argv[1];
+
+    printf("Testing MCP integration with config: %s\n", config_file.c_str());
+
+    mcp_context mcp;
+    std::string tool_to_run;
+    json        tool_args   = json::object();
+    std::string servers_arg = "";
+
+    if (argc >= 3) {
+        tool_to_run = argv[2];
+    }
+    if (argc >= 4) {
+        try {
+            tool_args = json::parse(argv[3]);
+        } catch (const std::exception & e) {
+            fprintf(stderr, "Error parsing tool arguments JSON: %s\n", e.what());
+            return 1;
+        }
+    }
+
+    if (mcp.load_config(config_file, servers_arg)) {
+        printf("MCP config loaded successfully.\n");
+
+        auto tools = mcp.get_tools();
+        printf("Found %zu tools.\n", tools.size());
+
+        for (const auto & tool : tools) {
+            printf("Tool: %s\n", tool.name.c_str());
+            printf("  Description: %s\n", tool.description.c_str());
+        }
+
+        if (!tool_to_run.empty()) {
+            printf("Calling tool %s...\n", tool_to_run.c_str());
+            mcp.set_yolo(true);
+            json res = mcp.call_tool(tool_to_run, tool_args);
+            printf("Result: %s\n", res.dump().c_str());
+        } else if (!tools.empty()) {
+            printf("No tool specified. Calling first tool '%s' with empty args as smoke test...\n", tools[0].name.c_str());
+            json args = json::object();
+            mcp.set_yolo(true);
+            json res = mcp.call_tool(tools[0].name, args);
+            printf("Result: %s\n", res.dump().c_str());
+        }
+
+    } else {
+        printf("Failed to load MCP config.\n");
+        return 1;
+    }
+
+    // Allow some time for threads to shutdown if any
+    std::this_thread::sleep_for(std::chrono::seconds(1));
+
+    return 0;
+}
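A possible smoke-test session against the dummy server. The config entry, interpreter name, and paths below are illustrative, not part of this diff:

```cpp
// mcp_config.json (hypothetical):
//   { "mcpServers": { "dummy": { "command": "python3", "args": ["mcp_dummy.py"] } } }
//
//   $ ./bin/test-mcp-integration mcp_config.json get_weather '{"location":"Paris"}'
//
// Given mcp_dummy.py's canned reply, the final printf should produce:
//   Result: {"content":[{"type":"text","text":"Weather in Paris is 25C"}]}
```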
diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp
index 02ccb72598..f9ad5094b6 100644
--- a/tools/cli/cli.cpp
+++ b/tools/cli/cli.cpp
@@ -1,22 +1,25 @@
-#include "common.h"
 #include "arg.h"
+#include "chat.h"
+#include "common.h"
 #include "console.h"
 // #include "log.h"
+#include "mcp.hpp"
 #include "server-context.h"
 #include "server-task.h"
 
+#include <set>
+
 #include <atomic>
 #include <string>
 #include <thread>
-#include
 
 #if defined(_WIN32)
-#define WIN32_LEAN_AND_MEAN
-#ifndef NOMINMAX
-# define NOMINMAX
-#endif
-#include <windows.h>
+# define WIN32_LEAN_AND_MEAN
+# ifndef NOMINMAX
+#  define NOMINMAX
+# endif
+# include <windows.h>
 #endif
 
 const char * LLAMA_ASCII_LOGO = R"(
@@ -30,11 +33,12 @@ const char * LLAMA_ASCII_LOGO = R"(
 )";
 
 static std::atomic<bool> g_is_interrupted = false;
+
 static bool should_stop() {
     return g_is_interrupted.load();
 }
 
-#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
 static void signal_handler(int) {
     if (g_is_interrupted.load()) {
         // second Ctrl+C - exit immediately
@@ -48,10 +52,12 @@ static void signal_handler(int) {
 #endif
 
 struct cli_context {
-    server_context ctx_server;
-    json messages = json::array();
+    server_context          ctx_server;
+    json                    messages = json::array();
+    json                    tools    = json::array();
     std::vector<raw_buffer> input_files;
-    task_params defaults;
+    task_params             defaults;
+    std::set<std::string>   approved_requests;
 
     // thread for showing "loading" animation
     std::atomic<bool> loading_show;
@@ -63,12 +69,12 @@ struct cli_context {
         defaults.n_predict  = params.n_predict;
         defaults.antiprompt = params.antiprompt;
 
-        defaults.stream = true; // make sure we always use streaming mode
-        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
+        defaults.stream            = true;  // make sure we always use streaming mode
+        defaults.timings_per_token = true;  // in order to get timings even when we cancel mid-way
         // defaults.return_progress = true; // TODO: show progress
     }
 
-    std::string generate_completion(result_timings & out_timings) {
+    std::string generate_completion(result_timings & out_timings, server_task_result_ptr & out_result) {
         server_response_reader rd = ctx_server.get_response_reader();
         auto chat_params = format_chat();
         {
@@ -88,7 +94,7 @@ struct cli_context {
             task.params.chat_parser_params.parser.load(chat_params.parser);
         }
 
-        rd.post_task({std::move(task)});
+        rd.post_task({ std::move(task) });
     }
 
     // wait for first result
@@ -97,7 +103,7 @@ struct cli_context {
         console::spinner::stop();
 
         std::string curr_content;
-        bool is_thinking = false;
+        bool        is_thinking = false;
 
         while (result) {
             if (should_stop()) {
@@ -110,6 +116,7 @@ struct cli_context {
                 } else {
                     console::error("Error: %s\n", err_data.dump().c_str());
                 }
+                out_result = std::move(result);
                 return curr_content;
             }
             auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
@@ -140,6 +147,7 @@ struct cli_context {
             auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
             if (res_final) {
                 out_timings = std::move(res_final->timings);
+                out_result  = std::move(result);
                 break;
             }
             result = rd.next(should_stop);
@@ -166,14 +174,14 @@ struct cli_context {
         }
     }
 
-    common_chat_params format_chat() {
+    common_chat_params format_chat() const {
         auto meta = ctx_server.get_meta();
         auto & chat_params = meta.chat_params;
 
         common_chat_templates_inputs inputs;
         inputs.messages = common_chat_msgs_parse_oaicompat(messages);
-        inputs.tools = {}; // TODO
-        inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
+        inputs.tools       = common_chat_tools_parse_oaicompat(tools);
+        inputs.tool_choice = tools.empty() ? COMMON_CHAT_TOOL_CHOICE_NONE : COMMON_CHAT_TOOL_CHOICE_AUTO;
         inputs.json_schema = ""; // TODO
         inputs.grammar = ""; // TODO
         inputs.use_jinja = chat_params.use_jinja;
@@ -189,7 +197,7 @@ struct cli_context {
 
 int main(int argc, char ** argv) {
     common_params params;
-    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs
+    params.verbosity = LOG_LEVEL_ERROR;  // by default, less verbose logs
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
         return 1;
@@ -215,21 +223,43 @@ int main(int argc, char ** argv) {
 
     console::set_display(DISPLAY_TYPE_RESET);
 
-#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
     struct sigaction sigint_action;
     sigint_action.sa_handler = signal_handler;
-    sigemptyset (&sigint_action.sa_mask);
+    sigemptyset(&sigint_action.sa_mask);
     sigint_action.sa_flags = 0;
     sigaction(SIGINT,  &sigint_action, NULL);
     sigaction(SIGTERM, &sigint_action, NULL);
-#elif defined (_WIN32)
+#elif defined(_WIN32)
     auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
         return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
     };
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
 
-    console::log("\nLoading model... "); // followed by loading animation
+    // Initialize MCP
+    mcp_context mcp;
+    if (!params.mcp_config.empty()) {
+        if (mcp.load_config(params.mcp_config, params.mcp_servers)) {
+            mcp.set_yolo(params.mcp_yolo);
+            auto mcp_tools = mcp.get_tools();
+            for (const auto & t : mcp_tools) {
+                json tool = {
+                    { "type", "function" },
+                    { "function",
+                      { { "name", t.name }, { "description", t.description }, { "parameters", t.input_schema } } }
+                };
+                ctx_cli.tools.push_back(tool);
+            }
+            if (!ctx_cli.tools.empty()) {
+                console::log("Enabled %d MCP tools\n", (int) ctx_cli.tools.size());
+            }
+        } else {
+            console::error("Failed to load MCP config: %s\n", params.mcp_config.c_str());
+        }
+    }
+
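The MCP-init block above re-wraps each MCP tool as an OpenAI-style function tool before it reaches the chat template; the server's `inputSchema` is passed through verbatim as `parameters`. For the dummy server's `get_weather`, the loop produces:

```cpp
json tool = {
    { "type", "function" },
    { "function", {
        { "name",        "get_weather" },
        { "description", "Get weather for a location" },
        { "parameters", {
            { "type", "object" },
            { "properties", {
                { "location", { { "type", "string" } } },
                { "unit",     { { "type", "string" }, { "enum", { "celsius", "fahrenheit" } } } },
            } },
            { "required", { "location" } },
        } },
    } },
};
```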
"); // followed by loading animation console::spinner::start(); if (!ctx_cli.ctx_server.load_model(params)) { console::spinner::stop(); @@ -240,11 +270,9 @@ int main(int argc, char ** argv) { console::spinner::stop(); console::log("\n"); - std::thread inference_thread([&ctx_cli]() { - ctx_cli.ctx_server.start_loop(); - }); + std::thread inference_thread([&ctx_cli]() { ctx_cli.ctx_server.start_loop(); }); - auto inf = ctx_cli.ctx_server.get_meta(); + auto inf = ctx_cli.ctx_server.get_meta(); std::string modalities = "text"; if (inf.has_inp_image) { modalities += ", vision"; @@ -255,8 +283,8 @@ int main(int argc, char ** argv) { if (!params.system_prompt.empty()) { ctx_cli.messages.push_back({ - {"role", "system"}, - {"content", params.system_prompt} + { "role", "system" }, + { "content", params.system_prompt } }); } @@ -290,7 +318,7 @@ int main(int argc, char ** argv) { if (params.prompt.empty()) { console::log("\n> "); std::string line; - bool another_line = true; + bool another_line = true; do { another_line = console::readline(line, params.multiline_input); buffer += line; @@ -312,7 +340,7 @@ int main(int argc, char ** argv) { } else { console::log("\n> %s\n", buffer.c_str()); } - params.prompt.clear(); // only use it once + params.prompt.clear(); // only use it once } console::set_display(DISPLAY_TYPE_RESET); console::log("\n"); @@ -323,7 +351,7 @@ int main(int argc, char ** argv) { } // remove trailing newline - if (!buffer.empty() &&buffer.back() == '\n') { + if (!buffer.empty() && buffer.back() == '\n') { buffer.pop_back(); } @@ -351,11 +379,10 @@ int main(int argc, char ** argv) { ctx_cli.input_files.clear(); console::log("Chat history cleared.\n"); continue; - } else if ( - (string_starts_with(buffer, "/image ") && inf.has_inp_image) || - (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) { + } else if ((string_starts_with(buffer, "/image ") && inf.has_inp_image) || + (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) { // just in case (bad copy-paste for example), we strip all trailing/leading spaces - std::string fname = string_strip(buffer.substr(7)); + std::string fname = string_strip(buffer.substr(7)); std::string marker = ctx_cli.load_input_file(fname, true); if (marker.empty()) { console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); @@ -365,7 +392,7 @@ int main(int argc, char ** argv) { console::log("Loaded media from '%s'\n", fname.c_str()); continue; } else if (string_starts_with(buffer, "/read ")) { - std::string fname = string_strip(buffer.substr(6)); + std::string fname = string_strip(buffer.substr(6)); std::string marker = ctx_cli.load_input_file(fname, false); if (marker.empty()) { console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); @@ -382,23 +409,100 @@ int main(int argc, char ** argv) { // generate response if (add_user_msg) { ctx_cli.messages.push_back({ - {"role", "user"}, - {"content", cur_msg} + { "role", "user" }, + { "content", cur_msg } }); cur_msg.clear(); } result_timings timings; - std::string assistant_content = ctx_cli.generate_completion(timings); - ctx_cli.messages.push_back({ - {"role", "assistant"}, - {"content", assistant_content} - }); + while (true) { + server_task_result_ptr result; + std::string assistant_content = ctx_cli.generate_completion(timings, result); + auto res_final = dynamic_cast(result.get()); + + if (res_final && !res_final->oaicompat_msg.tool_calls.empty()) { + ctx_cli.messages.push_back(res_final->oaicompat_msg.to_json_oaicompat()); + + for (const 
+                for (const auto & tc : res_final->oaicompat_msg.tool_calls) {
+                    json args;
+                    try {
+                        if (tc.arguments.empty()) {
+                            args = json::object();
+                        } else {
+                            args = json::parse(tc.arguments);
+                        }
+                    } catch (...) {
+                        json err_msg = {
+                            { "role",         "tool" },
+                            { "content",      "Error parsing arguments" },
+                            { "tool_call_id", tc.id },
+                            { "name",         tc.name }
+                        };
+                        ctx_cli.messages.push_back(err_msg);
+                        continue;
+                    }
+
+                    std::string request_id = tc.name + tc.arguments;
+                    if (!mcp.get_yolo() && ctx_cli.approved_requests.find(request_id) == ctx_cli.approved_requests.end()) {
+                        // Prompt user
+                        fprintf(stdout, "\n\033[1;33mTool call: %s\033[0m\n", tc.name.c_str());
+                        fprintf(stdout, "Arguments: %s\n", args.dump(2).c_str());
+                        fprintf(stdout, "Approve? [y]es, [n]o, [A]lways allow this call: ");
+                        fflush(stdout);
+
+                        char        c = ' ';
+                        std::string line;
+                        // console::readline() is geared toward line editing in the main REPL;
+                        // for a single-character confirmation a plain blocking read from
+                        // std::cin is sufficient and does not disturb the console state.
+                        std::cin >> c;
+                        std::getline(std::cin, line); // consume rest
+
+                        if (c == 'y' || c == 'Y') {
+                            // approved once
+                        } else if (c == 'A') {
+                            ctx_cli.approved_requests.insert(request_id);
+                        } else {
+                            json err_msg = {
+                                { "role",         "tool" },
+                                { "content",      "User denied tool execution" },
+                                { "tool_call_id", tc.id },
+                                { "name",         tc.name }
+                            };
+                            ctx_cli.messages.push_back(err_msg);
+                            continue;
+                        }
+                    }
+
+                    json res = mcp.call_tool(tc.name, args);
+                    json tool_msg = {
+                        { "role",         "tool" },
+                        { "content",      res.dump() },
+                        { "tool_call_id", tc.id },
+                        { "name",         tc.name }
+                    };
+                    ctx_cli.messages.push_back(tool_msg);
+                }
+                // continue loop to generate with tool results
+            } else {
+                json assistant_msg = {
+                    { "role",    "assistant" },
+                    { "content", assistant_content }
+                };
+                ctx_cli.messages.push_back(assistant_msg);
+                break;
+            }
+        }
         console::log("\n");
 
         if (params.show_timings) {
             console::set_display(DISPLAY_TYPE_INFO);
             console::log("\n");
-            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
+            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second,
+                         timings.predicted_per_second);
             console::set_display(DISPLAY_TYPE_RESET);
         }
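The tool-call loop above grows the transcript in the usual OpenAI-compatible pattern: the assistant message with `tool_calls` (via `to_json_oaicompat()`), then one `tool`-role message per executed call, then a regeneration pass. A sketch of the transcript after one approved round trip; the call id and the exact assistant-message layout are illustrative, since they come from `to_json_oaicompat()`:

```cpp
json messages = json::array({
    { { "role", "user" }, { "content", "What's the weather in Paris?" } },
    { { "role", "assistant" },
      { "content", "" },
      { "tool_calls", json::array({ {
            { "id",   "call_0" },
            { "type", "function" },
            { "function", { { "name", "get_weather" }, { "arguments", "{\"location\":\"Paris\"}" } } },
      } }) } },
    { { "role", "tool" },
      { "content", "{\"content\":[{\"type\":\"text\",\"text\":\"Weather in Paris is 25C\"}]}" },
      { "tool_call_id", "call_0" },
      { "name", "get_weather" } },
});
// The while (true) loop then formats these messages again so the model can
// produce its final natural-language answer.
```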
diff --git a/tools/cli/mcp.hpp b/tools/cli/mcp.hpp
new file mode 100644
index 0000000000..62662f1ef0
--- /dev/null
+++ b/tools/cli/mcp.hpp
@@ -0,0 +1,499 @@
+#pragma once
+
+#include "../../vendor/sheredom/subprocess.h"
+#include "common.h"
+#include "log.h"
+
+#include <atomic>
+#include <chrono>
+#include <cstring>
+#include <fstream>
+#include <future>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+struct mcp_tool {
+    std::string name;
+    std::string description;
+    json        input_schema;
+    std::string server_name;
+};
+
+class mcp_server {
+    std::string                        name;
+    std::string                        command;
+    std::vector<std::string>           args;
+    std::map<std::string, std::string> env;
+
+    struct subprocess_s process = {};
+    std::atomic<bool>   running{ false };
+    std::thread         read_thread;
+    std::thread         err_thread;
+    std::atomic<bool>   stop_read{ false };
+
+    std::mutex mutex;
+    int        next_id = 1;
+    std::map<int, std::promise<json>> pending_requests;
+
+    // Buffer for reading
+    std::string read_buffer;
+
+  public:
+    mcp_server(const std::string & name,
+               const std::string & cmd,
+               const std::vector<std::string> & args,
+               const std::map<std::string, std::string> & env) :
+        name(name),
+        command(cmd),
+        args(args),
+        env(env) {}
+
+    mcp_server(const mcp_server &) = delete;
+    mcp_server & operator=(const mcp_server &) = delete;
+
+    mcp_server(mcp_server && other) noexcept :
+        name(std::move(other.name)),
+        command(std::move(other.command)),
+        args(std::move(other.args)),
+        env(std::move(other.env)),
+        process(other.process),
+        running(other.running.load()),
+        read_thread(std::move(other.read_thread)),
+        err_thread(std::move(other.err_thread)),
+        stop_read(other.stop_read.load()),
+        next_id(other.next_id),
+        pending_requests(std::move(other.pending_requests)),
+        read_buffer(std::move(other.read_buffer)) {
+        // Zero out the source process to prevent double-free
+        other.process = {};
+        other.running = false;
+    }
+
+    mcp_server & operator=(mcp_server && other) noexcept {
+        if (this != &other) {
+            stop(); // Clean up current resources
+
+            name             = std::move(other.name);
+            command          = std::move(other.command);
+            args             = std::move(other.args);
+            env              = std::move(other.env);
+            process          = other.process;
+            running          = other.running.load();
+            read_thread      = std::move(other.read_thread);
+            err_thread       = std::move(other.err_thread);
+            stop_read        = other.stop_read.load();
+            next_id          = other.next_id;
+            pending_requests = std::move(other.pending_requests);
+            read_buffer      = std::move(other.read_buffer);
+
+            // Zero out source
+            other.process = {};
+            other.running = false;
+        }
+        return *this;
+    }
+
+    ~mcp_server() { stop(); }
+
+    bool start() {
+        std::vector<const char *> cmd_args;
+        cmd_args.push_back(command.c_str());
+        for (const auto & arg : args) {
+            cmd_args.push_back(arg.c_str());
+        }
+        cmd_args.push_back(nullptr);
+
+        std::vector<const char *> env_vars;
+        std::vector<std::string>  env_strings; // keep strings alive
+        if (!env.empty()) {
+            for (const auto & kv : env) {
+                env_strings.push_back(kv.first + "=" + kv.second);
+            }
+            for (const auto & s : env_strings) {
+                env_vars.push_back(s.c_str());
+            }
+            env_vars.push_back(nullptr);
+        }
+
+        // Blocking I/O is simpler with threads
+        int options = subprocess_option_search_user_path;
+        int result;
+        if (env.empty()) {
+            options |= subprocess_option_inherit_environment;
+            result = subprocess_create(cmd_args.data(), options, &process);
+        } else {
+            result = subprocess_create_ex(cmd_args.data(), options, env_vars.data(), &process);
+        }
+
+        if (result != 0) {
+            LOG_ERR("Failed to start MCP server %s: error %d (%s)\n", name.c_str(), errno, strerror(errno));
+            return false;
+        }
+
+        running     = true;
+        read_thread = std::thread(&mcp_server::read_loop, this);
+        err_thread  = std::thread(&mcp_server::err_loop, this);
+
+        return true;
+    }
+
+    void stop() {
+        // read_loop() flips `running` to false when the process dies on its own,
+        // so key the early-out on the worker threads instead of on `running`;
+        // otherwise a dead server would be destructed with joinable threads.
+        if (!read_thread.joinable() && !err_thread.joinable()) {
+            return;
+        }
+        LOG_INF("Stopping MCP server %s...\n", name.c_str());
+        stop_read = true;
+
+        // 1. Close stdin to signal EOF
+        if (process.stdin_file) {
+            fclose(process.stdin_file);
+            process.stdin_file = nullptr;
+        }
+
+        // 2. Wait up to 10 seconds for normal termination
+        bool terminated = false;
+        for (int i = 0; i < 100; ++i) { // 100 * 100ms = 10s
+            if (subprocess_alive(&process) == 0) {
+                terminated = true;
+                break;
+            }
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+        }
+
+        // 3. Terminate if still running
+        if (!terminated) {
+            LOG_WRN("MCP server %s did not exit gracefully, terminating...\n", name.c_str());
+            subprocess_terminate(&process);
+        }
+
+        // 4. Join threads
+        if (read_thread.joinable()) {
+            read_thread.join();
+        }
+        if (err_thread.joinable()) {
+            err_thread.join();
+        }
+
+        // 5. Cleanup
+        subprocess_destroy(&process);
+        running = false;
+        LOG_INF("MCP server %s stopped.\n", name.c_str());
+    }
+    json send_request(const std::string & method, const json & params = json::object()) {
+        int id;
+        std::future<json> future;
+        {
+            std::lock_guard<std::mutex> lock(mutex);
+            id     = next_id++;
+            future = pending_requests[id].get_future();
+        }
+
+        json req = {
+            { "jsonrpc", "2.0" },
+            { "id",      id },
+            { "method",  method },
+            { "params",  params }
+        };
+
+        std::string req_str = req.dump() + "\n";
+        FILE * stdin_file = subprocess_stdin(&process);
+        if (stdin_file) {
+            fwrite(req_str.c_str(), 1, req_str.size(), stdin_file);
+            fflush(stdin_file);
+        }
+
+        // Wait for response with timeout
+        if (future.wait_for(std::chrono::seconds(10)) == std::future_status::timeout) {
+            LOG_ERR("Timeout waiting for response from %s (method: %s)\n", name.c_str(), method.c_str());
+            std::lock_guard<std::mutex> lock(mutex);
+            pending_requests.erase(id);
+            return nullptr;
+        }
+
+        return future.get();
+    }
+
+    void send_notification(const std::string & method, const json & params = json::object()) {
+        json req = {
+            { "jsonrpc", "2.0" },
+            { "method",  method },
+            { "params",  params }
+        };
+
+        std::string req_str = req.dump() + "\n";
+        FILE * stdin_file = subprocess_stdin(&process);
+        if (stdin_file) {
+            fwrite(req_str.c_str(), 1, req_str.size(), stdin_file);
+            fflush(stdin_file);
+        }
+    }
+
+    // Initialize handshake
+    bool initialize() {
+        // Send initialize
+        json init_params = {
+            { "protocolVersion", "2024-11-05" },
+            { "capabilities", {
+                { "roots", {
+                    { "listChanged", false }
+                } },
+                { "sampling", json::object() }
+            } },
+            { "clientInfo", {
+                { "name", "llama.cpp-cli" },
+                { "version", "0.1.0" } // TODO: use real version
+            } }
+        };
+
+        json res = send_request("initialize", init_params);
+        if (res.is_null() || res.contains("error")) {
+            LOG_ERR("Failed to initialize MCP server %s\n", name.c_str());
+            return false;
+        }
+
+        // Send initialized notification
+        send_notification("notifications/initialized");
+        return true;
+    }
+
+    std::vector<mcp_tool> list_tools() {
+        std::vector<mcp_tool> tools;
+        json res = send_request("tools/list");
+        if (res.is_null() || res.contains("error")) {
+            LOG_ERR("Failed to list tools from %s\n", name.c_str());
+            return tools;
+        }
+
+        if (res.contains("result") && res["result"].contains("tools")) {
+            for (const auto & t : res["result"]["tools"]) {
+                mcp_tool tool;
+                tool.name = t["name"].get<std::string>();
+                if (t.contains("description")) {
+                    tool.description = t["description"].get<std::string>();
+                }
+                if (t.contains("inputSchema")) {
+                    tool.input_schema = t["inputSchema"];
+                }
+                tool.server_name = name;
+                tools.push_back(tool);
+            }
+        }
+        return tools;
+    }
+
+    json call_tool(const std::string & tool_name, const json & args) {
+        json params = {
+            { "name",      tool_name },
+            { "arguments", args }
+        };
+        json res = send_request("tools/call", params);
+        if (res.is_null() || res.contains("error")) {
+            return {
+                { "error", res.contains("error") ? res["error"] : "Unknown error" }
+            };
+        }
+        if (res.contains("result")) {
+            return res["result"];
+        }
+        return {
+            { "error", "No result returned" }
+        };
+    }
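For reference, this is the kind of `tools/list` response `list_tools()` consumes; it is the exact shape `mcp_dummy.py` emits, with the schema trimmed here for brevity:

```cpp
const char * tools_list_response = R"({
    "jsonrpc": "2.0",
    "id": 2,
    "result": {
        "tools": [ {
            "name": "get_weather",
            "description": "Get weather for a location",
            "inputSchema": { "type": "object", "properties": { "location": { "type": "string" } }, "required": [ "location" ] }
        } ]
    }
})";
```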
+  private:
+    void read_loop() {
+        char buffer[4096];
+        while (!stop_read && running) {
+            unsigned bytes_read = subprocess_read_stdout(&process, buffer, sizeof(buffer));
+
+            if (bytes_read == 0) {
+                // A blocking read returning 0 means EOF (process exited or pipe closed).
+                // Do NOT call subprocess_alive() here: it calls subprocess_join(), which
+                // modifies the process struct (closes stdin) and causes races/double-frees
+                // when stop() runs concurrently. Just break; the process is dead or dying.
+                if (!stop_read) {
+                    LOG_ERR("MCP process died (stdout closed)\n");
+                }
+                running = false;
+                break;
+            }
+
+            read_buffer.append(buffer, bytes_read);
+            size_t pos;
+            while ((pos = read_buffer.find('\n')) != std::string::npos) {
+                std::string line = read_buffer.substr(0, pos);
+                read_buffer.erase(0, pos + 1);
+
+                if (line.empty()) {
+                    continue;
+                }
+
+                try {
+                    json msg = json::parse(line);
+                    if (msg.contains("id")) {
+                        // Response
+                        int id = msg["id"].get<int>();
+                        std::lock_guard<std::mutex> lock(mutex);
+                        if (pending_requests.count(id)) {
+                            pending_requests[id].set_value(msg);
+                            pending_requests.erase(id);
+                        } else {
+                            // ID not found (e.g. the request already timed out)
+                        }
+                    } else {
+                        // No id: a notification (or server-initiated request), e.g. logging.
+                        // Not handled yet, so just surface it in the log.
+                        LOG_INF("MCP notification from %s: %s\n", name.c_str(), line.c_str());
+                    }
+                } catch (const std::exception & e) {
+                    // A complete line should be a valid JSON-RPC message; report and move on.
+                    LOG_WRN("Failed to parse JSON from %s: %s\n", name.c_str(), e.what());
+                }
+            }
+        }
+    }
+
+    void err_loop() {
+        char buffer[1024];
+        while (!stop_read && running) {
+            unsigned bytes_read = subprocess_read_stderr(&process, buffer, sizeof(buffer));
+            if (bytes_read > 0) {
+                if (stop_read) {
+                    break; // don't forward stderr during shutdown; it is just teardown noise
+                }
+                std::string err_str(buffer, bytes_read);
+                LOG_WRN("[%s stderr] %s", name.c_str(), err_str.c_str());
+            } else {
+                // EOF
+                break;
+            }
+        }
+    }
+};
+
+class mcp_context {
+    std::map<std::string, std::shared_ptr<mcp_server>> servers;
+    std::vector<mcp_tool> tools;
+    bool yolo = false;
+
+  public:
+    void set_yolo(bool y) { yolo = y; }
+
+    bool load_config(const std::string & config_path, const std::string & enabled_servers_str) {
+        std::ifstream f(config_path);
+        if (!f) {
+            return false;
+        }
+
+        json config;
+        try {
+            f >> config;
+        } catch (...) {
+            return false;
+        }
+
+        std::vector<std::string> enabled_list;
+        std::stringstream ss(enabled_servers_str);
+        std::string item;
+        while (std::getline(ss, item, ',')) {
+            if (!item.empty()) {
+                enabled_list.push_back(item);
+            }
+        }
+
+        if (config.contains("mcpServers")) {
+            std::string server_list;
+            for (auto & [key, val] : config["mcpServers"].items()) {
+                if (!server_list.empty()) {
+                    server_list += ", ";
+                }
+                server_list += key;
+            }
+            LOG_INF("MCP configuration found with servers: %s\n", server_list.c_str());
+
+            for (auto & [key, val] : config["mcpServers"].items()) {
+                // `enabled_servers_str` (--mcp-servers) acts as an allow-list: when it is
+                // non-empty, only the servers named in it are started; when it is empty,
+                // every configured server is enabled.
+                bool enabled = true;
+                if (!enabled_list.empty()) {
+                    bool found = false;
+                    for (const auto & s : enabled_list) {
+                        if (s == key) {
+                            found = true;
+                        }
+                    }
+                    if (!found) {
+                        enabled = false;
+                    }
+                }
+
+                if (enabled) {
+                    std::string              cmd  = val["command"].get<std::string>();
+                    std::vector<std::string> args = val.value("args", std::vector<std::string>{});
+                    std::map<std::string, std::string> env;
+                    if (val.contains("env")) {
+                        for (auto & [ek, ev] : val["env"].items()) {
+                            env[ek] = ev.get<std::string>();
+                        }
+                    }
+
+                    auto server = std::make_shared<mcp_server>(key, cmd, args, env);
+                    LOG_INF("Trying to start MCP server: %s...\n", key.c_str());
+                    if (server->start()) {
+                        if (server->initialize()) {
+                            servers[key] = server;
+                            LOG_INF("MCP Server '%s' started and initialized.\n", key.c_str());
+
+                            auto server_tools = server->list_tools();
+                            tools.insert(tools.end(), server_tools.begin(), server_tools.end());
+                        } else {
+                            LOG_ERR("MCP Server '%s' failed to initialize.\n", key.c_str());
+                        }
+                    }
+                }
+            }
+        }
+        return true;
+    }
+
+    std::vector<mcp_tool> get_tools() const { return tools; }
+
+    bool get_yolo() const { return yolo; }
+
+    json call_tool(const std::string & tool_name, const json & args) {
+        // Find which server has this tool
+        std::string server_name;
+        for (const auto & t : tools) {
+            if (t.name == tool_name) {
+                server_name = t.server_name;
+                break;
+            }
+        }
+
+        if (server_name.empty()) {
+            return {
+                { "error", "Tool not found" }
+            };
+        }
+
+        if (servers.count(server_name)) {
+            return servers[server_name]->call_tool(tool_name, args);
+        }
+        return {
+            { "error", "Server not found" }
+        };
+    }
+};
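End-to-end, the feature is driven by the two flags registered in common/arg.cpp; the model path below is illustrative:

```cpp
//   llama-cli -m model.gguf --mcp-config mcp_config.json             // prompt before each tool call
//   llama-cli -m model.gguf --mcp-config mcp_config.json --mcp-yolo  // auto-approve every call
//
// common.h also adds params.mcp_servers (the comma-separated allow-list consumed
// by mcp_context::load_config), but no CLI flag is registered for it in this
// diff, so it is currently only reachable from the test harness.
```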