diff --git a/common/common.cpp b/common/common.cpp index 59d75a3b95..affb3d7bac 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -656,6 +656,38 @@ bool string_parse_kv_override(const char * data, std::vector & values); std::string string_from(const struct llama_context * ctx, const std::vector & tokens); std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); +bool glob_match(const std::string & pattern, const std::string & str); + // // Filesystem utils // diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index f5b4426f6f..c58fda83e2 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -224,10 +224,11 @@ struct cli_context { }; // TODO?: Make this reusable, enums, docs -static const std::array cmds = { +static const std::array cmds = { "/audio ", "/clear", "/exit", + "/glob ", "/image ", "/read ", "/regen", @@ -258,7 +259,7 @@ static std::vector> auto_completion_callback(std: } } - if (!cmd.empty() && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) { + if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) { const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length())); const std::string path_postfix = std::string(line.substr(cursor_byte_pos)); auto cur_dir = std::filesystem::current_path(); @@ -339,6 +340,8 @@ static std::vector> auto_completion_callback(std: return matches; } +static constexpr size_t FILE_GLOB_MAX_RESULTS = 100; + int main(int argc, char ** argv) { common_params params; @@ -430,7 +433,8 @@ int main(int argc, char ** argv) { console::log(" /exit or Ctrl+C stop or exit\n"); console::log(" /regen regenerate the last response\n"); console::log(" /clear clear the chat history\n"); - console::log(" /read add a text file\n"); + console::log(" /read add a text file\n"); + console::log(" /glob add text files using globbing pattern\n"); if (inf.has_inp_image) { console::log(" /image add an image file\n"); } @@ -441,6 +445,27 @@ int main(int argc, char ** argv) { // interactive loop std::string cur_msg; + + auto add_text_file = [&](const std::string & fname) -> bool { + std::string marker = ctx_cli.load_input_file(fname, false); + if (marker.empty()) { + console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); + return false; + } + if (inf.fim_sep_token != LLAMA_TOKEN_NULL) { + cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true); + cur_msg += fname; + cur_msg.push_back('\n'); + } else { + cur_msg += "--- File: "; + cur_msg += fname; + cur_msg += " ---\n"; + } + cur_msg += marker; + console::log("Loaded text from '%s'\n", fname.c_str()); + return true; + }; + while (true) { std::string buffer; console::set_display(DISPLAY_TYPE_USER_INPUT); @@ -525,22 +550,60 @@ int main(int argc, char ** argv) { continue; } else if (string_starts_with(buffer, "/read ")) { std::string fname = string_strip(buffer.substr(6)); - std::string marker = ctx_cli.load_input_file(fname, false); - if (marker.empty()) { - console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); - continue; + add_text_file(fname); + continue; + } else if (string_starts_with(buffer, "/glob ")) { + std::error_code ec; + size_t count = 0; + auto curdir = std::filesystem::current_path(); + std::string pattern = string_strip(buffer.substr(6)); + std::filesystem::path rel_path; + + auto startglob = pattern.find_first_of("![*?"); + if (startglob != std::string::npos && startglob != 0) { + auto endpath = pattern.substr(0, startglob).find_last_of('/'); + if (endpath != std::string::npos) { + std::string rel_pattern = pattern.substr(0, endpath); +#if !defined(_WIN32) + if (string_starts_with(rel_pattern, "~")) { + const char * home = std::getenv("HOME"); + if (home && home[0]) { + rel_pattern = std::string(home) + rel_pattern.substr(1); + } + } +#endif + rel_path = rel_pattern; + pattern.erase(0, endpath + 1); + curdir /= rel_path; + } } - if (inf.fim_sep_token != LLAMA_TOKEN_NULL) { - cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true); - cur_msg += fname; - cur_msg.push_back('\n'); - } else { - cur_msg += "--- File: "; - cur_msg += fname; - cur_msg += " ---\n"; + + for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir, + std::filesystem::directory_options::skip_permission_denied, ec)) { + if (!entry.is_regular_file()) { + continue; + } + + std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string(); + if (ec) { + ec.clear(); + continue; + } + std::replace(rel.begin(), rel.end(), '\\', '/'); + + if (!glob_match(pattern, rel)) { + continue; + } + + if (!add_text_file((rel_path / rel).string())) { + continue; + } + + if (++count >= FILE_GLOB_MAX_RESULTS) { + console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS); + break; + } } - cur_msg += marker; - console::log("Loaded text from '%s'\n", fname.c_str()); continue; } else { // not a command diff --git a/tools/server/server-tools.cpp b/tools/server/server-tools.cpp index 5e89a5668b..81e360de46 100644 --- a/tools/server/server-tools.cpp +++ b/tools/server/server-tools.cpp @@ -101,38 +101,6 @@ static run_proc_result run_process( return res; } -// simple glob: * matches non-/ chars, ** matches anything including / -static bool glob_match(const char * pattern, const char * str) { - if (*pattern == '\0') { - return *str == '\0'; - } - if (pattern[0] == '*' && pattern[1] == '*') { - const char * p = pattern + 2; - if (*p == '/') p++; - if (glob_match(p, str)) return true; - if (*str != '\0') return glob_match(pattern, str + 1); - return false; - } - if (*pattern == '*') { - const char * p = pattern + 1; - for (; *str != '\0' && *str != '/'; str++) { - if (glob_match(p, str)) return true; - } - return glob_match(p, str); - } - if (*pattern == '?' && *str != '\0' && *str != '/') { - return glob_match(pattern + 1, str + 1); - } - if (*pattern == *str) { - return glob_match(pattern + 1, str + 1); - } - return false; -} - -static bool glob_match(const std::string & pattern, const std::string & str) { - return glob_match(pattern.c_str(), str.c_str()); -} - json server_tool::to_json() { return { {"display_name", display_name},