This commit is contained in:
parent b828e18c75
commit 323d23eb1c
@@ -3784,6 +3784,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));

    add_opt(common_arg(
        {"--mcp-config"}, "FILE",
        "path to MCP configuration file",
        [](common_params & params, const std::string & value) {
            params.mcp_config = value;
        }
    ).set_examples({LLAMA_EXAMPLE_CLI}));

    add_opt(common_arg(
        {"--mcp-yolo"},
        "auto-approve all MCP tool calls (no user confirmation)",
        [](common_params & params) {
            params.mcp_yolo = true;
        }
    ).set_examples({LLAMA_EXAMPLE_CLI}));

    return ctx_arg;
}
@@ -1140,6 +1140,9 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
        return;
    }

    // DEBUG: trace input
    // fprintf(stderr, "hermes_parse: input='%s'\n", builder.input().substr(builder.pos()).c_str());

    static const common_regex open_regex(
        "(?:"
            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
@@ -1160,6 +1163,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
    );

    while (auto res = builder.try_find_regex(open_regex)) {
        // fprintf(stderr, "hermes_parse: found regex match\n");
        const auto & block_start = res->groups[1];
        std::string block_end = block_start.empty() ? "" : "```";

@@ -1172,16 +1176,23 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {

            if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
                if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
                    fprintf(stderr, "hermes: json tool add failed or partial (partial=%d)\n", tool_call->is_partial);
                    throw common_chat_msg_partial_exception("incomplete tool call");
                }
                builder.consume_spaces();
                builder.consume_literal(close_tag);
                // builder.consume_literal(close_tag); // Handle mismatched close tag gracefully?
                if (!builder.try_consume_literal(close_tag)) {
                    fprintf(stderr, "hermes: failed to consume close tag '%s'. Remaining: '%s'\n", close_tag.c_str(), builder.input().substr(builder.pos()).c_str());
                    // If closing tag is missing, is it partial?
                    throw common_chat_msg_partial_exception("missing close tag");
                }
                builder.consume_spaces();
                if (!block_end.empty()) {
                    builder.consume_literal(block_end);
                    builder.consume_spaces();
                }
            } else {
                fprintf(stderr, "hermes: failed to consume json\n");
                throw common_chat_msg_partial_exception("failed to parse tool call");
            }
        } else {
@@ -543,6 +543,11 @@ struct common_params {

    std::map<std::string, std::string> default_template_kwargs;

    // MCP params
    std::string mcp_config = "";
    std::string mcp_servers = "";
    bool mcp_yolo = false;

    // webui configs
    bool webui = true;
    std::string webui_config_json;
@@ -0,0 +1,13 @@
{
  "mcpServers": {
    "serena": {
      "command": "uvx",
      "args": [
        "--from",
        "git+https://github.com/oraios/serena",
        "serena",
        "start-mcp-server"
      ]
    }
  }
}
@@ -0,0 +1,82 @@
import sys
import os
import json
import logging

logging.basicConfig(filename='/devel/tools/llama.cpp/mcp_dummy.log', level=logging.DEBUG)

def main():
    logging.info("Starting MCP Dummy Server")
    while True:
        try:
            line = sys.stdin.readline()
            if not line:
                break
            logging.info(f"Received: {line.strip()}")
            try:
                req = json.loads(line)
            except json.JSONDecodeError:
                continue

            if "method" in req:
                method = req["method"]
                req_id = req.get("id")

                resp = {"jsonrpc": "2.0", "id": req_id}

                if method == "initialize":
                    resp["result"] = {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {},
                        "serverInfo": {"name": "dummy", "version": "1.0"}
                    }
                elif method == "tools/list":
                    resp["result"] = {
                        "tools": [
                            {
                                "name": "get_weather",
                                "description": "Get weather for a location",
                                "inputSchema": {
                                    "type": "object",
                                    "properties": {
                                        "location": {"type": "string"},
                                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
                                    },
                                    "required": ["location"]
                                }
                            }
                        ]
                    }
                elif method == "tools/call":
                    params = req.get("params", {})
                    name = params.get("name")
                    args = params.get("arguments", {})

                    logging.info(f"Tool call: {name} with {args}")

                    content = [{"type": "text", "text": f"Weather in {args.get('location')} is 25C"}]
                    # For simplicity, return raw content or follow MCP spec?
                    # MCP spec: result: { content: [ {type: "text", text: "..."} ] }
                    # My mcp.hpp returns res["result"].
                    # My cli.cpp dumps res.dump().
                    # So passing full result object is fine.
                    resp["result"] = {
                        "content": content
                    }
                else:
                    # Ignore notifications or other methods
                    if req_id is not None:
                        resp["error"] = {"code": -32601, "message": "Method not found"}
                    else:
                        continue

                logging.info(f"Sending: {json.dumps(resp)}")
                if req_id is not None:
                    sys.stdout.write(json.dumps(resp) + "\n\n")
                    sys.stdout.flush()
        except Exception as e:
            logging.error(f"Error: {e}")
            break

if __name__ == "__main__":
    main()
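For local testing, the dummy server above can be wired into the CLI through the same mcpServers config format shown earlier. A minimal sketch follows; the script file name and the python3 interpreter are assumptions, since the new file's path is not visible in this view:

{
  "mcpServers": {
    "dummy": {
      "command": "python3",
      "args": ["mcp_dummy.py"]
    }
  }
}

Passing such a file via --mcp-config would start the dummy server and expose its get_weather tool to the model.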
@@ -1,5 +1,9 @@
llama_add_compile_flags()

add_executable(test-mcp-integration test-mcp-integration.cpp)
target_link_libraries(test-mcp-integration PRIVATE common)
target_include_directories(test-mcp-integration PRIVATE ../tools/cli)

function(llama_build source)
    set(TEST_SOURCES ${source} ${ARGN})
@@ -0,0 +1,71 @@
#include "../tools/cli/mcp.hpp"
#include "common.h"
#include "log.h"

#include <chrono>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <mcp_config_file>\n", argv[0]);
        return 1;
    }

    std::string config_file = argv[1];

    printf("Testing MCP integration with config: %s\n", config_file.c_str());

    mcp_context mcp;
    std::string tool_to_run;
    json tool_args = json::object();
    std::string servers_arg = "";

    if (argc >= 3) {
        tool_to_run = argv[2];
    }
    if (argc >= 4) {
        try {
            tool_args = json::parse(argv[3]);
        } catch (const std::exception & e) {
            fprintf(stderr, "Error parsing tool arguments JSON: %s\n", e.what());
            return 1;
        }
    }

    if (mcp.load_config(config_file, servers_arg)) {
        printf("MCP config loaded successfully.\n");

        auto tools = mcp.get_tools();
        printf("Found %zu tools.\n", tools.size());

        for (const auto & tool : tools) {
            printf("Tool: %s\n", tool.name.c_str());
            printf(" Description: %s\n", tool.description.c_str());
        }

        if (!tool_to_run.empty()) {
            printf("Calling tool %s...\n", tool_to_run.c_str());
            mcp.set_yolo(true);
            json res = mcp.call_tool(tool_to_run, tool_args);
            printf("Result: %s\n", res.dump().c_str());
        } else if (!tools.empty()) {
            printf("No tool specified. Calling first tool '%s' with empty args as smoke test...\n", tools[0].name.c_str());
            json args = json::object();
            mcp.set_yolo(true);
            json res = mcp.call_tool(tools[0].name, args);
            printf("Result: %s\n", res.dump().c_str());
        }

    } else {
        printf("Failed to load MCP config.\n");
        return 1;
    }

    // Allow some time for threads to shutdown if any
    std::this_thread::sleep_for(std::chrono::seconds(1));

    return 0;
}
@@ -1,22 +1,25 @@
#include "common.h"
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "console.h"
// #include "log.h"

#include "mcp.hpp"
#include "server-context.h"
#include "server-task.h"

#include <signal.h>

#include <atomic>
#include <fstream>
#include <thread>
#include <signal.h>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
#  define NOMINMAX
# endif
# include <windows.h>
#endif

const char * LLAMA_ASCII_LOGO = R"(
@@ -30,11 +33,12 @@ const char * LLAMA_ASCII_LOGO = R"(
)";

static std::atomic<bool> g_is_interrupted = false;

static bool should_stop() {
    return g_is_interrupted.load();
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
static void signal_handler(int) {
    if (g_is_interrupted.load()) {
        // second Ctrl+C - exit immediately
@@ -48,10 +52,12 @@ static void signal_handler(int) {
#endif

struct cli_context {
    server_context ctx_server;
    json messages = json::array();
    server_context ctx_server;
    json messages = json::array();
    json tools = json::array();
    std::vector<raw_buffer> input_files;
    task_params defaults;
    task_params defaults;
    std::set<std::string> approved_requests;

    // thread for showing "loading" animation
    std::atomic<bool> loading_show;
@@ -63,12 +69,12 @@ struct cli_context {
        defaults.n_predict = params.n_predict;
        defaults.antiprompt = params.antiprompt;

        defaults.stream = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        defaults.stream = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        // defaults.return_progress = true; // TODO: show progress
    }

    std::string generate_completion(result_timings & out_timings) {
    std::string generate_completion(result_timings & out_timings, server_task_result_ptr & out_result) {
        server_response_reader rd = ctx_server.get_response_reader();
        auto chat_params = format_chat();
        {
@@ -88,7 +94,7 @@ struct cli_context {
                task.params.chat_parser_params.parser.load(chat_params.parser);
            }

            rd.post_task({std::move(task)});
            rd.post_task({ std::move(task) });
        }

        // wait for first result
@@ -97,7 +103,7 @@ struct cli_context {

        console::spinner::stop();
        std::string curr_content;
        bool is_thinking = false;
        bool is_thinking = false;

        while (result) {
            if (should_stop()) {
@@ -110,6 +116,7 @@ struct cli_context {
                } else {
                    console::error("Error: %s\n", err_data.dump().c_str());
                }
                out_result = std::move(result);
                return curr_content;
            }
            auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
@@ -140,6 +147,7 @@ struct cli_context {
            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
            if (res_final) {
                out_timings = std::move(res_final->timings);
                out_result = std::move(result);
                break;
            }
            result = rd.next(should_stop);
@@ -166,14 +174,14 @@ struct cli_context {
        }
    }

    common_chat_params format_chat() {
    common_chat_params format_chat() const {
        auto meta = ctx_server.get_meta();
        auto & chat_params = meta.chat_params;

        common_chat_templates_inputs inputs;
        inputs.messages = common_chat_msgs_parse_oaicompat(messages);
        inputs.tools = {}; // TODO
        inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
        inputs.tools = common_chat_tools_parse_oaicompat(tools);
        inputs.tool_choice = tools.empty() ? COMMON_CHAT_TOOL_CHOICE_NONE : COMMON_CHAT_TOOL_CHOICE_AUTO;
        inputs.json_schema = ""; // TODO
        inputs.grammar = ""; // TODO
        inputs.use_jinja = chat_params.use_jinja;
@@ -189,7 +197,7 @@ struct cli_context {
int main(int argc, char ** argv) {
    common_params params;

    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs
    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
        return 1;
@@ -215,21 +223,43 @@ int main(int argc, char ** argv) {

    console::set_display(DISPLAY_TYPE_RESET);

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
#elif defined(_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    console::log("\nLoading model... "); // followed by loading animation
    // Initialize MCP
    mcp_context mcp;
    if (!params.mcp_config.empty()) {
        if (mcp.load_config(params.mcp_config, params.mcp_servers)) {
            mcp.set_yolo(params.mcp_yolo);
            auto mcp_tools = mcp.get_tools();
            for (const auto & t : mcp_tools) {
                json tool = {
                    { "type", "function" },
                    { "function",
                      { { "name", t.name }, { "description", t.description }, { "parameters", t.input_schema } } }
                };
                ctx_cli.tools.push_back(tool);
            }
            if (!ctx_cli.tools.empty()) {
                console::log("Enabled %d MCP tools\n", (int) ctx_cli.tools.size());
            }
        } else {
            console::error("Failed to load MCP config: %s\n", params.mcp_config.c_str());
        }
    }

    console::log("\nLoading model... "); // followed by loading animation
    console::spinner::start();
    if (!ctx_cli.ctx_server.load_model(params)) {
        console::spinner::stop();
@@ -240,11 +270,9 @@ int main(int argc, char ** argv) {
    console::spinner::stop();
    console::log("\n");

    std::thread inference_thread([&ctx_cli]() {
        ctx_cli.ctx_server.start_loop();
    });
    std::thread inference_thread([&ctx_cli]() { ctx_cli.ctx_server.start_loop(); });

    auto inf = ctx_cli.ctx_server.get_meta();
    auto inf = ctx_cli.ctx_server.get_meta();
    std::string modalities = "text";
    if (inf.has_inp_image) {
        modalities += ", vision";
@@ -255,8 +283,8 @@ int main(int argc, char ** argv) {

    if (!params.system_prompt.empty()) {
        ctx_cli.messages.push_back({
            {"role", "system"},
            {"content", params.system_prompt}
            { "role", "system" },
            { "content", params.system_prompt }
        });
    }

@@ -290,7 +318,7 @@ int main(int argc, char ** argv) {
        if (params.prompt.empty()) {
            console::log("\n> ");
            std::string line;
            bool another_line = true;
            bool another_line = true;
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
@@ -312,7 +340,7 @@ int main(int argc, char ** argv) {
            } else {
                console::log("\n> %s\n", buffer.c_str());
            }
            params.prompt.clear(); // only use it once
            params.prompt.clear(); // only use it once
        }
        console::set_display(DISPLAY_TYPE_RESET);
        console::log("\n");
@@ -323,7 +351,7 @@ int main(int argc, char ** argv) {
        }

        // remove trailing newline
        if (!buffer.empty() &&buffer.back() == '\n') {
        if (!buffer.empty() && buffer.back() == '\n') {
            buffer.pop_back();
        }

@@ -351,11 +379,10 @@ int main(int argc, char ** argv) {
            ctx_cli.input_files.clear();
            console::log("Chat history cleared.\n");
            continue;
        } else if (
            (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
            (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
        } else if ((string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
                   (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
            // just in case (bad copy-paste for example), we strip all trailing/leading spaces
            std::string fname = string_strip(buffer.substr(7));
            std::string fname = string_strip(buffer.substr(7));
            std::string marker = ctx_cli.load_input_file(fname, true);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
@@ -365,7 +392,7 @@ int main(int argc, char ** argv) {
            console::log("Loaded media from '%s'\n", fname.c_str());
            continue;
        } else if (string_starts_with(buffer, "/read ")) {
            std::string fname = string_strip(buffer.substr(6));
            std::string fname = string_strip(buffer.substr(6));
            std::string marker = ctx_cli.load_input_file(fname, false);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
@@ -382,23 +409,100 @@ int main(int argc, char ** argv) {
        // generate response
        if (add_user_msg) {
            ctx_cli.messages.push_back({
                {"role", "user"},
                {"content", cur_msg}
                { "role", "user" },
                { "content", cur_msg }
            });
            cur_msg.clear();
        }
        result_timings timings;
        std::string assistant_content = ctx_cli.generate_completion(timings);
        ctx_cli.messages.push_back({
            {"role", "assistant"},
            {"content", assistant_content}
        });
        while (true) {
            server_task_result_ptr result;
            std::string assistant_content = ctx_cli.generate_completion(timings, result);
            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());

            if (res_final && !res_final->oaicompat_msg.tool_calls.empty()) {
                ctx_cli.messages.push_back(res_final->oaicompat_msg.to_json_oaicompat());

                for (const auto & tc : res_final->oaicompat_msg.tool_calls) {
                    json args;
                    try {
                        if (tc.arguments.empty()) {
                            args = json::object();
                        } else {
                            args = json::parse(tc.arguments);
                        }
                    } catch (...) {
                        json err_msg = {
                            { "role", "tool" },
                            { "content", "Error parsing arguments" },
                            { "tool_call_id", tc.id },
                            { "name", tc.name }
                        };
                        ctx_cli.messages.push_back(err_msg);
                        continue;
                    }

                    std::string request_id = tc.name + tc.arguments;
                    if (!mcp.get_yolo() && ctx_cli.approved_requests.find(request_id) == ctx_cli.approved_requests.end()) {
                        // Prompt user
                        fprintf(stdout, "\n\033[1;33mTool call: %s\033[0m\n", tc.name.c_str());
                        fprintf(stdout, "Arguments: %s\n", args.dump(2).c_str());
                        fprintf(stdout, "Approve? [y]es, [n]o, [A]lways allow feature: ");
                        fflush(stdout);

                        char c = ' ';
                        std::string line;
                        // We are in main loop which might have console settings.
                        // Use console::readline or similar?
                        // But for single char input, standard cin/getline might interfere with console lib state if it's separate.
                        // Given `console::readline` is used above, let's use standard input as fallback or just try cin.

                        // Simple blocking read
                        std::cin >> c;
                        std::getline(std::cin, line); // consume rest

                        if (c == 'y' || c == 'Y') {
                            // approved once
                        } else if (c == 'A') {
                            ctx_cli.approved_requests.insert(request_id);
                        } else {
                            json err_msg = {
                                { "role", "tool" },
                                { "content", "User denied tool execution" },
                                { "tool_call_id", tc.id },
                                { "name", tc.name }
                            };
                            ctx_cli.messages.push_back(err_msg);
                            continue;
                        }
                    }

                    json res = mcp.call_tool(tc.name, args);
                    json tool_msg = {
                        { "role", "tool" },
                        { "content", res.dump() },
                        { "tool_call_id", tc.id },
                        { "name", tc.name }
                    };
                    ctx_cli.messages.push_back(tool_msg);
                }
                // continue loop to generate with tool results
            } else {
                json assistant_msg = {
                    { "role", "assistant" },
                    { "content", assistant_content }
                };
                ctx_cli.messages.push_back(assistant_msg);
                break;
            }
        }
        console::log("\n");

        if (params.show_timings) {
            console::set_display(DISPLAY_TYPE_INFO);
            console::log("\n");
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second,
                         timings.predicted_per_second);
            console::set_display(DISPLAY_TYPE_RESET);
        }
@@ -0,0 +1,499 @@
#pragma once

#include "../../vendor/sheredom/subprocess.h"
#include "common.h"
#include "log.h"

#include <atomic>
#include <condition_variable>
#include <fstream>
#include <future>
#include <iostream>
#include <map>
#include <mutex>
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

using json = nlohmann::ordered_json;

struct mcp_tool {
    std::string name;
    std::string description;
    json input_schema;
    std::string server_name;
};

class mcp_server {
    std::string name;
    std::string command;
    std::vector<std::string> args;
    std::map<std::string, std::string> env;

    struct subprocess_s process = {};
    std::atomic<bool> running{ false };
    std::thread read_thread;
    std::thread err_thread;
    std::atomic<bool> stop_read{ false };

    std::mutex mutex;
    int next_id = 1;
    std::map<int, std::promise<json>> pending_requests;

    // Buffer for reading
    std::string read_buffer;

  public:
    mcp_server(const std::string & name,
               const std::string & cmd,
               const std::vector<std::string> & args,
               const std::map<std::string, std::string> & env) :
        name(name),
        command(cmd),
        args(args),
        env(env) {}

    mcp_server(const mcp_server &) = delete;
    mcp_server & operator=(const mcp_server &) = delete;

    mcp_server(mcp_server && other) noexcept :
        name(std::move(other.name)),
        command(std::move(other.command)),
        args(std::move(other.args)),
        env(std::move(other.env)),
        process(other.process),
        running(other.running.load()),
        read_thread(std::move(other.read_thread)),
        err_thread(std::move(other.err_thread)),
        stop_read(other.stop_read.load()),
        next_id(other.next_id),
        pending_requests(std::move(other.pending_requests)),
        read_buffer(std::move(other.read_buffer)) {
        // Zero out the source process to prevent double-free
        other.process = {};
        other.running = false;
    }

    mcp_server & operator=(mcp_server && other) noexcept {
        if (this != &other) {
            stop(); // Clean up current resources

            name = std::move(other.name);
            command = std::move(other.command);
            args = std::move(other.args);
            env = std::move(other.env);
            process = other.process;
            running = other.running.load();
            read_thread = std::move(other.read_thread);
            err_thread = std::move(other.err_thread);
            stop_read = other.stop_read.load();
            next_id = other.next_id;
            pending_requests = std::move(other.pending_requests);
            read_buffer = std::move(other.read_buffer);

            // Zero out source
            other.process = {};
            other.running = false;
        }
        return *this;
    }

    ~mcp_server() { stop(); }

    bool start() {
        std::vector<const char *> cmd_args;
        cmd_args.push_back(command.c_str());
        for (const auto & arg : args) {
            cmd_args.push_back(arg.c_str());
        }
        cmd_args.push_back(nullptr);

        std::vector<const char *> env_vars;
        std::vector<std::string> env_strings; // keep strings alive
        if (!env.empty()) {
            for (const auto & kv : env) {
                env_strings.push_back(kv.first + "=" + kv.second);
            }
            for (const auto & s : env_strings) {
                env_vars.push_back(s.c_str());
            }
            env_vars.push_back(nullptr);
        }

        // Blocking I/O is simpler with threads
        int options = subprocess_option_search_user_path;
        int result;
        if (env.empty()) {
            options |= subprocess_option_inherit_environment;
            result = subprocess_create(cmd_args.data(), options, &process);
        } else {
            result = subprocess_create_ex(cmd_args.data(), options, env_vars.data(), &process);
        }

        if (result != 0) {
            LOG_ERR("Failed to start MCP server %s: error %d (%s)\n", name.c_str(), errno, strerror(errno));
            return false;
        }

        running = true;
        read_thread = std::thread(&mcp_server::read_loop, this);
        err_thread = std::thread(&mcp_server::err_loop, this);

        return true;
    }

    void stop() {
        if (!running) return;
        LOG_INF("Stopping MCP server %s...\n", name.c_str());
        stop_read = true;

        // 1. Close stdin to signal EOF
        if (process.stdin_file) {
            fclose(process.stdin_file);
            process.stdin_file = nullptr;
        }

        // 2. Wait for 10 seconds for normal termination
        bool terminated = false;
        for (int i = 0; i < 100; ++i) { // 100 * 100ms = 10s
            if (subprocess_alive(&process) == 0) {
                terminated = true;
                break;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }

        // 3. Terminate if still running
        if (!terminated) {
            LOG_WRN("MCP server %s did not exit gracefully, terminating...\n", name.c_str());
            subprocess_terminate(&process);
        }

        // 4. Join threads
        if (read_thread.joinable()) {
            read_thread.join();
        }
        if (err_thread.joinable()) {
            err_thread.join();
        }

        // 5. Cleanup
        if (running) {
            subprocess_destroy(&process);
            running = false;
        }
        LOG_INF("MCP server %s stopped.\n", name.c_str());
    }

    json send_request(const std::string & method, const json & params = json::object()) {
        int id;
        std::future<json> future;
        {
            std::lock_guard<std::mutex> lock(mutex);
            id = next_id++;
            future = pending_requests[id].get_future();
        }

        json req = {
            { "jsonrpc", "2.0" },
            { "id", id },
            { "method", method },
            { "params", params }
        };

        std::string req_str = req.dump() + "\n";
        FILE * stdin_file = subprocess_stdin(&process);
        if (stdin_file) {
            fwrite(req_str.c_str(), 1, req_str.size(), stdin_file);
            fflush(stdin_file);
        }

        // Wait for response with timeout
        if (future.wait_for(std::chrono::seconds(10)) == std::future_status::timeout) {
            LOG_ERR("Timeout waiting for response from %s (method: %s)\n", name.c_str(), method.c_str());
            std::lock_guard<std::mutex> lock(mutex);
            pending_requests.erase(id);
            return nullptr;
        }

        return future.get();
    }

    void send_notification(const std::string & method, const json & params = json::object()) {
        json req = {
            { "jsonrpc", "2.0" },
            { "method", method },
            { "params", params }
        };

        std::string req_str = req.dump() + "\n";
        FILE * stdin_file = subprocess_stdin(&process);
        if (stdin_file) {
            fwrite(req_str.c_str(), 1, req_str.size(), stdin_file);
            fflush(stdin_file);
        }
    }

    // Initialize handshake
    bool initialize() {
        // Send initialize
        json init_params = {
            { "protocolVersion", "2024-11-05" },
            { "capabilities", {
                { "roots", {
                    { "listChanged", false }
                } },
                { "sampling", json::object() }
            } },
            { "clientInfo", {
                { "name", "llama.cpp-cli" },
                { "version", "0.1.0" } // TODO: use real version
            } }
        };

        json res = send_request("initialize", init_params);
        if (res.is_null() || res.contains("error")) {
            LOG_ERR("Failed to initialize MCP server %s\n", name.c_str());
            return false;
        }

        // Send initialized notification
        send_notification("notifications/initialized");
        return true;
    }

    std::vector<mcp_tool> list_tools() {
        std::vector<mcp_tool> tools;
        json res = send_request("tools/list");
        if (res.is_null() || res.contains("error")) {
            LOG_ERR("Failed to list tools from %s\n", name.c_str());
            return tools;
        }

        if (res.contains("result") && res["result"].contains("tools")) {
            for (const auto & t : res["result"]["tools"]) {
                mcp_tool tool;
                tool.name = t["name"].get<std::string>();
                if (t.contains("description")) {
                    tool.description = t["description"].get<std::string>();
                }
                if (t.contains("inputSchema")) {
                    tool.input_schema = t["inputSchema"];
                }
                tool.server_name = name;
                tools.push_back(tool);
            }
        }
        return tools;
    }

    json call_tool(const std::string & tool_name, const json & args) {
        json params = {
            { "name", tool_name },
            { "arguments", args }
        };
        json res = send_request("tools/call", params);
        if (res.is_null() || res.contains("error")) {
            return {
                { "error", res.contains("error") ? res["error"] : "Unknown error" }
            };
        }
        if (res.contains("result")) {
            return res["result"];
        }
        return {
            { "error", "No result returned" }
        };
    }

  private:
    void read_loop() {
        char buffer[4096];
        while (!stop_read && running) {
            unsigned bytes_read = subprocess_read_stdout(&process, buffer, sizeof(buffer));

            if (bytes_read == 0) {
                // If blocking read returns 0, it means EOF (process exited or pipe closed).
                // We should NOT call subprocess_alive() here because it calls subprocess_join()
                // which modifies the process struct (closes stdin) and causes race conditions/double-free
                // when stop() is called concurrently.
                // Just break the loop. The process is likely dead or dying.
                if (!stop_read) {
                    LOG_ERR("MCP process died (stdout closed)\n");
                }
                running = false;
                break;
            }

            read_buffer.append(buffer, bytes_read);
            size_t pos;
            while ((pos = read_buffer.find('\n')) != std::string::npos) {
                std::string line = read_buffer.substr(0, pos);
                read_buffer.erase(0, pos + 1);

                if (line.empty()) {
                    continue;
                }

                try {
                    json msg = json::parse(line);
                    if (msg.contains("id")) {
                        // Response
                        int id = msg["id"].get<int>();
                        std::lock_guard<std::mutex> lock(mutex);
                        if (pending_requests.count(id)) {
                            pending_requests[id].set_value(msg);
                            pending_requests.erase(id);
                        } else {
                            // ID not found
                        }
                    } else {
                        // Notification or request from server -> ignore for now or log
                        // MCP servers might send notifications (e.g. logging)
                        LOG_ERR("MCP Notification from %s: %s\n", name.c_str(), line.c_str());
                    }
                } catch (const std::exception & e) {
                    // Not a full JSON yet? Or invalid?
                    // If it was a line, it should be valid JSON-RPC
                    LOG_WRN("Failed to parse JSON from %s: %s\n", name.c_str(), e.what());
                }
            }
        }
    }

    void err_loop() {
        char buffer[1024];
        while (!stop_read && running) {
            unsigned bytes_read = subprocess_read_stderr(&process, buffer, sizeof(buffer));
            if (bytes_read > 0) {
                if (stop_read) break; // Don't log stderr during shutdown
                std::string err_str(buffer, bytes_read);
                // Filter out empty/whitespace-only stderr if desired, or just keep it.
                // User said "extra logging that passes the stderr here is unnecessary" referring to shutdown.
                LOG_WRN("[%s stderr] %s", name.c_str(), err_str.c_str());
            } else {
                // EOF
                break;
            }
        }
    }
};

class mcp_context {
    std::map<std::string, std::shared_ptr<mcp_server>> servers;
    std::vector<mcp_tool> tools;
    bool yolo = false;

  public:
    void set_yolo(bool y) { yolo = y; }

    bool load_config(const std::string & config_path, const std::string & enabled_servers_str) {
        std::ifstream f(config_path);
        if (!f) {
            return false;
        }

        json config;
        try {
            f >> config;
        } catch (...) {
            return false;
        }

        std::vector<std::string> enabled_list;
        std::stringstream ss(enabled_servers_str);
        std::string item;
        while (std::getline(ss, item, ',')) {
            if (!item.empty()) {
                enabled_list.push_back(item);
            }
        }

        if (config.contains("mcpServers")) {
            std::string server_list;
            for (auto & [key, val] : config["mcpServers"].items()) {
                if (!server_list.empty()) server_list += ", ";
                server_list += key;
            }
            LOG_INF("MCP configuration found with servers: %s\n", server_list.c_str());

            for (auto & [key, val] : config["mcpServers"].items()) {
                // If enabled_servers_str is empty, enable all? User said "possibility to pick which MCP servers to enable".
                // If the user specifies explicit list, we filter. If not, maybe we shouldn't enable any or enable all?
                // The prompt says "possibility to pick".
                // Let's assume if list provided, use it. If not, enable all? Or none?
                // User provided example implies explicit enabling might be desired.
                // Let's assume if `enabled_servers_str` is not empty, we filter.

                bool enabled = true;
                if (!enabled_list.empty()) {
                    bool found = false;
                    for (const auto & s : enabled_list) {
                        if (s == key) {
                            found = true;
                        }
                    }
                    if (!found) {
                        enabled = false;
                    }
                }

                if (enabled) {
                    std::string cmd = val["command"].get<std::string>();
                    std::vector<std::string> args = val.value("args", std::vector<std::string>{});
                    std::map<std::string, std::string> env;
                    if (val.contains("env")) {
                        for (auto & [ek, ev] : val["env"].items()) {
                            env[ek] = ev.get<std::string>();
                        }
                    }

                    auto server = std::make_shared<mcp_server>(key, cmd, args, env);
                    LOG_INF("Trying to start MCP server: %s...\n", key.c_str());
                    if (server->start()) {
                        if (server->initialize()) {
                            servers[key] = server;
                            LOG_INF("MCP Server '%s' started and initialized.\n", key.c_str());

                            auto server_tools = server->list_tools();
                            tools.insert(tools.end(), server_tools.begin(), server_tools.end());
                        } else {
                            LOG_ERR("MCP Server '%s' failed to initialize.\n", key.c_str());
                        }
                    }
                }
            }
        }
        return true;
    }

    std::vector<mcp_tool> get_tools() const { return tools; }

    bool get_yolo() const { return yolo; }

    json call_tool(const std::string & tool_name, const json & args) {
        // Find which server has this tool
        std::string server_name;
        for (const auto & t : tools) {
            if (t.name == tool_name) {
                server_name = t.server_name;
                break;
            }
        }

        if (server_name.empty()) {
            return {
                { "error", "Tool not found" }
            };
        }

        if (servers.count(server_name)) {
            return servers[server_name]->call_tool(tool_name, args);
        }
        return {
            { "error", "Server not found" }
        };
    }
};
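To make the stdio framing above concrete, this is roughly the newline-delimited JSON-RPC exchange that mcp_server::initialize() would have with the dummy test server (one message per line; a sketch assembled from the request fields in send_request()/initialize() and the dummy server's reply, not a captured trace):

{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {"roots": {"listChanged": false}, "sampling": {}}, "clientInfo": {"name": "llama.cpp-cli", "version": "0.1.0"}}}
{"jsonrpc": "2.0", "id": 1, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "dummy", "version": "1.0"}}}

After this handshake the client sends the notifications/initialized notification (no id, so no response is expected) and then tools/list.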