diff --git a/common/arg.cpp b/common/arg.cpp index 0096e7a68d..32d8e69f96 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2843,6 +2843,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.webui_mcp_proxy = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY")); + add_opt(common_arg( + {"--tools"}, "TOOL1,TOOL2,...", + "experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n" + "specify \"all\" to enable all tools\n" + "available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff", + [](common_params & params, const std::string & value) { + params.server_tools = parse_csv_row(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS")); add_opt(common_arg( {"--webui"}, {"--no-webui"}, diff --git a/common/common.h b/common/common.h index 62201ea1ad..fde5ba996e 100644 --- a/common/common.h +++ b/common/common.h @@ -613,6 +613,9 @@ struct common_params { bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; + // enable built-in tools + std::vector server_tools; + // router server configs std::string models_dir = ""; // directory containing models for the router server std::string models_preset = ""; // directory containing model presets for the router server diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 5621a51b22..fc4cb5dcc8 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -13,6 +13,8 @@ add_library(${TARGET} STATIC server-common.h server-context.cpp server-context.h + server-tools.cpp + server-tools.h ) if (BUILD_SHARED_LIBS) diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 5f82e35d6c..1dd00b72e4 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -125,6 +125,61 @@ The framework automatically starts a `llama-server` instance, sends requests, an For detailed instructions, see the [test documentation](./tests/README.md). +### API for tools + +This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future. + +**GET /tools** + +Get a list of tools, each tool has these fields: +- `tool` (string): the ID name of the tool, to be used in POST call. Example: `read_file` +- `display_name` (string): the name to be displayed on UI. Example: `Read file` +- `type` (string): always be `"builtin"` for now +- `permissions` (object): a mapping string --> boolean that indicates the permission required by this tool. This is useful for the UI to ask the user before calling the tool. For now, the only permission supported is `"write"` +- `definition` (object): the OAI-compat definition of this tool + +**POST /tools** + +Invoke a tool call, request body is a JSON object with: +- `tool` (string): the name of the tool +- `params` (object): a mapping from argument name (string) to argument value + +Returns JSON object. There are two response formats: + +Format 1: Plain text. The text will be placed into a field called `plain_text_response`, example: + +```json +{ + "plain_text_response": "this is a text response" +} +``` + +The client should extract this value and place it inside message content (note: content is no longer a JSON), example + +```json +{ + "role": "tool", + "content": "this is a text response" +} +``` + +Format 2: Normal JSON response, example: + +```json +{ + "error": "cannot open this file" +} +``` + +That requires `JSON.stringify` when formatted to message content: + +```json +{ + "role": "tool", + "content": "{\"error\":\"cannot open this file\"}" +} +``` + ### Notable Related PRs - Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443 diff --git a/tools/server/README.md b/tools/server/README.md index 5a5d3aa02b..cb53678416 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -194,6 +194,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `--webui-config JSON` | JSON that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG) | | `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG_FILE) | | `--webui-mcp-proxy, --no-webui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_WEBUI_MCP_PROXY) | +| `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)
specify "all" to enable all tools
available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff
(env: LLAMA_ARG_TOOLS) | | `--webui, --no-webui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_WEBUI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | @@ -293,6 +294,12 @@ It is currently available in the following endpoints: For more details, please refer to [multimodal documentation](../../docs/multimodal.md) +### Built-in tools support + +The server includes a set of built-in tools that enable the LLM to access the local file system directly from the Web UI. + +To use this feature, start the server with `--tools all`. You can also enable only specific tools by passing a comma-separated list: `--tools name1,name2,...`. Run `--help` for the full list of available tool names. + ## Build `llama-server` is built alongside everything else from the root of the project @@ -1438,6 +1445,14 @@ curl http://localhost:8080/v1/messages/count_tokens \ {"input_tokens": 10} ``` +## Server built-in tools + +The server exposes a REST API under `/tools` that allows the Web UI to call built-in tools. This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future. + +**Please do NOT use this endpoint in a downstream application** + +For further documentation about this endpoint, please refer to [server internal documentation](./README-dev.md) + ## Using multiple models `llama-server` can be launched in a **router mode** that exposes an API for dynamically loading and unloading models. The main process (the "router") automatically forwards each request to the appropriate model instance. diff --git a/tools/server/server-tools.cpp b/tools/server/server-tools.cpp new file mode 100644 index 0000000000..5e89a5668b --- /dev/null +++ b/tools/server/server-tools.cpp @@ -0,0 +1,800 @@ +#include "server-tools.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +// +// internal helpers +// + +static std::vector to_cstr_vec(const std::vector & v) { + std::vector r; + r.reserve(v.size() + 1); + for (const auto & s : v) { + r.push_back(const_cast(s.c_str())); + } + r.push_back(nullptr); + return r; +} + +struct run_proc_result { + std::string output; + int exit_code = -1; + bool timed_out = false; +}; + +static run_proc_result run_process( + const std::vector & args, + size_t max_output, + int timeout_secs) { + run_proc_result res; + + subprocess_s proc; + auto argv = to_cstr_vec(args); + + int options = subprocess_option_no_window + | subprocess_option_combined_stdout_stderr + | subprocess_option_inherit_environment + | subprocess_option_search_user_path; + + if (subprocess_create(argv.data(), options, &proc) != 0) { + res.output = "failed to spawn process"; + return res; + } + + std::atomic done{false}; + std::atomic timed_out{false}; + + std::thread timeout_thread([&]() { + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_secs); + while (!done.load()) { + if (std::chrono::steady_clock::now() >= deadline) { + timed_out.store(true); + subprocess_terminate(&proc); + return; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + }); + + FILE * f = subprocess_stdout(&proc); + std::string output; + bool truncated = false; + if (f) { + char buf[4096]; + while (fgets(buf, sizeof(buf), f) != nullptr) { + if (!truncated) { + size_t len = strlen(buf); + if (output.size() + len <= max_output) { + output.append(buf, len); + } else { + output.append(buf, max_output - output.size()); + truncated = true; + } + } + } + } + + done.store(true); + if (timeout_thread.joinable()) { + timeout_thread.join(); + } + + subprocess_join(&proc, &res.exit_code); + subprocess_destroy(&proc); + + res.output = output; + res.timed_out = timed_out.load(); + if (truncated) { + res.output += "\n[output truncated]"; + } + return res; +} + +// simple glob: * matches non-/ chars, ** matches anything including / +static bool glob_match(const char * pattern, const char * str) { + if (*pattern == '\0') { + return *str == '\0'; + } + if (pattern[0] == '*' && pattern[1] == '*') { + const char * p = pattern + 2; + if (*p == '/') p++; + if (glob_match(p, str)) return true; + if (*str != '\0') return glob_match(pattern, str + 1); + return false; + } + if (*pattern == '*') { + const char * p = pattern + 1; + for (; *str != '\0' && *str != '/'; str++) { + if (glob_match(p, str)) return true; + } + return glob_match(p, str); + } + if (*pattern == '?' && *str != '\0' && *str != '/') { + return glob_match(pattern + 1, str + 1); + } + if (*pattern == *str) { + return glob_match(pattern + 1, str + 1); + } + return false; +} + +static bool glob_match(const std::string & pattern, const std::string & str) { + return glob_match(pattern.c_str(), str.c_str()); +} + +json server_tool::to_json() { + return { + {"display_name", display_name}, + {"tool", name}, + {"type", "builtin"}, + {"permissions", json{ + {"write", permission_write} + }}, + {"definition", get_definition()}, + }; +} + +// +// read_file: read a file with optional line range and line-number prefix +// + +static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB + +struct server_tool_read_file : server_tool { + server_tool_read_file() { + name = "read_file"; + display_name = "Read file"; + permission_write = false; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", "Read the contents of a file. Optionally specify a 1-based line range. " + "If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"path", {{"type", "string"}, {"description", "Path to the file"}}}, + {"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}}, + {"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}}, + {"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}}, + }}, + {"required", json::array({"path"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string path = params.at("path").get(); + int start_line = json_value(params, "start_line", 1); + int end_line = json_value(params, "end_line", -1); // -1 = no limit + bool append_loc = json_value(params, "append_loc", false); + + std::error_code ec; + uintmax_t file_size = fs::file_size(path, ec); + if (ec) { + return {{"error", "cannot stat file: " + ec.message()}}; + } + if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) { + return {{"error", string_format( + "file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.", + (size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}}; + } + + std::ifstream f(path); + if (!f) { + return {{"error", "failed to open file: " + path}}; + } + + std::string result; + std::string line; + int lineno = 0; + + while (std::getline(f, line)) { + lineno++; + if (lineno < start_line) continue; + if (end_line != -1 && lineno > end_line) break; + + std::string out_line; + if (append_loc) { + out_line = std::to_string(lineno) + "\u2192 " + line + "\n"; + } else { + out_line = line + "\n"; + } + + if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) { + result += "[output truncated]"; + break; + } + result += out_line; + } + + return {{"plain_text_response", result}}; + } +}; + +// +// file_glob_search: find files matching a glob pattern under a base directory +// + +static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100; + +struct server_tool_file_glob_search : server_tool { + server_tool_file_glob_search() { + name = "file_glob_search"; + display_name = "File search"; + permission_write = false; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", "Recursively search for files matching a glob pattern under a directory."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"path", {{"type", "string"}, {"description", "Base directory to search in"}}}, + {"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}}, + {"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}}, + }}, + {"required", json::array({"path"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string base = params.at("path").get(); + std::string include = json_value(params, "include", std::string("**")); + std::string exclude = json_value(params, "exclude", std::string("")); + + std::ostringstream output_text; + size_t count = 0; + + std::error_code ec; + for (const auto & entry : fs::recursive_directory_iterator(base, + fs::directory_options::skip_permission_denied, ec)) { + if (!entry.is_regular_file()) continue; + + std::string rel = fs::relative(entry.path(), base, ec).string(); + if (ec) continue; + std::replace(rel.begin(), rel.end(), '\\', '/'); + + if (!glob_match(include, rel)) continue; + if (!exclude.empty() && glob_match(exclude, rel)) continue; + + output_text << entry.path().string() << "\n"; + if (++count >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) { + break; + } + } + + output_text << "\n---\nTotal matches: " << count << "\n"; + + return {{"plain_text_response", output_text.str()}}; + } +}; + +// +// grep_search: search for a regex pattern in files +// + +static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100; + +struct server_tool_grep_search : server_tool { + server_tool_grep_search() { + name = "grep_search"; + display_name = "Grep search"; + permission_write = false; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", "Search for a regex pattern in files under a path. Returns matching lines."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"path", {{"type", "string"}, {"description", "File or directory to search in"}}}, + {"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}}, + {"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}}, + {"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}}, + {"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}}, + }}, + {"required", json::array({"path", "pattern"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string path = params.at("path").get(); + std::string pat_str = params.at("pattern").get(); + std::string include = json_value(params, "include", std::string("**")); + std::string exclude = json_value(params, "exclude", std::string("")); + bool show_lineno = json_value(params, "return_line_numbers", false); + + std::regex pattern; + try { + pattern = std::regex(pat_str); + } catch (const std::regex_error & e) { + return {{"error", std::string("invalid regex: ") + e.what()}}; + } + + std::ostringstream output_text; + size_t total = 0; + + auto search_file = [&](const fs::path & fpath) { + std::ifstream f(fpath); + if (!f) return; + std::string line; + int lineno = 0; + while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) { + lineno++; + if (std::regex_search(line, pattern)) { + output_text << fpath.string() << ":"; + if (show_lineno) { + output_text << lineno << ":"; + } + output_text << line << "\n"; + total++; + } + } + }; + + std::error_code ec; + if (fs::is_regular_file(path, ec)) { + search_file(path); + } else if (fs::is_directory(path, ec)) { + for (const auto & entry : fs::recursive_directory_iterator(path, + fs::directory_options::skip_permission_denied, ec)) { + if (!entry.is_regular_file()) continue; + if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break; + + std::string rel = fs::relative(entry.path(), path, ec).string(); + if (ec) continue; + std::replace(rel.begin(), rel.end(), '\\', '/'); + + if (!glob_match(include, rel)) continue; + if (!exclude.empty() && glob_match(exclude, rel)) continue; + + search_file(entry.path()); + } + } else { + return {{"error", "path does not exist: " + path}}; + } + + output_text << "\n\n---\nTotal matches: " << total << "\n"; + + return {{"plain_text_response", output_text.str()}}; + } +}; + +// +// exec_shell_command: run an arbitrary shell command +// + +static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB +static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds + +struct server_tool_exec_shell_command : server_tool { + server_tool_exec_shell_command() { + name = "exec_shell_command"; + display_name = "Execute shell command"; + permission_write = true; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", "Execute a shell command and return its output (stdout and stderr combined)."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"command", {{"type", "string"}, {"description", "Shell command to execute"}}}, + {"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}}, + {"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}}, + }}, + {"required", json::array({"command"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string command = params.at("command").get(); + int timeout = json_value(params, "timeout", 10); + size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE); + + timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT); + max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE); + +#ifdef _WIN32 + std::vector args = {"cmd", "/c", command}; +#else + std::vector args = {"sh", "-c", command}; +#endif + + auto res = run_process(args, max_output, timeout); + + std::string text_output = res.output; + text_output += string_format("\n[exit code: %d]", res.exit_code); + if (res.timed_out) { + text_output += " [exit due to timed out]"; + } + + return {{"plain_text_response", text_output}}; + } +}; + +// +// write_file: create or overwrite a file +// + +struct server_tool_write_file : server_tool { + server_tool_write_file() { + name = "write_file"; + display_name = "Write file"; + permission_write = true; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", "Write content to a file, creating it (including parent directories) if it does not exist. May use with edit_file for more complex edits."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"path", {{"type", "string"}, {"description", "Path of the file to write"}}}, + {"content", {{"type", "string"}, {"description", "Content to write"}}}, + }}, + {"required", json::array({"path", "content"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string path = params.at("path").get(); + std::string content = params.at("content").get(); + + std::error_code ec; + fs::path fpath(path); + if (fpath.has_parent_path()) { + fs::create_directories(fpath.parent_path(), ec); + if (ec) { + return {{"error", "failed to create directories: " + ec.message()}}; + } + } + + std::ofstream f(path, std::ios::binary); + if (!f) { + return {{"error", "failed to open file for writing: " + path}}; + } + f << content; + if (!f) { + return {{"error", "failed to write file: " + path}}; + } + + return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}}; + } +}; + +// +// edit_file: edit file content via line-based changes +// + +struct server_tool_edit_file : server_tool { + server_tool_edit_file() { + name = "edit_file"; + display_name = "Edit file"; + permission_write = true; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", + "Edit a file by applying a list of line-based changes. " + "Each change targets a 1-based inclusive line range and has a mode: " + "\"replace\" (replace lines with content), " + "\"delete\" (remove lines, content must be empty string), " + "\"append\" (insert content after line_end). " + "Set line_start to -1 to target the end of file (line_end is ignored in that case). " + "Changes must not overlap. They are applied in reverse line order automatically."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"path", {{"type", "string"}, {"description", "Path to the file to edit"}}}, + {"changes", { + {"type", "array"}, + {"description", "List of changes to apply"}, + {"items", { + {"type", "object"}, + {"properties", { + {"mode", {{"type", "string"}, {"description", "\"replace\", \"delete\", or \"append\""}}}, + {"line_start", {{"type", "integer"}, {"description", "First line of the range (1-based); use -1 for end of file"}}}, + {"line_end", {{"type", "integer"}, {"description", "Last line of the range (1-based, inclusive); ignored when line_start is -1"}}}, + {"content", {{"type", "string"}, {"description", "Content to insert; must be empty string for delete mode"}}}, + }}, + {"required", json::array({"mode", "line_start", "line_end", "content"})}, + }}, + }}, + }}, + {"required", json::array({"path", "changes"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string path = params.at("path").get(); + const json & changes = params.at("changes"); + + if (!changes.is_array()) { + return {{"error", "\"changes\" must be an array"}}; + } + + // read file into lines + std::ifstream fin(path); + if (!fin) { + return {{"error", "failed to open file: " + path}}; + } + std::vector lines; + { + std::string line; + while (std::getline(fin, line)) { + lines.push_back(line); + } + } + fin.close(); + + // validate and collect changes, then sort descending by line_start + struct change_entry { + std::string mode; + int line_start; // 1-based + int line_end; // 1-based inclusive + std::string content; + }; + std::vector entries; + entries.reserve(changes.size()); + + for (const auto & ch : changes) { + change_entry e; + e.mode = ch.at("mode").get(); + e.line_start = ch.at("line_start").get(); + e.line_end = ch.at("line_end").get(); + e.content = ch.at("content").get(); + + if (e.mode != "replace" && e.mode != "delete" && e.mode != "append") { + return {{"error", "invalid mode \"" + e.mode + "\"; must be replace, delete, or append"}}; + } + if (e.mode == "delete" && !e.content.empty()) { + return {{"error", "content must be empty string for delete mode"}}; + } + int n = (int) lines.size(); + if (e.line_start == -1) { + // -1 means end of file; line_end is ignored — normalize to point past last line + e.line_start = n + 1; + e.line_end = n + 1; + } else { + if (e.line_start < 1 || e.line_end < e.line_start) { + return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}}; + } + if (e.line_end > n) { + return {{"error", string_format("line_end %d exceeds file length %d", e.line_end, n)}}; + } + } + entries.push_back(std::move(e)); + } + + // sort descending so earlier-indexed changes don't shift later ones + std::sort(entries.begin(), entries.end(), [](const change_entry & a, const change_entry & b) { + return a.line_start > b.line_start; + }); + + // apply changes (0-based indices internally) + for (const auto & e : entries) { + int idx_start = e.line_start - 1; // 0-based + int idx_end = e.line_end - 1; // 0-based inclusive + + // split content into lines (preserve trailing newline awareness) + std::vector new_lines; + if (!e.content.empty()) { + std::istringstream ss(e.content); + std::string ln; + while (std::getline(ss, ln)) { + new_lines.push_back(ln); + } + // if content ends with \n, getline consumed it — no extra empty line needed + // if content does NOT end with \n, last line is still captured correctly + } + + if (e.mode == "replace") { + // erase [idx_start, idx_end] and insert new_lines + lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1); + lines.insert(lines.begin() + idx_start, new_lines.begin(), new_lines.end()); + } else if (e.mode == "delete") { + lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1); + } else { // append + // idx_end + 1 may equal lines.size() when line_start == -1 (end of file) + lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end()); + } + } + + // write file back + std::ofstream fout(path, std::ios::binary); + if (!fout) { + return {{"error", "failed to open file for writing: " + path}}; + } + for (size_t i = 0; i < lines.size(); i++) { + fout << lines[i]; + if (i + 1 < lines.size()) { + fout << "\n"; + } + } + if (!lines.empty()) { + fout << "\n"; + } + if (!fout) { + return {{"error", "failed to write file: " + path}}; + } + + return {{"result", "file edited successfully"}, {"path", path}, {"lines", (int) lines.size()}}; + } +}; + +// +// apply_diff: apply a unified diff via git apply +// + +struct server_tool_apply_diff : server_tool { + server_tool_apply_diff() { + name = "apply_diff"; + display_name = "Apply diff"; + permission_write = true; + } + + json get_definition() override { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", "Apply a unified diff to edit one or more files using git apply. Use this instead of edit_file when the changes are complex."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}}, + }}, + {"required", json::array({"diff"})}, + }}, + }}, + }; + } + + json invoke(json params) override { + std::string diff = params.at("diff").get(); + + // write diff to a temporary file + static std::atomic counter{0}; + std::string tmp_path = (fs::temp_directory_path() / + ("llama_patch_" + std::to_string(++counter) + ".patch")).string(); + + { + std::ofstream f(tmp_path, std::ios::binary); + if (!f) { + return {{"error", "failed to create temp patch file"}}; + } + f << diff; + } + + auto res = run_process({"git", "apply", tmp_path}, 4096, 10); + + std::error_code ec; + fs::remove(tmp_path, ec); + + if (res.exit_code != 0) { + return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}}; + } + return {{"result", "patch applied successfully"}}; + } +}; + +// +// public API +// + +static std::vector> build_tools() { + std::vector> tools; + tools.push_back(std::make_unique()); + tools.push_back(std::make_unique()); + tools.push_back(std::make_unique()); + tools.push_back(std::make_unique()); + tools.push_back(std::make_unique()); + tools.push_back(std::make_unique()); + tools.push_back(std::make_unique()); + return tools; +} + +void server_tools::setup(const std::vector & enabled_tools) { + if (!enabled_tools.empty()) { + std::unordered_set enabled_set(enabled_tools.begin(), enabled_tools.end()); + auto all_tools = build_tools(); + + tools.clear(); + for (auto & t : all_tools) { + if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) { + tools.push_back(std::move(t)); + } + } + } + + handle_get = [this](const server_http_req &) -> server_http_res_ptr { + auto res = std::make_unique(); + try { + json result = json::array(); + for (const auto & t : tools) { + result.push_back(t->to_json()); + } + res->data = safe_json_to_str(result); + } catch (const std::exception & e) { + SRV_ERR("got exception: %s\n", e.what()); + res->status = 500; + res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER)); + } + return res; + }; + + handle_post = [this](const server_http_req & req) -> server_http_res_ptr { + auto res = std::make_unique(); + try { + json body = json::parse(req.body); + std::string tool_name = body.at("tool").get(); + json params = body.value("params", json::object()); + json result = invoke(tool_name, params); + res->data = safe_json_to_str(result); + } catch (const json::exception & e) { + res->status = 400; + res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + } catch (const std::exception & e) { + SRV_ERR("got exception: %s\n", e.what()); + res->status = 500; + res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER)); + } + return res; + }; +} + +json server_tools::invoke(const std::string & name, const json & params) { + for (auto & t : tools) { + if (t->name == name) { + return t->invoke(params); + } + } + return {{"error", "unknown tool: " + name}}; +} diff --git a/tools/server/server-tools.h b/tools/server/server-tools.h new file mode 100644 index 0000000000..444ef5f809 --- /dev/null +++ b/tools/server/server-tools.h @@ -0,0 +1,26 @@ +#pragma once + +#include "server-common.h" +#include "server-http.h" + +struct server_tool { + std::string name; + std::string display_name; + bool permission_write = false; + + virtual ~server_tool() = default; + virtual json get_definition() = 0; + virtual json invoke(json params) = 0; + + json to_json(); +}; + +struct server_tools { + std::vector> tools; + + void setup(const std::vector & enabled_tools); + json invoke(const std::string & name, const json & params); + + server_http_context::handler_t handle_get; + server_http_context::handler_t handle_post; +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 73e3aa44e0..ef54a46b19 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2,6 +2,7 @@ #include "server-http.h" #include "server-models.h" #include "server-cors-proxy.h" +#include "server-tools.h" #include "arg.h" #include "common.h" @@ -124,6 +125,7 @@ int main(int argc, char ** argv) { // register API routes server_routes routes(params, ctx_server); + server_tools tools; bool is_router_server = params.model.path.empty(); std::optional models_routes{}; @@ -211,6 +213,16 @@ int main(int argc, char ** argv) { ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get)); ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post)); } + // EXPERIMENTAL built-in tools + if (!params.server_tools.empty()) { + tools.setup(params.server_tools); + SRV_WRN("%s", "-----------------\n"); + SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n"); + SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n"); + SRV_WRN("%s", "-----------------\n"); + ctx_http.get ("/tools", ex_wrapper(tools.handle_get)); + ctx_http.post("/tools", ex_wrapper(tools.handle_post)); + } // // Start the server