597 lines
20 KiB
C++
597 lines
20 KiB
C++
#include "server-tools.h"

#include <sheredom/subprocess.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <climits>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <mutex>
#include <regex>
#include <thread>
|
|
|
|
namespace fs = std::filesystem;
|
|
|
|
//
|
|
// internal helpers
|
|
//
|
|
|
|
// Build a nullptr-terminated argv-style array pointing into the strings of `v`.
// The returned pointers borrow from `v`: they stay valid only while `v` is alive
// and unmodified. const_cast is needed because the C API takes `char **`.
static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
    std::vector<char *> argv_ptrs(v.size() + 1, nullptr); // last slot stays nullptr
    for (size_t i = 0; i < v.size(); ++i) {
        argv_ptrs[i] = const_cast<char *>(v[i].c_str());
    }
    return argv_ptrs;
}
|
|
|
|
// Outcome of run_process().
struct run_proc_result {
    std::string output;           // combined stdout/stderr (possibly truncated), or a spawn-failure message
    int         exit_code{-1};    // stays -1 when the process could not be spawned
    bool        timed_out{false}; // true when the watchdog terminated the process
};
|
|
|
|
// Spawn `args` (argv[0] is looked up on the user's PATH; no shell interpretation),
// capture its combined stdout/stderr and wait for it to exit.
//
//   args         - program followed by its arguments
//   max_output   - maximum number of output bytes kept; anything beyond is still read
//                  (so the child never blocks on a full pipe) but discarded, and
//                  "\n[output truncated]" is appended to the result
//   timeout_secs - wall-clock limit; when it elapses the child is terminated and
//                  `timed_out` is set in the result
//
// Returns the captured output, the exit code (-1 if the process could not be spawned)
// and the timeout flag.
static run_proc_result run_process(
        const std::vector<std::string> & args,
        size_t max_output,
        int timeout_secs) {
    run_proc_result res;

    subprocess_s proc;
    auto argv = to_cstr_vec(args);

    int options = subprocess_option_no_window
                | subprocess_option_combined_stdout_stderr
                | subprocess_option_inherit_environment
                | subprocess_option_search_user_path;

    if (subprocess_create(argv.data(), options, &proc) != 0) {
        res.output = "failed to spawn process";
        return res;
    }

    // Watchdog: sleeps until the reader signals completion or the deadline passes.
    // Waiting on a condition variable (instead of polling) lets the join below return
    // as soon as the output has been drained.
    std::mutex              mtx;
    std::condition_variable cv;
    bool done      = false; // guarded by mtx
    bool timed_out = false; // written only by the watchdog; read after join()

    std::thread watchdog([&]() {
        std::unique_lock<std::mutex> lock(mtx);
        const bool finished = cv.wait_for(lock, std::chrono::seconds(timeout_secs), [&] { return done; });
        if (!finished) {
            timed_out = true;
            subprocess_terminate(&proc);
        }
    });

    std::string output;
    bool truncated = false;
    if (FILE * f = subprocess_stdout(&proc)) {
        char buf[4096];
        size_t n;
        // fread (rather than fgets + strlen) keeps output that contains NUL bytes intact
        while ((n = fread(buf, 1, sizeof(buf), f)) > 0) {
            if (truncated) {
                continue; // keep draining so the child never blocks writing to a full pipe
            }
            const size_t room = max_output - output.size();
            if (n <= room) {
                output.append(buf, n);
            } else {
                output.append(buf, room);
                truncated = true;
            }
        }
    }

    {
        std::lock_guard<std::mutex> lock(mtx);
        done = true;
    }
    cv.notify_one();
    watchdog.join();

    subprocess_join(&proc, &res.exit_code);
    subprocess_destroy(&proc);

    res.output    = std::move(output);
    res.timed_out = timed_out;
    if (truncated) {
        res.output += "\n[output truncated]";
    }
    return res;
}
|
|
|
|
// Simple glob matcher for '/'-separated relative paths:
//   *    matches any run of characters except '/'
//   ?    matches exactly one character except '/'
//   **   matches any run of characters, including '/'
//   **/  matches zero or more whole leading directories, so "**/x.cpp" matches
//        "x.cpp" and "a/b/x.cpp" but not "ax.cpp"
// Any other character matches itself.
static bool glob_match(const char * pattern, const char * str) {
    if (*pattern == '\0') {
        return *str == '\0';
    }

    if (pattern[0] == '*' && pattern[1] == '*') {
        const char * rest = pattern + 2;
        if (*rest == '/') {
            rest++;
            // zero directories consumed
            if (glob_match(rest, str)) {
                return true;
            }
            // otherwise the remainder must start right after a '/', i.e. on a
            // path-component boundary (previously it could start mid-name)
            for (const char * s = str; *s != '\0'; s++) {
                if (*s == '/' && glob_match(rest, s + 1)) {
                    return true;
                }
            }
            return false;
        }
        // bare "**": try every suffix of str, including the empty one
        for (const char * s = str; ; s++) {
            if (glob_match(rest, s)) {
                return true;
            }
            if (*s == '\0') {
                return false;
            }
        }
    }

    if (*pattern == '*') {
        const char * rest = pattern + 1;
        for (; *str != '\0' && *str != '/'; str++) {
            if (glob_match(rest, str)) {
                return true;
            }
        }
        return glob_match(rest, str);
    }

    if (*pattern == '?' && *str != '\0' && *str != '/') {
        return glob_match(pattern + 1, str + 1);
    }

    if (*pattern == *str) {
        return glob_match(pattern + 1, str + 1);
    }

    return false;
}

