From b0a1b31477d8c68ea7d23bce4b3748de27e88441 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 21 Mar 2026 01:11:57 +0100 Subject: [PATCH] change arg to --tools all --- common/arg.cpp | 11 +-- common/common.h | 2 +- tools/server/server-tools.cpp | 129 ++++++++++++++++------------------ tools/server/server-tools.h | 23 +++++- tools/server/server.cpp | 8 ++- 5 files changed, 94 insertions(+), 79 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index ba2afc77a7..98070d43e2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2849,11 +2849,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY")); add_opt(common_arg( - {"--tools"}, - {"--no-tools"}, - string_format("experimental: whether to enable tools for AI agents - do not enable in untrusted environments (default: %s)", params.server_tools ? "enabled" : "disabled"), - [](common_params & params, bool value) { - params.server_tools = value; + {"--tools"}, "TOOL1,TOOL2,...", + "experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n" + "specify \"all\" to enable all tools\n" + "available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff", + [](common_params & params, const std::string & value) { + params.server_tools = parse_csv_row(value); } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS")); add_opt(common_arg( diff --git a/common/common.h b/common/common.h index 9fd1b4dbbe..fde5ba996e 100644 --- a/common/common.h +++ b/common/common.h @@ -614,7 +614,7 @@ struct common_params { bool endpoint_metrics = false; // enable built-in tools - bool server_tools = false; + std::vector server_tools; // router server configs std::string models_dir = ""; // directory containing models for the router server diff --git a/tools/server/server-tools.cpp b/tools/server/server-tools.cpp index d7c116abd8..5e89a5668b 100644 --- a/tools/server/server-tools.cpp +++ b/tools/server/server-tools.cpp @@ -29,8 +29,8 @@ static std::vector to_cstr_vec(const std::vector & v) { struct run_proc_result { std::string output; - int exit_code = -1; - bool timed_out = false; + int exit_code = -1; + bool timed_out = false; }; static run_proc_result run_process( @@ -133,31 +133,17 @@ static bool glob_match(const std::string & pattern, const std::string & str) { return glob_match(pattern.c_str(), str.c_str()); } -// -// base struct -// - -struct server_tool { - std::string name; - std::string display_name; - bool permission_write = false; - - virtual ~server_tool() = default; - virtual json get_definition() = 0; - virtual json invoke(json params) = 0; - - json to_json() { - return { - {"display_name", display_name}, - {"tool", name}, - {"type", "builtin"}, - {"permissions", json{ - {"write", permission_write} - }}, - {"definition", get_definition()}, - }; - } -}; +json server_tool::to_json() { + return { + {"display_name", display_name}, + {"tool", name}, + {"type", "builtin"}, + {"permissions", json{ + {"write", permission_write} + }}, + {"definition", get_definition()}, + }; +} // // read_file: read a file with optional line range and line-number prefix @@ -533,7 +519,7 @@ struct server_tool_edit_file : server_tool { {"type", "function"}, {"function", { {"name", name}, - {"description", + {"description", "Edit a file by applying a list of line-based changes. " "Each change targets a 1-based inclusive line range and has a mode: " "\"replace\" (replace lines with content), " @@ -755,17 +741,56 @@ static std::vector> build_tools() { return tools; } -static json server_tools_list() { - auto tools = build_tools(); - json result = json::array(); - for (const auto & t : tools) { - result.push_back(t->to_json()); +void server_tools::setup(const std::vector & enabled_tools) { + if (!enabled_tools.empty()) { + std::unordered_set enabled_set(enabled_tools.begin(), enabled_tools.end()); + auto all_tools = build_tools(); + + tools.clear(); + for (auto & t : all_tools) { + if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) { + tools.push_back(std::move(t)); + } + } } - return result; + + handle_get = [this](const server_http_req &) -> server_http_res_ptr { + auto res = std::make_unique(); + try { + json result = json::array(); + for (const auto & t : tools) { + result.push_back(t->to_json()); + } + res->data = safe_json_to_str(result); + } catch (const std::exception & e) { + SRV_ERR("got exception: %s\n", e.what()); + res->status = 500; + res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER)); + } + return res; + }; + + handle_post = [this](const server_http_req & req) -> server_http_res_ptr { + auto res = std::make_unique(); + try { + json body = json::parse(req.body); + std::string tool_name = body.at("tool").get(); + json params = body.value("params", json::object()); + json result = invoke(tool_name, params); + res->data = safe_json_to_str(result); + } catch (const json::exception & e) { + res->status = 400; + res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + } catch (const std::exception & e) { + SRV_ERR("got exception: %s\n", e.what()); + res->status = 500; + res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER)); + } + return res; + }; } -static json server_tool_call(const std::string & name, const json & params) { - auto tools = build_tools(); +json server_tools::invoke(const std::string & name, const json & params) { for (auto & t : tools) { if (t->name == name) { return t->invoke(params); @@ -773,35 +798,3 @@ static json server_tool_call(const std::string & name, const json & params) { } return {{"error", "unknown tool: " + name}}; } - -server_http_context::handler_t server_tools_get = [](const server_http_req &) -> server_http_res_ptr { - auto res = std::make_unique(); - try { - json tools = server_tools_list(); - res->data = safe_json_to_str(tools); - } catch (const std::exception & e) { - SRV_ERR("got exception: %s\n", e.what()); - res->status = 500; - res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER)); - } - return res; -}; - -server_http_context::handler_t server_tools_post = [](const server_http_req & req) -> server_http_res_ptr { - auto res = std::make_unique(); - try { - json body = json::parse(req.body); - std::string tool_name = body.at("tool").get(); - json params = body.value("params", json::object()); - json result = server_tool_call(tool_name, params); - res->data = safe_json_to_str(result); - } catch (const json::exception & e) { - res->status = 400; - res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - } catch (const std::exception & e) { - SRV_ERR("got exception: %s\n", e.what()); - res->status = 500; - res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER)); - } - return res; -}; diff --git a/tools/server/server-tools.h b/tools/server/server-tools.h index 141235d799..444ef5f809 100644 --- a/tools/server/server-tools.h +++ b/tools/server/server-tools.h @@ -3,5 +3,24 @@ #include "server-common.h" #include "server-http.h" -extern server_http_context::handler_t server_tools_get; -extern server_http_context::handler_t server_tools_post; \ No newline at end of file +struct server_tool { + std::string name; + std::string display_name; + bool permission_write = false; + + virtual ~server_tool() = default; + virtual json get_definition() = 0; + virtual json invoke(json params) = 0; + + json to_json(); +}; + +struct server_tools { + std::vector> tools; + + void setup(const std::vector & enabled_tools); + json invoke(const std::string & name, const json & params); + + server_http_context::handler_t handle_get; + server_http_context::handler_t handle_post; +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d1db1ed1ea..2a0cf1bcf9 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -125,6 +125,7 @@ int main(int argc, char ** argv) { // register API routes server_routes routes(params, ctx_server); + server_tools tools; bool is_router_server = params.model.path.empty(); std::optional models_routes{}; @@ -213,13 +214,14 @@ int main(int argc, char ** argv) { ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post)); } // EXPERIMENTAL built-in tools - if (params.server_tools) { + if (!params.server_tools.empty()) { + tools.setup(params.server_tools); SRV_WRN("%s", "-----------------\n"); SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n"); SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n"); SRV_WRN("%s", "-----------------\n"); - ctx_http.get ("/tools", ex_wrapper(server_tools_get)); - ctx_http.post("/tools", ex_wrapper(server_tools_post)); + ctx_http.get ("/tools", ex_wrapper(tools.handle_get)); + ctx_http.post("/tools", ex_wrapper(tools.handle_post)); } //