Merge 9e60c0c83f into 06bf3796f4
This commit is contained in:
commit
5247db9fac
|
|
@ -3784,6 +3784,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--mcp-config"}, "FILE",
|
||||
"path to MCP configuration file",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.mcp_config = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--mcp-yolo"},
|
||||
"auto-approve all MCP tool calls (no user confirmation)",
|
||||
[](common_params & params) {
|
||||
params.mcp_yolo = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -543,6 +543,11 @@ struct common_params {
|
|||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// MCP params
|
||||
std::string mcp_config = "";
|
||||
std::string mcp_servers = "";
|
||||
bool mcp_yolo = false;
|
||||
|
||||
// webui configs
|
||||
bool webui = true;
|
||||
std::string webui_config_json;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"serena": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"--from",
|
||||
"git+https://github.com/oraios/serena",
|
||||
"serena",
|
||||
"start-mcp-server"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
import sys
|
||||
import json
|
||||
import logging
|
||||
|
||||
logging.basicConfig(filename='mcp_dummy.log', level=logging.DEBUG)
|
||||
|
||||
|
||||
def main():
|
||||
logging.info("Starting MCP Dummy Server")
|
||||
while True:
|
||||
try:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
logging.info(f"Received: {line.strip()}")
|
||||
try:
|
||||
req = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if "method" in req:
|
||||
method = req["method"]
|
||||
req_id = req.get("id")
|
||||
|
||||
resp = {"jsonrpc": "2.0", "id": req_id}
|
||||
|
||||
if method == "initialize":
|
||||
resp["result"] = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"serverInfo": {"name": "dummy", "version": "1.0"}
|
||||
}
|
||||
elif method == "tools/list":
|
||||
resp["result"] = {
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a location",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
params = req.get("params", {})
|
||||
name = params.get("name")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
logging.info(f"Tool call: {name} with {args}")
|
||||
|
||||
content = [{"type": "text", "text": f"Weather in {args.get('location')} is 25C"}]
|
||||
# For simplicity, return raw content or follow MCP spec?
|
||||
# MCP spec: result: { content: [ {type: "text", text: "..."} ] }
|
||||
# My mcp.hpp returns res["result"].
|
||||
# My cli.cpp dumps res.dump().
|
||||
# So passing full result object is fine.
|
||||
resp["result"] = {
|
||||
"content": content
|
||||
}
|
||||
else:
|
||||
# Ignore notifications or other methods
|
||||
if req_id is not None:
|
||||
resp["error"] = {"code": -32601, "message": "Method not found"}
|
||||
else:
|
||||
continue
|
||||
|
||||
logging.info(f"Sending: {json.dumps(resp)}")
|
||||
if req_id is not None:
|
||||
sys.stdout.write(json.dumps(resp) + "\n\n")
|
||||
sys.stdout.flush()
|
||||
except Exception as e:
|
||||
logging.error(f"Error: {e}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
llama_add_compile_flags()
|
||||
|
||||
add_executable(test-mcp-integration test-mcp-integration.cpp)
|
||||
target_link_libraries(test-mcp-integration PRIVATE common)
|
||||
target_include_directories(test-mcp-integration PRIVATE ../tools/cli)
|
||||
|
||||
function(llama_build source)
|
||||
set(TEST_SOURCES ${source} ${ARGN})
|
||||
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ static common_chat_msg normalize(const common_chat_msg & msg) {
|
|||
for (auto & tool_call : normalized.tool_calls) {
|
||||
try {
|
||||
tool_call.arguments = json::parse(tool_call.arguments).dump();
|
||||
} catch (const std::exception &) {
|
||||
// Do nothing
|
||||
} catch (const std::exception &e) {
|
||||
LOG_DBG("Normalize failed on tool call: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
return normalized;
|
||||
|
|
@ -183,7 +183,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha
|
|||
}
|
||||
}
|
||||
|
||||
common_chat_tool special_function_tool {
|
||||
static common_chat_tool special_function_tool {
|
||||
/* .name = */ "special_function",
|
||||
/* .description = */ "I'm special",
|
||||
/* .parameters = */ R"({
|
||||
|
|
@ -197,7 +197,7 @@ common_chat_tool special_function_tool {
|
|||
"required": ["arg1"]
|
||||
})",
|
||||
};
|
||||
common_chat_tool special_function_tool_with_optional_param {
|
||||
static common_chat_tool special_function_tool_with_optional_param {
|
||||
/* .name = */ "special_function_with_opt",
|
||||
/* .description = */ "I'm special but have optional stuff",
|
||||
/* .parameters = */ R"({
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
#include "../tools/cli/mcp.hpp"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc < 2) {
|
||||
fprintf(stderr, "Usage: %s <mcp_config_file>\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string config_file = argv[1];
|
||||
|
||||
printf("Testing MCP integration with config: %s\n", config_file.c_str());
|
||||
|
||||
mcp_context mcp;
|
||||
std::string tool_to_run;
|
||||
json tool_args = json::object();
|
||||
std::string servers_arg = "";
|
||||
|
||||
if (argc >= 3) {
|
||||
tool_to_run = argv[2];
|
||||
}
|
||||
if (argc >= 4) {
|
||||
try {
|
||||
tool_args = json::parse(argv[3]);
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "Error parsing tool arguments JSON: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (mcp.load_config(config_file, servers_arg)) {
|
||||
printf("MCP config loaded successfully.\n");
|
||||
|
||||
auto tools = mcp.get_tools();
|
||||
printf("Found %zu tools.\n", tools.size());
|
||||
|
||||
for (const auto & tool : tools) {
|
||||
printf("Tool: %s\n", tool.name.c_str());
|
||||
printf(" Description: %s\n", tool.description.c_str());
|
||||
}
|
||||
|
||||
if (!tool_to_run.empty()) {
|
||||
printf("Calling tool %s...\n", tool_to_run.c_str());
|
||||
mcp.set_yolo(true);
|
||||
json res = mcp.call_tool(tool_to_run, tool_args);
|
||||
printf("Result: %s\n", res.dump().c_str());
|
||||
} else if (!tools.empty()) {
|
||||
printf("No tool specified. Calling first tool '%s' with empty args as smoke test...\n", tools[0].name.c_str());
|
||||
json args = json::object();
|
||||
mcp.set_yolo(true);
|
||||
json res = mcp.call_tool(tools[0].name, args);
|
||||
printf("Result: %s\n", res.dump().c_str());
|
||||
}
|
||||
|
||||
} else {
|
||||
printf("Failed to load MCP config.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Allow some time for threads to shutdown if any
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
#include "console.h"
|
||||
// #include "log.h"
|
||||
|
||||
#include "mcp.hpp"
|
||||
#include "server-context.h"
|
||||
#include "server-task.h"
|
||||
|
||||
|
|
@ -48,10 +49,12 @@ static void signal_handler(int) {
|
|||
#endif
|
||||
|
||||
struct cli_context {
|
||||
server_context ctx_server;
|
||||
json messages = json::array();
|
||||
server_context ctx_server;
|
||||
json messages = json::array();
|
||||
json tools = json::array();
|
||||
std::vector<raw_buffer> input_files;
|
||||
task_params defaults;
|
||||
task_params defaults;
|
||||
std::set<std::string> approved_requests;
|
||||
|
||||
// thread for showing "loading" animation
|
||||
std::atomic<bool> loading_show;
|
||||
|
|
@ -68,7 +71,7 @@ struct cli_context {
|
|||
// defaults.return_progress = true; // TODO: show progress
|
||||
}
|
||||
|
||||
std::string generate_completion(result_timings & out_timings) {
|
||||
std::string generate_completion(result_timings & out_timings, server_task_result_ptr & out_result) {
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
auto chat_params = format_chat();
|
||||
{
|
||||
|
|
@ -110,11 +113,12 @@ struct cli_context {
|
|||
} else {
|
||||
console::error("Error: %s\n", err_data.dump().c_str());
|
||||
}
|
||||
out_result = std::move(result);
|
||||
return curr_content;
|
||||
}
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
auto *res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial) {
|
||||
out_timings = std::move(res_partial->timings);
|
||||
out_timings = res_partial->timings;
|
||||
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (is_thinking) {
|
||||
|
|
@ -137,9 +141,10 @@ struct cli_context {
|
|||
}
|
||||
}
|
||||
}
|
||||
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
|
||||
auto *res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
|
||||
if (res_final) {
|
||||
out_timings = std::move(res_final->timings);
|
||||
out_timings = res_final->timings;
|
||||
out_result = std::move(result);
|
||||
break;
|
||||
}
|
||||
result = rd.next(should_stop);
|
||||
|
|
@ -160,20 +165,18 @@ struct cli_context {
|
|||
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
input_files.push_back(std::move(buf));
|
||||
return mtmd_default_marker();
|
||||
} else {
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
return content;
|
||||
}
|
||||
return std::string((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
}
|
||||
|
||||
common_chat_params format_chat() {
|
||||
common_chat_params format_chat() const {
|
||||
auto meta = ctx_server.get_meta();
|
||||
auto & chat_params = meta.chat_params;
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = {}; // TODO
|
||||
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = tools.empty() ? COMMON_CHAT_TOOL_CHOICE_NONE : COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
inputs.json_schema = ""; // TODO
|
||||
inputs.grammar = ""; // TODO
|
||||
inputs.use_jinja = chat_params.use_jinja;
|
||||
|
|
@ -229,7 +232,29 @@ int main(int argc, char ** argv) {
|
|||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
console::log("\nLoading model... "); // followed by loading animation
|
||||
// Initialize MCP
|
||||
mcp_context mcp;
|
||||
if (!params.mcp_config.empty()) {
|
||||
if (mcp.load_config(params.mcp_config, params.mcp_servers)) {
|
||||
mcp.set_yolo(params.mcp_yolo);
|
||||
auto mcp_tools = mcp.get_tools();
|
||||
for (const auto & t : mcp_tools) {
|
||||
json tool = {
|
||||
{ "type", "function" },
|
||||
{ "function",
|
||||
{ { "name", t.name }, { "description", t.description }, { "parameters", t.input_schema } } }
|
||||
};
|
||||
ctx_cli.tools.push_back(tool);
|
||||
}
|
||||
if (!ctx_cli.tools.empty()) {
|
||||
console::log("Enabled %d MCP tools\n", (int) ctx_cli.tools.size());
|
||||
}
|
||||
} else {
|
||||
console::error("Failed to load MCP config: %s\n", params.mcp_config.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
console::log("\nLoading model... "); // followed by loading animation
|
||||
console::spinner::start();
|
||||
if (!ctx_cli.ctx_server.load_model(params)) {
|
||||
console::spinner::stop();
|
||||
|
|
@ -337,7 +362,8 @@ int main(int argc, char ** argv) {
|
|||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
}
|
||||
if (string_starts_with(buffer, "/regen")) {
|
||||
if (ctx_cli.messages.size() >= 2) {
|
||||
size_t last_idx = ctx_cli.messages.size() - 1;
|
||||
ctx_cli.messages.erase(last_idx);
|
||||
|
|
@ -388,11 +414,85 @@ int main(int argc, char ** argv) {
|
|||
cur_msg.clear();
|
||||
}
|
||||
result_timings timings;
|
||||
std::string assistant_content = ctx_cli.generate_completion(timings);
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
while (true) {
|
||||
server_task_result_ptr result;
|
||||
std::string assistant_content = ctx_cli.generate_completion(timings, result);
|
||||
auto * res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
|
||||
|
||||
if (res_final && !res_final->oaicompat_msg.tool_calls.empty()) {
|
||||
ctx_cli.messages.push_back(res_final->oaicompat_msg.to_json_oaicompat());
|
||||
|
||||
for (const auto & tc : res_final->oaicompat_msg.tool_calls) {
|
||||
json args;
|
||||
try {
|
||||
if (tc.arguments.empty()) {
|
||||
args = json::object();
|
||||
} else {
|
||||
args = json::parse(tc.arguments);
|
||||
}
|
||||
} catch (...) {
|
||||
json err_msg = {
|
||||
{ "role", "tool" },
|
||||
{ "content", "Error parsing arguments" },
|
||||
{ "tool_call_id", tc.id },
|
||||
{ "name", tc.name }
|
||||
};
|
||||
ctx_cli.messages.push_back(err_msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string request_id = tc.name + tc.arguments;
|
||||
if (!mcp.get_yolo() &&
|
||||
ctx_cli.approved_requests.find(request_id) == ctx_cli.approved_requests.end()) {
|
||||
// Prompt user
|
||||
fprintf(stdout, "\n\n\033[1;33mTool call: %s\033[0m\n", tc.name.c_str());
|
||||
fprintf(stdout, "Arguments: %s\n", args.dump(2).c_str());
|
||||
fprintf(stdout, "Approve? [y]es, [n]o, [A]lways allow feature: ");
|
||||
fflush(stdout);
|
||||
|
||||
char c = ' ';
|
||||
std::string line;
|
||||
|
||||
console::readline(line, false);
|
||||
if (!line.empty()) {
|
||||
c = line[0];
|
||||
}
|
||||
|
||||
if (c == 'y' || c == 'Y') {
|
||||
// approved once
|
||||
} else if (c == 'A') {
|
||||
ctx_cli.approved_requests.insert(request_id);
|
||||
} else {
|
||||
json err_msg = {
|
||||
{ "role", "tool" },
|
||||
{ "content", "User denied tool execution" },
|
||||
{ "tool_call_id", tc.id },
|
||||
{ "name", tc.name }
|
||||
};
|
||||
ctx_cli.messages.push_back(err_msg);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
json res = mcp.call_tool(tc.name, args);
|
||||
json tool_msg = {
|
||||
{ "role", "tool" },
|
||||
{ "content", res.dump() },
|
||||
{ "tool_call_id", tc.id },
|
||||
{ "name", tc.name }
|
||||
};
|
||||
ctx_cli.messages.push_back(tool_msg);
|
||||
}
|
||||
// continue loop to generate with tool results
|
||||
} else {
|
||||
json assistant_msg = {
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
};
|
||||
ctx_cli.messages.push_back(assistant_msg);
|
||||
break;
|
||||
}
|
||||
}
|
||||
console::log("\n");
|
||||
|
||||
if (params.show_timings) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,569 @@
|
|||
#pragma once
|
||||
|
||||
#include "../../vendor/sheredom/subprocess.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
struct mcp_tool {
|
||||
std::string name;
|
||||
std::string description;
|
||||
json input_schema;
|
||||
std::string server_name;
|
||||
};
|
||||
|
||||
// Low-level transport layer for subprocess-based MCP communication.
|
||||
// Handles subprocess lifecycle, raw I/O, and line-based message framing.
|
||||
class mcp_transport {
|
||||
std::string name;
|
||||
std::string command;
|
||||
std::vector<std::string> args;
|
||||
std::map<std::string, std::string> env;
|
||||
|
||||
struct subprocess_s process = {};
|
||||
std::atomic<bool> running{ false };
|
||||
std::thread read_thread;
|
||||
std::thread err_thread;
|
||||
std::atomic<bool> stop_read{ false };
|
||||
|
||||
std::string read_buffer;
|
||||
|
||||
std::function<void(const std::string &)> line_handler;
|
||||
|
||||
public:
|
||||
mcp_transport(const std::string & name,
|
||||
const std::string & cmd,
|
||||
const std::vector<std::string> & args,
|
||||
const std::map<std::string, std::string> & env) :
|
||||
name(name),
|
||||
command(cmd),
|
||||
args(args),
|
||||
env(env) {}
|
||||
|
||||
mcp_transport(const mcp_transport &) = delete;
|
||||
mcp_transport & operator=(const mcp_transport &) = delete;
|
||||
|
||||
mcp_transport(mcp_transport && other) noexcept :
|
||||
name(std::move(other.name)),
|
||||
command(std::move(other.command)),
|
||||
args(std::move(other.args)),
|
||||
env(std::move(other.env)),
|
||||
process(other.process),
|
||||
running(other.running.load()),
|
||||
read_thread(std::move(other.read_thread)),
|
||||
err_thread(std::move(other.err_thread)),
|
||||
stop_read(other.stop_read.load()),
|
||||
read_buffer(std::move(other.read_buffer)),
|
||||
line_handler(std::move(other.line_handler)) {
|
||||
other.process = {};
|
||||
other.running = false;
|
||||
}
|
||||
|
||||
mcp_transport & operator=(mcp_transport && other) noexcept {
|
||||
if (this != &other) {
|
||||
stop();
|
||||
|
||||
name = std::move(other.name);
|
||||
command = std::move(other.command);
|
||||
args = std::move(other.args);
|
||||
env = std::move(other.env);
|
||||
process = other.process;
|
||||
running = other.running.load();
|
||||
read_thread = std::move(other.read_thread);
|
||||
err_thread = std::move(other.err_thread);
|
||||
stop_read = other.stop_read.load();
|
||||
read_buffer = std::move(other.read_buffer);
|
||||
line_handler = std::move(other.line_handler);
|
||||
|
||||
other.process = {};
|
||||
other.running = false;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
~mcp_transport() { stop(); }
|
||||
|
||||
void set_line_handler(std::function<void(const std::string &)> handler) {
|
||||
line_handler = std::move(handler);
|
||||
}
|
||||
|
||||
bool is_running() const { return running; }
|
||||
|
||||
bool start() {
|
||||
std::vector<const char *> cmd_args;
|
||||
cmd_args.push_back(command.c_str());
|
||||
for (const auto & arg : args) {
|
||||
cmd_args.push_back(arg.c_str());
|
||||
}
|
||||
cmd_args.push_back(nullptr);
|
||||
|
||||
std::vector<const char *> env_vars;
|
||||
std::vector<std::string> env_strings; // keep strings alive
|
||||
if (!env.empty()) {
|
||||
for (const auto & kv : env) {
|
||||
env_strings.push_back(kv.first + "=" + kv.second);
|
||||
}
|
||||
for (const auto & s : env_strings) {
|
||||
env_vars.push_back(s.c_str());
|
||||
}
|
||||
env_vars.push_back(nullptr);
|
||||
}
|
||||
|
||||
int options = subprocess_option_search_user_path;
|
||||
int result;
|
||||
if (env.empty()) {
|
||||
options |= subprocess_option_inherit_environment;
|
||||
result = subprocess_create(cmd_args.data(), options, &process);
|
||||
} else {
|
||||
result = subprocess_create_ex(cmd_args.data(), options, env_vars.data(), &process);
|
||||
}
|
||||
|
||||
if (result != 0) {
|
||||
LOG_ERR("Failed to start MCP server %s: error %d (%s)\n", name.c_str(), errno, strerror(errno));
|
||||
return false;
|
||||
}
|
||||
|
||||
running = true;
|
||||
read_thread = std::thread(&mcp_transport::read_loop, this);
|
||||
err_thread = std::thread(&mcp_transport::err_loop, this);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void stop() {
|
||||
if (!running) {
|
||||
return;
|
||||
}
|
||||
LOG_INF("Stopping MCP server %s...\n", name.c_str());
|
||||
stop_read = true;
|
||||
|
||||
// 1. Close stdin to signal EOF
|
||||
if (process.stdin_file) {
|
||||
fclose(process.stdin_file);
|
||||
process.stdin_file = nullptr;
|
||||
}
|
||||
|
||||
// 2. Wait for 10 seconds for normal termination
|
||||
bool terminated = false;
|
||||
for (int i = 0; i < 100; ++i) { // 100 * 100ms = 10s
|
||||
if (subprocess_alive(&process) == 0) {
|
||||
terminated = true;
|
||||
break;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
// 3. Terminate if still running
|
||||
if (!terminated) {
|
||||
LOG_WRN("MCP server %s did not exit gracefully, terminating...\n", name.c_str());
|
||||
subprocess_terminate(&process);
|
||||
}
|
||||
|
||||
// 4. Join threads
|
||||
if (read_thread.joinable()) {
|
||||
read_thread.join();
|
||||
}
|
||||
if (err_thread.joinable()) {
|
||||
err_thread.join();
|
||||
}
|
||||
|
||||
// 5. Cleanup
|
||||
if (running) {
|
||||
subprocess_destroy(&process);
|
||||
running = false;
|
||||
}
|
||||
LOG_INF("MCP server %s stopped.\n", name.c_str());
|
||||
}
|
||||
|
||||
bool send(const std::string & data) {
|
||||
FILE * f = subprocess_stdin(&process);
|
||||
if (!f) {
|
||||
return false;
|
||||
}
|
||||
fwrite(data.c_str(), 1, data.size(), f);
|
||||
fflush(f);
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
void read_loop() {
|
||||
LOG_DBG("MCP read_loop started for %s\n", name.c_str());
|
||||
char buffer[4096];
|
||||
while (!stop_read && running) {
|
||||
unsigned bytes_read = subprocess_read_stdout(&process, buffer, sizeof(buffer));
|
||||
|
||||
if (bytes_read == 0) {
|
||||
// If blocking read returns 0, it means EOF (process exited or pipe closed).
|
||||
// We should NOT call subprocess_alive() here because it calls subprocess_join()
|
||||
// which modifies the process struct (closes stdin) and causes race conditions/double-free
|
||||
// when stop() is called concurrently.
|
||||
// Just break the loop. The process is likely dead or dying.
|
||||
if (!stop_read) {
|
||||
LOG_ERR("MCP server %s: read_loop exiting (stdout closed/EOF)\n", name.c_str());
|
||||
}
|
||||
running = false;
|
||||
break;
|
||||
}
|
||||
|
||||
read_buffer.append(buffer, bytes_read);
|
||||
size_t pos;
|
||||
while ((pos = read_buffer.find('\n')) != std::string::npos) {
|
||||
std::string line = read_buffer.substr(0, pos);
|
||||
read_buffer.erase(0, pos + 1);
|
||||
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (line_handler) {
|
||||
line_handler(line);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG_DBG("MCP read_loop exiting for %s (stop_read=%d, running=%d)\n", name.c_str(), stop_read.load(),
|
||||
running.load());
|
||||
}
|
||||
|
||||
void err_loop() {
|
||||
char buffer[1024];
|
||||
while (!stop_read && running) {
|
||||
unsigned bytes_read = subprocess_read_stderr(&process, buffer, sizeof(buffer));
|
||||
if (bytes_read > 0) {
|
||||
if (stop_read) {
|
||||
break; // Don't log stderr during shutdown
|
||||
}
|
||||
std::string err_str(buffer, bytes_read);
|
||||
LOG_WRN("[%s stderr] %s", name.c_str(), err_str.c_str());
|
||||
} else {
|
||||
// EOF
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// JSON-RPC layer for MCP server communication.
|
||||
// Handles request/response correlation, message formatting, and MCP protocol methods.
|
||||
class mcp_server {
|
||||
std::string name;
|
||||
mcp_transport transport;
|
||||
|
||||
std::mutex mutex;
|
||||
int next_id = 1;
|
||||
std::map<int, std::promise<json>> pending_requests;
|
||||
|
||||
public:
|
||||
mcp_server(const std::string & name,
|
||||
const std::string & cmd,
|
||||
const std::vector<std::string> & args,
|
||||
const std::map<std::string, std::string> & env) :
|
||||
name(name),
|
||||
transport(name, cmd, args, env) {}
|
||||
|
||||
mcp_server(const mcp_server &) = delete;
|
||||
mcp_server & operator=(const mcp_server &) = delete;
|
||||
|
||||
mcp_server(mcp_server && other) noexcept :
|
||||
name(std::move(other.name)),
|
||||
transport(std::move(other.transport)),
|
||||
next_id(other.next_id),
|
||||
pending_requests(std::move(other.pending_requests)) {}
|
||||
|
||||
mcp_server & operator=(mcp_server && other) noexcept {
|
||||
if (this != &other) {
|
||||
stop();
|
||||
|
||||
name = std::move(other.name);
|
||||
transport = std::move(other.transport);
|
||||
next_id = other.next_id;
|
||||
pending_requests = std::move(other.pending_requests);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
~mcp_server() { stop(); }
|
||||
|
||||
bool start() {
|
||||
transport.set_line_handler([this](const std::string & line) {
|
||||
handle_line(line);
|
||||
});
|
||||
return transport.start();
|
||||
}
|
||||
|
||||
void stop() { transport.stop(); }
|
||||
|
||||
json send_request(const std::string & method, const json & params = json::object()) {
|
||||
if (!transport.is_running()) {
|
||||
LOG_ERR("Cannot send request to %s: server not running\n", name.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int id;
|
||||
std::future<json> future;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
id = next_id++;
|
||||
future = pending_requests[id].get_future();
|
||||
}
|
||||
|
||||
json req = {
|
||||
{ "jsonrpc", "2.0" },
|
||||
{ "id", id },
|
||||
{ "method", method },
|
||||
{ "params", params }
|
||||
};
|
||||
|
||||
std::string req_str = req.dump() + "\n";
|
||||
LOG_DBG("MCP request to %s [id=%d]: %s\n", name.c_str(), id, method.c_str());
|
||||
if (!transport.send(req_str)) {
|
||||
LOG_ERR("Cannot send request to %s: stdin is null\n", name.c_str());
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
pending_requests.erase(id);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
if (future.wait_for(std::chrono::seconds(30)) == std::future_status::timeout) {
|
||||
LOG_ERR("Timeout waiting for response from %s (method: %s, id: %d)\n", name.c_str(), method.c_str(), id);
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
pending_requests.erase(id);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return future.get();
|
||||
}
|
||||
|
||||
void send_notification(const std::string & method, const json & params = json::object()) {
|
||||
json req = {
|
||||
{ "jsonrpc", "2.0" },
|
||||
{ "method", method },
|
||||
{ "params", params }
|
||||
};
|
||||
|
||||
transport.send(req.dump() + "\n");
|
||||
}
|
||||
|
||||
// Initialize handshake
|
||||
bool initialize() {
|
||||
// Send initialize
|
||||
json init_params = {
|
||||
{ "protocolVersion", "2024-11-05" },
|
||||
{ "capabilities", { { "roots", { { "listChanged", false } } }, { "sampling", json::object() } } },
|
||||
{ "clientInfo",
|
||||
{
|
||||
{ "name", "llama.cpp-cli" }, { "version", "0.1.0" } // TODO: use real version
|
||||
} }
|
||||
};
|
||||
|
||||
json res = send_request("initialize", init_params);
|
||||
if (res.is_null() || res.contains("error")) {
|
||||
LOG_ERR("Failed to initialize MCP server %s\n", name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// Send initialized notification
|
||||
send_notification("notifications/initialized");
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<mcp_tool> list_tools() {
|
||||
std::vector<mcp_tool> tools;
|
||||
json res = send_request("tools/list");
|
||||
if (res.is_null() || res.contains("error")) {
|
||||
LOG_ERR("Failed to list tools from %s\n", name.c_str());
|
||||
return tools;
|
||||
}
|
||||
|
||||
if (res.contains("result") && res["result"].contains("tools")) {
|
||||
for (const auto & t : res["result"]["tools"]) {
|
||||
mcp_tool tool;
|
||||
tool.name = t["name"].get<std::string>();
|
||||
if (t.contains("description")) {
|
||||
tool.description = t["description"].get<std::string>();
|
||||
}
|
||||
if (t.contains("inputSchema")) {
|
||||
tool.input_schema = t["inputSchema"];
|
||||
}
|
||||
tool.server_name = name;
|
||||
tools.push_back(tool);
|
||||
}
|
||||
}
|
||||
return tools;
|
||||
}
|
||||
|
||||
json call_tool(const std::string & tool_name, const json & args) {
|
||||
json params = {
|
||||
{ "name", tool_name },
|
||||
{ "arguments", args }
|
||||
};
|
||||
json res = send_request("tools/call", params);
|
||||
if (res.is_null() || res.contains("error")) {
|
||||
return {
|
||||
{ "error", res.contains("error") ? res["error"] : "Unknown error" }
|
||||
};
|
||||
}
|
||||
if (res.contains("result")) {
|
||||
return res["result"];
|
||||
}
|
||||
return {
|
||||
{ "error", "No result returned" }
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
void handle_line(const std::string & line) {
|
||||
try {
|
||||
json msg = json::parse(line);
|
||||
if (msg.contains("id") && !msg["id"].is_null()) {
|
||||
// Response - handle both int and string IDs (JSON-RPC allows both)
|
||||
int id;
|
||||
if (msg["id"].is_string()) {
|
||||
id = std::stoi(msg["id"].get<std::string>());
|
||||
} else {
|
||||
id = msg["id"].get<int>();
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (pending_requests.count(id)) {
|
||||
LOG_DBG("MCP response received from %s [id=%d]\n", name.c_str(), id);
|
||||
pending_requests[id].set_value(msg);
|
||||
pending_requests.erase(id);
|
||||
} else {
|
||||
LOG_WRN("MCP response for unknown id %d from %s: %s\n", id, name.c_str(), line.c_str());
|
||||
}
|
||||
} else {
|
||||
// Notification or request from server -> ignore for now or log
|
||||
// MCP servers might send notifications (e.g. logging)
|
||||
LOG_INF("MCP Notification from %s: %s\n", name.c_str(), line.c_str());
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
// Not a full JSON yet? Or invalid?
|
||||
// If it was a line, it should be valid JSON-RPC
|
||||
LOG_WRN("Failed to parse JSON from %s: %s (line: %s)\n", name.c_str(), e.what(), line.c_str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class mcp_context {
|
||||
std::map<std::string, std::shared_ptr<mcp_server>> servers;
|
||||
std::vector<mcp_tool> tools;
|
||||
bool yolo = false;
|
||||
|
||||
public:
|
||||
void set_yolo(bool y) { yolo = y; }
|
||||
|
||||
bool load_config(const std::string & config_path, const std::string & enabled_servers_str) {
|
||||
std::ifstream f(config_path);
|
||||
if (!f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
json config;
|
||||
try {
|
||||
f >> config;
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> enabled_list;
|
||||
std::stringstream ss(enabled_servers_str);
|
||||
std::string item;
|
||||
while (std::getline(ss, item, ',')) {
|
||||
if (!item.empty()) {
|
||||
enabled_list.push_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
if (config.contains("mcpServers")) {
|
||||
std::string server_list;
|
||||
for (auto & [key, val] : config["mcpServers"].items()) {
|
||||
if (!server_list.empty()) {
|
||||
server_list += ", ";
|
||||
}
|
||||
server_list += key;
|
||||
}
|
||||
LOG_INF("MCP configuration found with servers: %s\n", server_list.c_str());
|
||||
|
||||
for (auto & [key, val] : config["mcpServers"].items()) {
|
||||
bool enabled = true;
|
||||
if (!enabled_list.empty()) {
|
||||
bool found = false;
|
||||
for (const auto & s : enabled_list) {
|
||||
if (s == key) {
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
enabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (enabled) {
|
||||
std::string cmd = val["command"].get<std::string>();
|
||||
std::vector<std::string> args = val.value("args", std::vector<std::string>{});
|
||||
std::map<std::string, std::string> env;
|
||||
if (val.contains("env")) {
|
||||
for (auto & [ek, ev] : val["env"].items()) {
|
||||
env[ek] = ev.get<std::string>();
|
||||
}
|
||||
}
|
||||
|
||||
auto server = std::make_shared<mcp_server>(key, cmd, args, env);
|
||||
LOG_INF("Trying to start MCP server: %s...\n", key.c_str());
|
||||
if (server->start()) {
|
||||
if (server->initialize()) {
|
||||
servers[key] = server;
|
||||
LOG_INF("MCP Server '%s' started and initialized.\n", key.c_str());
|
||||
|
||||
auto server_tools = server->list_tools();
|
||||
tools.insert(tools.end(), server_tools.begin(), server_tools.end());
|
||||
} else {
|
||||
LOG_ERR("MCP Server '%s' failed to initialize.\n", key.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<mcp_tool> get_tools() const { return tools; }
|
||||
|
||||
bool get_yolo() const { return yolo; }
|
||||
|
||||
json call_tool(const std::string & tool_name, const json & args) {
|
||||
// Find which server has this tool
|
||||
std::string server_name;
|
||||
for (const auto & t : tools) {
|
||||
if (t.name == tool_name) {
|
||||
server_name = t.server_name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (server_name.empty()) {
|
||||
return {
|
||||
{ "error", "Tool not found" }
|
||||
};
|
||||
}
|
||||
|
||||
if (servers.count(server_name)) {
|
||||
return servers[server_name]->call_tool(tool_name, args);
|
||||
}
|
||||
return {
|
||||
{ "error", "Server not found" }
|
||||
};
|
||||
}
|
||||
};
|
||||
Loading…
Reference in New Issue