// std::string convenience overload of the matcher above.
static bool glob_match(const std::string & pattern, const std::string & str) {
    return glob_match(pattern.c_str(), str.c_str());
}
|
|
|
|
//
|
|
// base struct
|
|
//
|
|
|
|
struct server_tool {
|
|
std::string name;
|
|
json definition;
|
|
bool permission_write = false;
|
|
virtual ~server_tool() = default;
|
|
virtual json to_json() = 0;
|
|
virtual json invoke(json params) = 0;
|
|
};
|
|
|
|
//
|
|
// read_file: read a file with optional line range and line-number prefix
|
|
//
|
|
|
|
static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB
|
|
|
|
struct server_tool_read_file : server_tool {
|
|
server_tool_read_file() { name = "read_file"; permission_write = false; }
|
|
|
|
json to_json() override {
|
|
return {
|
|
{"type", "function"},
|
|
{"function", {
|
|
{"name", name},
|
|
{"description", "Read the contents of a file. Optionally specify a 1-based line range. "
|
|
"If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."},
|
|
{"parameters", {
|
|
{"type", "object"},
|
|
{"properties", {
|
|
{"path", {{"type", "string"}, {"description", "Path to the file"}}},
|
|
{"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}},
|
|
{"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}},
|
|
{"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}},
|
|
}},
|
|
{"required", json::array({"path"})},
|
|
}},
|
|
}},
|
|
};
|
|
}
|
|
|
|
json invoke(json params) override {
|
|
std::string path = params.at("path").get<std::string>();
|
|
int start_line = json_value(params, "start_line", 1);
|
|
int end_line = json_value(params, "end_line", -1); // -1 = no limit
|
|
bool append_loc = json_value(params, "append_loc", false);
|
|
|
|
std::error_code ec;
|
|
uintmax_t file_size = fs::file_size(path, ec);
|
|
if (ec) {
|
|
return {{"error", "cannot stat file: " + ec.message()}};
|
|
}
|
|
if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) {
|
|
return {{"error", string_format(
|
|
"file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.",
|
|
(size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}};
|
|
}
|
|
|
|
std::ifstream f(path);
|
|
if (!f) {
|
|
return {{"error", "failed to open file: " + path}};
|
|
}
|
|
|
|
std::string result;
|
|
std::string line;
|
|
int lineno = 0;
|
|
|
|
while (std::getline(f, line)) {
|
|
lineno++;
|
|
if (lineno < start_line) continue;
|
|
if (end_line != -1 && lineno > end_line) break;
|
|
|
|
std::string out_line;
|
|
if (append_loc) {
|
|
out_line = std::to_string(lineno) + "\u2192 " + line + "\n";
|
|
} else {
|
|
out_line = line + "\n";
|
|
}
|
|
|
|
if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) {
|
|
result += "[output truncated]";
|
|
break;
|
|
}
|
|
result += out_line;
|
|
}
|
|
|
|
return {{"content", result}};
|
|
}
|
|
};
|
|
|
|
//
|
|
// file_glob_search: find files matching a glob pattern under a base directory
|
|
//
|
|
|
|
static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100;
|
|
|
|
struct server_tool_file_glob_search : server_tool {
|
|
server_tool_file_glob_search() { name = "file_glob_search"; permission_write = false; }
|
|
|
|
json to_json() override {
|
|
return {
|
|
{"type", "function"},
|
|
{"function", {
|
|
{"name", name},
|
|
{"description", "Recursively search for files matching a glob pattern under a directory."},
|
|
{"parameters", {
|
|
{"type", "object"},
|
|
{"properties", {
|
|
{"path", {{"type", "string"}, {"description", "Base directory to search in"}}},
|
|
{"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}},
|
|
{"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}},
|
|
}},
|
|
{"required", json::array({"path"})},
|
|
}},
|
|
}},
|
|
};
|
|
}
|
|
|
|
json invoke(json params) override {
|
|
std::string base = params.at("path").get<std::string>();
|
|
std::string include = json_value(params, "include", std::string("**"));
|
|
std::string exclude = json_value(params, "exclude", std::string(""));
|
|
|
|
json files = json::array();
|
|
|
|
std::error_code ec;
|
|
for (const auto & entry : fs::recursive_directory_iterator(base,
|
|
fs::directory_options::skip_permission_denied, ec)) {
|
|
if (!entry.is_regular_file()) continue;
|
|
|
|
std::string rel = fs::relative(entry.path(), base, ec).string();
|
|
if (ec) continue;
|
|
std::replace(rel.begin(), rel.end(), '\\', '/');
|
|
|
|
if (!glob_match(include, rel)) continue;
|
|
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
|
|
|
files.push_back(entry.path().string());
|
|
if (files.size() >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
return {{"files", files}, {"count", files.size()}};
|
|
}
|
|
};
|
|
|
|
//
|
|
// grep_search: search for a regex pattern in files
|
|
//
|
|
|
|
static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100;
|
|
|
|
struct server_tool_grep_search : server_tool {
|
|
server_tool_grep_search() { name = "grep_search"; permission_write = false; }
|
|
|
|
json to_json() override {
|
|
return {
|
|
{"type", "function"},
|
|
{"function", {
|
|
{"name", name},
|
|
{"description", "Search for a regex pattern in files under a path. Returns matching lines."},
|
|
{"parameters", {
|
|
{"type", "object"},
|
|
{"properties", {
|
|
{"path", {{"type", "string"}, {"description", "File or directory to search in"}}},
|
|
{"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}},
|
|
{"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}},
|
|
{"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}},
|
|
{"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}},
|
|
}},
|
|
{"required", json::array({"path", "pattern"})},
|
|
}},
|
|
}},
|
|
};
|
|
}
|
|
|
|
json invoke(json params) override {
|
|
std::string path = params.at("path").get<std::string>();
|
|
std::string pat_str = params.at("pattern").get<std::string>();
|
|
std::string include = json_value(params, "include", std::string("**"));
|
|
std::string exclude = json_value(params, "exclude", std::string(""));
|
|
bool show_lineno = json_value(params, "return_line_numbers", false);
|
|
|
|
std::regex pattern;
|
|
try {
|
|
pattern = std::regex(pat_str);
|
|
} catch (const std::regex_error & e) {
|
|
return {{"error", std::string("invalid regex: ") + e.what()}};
|
|
}
|
|
|
|
json matches = json::array();
|
|
size_t total = 0;
|
|
|
|
auto search_file = [&](const fs::path & fpath) {
|
|
std::ifstream f(fpath);
|
|
if (!f) return;
|
|
std::string line;
|
|
int lineno = 0;
|
|
while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) {
|
|
lineno++;
|
|
if (std::regex_search(line, pattern)) {
|
|
json match = {{"file", fpath.string()}, {"content", line}};
|
|
if (show_lineno) {
|
|
match["line"] = lineno;
|
|
}
|
|
matches.push_back(match);
|
|
total++;
|
|
}
|
|
}
|
|
};
|
|
|
|
std::error_code ec;
|
|
if (fs::is_regular_file(path, ec)) {
|
|
search_file(path);
|
|
} else if (fs::is_directory(path, ec)) {
|
|
for (const auto & entry : fs::recursive_directory_iterator(path,
|
|
fs::directory_options::skip_permission_denied, ec)) {
|
|
if (!entry.is_regular_file()) continue;
|
|
if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break;
|
|
|
|
std::string rel = fs::relative(entry.path(), path, ec).string();
|
|
if (ec) continue;
|
|
std::replace(rel.begin(), rel.end(), '\\', '/');
|
|
|
|
if (!glob_match(include, rel)) continue;
|
|
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
|
|
|
search_file(entry.path());
|
|
}
|
|
} else {
|
|
return {{"error", "path does not exist: " + path}};
|
|
}
|
|
|
|
return {{"matches", matches}, {"count", total}};
|
|
}
|
|
};
|
|
|
|
//
|
|
// exec_shell_command: run an arbitrary shell command
|
|
//
|
|
|
|
static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB
|
|
static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds
|
|
|
|
struct server_tool_exec_shell_command : server_tool {
|
|
server_tool_exec_shell_command() { name = "exec_shell_command"; permission_write = true; }
|
|
|
|
json to_json() override {
|
|
return {
|
|
{"type", "function"},
|
|
{"function", {
|
|
{"name", name},
|
|
{"description", "Execute a shell command and return its output (stdout and stderr combined)."},
|
|
{"parameters", {
|
|
{"type", "object"},
|
|
{"properties", {
|
|
{"command", {{"type", "string"}, {"description", "Shell command to execute"}}},
|
|
{"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}},
|
|
{"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}},
|
|
}},
|
|
{"required", json::array({"command"})},
|
|
}},
|
|
}},
|
|
};
|
|
}
|
|
|
|
json invoke(json params) override {
|
|
std::string command = params.at("command").get<std::string>();
|
|
int timeout = json_value(params, "timeout", 10);
|
|
size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
|
|
|
timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT);
|
|
max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
|
|
|
#ifdef _WIN32
|
|
std::vector<std::string> args = {"cmd", "/c", command};
|
|
#else
|
|
std::vector<std::string> args = {"sh", "-c", command};
|
|
#endif
|
|
|
|
auto res = run_process(args, max_output, timeout);
|
|
|
|
json out = {{"output", res.output}, {"exit_code", res.exit_code}};
|
|
if (res.timed_out) {
|
|
out["timed_out"] = true;
|
|
}
|
|
return out;
|
|
}
|
|
};
|
|
|
|
//
|
|
// write_file: create or overwrite a file
|
|
//
|
|
|
|
struct server_tool_write_file : server_tool {
|
|
server_tool_write_file() { name = "write_file"; permission_write = true; }
|
|
|
|
json to_json() override {
|
|
return {
|
|
{"type", "function"},
|
|
{"function", {
|
|
{"name", name},
|
|
{"description", "Write content to a file, creating it (including parent directories) if it does not exist."},
|
|
{"parameters", {
|
|
{"type", "object"},
|
|
{"properties", {
|
|
{"path", {{"type", "string"}, {"description", "Path of the file to write"}}},
|
|
{"content", {{"type", "string"}, {"description", "Content to write"}}},
|
|
}},
|
|
{"required", json::array({"path", "content"})},
|
|
}},
|
|
}},
|
|
};
|
|
}
|
|
|
|
json invoke(json params) override {
|
|
std::string path = params.at("path").get<std::string>();
|
|
std::string content = params.at("content").get<std::string>();
|
|
|
|
std::error_code ec;
|
|
fs::path fpath(path);
|
|
if (fpath.has_parent_path()) {
|
|
fs::create_directories(fpath.parent_path(), ec);
|
|
if (ec) {
|
|
return {{"error", "failed to create directories: " + ec.message()}};
|
|
}
|
|
}
|
|
|
|
std::ofstream f(path, std::ios::binary);
|
|
if (!f) {
|
|
return {{"error", "failed to open file for writing: " + path}};
|
|
}
|
|
f << content;
|
|
if (!f) {
|
|
return {{"error", "failed to write file: " + path}};
|
|
}
|
|
|
|
return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}};
|
|
}
|
|
};
|
|
|
|
//
|
|
// edit_file: apply a unified diff via git apply
|
|
//
|
|
|
|
struct server_tool_edit_file : server_tool {
|
|
server_tool_edit_file() { name = "edit_file"; permission_write = true; }
|
|
|
|
json to_json() override {
|
|
return {
|
|
{"type", "function"},
|
|
{"function", {
|
|
{"name", name},
|
|
{"description", "Apply a unified diff to edit one or more files using git apply."},
|
|
{"parameters", {
|
|
{"type", "object"},
|
|
{"properties", {
|
|
{"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}},
|
|
}},
|
|
{"required", json::array({"diff"})},
|
|
}},
|
|
}},
|
|
};
|
|
}
|
|
|
|
json invoke(json params) override {
|
|
std::string diff = params.at("diff").get<std::string>();
|
|
|
|
// write diff to a temporary file
|
|
static std::atomic<int> counter{0};
|
|
std::string tmp_path = (fs::temp_directory_path() /
|
|
("llama_patch_" + std::to_string(++counter) + ".patch")).string();
|
|
|
|
{
|
|
std::ofstream f(tmp_path, std::ios::binary);
|
|
if (!f) {
|
|
return {{"error", "failed to create temp patch file"}};
|
|
}
|
|
f << diff;
|
|
}
|
|
|
|
auto res = run_process({"git", "apply", tmp_path}, 4096, 10);
|
|
|
|
std::error_code ec;
|
|
fs::remove(tmp_path, ec);
|
|
|
|
if (res.exit_code != 0) {
|
|
return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}};
|
|
}
|
|
return {{"result", "patch applied successfully"}};
|
|
}
|
|
};
|
|
|
|
//
|
|
// public API
|
|
//
|
|
|
|
static std::vector<std::unique_ptr<server_tool>> build_tools() {
|
|
std::vector<std::unique_ptr<server_tool>> tools;
|
|
tools.push_back(std::make_unique<server_tool_read_file>());
|
|
tools.push_back(std::make_unique<server_tool_file_glob_search>());
|
|
tools.push_back(std::make_unique<server_tool_grep_search>());
|
|
tools.push_back(std::make_unique<server_tool_exec_shell_command>());
|
|
tools.push_back(std::make_unique<server_tool_write_file>());
|
|
tools.push_back(std::make_unique<server_tool_edit_file>());
|
|
return tools;
|
|
}
|
|
|
|
// Return a JSON array holding the to_json() definition of every built-in tool.
static json server_tools_list() {
    json definitions = json::array();
    for (const auto & tool : build_tools()) {
        definitions.push_back(tool->to_json());
    }
    return definitions;
}
|
|
|
|
// Dispatch to the built-in tool called `name` and return its result.
// Yields {"error": "unknown tool: <name>"} when no tool has that name;
// exceptions thrown by the tool are not caught here.
static json server_tool_call(const std::string & name, const json & params) {
    const auto tools = build_tools();
    for (size_t i = 0; i < tools.size(); ++i) {
        if (tools[i]->name != name) {
            continue;
        }
        return tools[i]->invoke(params);
    }
    return {{"error", "unknown tool: " + name}};
}
|
|
|
|
// GET handler: responds with the JSON array of tool definitions.
// Any exception is logged and turned into an HTTP 500 error response.
server_http_context::handler_t server_tools_get = [](const server_http_req &) -> server_http_res_ptr {
    auto response = std::make_unique<server_http_res>();
    try {
        json tool_list = server_tools_list();
        response->data = safe_json_to_str(tool_list);
    } catch (const std::exception & e) {
        SRV_ERR("got exception: %s\n", e.what());
        response->status = 500;
        response->data   = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
    }
    return response;
};
|
|
|
|
// POST handler: body {"tool": <name>, "params": {...}} -> the tool's JSON result.
// Any json::exception (malformed body, missing "tool", or one raised while a tool
// reads its params) becomes HTTP 400; any other exception is logged and becomes 500.
server_http_context::handler_t server_tools_post = [](const server_http_req & req) -> server_http_res_ptr {
    auto response = std::make_unique<server_http_res>();
    try {
        const json request          = json::parse(req.body);
        const std::string tool_name = request.at("tool").get<std::string>();
        const json tool_params      = request.value("params", json::object());

        json tool_result = server_tool_call(tool_name, tool_params);
        response->data   = safe_json_to_str(tool_result);
    } catch (const json::exception & e) {
        response->status = 400;
        response->data   = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
    } catch (const std::exception & e) {
        SRV_ERR("got exception: %s\n", e.what());
        response->status = 500;
        response->data   = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
    }
    return response;
};
|