change arg to --tools all

This commit is contained in:
Xuan Son Nguyen 2026-03-21 01:11:57 +01:00
parent 3c5dac1ddf
commit b0a1b31477
5 changed files with 94 additions and 79 deletions

View File

@ -2849,11 +2849,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
add_opt(common_arg(
{"--tools"},
{"--no-tools"},
string_format("experimental: whether to enable tools for AI agents - do not enable in untrusted environments (default: %s)", params.server_tools ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.server_tools = value;
{"--tools"}, "TOOL1,TOOL2,...",
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
"specify \"all\" to enable all tools\n"
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
[](common_params & params, const std::string & value) {
params.server_tools = parse_csv_row(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
add_opt(common_arg(

View File

@ -614,7 +614,7 @@ struct common_params {
bool endpoint_metrics = false;
// enable built-in tools
bool server_tools = false;
std::vector<std::string> server_tools;
// router server configs
std::string models_dir = ""; // directory containing models for the router server

View File

@ -29,8 +29,8 @@ static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
struct run_proc_result {
std::string output;
int exit_code = -1;
bool timed_out = false;
int exit_code = -1;
bool timed_out = false;
};
static run_proc_result run_process(
@ -133,31 +133,17 @@ static bool glob_match(const std::string & pattern, const std::string & str) {
return glob_match(pattern.c_str(), str.c_str());
}
//
// base struct
//
struct server_tool {
std::string name;
std::string display_name;
bool permission_write = false;
virtual ~server_tool() = default;
virtual json get_definition() = 0;
virtual json invoke(json params) = 0;
json to_json() {
return {
{"display_name", display_name},
{"tool", name},
{"type", "builtin"},
{"permissions", json{
{"write", permission_write}
}},
{"definition", get_definition()},
};
}
};
json server_tool::to_json() {
return {
{"display_name", display_name},
{"tool", name},
{"type", "builtin"},
{"permissions", json{
{"write", permission_write}
}},
{"definition", get_definition()},
};
}
//
// read_file: read a file with optional line range and line-number prefix
@ -533,7 +519,7 @@ struct server_tool_edit_file : server_tool {
{"type", "function"},
{"function", {
{"name", name},
{"description",
{"description",
"Edit a file by applying a list of line-based changes. "
"Each change targets a 1-based inclusive line range and has a mode: "
"\"replace\" (replace lines with content), "
@ -755,17 +741,56 @@ static std::vector<std::unique_ptr<server_tool>> build_tools() {
return tools;
}
static json server_tools_list() {
auto tools = build_tools();
json result = json::array();
for (const auto & t : tools) {
result.push_back(t->to_json());
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
if (!enabled_tools.empty()) {
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
auto all_tools = build_tools();
tools.clear();
for (auto & t : all_tools) {
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
tools.push_back(std::move(t));
}
}
}
return result;
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json result = json::array();
for (const auto & t : tools) {
result.push_back(t->to_json());
}
res->data = safe_json_to_str(result);
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json body = json::parse(req.body);
std::string tool_name = body.at("tool").get<std::string>();
json params = body.value("params", json::object());
json result = invoke(tool_name, params);
res->data = safe_json_to_str(result);
} catch (const json::exception & e) {
res->status = 400;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
}
static json server_tool_call(const std::string & name, const json & params) {
auto tools = build_tools();
json server_tools::invoke(const std::string & name, const json & params) {
for (auto & t : tools) {
if (t->name == name) {
return t->invoke(params);
@ -773,35 +798,3 @@ static json server_tool_call(const std::string & name, const json & params) {
}
return {{"error", "unknown tool: " + name}};
}
server_http_context::handler_t server_tools_get = [](const server_http_req &) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json tools = server_tools_list();
res->data = safe_json_to_str(tools);
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
server_http_context::handler_t server_tools_post = [](const server_http_req & req) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json body = json::parse(req.body);
std::string tool_name = body.at("tool").get<std::string>();
json params = body.value("params", json::object());
json result = server_tool_call(tool_name, params);
res->data = safe_json_to_str(result);
} catch (const json::exception & e) {
res->status = 400;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};

View File

@ -3,5 +3,24 @@
#include "server-common.h"
#include "server-http.h"
extern server_http_context::handler_t server_tools_get;
extern server_http_context::handler_t server_tools_post;
struct server_tool {
std::string name;
std::string display_name;
bool permission_write = false;
virtual ~server_tool() = default;
virtual json get_definition() = 0;
virtual json invoke(json params) = 0;
json to_json();
};
struct server_tools {
std::vector<std::unique_ptr<server_tool>> tools;
void setup(const std::vector<std::string> & enabled_tools);
json invoke(const std::string & name, const json & params);
server_http_context::handler_t handle_get;
server_http_context::handler_t handle_post;
};

View File

@ -125,6 +125,7 @@ int main(int argc, char ** argv) {
// register API routes
server_routes routes(params, ctx_server);
server_tools tools;
bool is_router_server = params.model.path.empty();
std::optional<server_models_routes> models_routes{};
@ -213,13 +214,14 @@ int main(int argc, char ** argv) {
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
}
// EXPERIMENTAL built-in tools
if (params.server_tools) {
if (!params.server_tools.empty()) {
tools.setup(params.server_tools);
SRV_WRN("%s", "-----------------\n");
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
SRV_WRN("%s", "-----------------\n");
ctx_http.get ("/tools", ex_wrapper(server_tools_get));
ctx_http.post("/tools", ex_wrapper(server_tools_post));
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
}
//