change arg to --tools all
This commit is contained in:
parent
3c5dac1ddf
commit
b0a1b31477
|
|
@ -2849,11 +2849,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
|
||||
add_opt(common_arg(
|
||||
{"--tools"},
|
||||
{"--no-tools"},
|
||||
string_format("experimental: whether to enable tools for AI agents - do not enable in untrusted environments (default: %s)", params.server_tools ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.server_tools = value;
|
||||
{"--tools"}, "TOOL1,TOOL2,...",
|
||||
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
|
||||
"specify \"all\" to enable all tools\n"
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
add_opt(common_arg(
|
||||
|
|
|
|||
|
|
@ -614,7 +614,7 @@ struct common_params {
|
|||
bool endpoint_metrics = false;
|
||||
|
||||
// enable built-in tools
|
||||
bool server_tools = false;
|
||||
std::vector<std::string> server_tools;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
|
|||
|
||||
struct run_proc_result {
|
||||
std::string output;
|
||||
int exit_code = -1;
|
||||
bool timed_out = false;
|
||||
int exit_code = -1;
|
||||
bool timed_out = false;
|
||||
};
|
||||
|
||||
static run_proc_result run_process(
|
||||
|
|
@ -133,31 +133,17 @@ static bool glob_match(const std::string & pattern, const std::string & str) {
|
|||
return glob_match(pattern.c_str(), str.c_str());
|
||||
}
|
||||
|
||||
//
|
||||
// base struct
|
||||
//
|
||||
|
||||
struct server_tool {
|
||||
std::string name;
|
||||
std::string display_name;
|
||||
bool permission_write = false;
|
||||
|
||||
virtual ~server_tool() = default;
|
||||
virtual json get_definition() = 0;
|
||||
virtual json invoke(json params) = 0;
|
||||
|
||||
json to_json() {
|
||||
return {
|
||||
{"display_name", display_name},
|
||||
{"tool", name},
|
||||
{"type", "builtin"},
|
||||
{"permissions", json{
|
||||
{"write", permission_write}
|
||||
}},
|
||||
{"definition", get_definition()},
|
||||
};
|
||||
}
|
||||
};
|
||||
json server_tool::to_json() {
|
||||
return {
|
||||
{"display_name", display_name},
|
||||
{"tool", name},
|
||||
{"type", "builtin"},
|
||||
{"permissions", json{
|
||||
{"write", permission_write}
|
||||
}},
|
||||
{"definition", get_definition()},
|
||||
};
|
||||
}
|
||||
|
||||
//
|
||||
// read_file: read a file with optional line range and line-number prefix
|
||||
|
|
@ -533,7 +519,7 @@ struct server_tool_edit_file : server_tool {
|
|||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description",
|
||||
{"description",
|
||||
"Edit a file by applying a list of line-based changes. "
|
||||
"Each change targets a 1-based inclusive line range and has a mode: "
|
||||
"\"replace\" (replace lines with content), "
|
||||
|
|
@ -755,17 +741,56 @@ static std::vector<std::unique_ptr<server_tool>> build_tools() {
|
|||
return tools;
|
||||
}
|
||||
|
||||
static json server_tools_list() {
|
||||
auto tools = build_tools();
|
||||
json result = json::array();
|
||||
for (const auto & t : tools) {
|
||||
result.push_back(t->to_json());
|
||||
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
|
||||
if (!enabled_tools.empty()) {
|
||||
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
|
||||
auto all_tools = build_tools();
|
||||
|
||||
tools.clear();
|
||||
for (auto & t : all_tools) {
|
||||
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
|
||||
tools.push_back(std::move(t));
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
||||
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json result = json::array();
|
||||
for (const auto & t : tools) {
|
||||
result.push_back(t->to_json());
|
||||
}
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
std::string tool_name = body.at("tool").get<std::string>();
|
||||
json params = body.value("params", json::object());
|
||||
json result = invoke(tool_name, params);
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const json::exception & e) {
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
static json server_tool_call(const std::string & name, const json & params) {
|
||||
auto tools = build_tools();
|
||||
json server_tools::invoke(const std::string & name, const json & params) {
|
||||
for (auto & t : tools) {
|
||||
if (t->name == name) {
|
||||
return t->invoke(params);
|
||||
|
|
@ -773,35 +798,3 @@ static json server_tool_call(const std::string & name, const json & params) {
|
|||
}
|
||||
return {{"error", "unknown tool: " + name}};
|
||||
}
|
||||
|
||||
server_http_context::handler_t server_tools_get = [](const server_http_req &) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json tools = server_tools_list();
|
||||
res->data = safe_json_to_str(tools);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
server_http_context::handler_t server_tools_post = [](const server_http_req & req) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
std::string tool_name = body.at("tool").get<std::string>();
|
||||
json params = body.value("params", json::object());
|
||||
json result = server_tool_call(tool_name, params);
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const json::exception & e) {
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,5 +3,24 @@
|
|||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
extern server_http_context::handler_t server_tools_get;
|
||||
extern server_http_context::handler_t server_tools_post;
|
||||
struct server_tool {
|
||||
std::string name;
|
||||
std::string display_name;
|
||||
bool permission_write = false;
|
||||
|
||||
virtual ~server_tool() = default;
|
||||
virtual json get_definition() = 0;
|
||||
virtual json invoke(json params) = 0;
|
||||
|
||||
json to_json();
|
||||
};
|
||||
|
||||
struct server_tools {
|
||||
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
|
||||
void setup(const std::vector<std::string> & enabled_tools);
|
||||
json invoke(const std::string & name, const json & params);
|
||||
|
||||
server_http_context::handler_t handle_get;
|
||||
server_http_context::handler_t handle_post;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// register API routes
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
bool is_router_server = params.model.path.empty();
|
||||
std::optional<server_models_routes> models_routes{};
|
||||
|
|
@ -213,13 +214,14 @@ int main(int argc, char ** argv) {
|
|||
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
|
||||
}
|
||||
// EXPERIMENTAL built-in tools
|
||||
if (params.server_tools) {
|
||||
if (!params.server_tools.empty()) {
|
||||
tools.setup(params.server_tools);
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
|
||||
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/tools", ex_wrapper(server_tools_get));
|
||||
ctx_http.post("/tools", ex_wrapper(server_tools_post));
|
||||
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
|
|
|
|||
Loading…
Reference in New Issue