From 6fd16ba05c88846ba5e6d8fd053c6230509d5ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Tue, 17 Mar 2026 13:56:48 +0000 Subject: [PATCH] common : add standard Hugging Face cache support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use HF API to find all files - Migrate all manifests to hugging face cache at startup Signed-off-by: Adrien Gallouët --- common/CMakeLists.txt | 2 + common/arg.cpp | 100 +++--- common/download.cpp | 505 ++++++++++++++++------------- common/download.h | 49 +-- common/hf-cache.cpp | 516 ++++++++++++++++++++++++++++++ common/hf-cache.h | 37 +++ common/preset.cpp | 4 +- tools/llama-bench/llama-bench.cpp | 29 +- 8 files changed, 903 insertions(+), 339 deletions(-) create mode 100644 common/hf-cache.cpp create mode 100644 common/hf-cache.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 75c6366c7f..b313a7320e 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -63,6 +63,8 @@ add_library(${TARGET} STATIC debug.h download.cpp download.h + hf-cache.cpp + hf-cache.h http.h json-partial.cpp json-partial.h diff --git a/common/arg.cpp b/common/arg.cpp index 666339a094..0fc5fae498 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3,6 +3,7 @@ #include "chat.h" #include "common.h" #include "download.h" +#include "hf-cache.h" #include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" @@ -326,60 +327,48 @@ struct handle_model_result { common_params_model mmproj; }; -static handle_model_result common_params_handle_model( - struct common_params_model & model, - const std::string & bearer_token, - bool offline) { +static handle_model_result common_params_handle_model(struct common_params_model & model, + const std::string & bearer_token, + bool offline) { handle_model_result result; - // handle pre-fill default model path and url based on hf_repo and hf_file - { - if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths - model.path = common_docker_resolve_model(model.docker_repo); - model.name = model.docker_repo; // set name for consistency - } else if (!model.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (model.hf_file.empty()) { - if (model.path.empty()) { - auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); - if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { - exit(1); // error message already printed - } - model.name = model.hf_repo; // repo name with tag - model.hf_repo = auto_detected.repo; // repo name without tag - model.hf_file = auto_detected.ggufFile; - if (!auto_detected.mmprojFile.empty()) { - result.found_mmproj = true; - result.mmproj.hf_repo = model.hf_repo; - result.mmproj.hf_file = auto_detected.mmprojFile; - } - } else { - model.hf_file = model.path; - } - } - - std::string model_endpoint = get_model_endpoint(); - model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; - // make sure model path is present (for caching purposes) - if (model.path.empty()) { - // this is to avoid different repo having same file name, or same file name in different subdirs - std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file); - model.path = fs_get_cache_file(filename); - } - - } else if (!model.url.empty()) { - if (model.path.empty()) { - auto f = string_split(model.url, '#').front(); - f = string_split(f, '?').front(); - model.path = fs_get_cache_file(string_split(f, '/').back()); - } + if (!model.docker_repo.empty()) { + model.path = common_docker_resolve_model(model.docker_repo); + model.name = model.docker_repo; + } else if (!model.hf_repo.empty()) { + // If -m was used with -hf, treat the model "path" as the hf_file to download + if (model.hf_file.empty() && !model.path.empty()) { + model.hf_file = model.path; + model.path = ""; } - } + common_download_model_opts opts; + opts.download_mmproj = true; + opts.offline = offline; + auto download_result = common_download_model(model, bearer_token, opts); - // then, download it if needed - if (!model.url.empty()) { - bool ok = common_download_model(model, bearer_token, offline); - if (!ok) { + if (download_result.model_path.empty()) { + LOG_ERR("error: failed to download model from Hugging Face\n"); + exit(1); + } + + model.name = model.hf_repo; + model.path = download_result.model_path; + + if (!download_result.mmproj_path.empty()) { + result.found_mmproj = true; + result.mmproj.path = download_result.mmproj_path; + } + } else if (!model.url.empty()) { + if (model.path.empty()) { + auto f = string_split(model.url, '#').front(); + f = string_split(f, '?').front(); + model.path = fs_get_cache_file(string_split(f, '/').back()); + } + + common_download_model_opts opts; + opts.offline = offline; + auto download_result = common_download_model(model, bearer_token, opts); + if (download_result.model_path.empty()) { LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); exit(1); } @@ -539,6 +528,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // parse the first time to get -hf option (used for remote preset) parse_cli_args(); + // TODO: Remove later + try { + hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline); + } catch (const std::exception & e) { + LOG_WRN("HF cache migration failed: %s\n", e.what()); + } + // maybe handle remote preset if (!params.model.hf_repo.empty()) { std::string cli_hf_repo = params.model.hf_repo; @@ -1061,12 +1057,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-cl", "--cache-list"}, "show list of models in cache", [](common_params &) { - printf("model cache directory: %s\n", fs_get_cache_directory().c_str()); auto models = common_list_cached_models(); printf("number of models in cache: %zu\n", models.size()); for (size_t i = 0; i < models.size(); i++) { - auto & model = models[i]; - printf("%4d. %s\n", (int) i + 1, model.to_string().c_str()); + printf("%4zu. %s\n", i + 1, models[i].c_str()); } exit(0); } diff --git a/common/download.cpp b/common/download.cpp index 5ef60a4208..ee37b84d5d 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -1,9 +1,9 @@ #include "arg.h" #include "common.h" -#include "gguf.h" // for reading GGUF splits #include "log.h" #include "download.h" +#include "hf-cache.h" #define JSON_ASSERT GGML_ASSERT #include @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -35,8 +36,6 @@ #endif #endif -#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 - // isatty #if defined(_WIN32) #include @@ -51,31 +50,6 @@ using json = nlohmann::ordered_json; // // validate repo name format: owner/repo -static bool validate_repo_name(const std::string & repo) { - static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)"); - return std::regex_match(repo, repo_regex); -} - -static std::string get_manifest_path(const std::string & repo, const std::string & tag) { - // we use "=" to avoid clashing with other component, while still being allowed on windows - std::string fname = "manifest=" + repo + "=" + tag + ".json"; - if (!validate_repo_name(repo)) { - throw std::runtime_error("error: repo name must be in the format 'owner/repo'"); - } - string_replace_all(fname, "/", "="); - return fs_get_cache_file(fname); -} - -static std::string read_file(const std::string & fname) { - std::ifstream file(fname); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); - } - std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - file.close(); - return content; -} - static void write_file(const std::string & fname, const std::string & content) { const std::string fname_tmp = fname + ".tmp"; std::ofstream file(fname_tmp); @@ -132,7 +106,7 @@ static bool is_http_status_ok(int status) { std::pair common_download_split_repo_tag(const std::string & hf_repo_with_tag) { auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string tag = parts.size() > 1 ? parts.back() : ""; std::string hf_repo = parts[0]; if (string_split(hf_repo, '/').size() != 2) { throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); @@ -290,7 +264,8 @@ static bool common_pull_file(httplib::Client & cli, static int common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token, - const common_header_list & custom_headers) { + const common_header_list & custom_headers, + bool skip_etag = false) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -310,6 +285,11 @@ static int common_download_file_single_online(const std::string & url, const bool file_exists = std::filesystem::exists(path); + if (file_exists && skip_etag) { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + std::string last_etag; if (file_exists) { last_etag = read_etag(path); @@ -361,6 +341,12 @@ static int common_download_file_single_online(const std::string & url, } } + { // silent + std::error_code ec; + std::filesystem::path p(path); + std::filesystem::create_directories(p.parent_path(), ec); + } + const std::string path_temporary = path + ".downloadInProgress"; int delay = retry_delay_seconds; @@ -391,7 +377,7 @@ static int common_download_file_single_online(const std::string & url, LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return -1; } - if (!etag.empty()) { + if (!etag.empty() && !skip_etag) { write_etag(path, etag); } return head->status; @@ -440,9 +426,10 @@ int common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline, - const common_header_list & headers) { + const common_header_list & headers, + bool skip_etag) { if (!offline) { - return common_download_file_single_online(url, path, bearer_token, headers); + return common_download_file_single_online(url, path, bearer_token, headers, skip_etag); } if (!std::filesystem::exists(path)) { @@ -454,193 +441,234 @@ int common_download_file_single(const std::string & url, return 304; // Not Modified - fake cached response } -// download multiple files from remote URLs to local paths -// the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, - const std::string & bearer_token, - bool offline, - const common_header_list & headers) { - // Prepare download in parallel - std::vector> futures_download; - futures_download.reserve(urls.size()); - - for (auto const & item : urls) { - futures_download.push_back( - std::async( - std::launch::async, - [&bearer_token, offline, &headers](const std::pair & it) -> bool { - const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers); - return is_http_status_ok(http_status); - }, - item - ) - ); +// "subdir/model-00001-of-00002.gguf" -> "subdir/model", 1, 2 +static std::tuple get_gguf_split_info(const std::string & path) { + if (path.empty()) { + return {}; } - // Wait for all downloads to complete - for (auto & f : futures_download) { - if (!f.get()) { - return false; - } - } + static const std::regex re(R"(^(.+)-([0-9]+)-of-([0-9]+)\.gguf$)", std::regex::icase); - return true; + std::smatch m; + if (std::regex_match(path, m, re)) { + return {m[1].str(), std::stoi(m[2].str()), std::stoi(m[3].str())}; + } + return {}; } -bool common_download_model(const common_params_model & model, - const std::string & bearer_token, - bool offline, - const common_header_list & headers) { - // Basic validation of the model.url - if (model.url.empty()) { - LOG_ERR("%s: invalid model url\n", __func__); - return false; - } +static hf_cache::hf_files get_split_files(const hf_cache::hf_files & all_files, + const hf_cache::hf_file & primary_file) { + hf_cache::hf_files result; + auto [prefix, idx, count] = get_gguf_split_info(primary_file.path); - const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers); - if (!is_http_status_ok(http_status)) { - return false; - } - - // check for additional GGUFs split to download - int n_split = 0; - { - struct gguf_init_params gguf_params = { - /*.no_alloc = */ true, - /*.ctx = */ NULL, - }; - auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params); - if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str()); - return false; - } - - auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); - if (key_n_split >= 0) { - n_split = gguf_get_val_u16(ctx_gguf, key_n_split); - } - - gguf_free(ctx_gguf); - } - - if (n_split > 1) { - char split_prefix[PATH_MAX] = {0}; - char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0}; - - // Verify the first split file format - // and extract split URL and PATH prefixes - { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split); - return false; - } - - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split); - return false; + if (count > 1) { + for (const auto & f : all_files) { + auto [sprefix, sidx, scount] = get_gguf_split_info(f.path); + if (scount == count && sprefix == prefix) { + result.push_back(f); } } - - std::vector> urls; - for (int idx = 1; idx < n_split; idx++) { - char split_path[PATH_MAX] = {0}; - llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); - - char split_url[LLAMA_MAX_URL_LENGTH] = {0}; - llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split); - - if (std::string(split_path) == model.path) { - continue; // skip the already downloaded file - } - - urls.push_back({split_url, split_path}); - } - - // Download in parallel - common_download_file_multiple(urls, bearer_token, offline, headers); - } - - return true; -} - -common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, - const std::string & bearer_token, - bool offline, - const common_header_list & custom_headers) { - // the returned hf_repo is without tag - auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag); - - std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; - - // headers - common_header_list headers = custom_headers; - headers.push_back({"Accept", "application/json"}); - if (!bearer_token.empty()) { - headers.push_back({"Authorization", "Bearer " + bearer_token}); - } - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - // User-Agent header is already set in common_remote_get_content, no need to set it here - - // make the request - common_remote_params params; - params.headers = headers; - long res_code = 0; - std::string res_str; - bool use_cache = false; - std::string cached_response_path = get_manifest_path(hf_repo, tag); - if (!offline) { - try { - auto res = common_remote_get_content(url, params); - res_code = res.first; - res_str = std::string(res.second.data(), res.second.size()); - } catch (const std::exception & e) { - LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); - } - } - if (res_code == 0) { - if (std::filesystem::exists(cached_response_path)) { - LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); - res_str = read_file(cached_response_path); - res_code = 200; - use_cache = true; - } else { - throw std::runtime_error( - offline ? "error: failed to get manifest (offline mode)" - : "error: failed to get manifest (check your internet connection)"); - } - } - std::string ggufFile; - std::string mmprojFile; - - if (res_code == 200 || res_code == 304) { - try { - auto j = json::parse(res_str); - - if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) { - ggufFile = j["ggufFile"]["rfilename"].get(); - } - if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) { - mmprojFile = j["mmprojFile"]["rfilename"].get(); - } - } catch (const std::exception & e) { - throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what()); - } - if (!use_cache) { - // if not using cached response, update the cache file - write_file(cached_response_path, res_str); - } - } else if (res_code == 401) { - throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); } else { - throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str())); + result.push_back(primary_file); + } + return result; +} + +static hf_cache::hf_files filter_gguf_by_quant(const hf_cache::hf_files & files, + const std::string & quant_tag) { + hf_cache::hf_files matches; + std::regex pattern(quant_tag + "[.-]", std::regex::icase); + + for (const auto & f : files) { + if (!string_ends_with(f.path, ".gguf")) { + continue; + } + if (f.path.find("mmproj") != std::string::npos) { + continue; + } + if (std::regex_search(f.path, pattern)) { + matches.push_back(f); + } + } + return matches; +} + +static void list_available_gguf_files(const hf_cache::hf_files & files) { + LOG_INF("Available GGUF files:\n"); + for (const auto & f : files) { + if (string_ends_with(f.path, ".gguf")) { + LOG_INF(" - %s\n", f.path.c_str()); + } + } +} + +struct hf_plan { + hf_cache::hf_file primary; + hf_cache::hf_file mmproj; + bool has_primary = false; + bool has_mmproj = false; + hf_cache::hf_files files; +}; + +static hf_plan get_hf_plan(const common_params_model & model, + const std::string & token, + const common_download_model_opts & opts) { + hf_plan plan; + auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); + + auto all = opts.offline ? hf_cache::get_cached_files(repo) + : hf_cache::get_repo_files(repo, token); + if (all.empty()) { + return plan; } - // check response - if (ggufFile.empty()) { - throw std::runtime_error("error: model does not have ggufFile"); + hf_cache::hf_files candidates; + + if (!model.hf_file.empty()) { + const hf_cache::hf_file * found_file = nullptr; + for (const auto & f : all) { + if (f.path == model.hf_file) { + found_file = &f; + break; + } + } + + if (!found_file) { + LOG_ERR("%s: --hf-file '%s' not found in repository\n", __func__, model.hf_file.c_str()); + list_available_gguf_files(all); + return plan; + } + + plan.primary = *found_file; + plan.has_primary = true; + candidates = get_split_files(all, *found_file); + } else { + std::vector search_priority = {!tag.empty() ? tag : "Q4_K_M", "Q4_0"}; + + for (const auto & q : search_priority) { + candidates = filter_gguf_by_quant(all, q); + if (!candidates.empty()) { + candidates = get_split_files(all, candidates[0]); + break; + } + } + + if (candidates.empty()) { + for (const auto & f : all) { + if (string_ends_with(f.path, ".gguf") && + f.path.find("mmproj") == std::string::npos) { + candidates = get_split_files(all, f); + break; + } + } + } + + if (candidates.empty()) { + LOG_ERR("%s: no GGUF files found in repository %s\n", __func__, repo.c_str()); + list_available_gguf_files(all); + return plan; + } + + plan.primary = candidates[0]; + plan.has_primary = true; } - return { hf_repo, ggufFile, mmprojFile }; + for (const auto & f : candidates) { + plan.files.push_back(f); + } + + if (opts.download_mmproj) { + for (const auto & f : all) { + if (string_ends_with(f.path, ".gguf") && + f.path.find("mmproj") != std::string::npos) { + plan.mmproj = f; + plan.has_mmproj = true; + plan.files.push_back(f); + break; + } + } + } + + return plan; +} + +static std::vector> get_url_tasks(const common_params_model & model) { + auto [prefix_url, idx, count] = get_gguf_split_info(model.url); + + if (count <= 1) { + return {{model.url, model.path}}; + } + + std::vector> files; + + size_t pos = prefix_url.rfind('/'); + std::string prefix_filename = (pos != std::string::npos) ? prefix_url.substr(pos + 1) : prefix_url; + std::string prefix_path = (std::filesystem::path(model.path).parent_path() / prefix_filename).string(); + + for (int i = 1; i <= count; i++) { + std::string suffix = string_format("-%05d-of-%05d.gguf", i, count); + files.emplace_back(prefix_url + suffix, prefix_path + suffix); + } + return files; +} + +common_download_model_result common_download_model(const common_params_model & model, + const std::string & bearer_token, + const common_download_model_opts & opts, + const common_header_list & headers) { + common_download_model_result result; + std::vector> to_download; + hf_plan hf; + + bool is_hf = !model.hf_repo.empty(); + + if (is_hf) { + hf = get_hf_plan(model, bearer_token, opts); + for (const auto & f : hf.files) { + to_download.emplace_back(f.url, f.local_path); + } + } else if (!model.url.empty()) { + to_download = get_url_tasks(model); + } else { + result.model_path = model.path; + return result; + } + + if (to_download.empty()) { + return result; + } + + std::vector> futures; + for (const auto & item : to_download) { + futures.push_back(std::async(std::launch::async, + [u = item.first, p = item.second, &bearer_token, offline = opts.offline, &headers, is_hf]() { + int status = common_download_file_single(u, p, bearer_token, offline, headers, is_hf); + return is_http_status_ok(status); + } + )); + } + + for (auto & f : futures) { + if (!f.get()) { + return {}; + } + } + + if (is_hf) { + for (const auto & f : hf.files) { + hf_cache::finalize_file(f); + } + if (hf.has_primary) { + result.model_path = hf_cache::finalize_file(hf.primary); + } + if (hf.has_mmproj) { + result.mmproj_path = hf_cache::finalize_file(hf.mmproj); + } + } else { + result.model_path = model.path; + } + + return result; } // @@ -764,29 +792,48 @@ std::string common_docker_resolve_model(const std::string & docker) { } } -std::vector common_list_cached_models() { - std::vector models; - const std::string cache_dir = fs_get_cache_directory(); - const std::vector files = fs_list(cache_dir, false); - for (const auto & file : files) { - if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) { - common_cached_model_info model_info; - model_info.manifest_path = file.path; - std::string fname = file.name; - string_replace_all(fname, ".json", ""); // remove extension - auto parts = string_split(fname, '='); - if (parts.size() == 4) { - // expect format: manifest==== - model_info.user = parts[1]; - model_info.model = parts[2]; - model_info.tag = parts[3]; - } else { - // invalid format - continue; +std::vector common_list_cached_models() { + auto files = hf_cache::get_cached_files(""); + std::set models; + + for (const auto & f : files) { + std::string tmp = f.path; + + if (!string_remove_suffix(tmp, ".gguf")) { + continue; + } + if (tmp.find("mmproj") != std::string::npos) { + continue; + } + auto split_pos = tmp.find("-00001-of-"); + + if (split_pos == std::string::npos && + tmp.find("-of-") != std::string::npos) { + continue; + } + if (split_pos != std::string::npos) { + tmp.erase(split_pos); + } + auto sep_pos = tmp.find_last_of("-."); + + if (sep_pos == std::string::npos || sep_pos == tmp.size() - 1) { + continue; + } + tmp.erase(0, sep_pos + 1); + + bool is_valid = true; + for (char & c : tmp) { + unsigned char uc = c; + if (!std::isalnum(uc) && uc != '_') { + is_valid = false; + break; } - model_info.size = 0; // TODO: get GGUF size, not manifest size - models.push_back(model_info); + c = std::toupper(uc); + } + if (is_valid) { + models.insert(f.repo_id + ":" + tmp); } } - return models; + + return {models.begin(), models.end()}; } diff --git a/common/download.h b/common/download.h index 1c1d8e6db5..db91ae7399 100644 --- a/common/download.h +++ b/common/download.h @@ -1,5 +1,7 @@ #pragma once +#include "hf-cache.h" + #include #include @@ -23,23 +25,16 @@ std::pair> common_remote_get_content(const std::string & // example: "user/model" -> <"user/model", "latest"> std::pair common_download_split_repo_tag(const std::string & hf_repo_with_tag); -struct common_cached_model_info { - std::string manifest_path; - std::string user; - std::string model; - std::string tag; - size_t size = 0; // GGUF size in bytes - // return string representation like "user/model:tag" - // if tag is "latest", it will be omitted - std::string to_string() const { - return user + "/" + model + (tag == "latest" ? "" : ":" + tag); - } +// Options for common_download_model +struct common_download_model_opts { + bool download_mmproj = false; + bool offline = false; }; -struct common_hf_file_res { - std::string repo; // repo name with ":tag" removed - std::string ggufFile; - std::string mmprojFile; +// Result of common_download_model +struct common_download_model_result { + std::string model_path; // path to downloaded model (empty on failure) + std::string mmproj_path; // path to downloaded mmproj (empty if not downloaded) }; /** @@ -47,37 +42,27 @@ struct common_hf_file_res { * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + * Tag is optional, it checks for Q4_K_M first, then Q4_0, then if not found, return the first GGUF file in repo */ -common_hf_file_res common_get_hf_file( - const std::string & hf_repo_with_tag, - const std::string & bearer_token, - bool offline, - const common_header_list & headers = {} -); - -// returns true if download succeeded -bool common_download_model( +common_download_model_result common_download_model( const common_params_model & model, const std::string & bearer_token, - bool offline, + const common_download_model_opts & opts = {}, const common_header_list & headers = {} ); // returns list of cached models -std::vector common_list_cached_models(); +std::vector common_list_cached_models(); // download single file from url to local path // returns status code or -1 on error +// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) int common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline, - const common_header_list & headers = {}); + const common_header_list & headers = {}, + bool skip_etag = false); // resolve and download model from Docker registry // return local path to downloaded model file diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp new file mode 100644 index 0000000000..968e95567a --- /dev/null +++ b/common/hf-cache.cpp @@ -0,0 +1,516 @@ +#include "hf-cache.h" + +#include "common.h" +#include "log.h" +#include "http.h" + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include +#include +#include // migration only +#include + +namespace nl = nlohmann; + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +namespace hf_cache { + +namespace fs = std::filesystem; + +static fs::path get_cache_directory() { + const char * hf_hub_cache = std::getenv("HF_HUB_CACHE"); + if (hf_hub_cache && *hf_hub_cache) { + return fs::path(hf_hub_cache); // assume shell-expanded; add expand logic if you want full parity + } + + const char * huggingface_hub_cache = std::getenv("HUGGINGFACE_HUB_CACHE"); + if (huggingface_hub_cache && *huggingface_hub_cache) { + return fs::path(huggingface_hub_cache); + } + + const char * hf_home = std::getenv("HF_HOME"); + if (hf_home && *hf_home) { + return fs::path(hf_home) / "hub"; + } + + const char * xdg_cache_home = std::getenv("XDG_CACHE_HOME"); + if (xdg_cache_home && *xdg_cache_home) { + return fs::path(xdg_cache_home) / "huggingface" / "hub"; + } +#if defined(_WIN32) + const char * userprofile = std::getenv("USERPROFILE"); + if (userprofile && *userprofile) { + return fs::path(userprofile) / ".cache" / "huggingface" / "hub"; + } +#else + const char * home = std::getenv("HOME"); + if (home && *home) { + return fs::path(home) / ".cache" / "huggingface" / "hub"; + } +#endif + throw std::runtime_error("Failed to determine HF cache directory"); +} + +static bool symlinks_supported() { +#ifdef _WIN32 + static bool supported = false; + static std::once_flag once; + std::call_once(once, []() { + fs::path link = get_cache_directory() / ("link_" + std::to_string(GetCurrentProcessId())); + + std::error_code ec; + fs::create_directory_symlink("..", link, ec); + supported = !ec; + + if (!ec) { + fs::remove(link, ec); + } else if (GetLastError() == ERROR_PRIVILEGE_NOT_HELD) { + LOG_WRN("symlink creation requires Developer Mode or admin privileges on Windows\n"); + } + }); + return supported; +#else + return true; +#endif +} + +static std::string folder_name_to_repo(const std::string & folder) { + if (folder.size() < 8 || folder.substr(0, 8) != "models--") { + return {}; + } + std::string repo_id; + for (size_t i = 8; i < folder.size(); ++i) { + if (i + 1 < folder.size() && folder[i] == '-' && folder[i+1] == '-') { + repo_id += '/'; + i++; + } else { + repo_id += folder[i]; + } + } + return repo_id; +} + +static std::string repo_to_folder_name(const std::string & repo_id) { + std::string name = "models--"; + for (char c : repo_id) { + if (c == '/') { + name += "--"; + } else { + name += c; + } + } + return name; +} + +static fs::path get_repo_path(const std::string & repo_id) { + return get_cache_directory() / repo_to_folder_name(repo_id); +} + +static void write_ref(const std::string & repo_id, + const std::string & ref, + const std::string & commit) { + fs::path refs_path = get_repo_path(repo_id) / "refs"; + std::error_code ec; + fs::create_directories(refs_path, ec); + + fs::path ref_path = refs_path / ref; + fs::path ref_path_tmp = refs_path / (ref + ".tmp"); + { + std::ofstream file(ref_path_tmp); + if (!file) { + throw std::runtime_error("Failed to write ref file: " + ref_path.string()); + } + file << commit; + } + std::error_code rename_ec; + fs::rename(ref_path_tmp, ref_path, rename_ec); + if (rename_ec) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, ref_path_tmp.c_str(), ref_path.c_str()); + fs::remove(ref_path_tmp, ec); + } +} + +static std::string get_repo_ref(const std::string & repo_id, + const std::string & bearer_token) { + std::string url = get_model_endpoint() + "api/models/" + repo_id + "/refs"; + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers; + headers.emplace("User-Agent", "llama-cpp/" + build_info); + headers.emplace("Accept", "application/json"); + if (!bearer_token.empty()) { + headers.emplace("Authorization", "Bearer " + bearer_token); + } + cli.set_default_headers(headers); + + auto res = cli.Get(parts.path); + if (!res || res->status != 200) { + LOG_WRN("%s: API request failed for %s, status: %d\n", __func__, url.c_str(), res ? res->status : -1); + return {}; + } + + try { + auto j = nl::json::parse(res->body); + + if (!j.contains("branches") || !j["branches"].is_array()) { + return {}; + } + + std::string name; + std::string commit; + + for (const auto & branch : j["branches"]) { + if (!branch.contains("name") || !branch.contains("targetCommit")) { + continue; + } + std::string _name = branch["name"].get(); + std::string _commit = branch["targetCommit"].get(); + + if (_name == "main") { + name = _name; + commit = _commit; + break; + } + + if (name.empty() || commit.empty()) { + name = _name; + commit = _commit; + } + } + + if (!name.empty() && !commit.empty()) { + write_ref(repo_id, name, commit); + } + return commit; + } catch (const std::exception & e) { + LOG_WRN("%s: failed to parse API response: %s\n", __func__, e.what()); + return {}; + } +} + +hf_files get_repo_files(const std::string & repo_id, + const std::string & bearer_token) { + std::string rev = get_repo_ref(repo_id, bearer_token); + if (rev.empty()) { + LOG_WRN("%s: failed to resolve commit hash for %s\n", __func__, repo_id.c_str()); + return {}; + } + + std::string url = get_model_endpoint() + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true"; + + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers; + headers.emplace("User-Agent", "llama-cpp/" + build_info); + headers.emplace("Accept", "application/json"); + if (!bearer_token.empty()) { + headers.emplace("Authorization", "Bearer " + bearer_token); + } + cli.set_default_headers(headers); + + auto res = cli.Get(parts.path); + if (!res || res->status != 200) { + LOG_WRN("%s: API request failed for %s, status: %d\n", __func__, url.c_str(), res ? res->status : -1); + return {}; + } + + std::string endpoint = get_model_endpoint(); // TODO + bool use_symlinks = symlinks_supported(); + hf_files files; + + try { + auto j = nl::json::parse(res->body); + + if (!j.is_array()) { + LOG_DBG("%s: response is not an array\n", __func__); + return files; + } + + for (const auto & item : j) { + if (!item.contains("type") || item["type"] != "file") { + continue; + } + if (!item.contains("path")) { + continue; + } + + hf_file file; + file.repo_id = repo_id; + file.path = item["path"].get(); + + if (item.contains("lfs") && item["lfs"].is_object()) { + if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) { + file.oid = item["lfs"]["oid"].get(); + } + } else if (item.contains("oid") && item["oid"].is_string()) { + file.oid = item["oid"].get(); + } + + file.url = endpoint + repo_id + "/resolve/" + rev + "/" + file.path; + + fs::path path = file.path; + fs::path repo_path = get_repo_path(repo_id); + fs::path snapshots_path = repo_path / "snapshots" / rev / path; + fs::path blobs_path = repo_path / "blobs" / file.oid; + + if (use_symlinks) { + file.local_path = blobs_path.string(); + file.link_path = snapshots_path.string(); + } else { // degraded mode + file.local_path = snapshots_path.string(); + } + + files.push_back(file); + } + } catch (const std::exception & e) { + LOG_WRN("%s: failed to parse API response: %s\n", __func__, e.what()); + return {}; + } + + return files; +} + +static std::string get_cached_ref(const fs::path & repo_path) { + fs::path refs_path = repo_path / "refs"; + if (!fs::is_directory(refs_path)) { + return {}; + } + for (const auto & entry : fs::directory_iterator(refs_path)) { + if (entry.is_regular_file()) { + std::ifstream f(entry.path()); + std::string commit; + if (f && std::getline(f, commit) && !commit.empty()) { + return commit; + } + } + } + return {}; +} + +hf_files get_cached_files(const std::string & repo_id) { + fs::path cache_dir = get_cache_directory(); + if (!fs::exists(cache_dir)) { + return {}; + } + hf_files files; + + for (const auto & repo : fs::directory_iterator(cache_dir)) { + if (!repo.is_directory()) { + continue; + } + fs::path snapshots_path = repo.path() / "snapshots"; + + if (!fs::exists(snapshots_path)) { + continue; + } + std::string _repo_id = folder_name_to_repo(repo.path().filename().string()); + + if (_repo_id.empty()) { + continue; + } + if (!repo_id.empty() && _repo_id != repo_id) { + continue; + } + std::string commit = get_cached_ref(repo.path()); + fs::path rev_path = snapshots_path / commit; + + if (commit.empty() || !fs::is_directory(rev_path)) { + continue; + } + for (const auto & entry : fs::recursive_directory_iterator(rev_path)) { + if (!entry.is_regular_file() && !entry.is_symlink()) { + continue; + } + fs::path path = entry.path().lexically_relative(rev_path); + + if (!path.empty()) { + hf_file file; + file.repo_id = _repo_id; + file.path = path.generic_string(); + file.local_path = entry.path().string(); + files.push_back(std::move(file)); + } + } + } + + return files; +} + +std::string finalize_file(const hf_file & file) { + if (file.link_path.empty()) { + return file.local_path; + } + + fs::path link_path(file.link_path); + fs::path local_path(file.local_path); + + std::error_code ec; + fs::create_directories(link_path.parent_path(), ec); + fs::path target_path = fs::relative(local_path, link_path.parent_path(), ec); + fs::create_symlink(target_path, link_path, ec); + + if (fs::exists(link_path)) { + return file.link_path; + } + + LOG_WRN("%s: failed to create symlink: %s\n", __func__, file.link_path.c_str()); + return file.local_path; +} + +// delete everything after this line, one day + +static std::pair parse_manifest_name(std::string & filename) { + static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)"); + std::smatch match; + if (std::regex_match(filename, match, re)) { + return {match[1].str(), match[2].str()}; + } + return {}; +} + +static std::string make_old_cache_filename(const std::string & owner, + const std::string & repo, + const std::string & filename) { + std::string name = owner + "_" + repo + "_" + filename; + for (char & c : name) { + if (c == '/') { + c = '_'; + } + } + return name; +} + +static bool migrate_single_file(const fs::path & old_cache, + const std::string & owner, + const std::string & repo, + const nl::json & node, + const hf_files & files) { + + if (!node.contains("rfilename") || + !node.contains("lfs") || + !node["lfs"].contains("sha256")) { + return false; + } + + std::string path = node["rfilename"]; + std::string sha256 = node["lfs"]["sha256"]; + + const hf_file * file_info = nullptr; + for (const auto & f : files) { + if (f.path == path) { + file_info = &f; + break; + } + } + + std::string old_filename = make_old_cache_filename(owner, repo, path); + fs::path old_path = old_cache / old_filename; + fs::path etag_path = old_path.string() + ".etag"; + + if (!fs::exists(old_path)) { + if (fs::exists(etag_path)) { + LOG_WRN("%s: %s is orphan, deleting...\n", __func__, etag_path.string().c_str()); + fs::remove(etag_path); + } + return false; + } + + bool delete_old_path = false; + + if (!file_info) { + LOG_WRN("%s: %s not found in current repo, deleting...\n", __func__, old_filename.c_str()); + delete_old_path = true; + } else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) { + LOG_WRN("%s: %s is not up to date (sha256 mismatch), deleting...\n", __func__, old_filename.c_str()); + delete_old_path = true; + } + + std::error_code ec; + + if (delete_old_path) { + fs::remove(old_path, ec); + fs::remove(etag_path, ec); + return true; + } + + fs::path new_path(file_info->local_path); + fs::create_directories(new_path.parent_path(), ec); + + if (!fs::exists(new_path, ec)) { + fs::rename(old_path, new_path, ec); + if (ec) { + fs::copy_file(old_path, new_path, ec); + if (ec) { + LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str()); + return false; + } + } + fs::remove(old_path, ec); + } + fs::remove(etag_path, ec); + + std::string snapshot_path = finalize_file(*file_info); + LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_path.c_str()); + + return true; +} + +void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline) { + fs::path old_cache = fs_get_cache_directory(); + if (!fs::exists(old_cache)) { + return; + } + + if (offline) { + LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__); + return; // -hf is not going to work + } + + for (const auto & entry : fs::directory_iterator(old_cache)) { + if (!entry.is_regular_file()) { + continue; + } + auto filename = entry.path().filename().string(); + auto [owner, repo] = parse_manifest_name(filename); + + if (owner.empty() || repo.empty()) { + continue; + } + + auto repo_id = owner + "/" + repo; + auto files = get_repo_files(repo_id, bearer_token); + + if (files.empty()) { + LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str()); + continue; + } + + try { + std::ifstream manifest_stream(entry.path()); + std::string content((std::istreambuf_iterator(manifest_stream)), std::istreambuf_iterator()); + auto j = nl::json::parse(content); + for (const char* key : {"ggufFile", "mmprojFile"}) { + if (j.contains(key)) { + migrate_single_file(old_cache, owner, repo, j[key], files); + } + } + } catch (const std::exception & e) { + LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what()); + continue; + } + fs::remove(entry.path()); + } +} + +} // namespace hf_cache diff --git a/common/hf-cache.h b/common/hf-cache.h new file mode 100644 index 0000000000..1d417111ac --- /dev/null +++ b/common/hf-cache.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include +#include + +// Ref: https://huggingface.co/docs/hub/local-cache.md + +namespace hf_cache { + +struct hf_file { + std::string path; + std::string url; + std::string local_path; + std::string link_path; + std::string oid; + std::string repo_id; +}; + +using hf_files = std::vector; + +// Get files from HF API +hf_files get_repo_files( + const std::string & repo_id, + const std::string & bearer_token +); + +hf_files get_cached_files(const std::string & repo_id); + +// Create symlink if link_path is set and returns the snapshot path +std::string finalize_file(const hf_file & file); + +// TODO: Remove later +void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline = false); + +} // namespace hf_cache diff --git a/common/preset.cpp b/common/preset.cpp index 57ccd000b5..6bbd591c64 100644 --- a/common/preset.cpp +++ b/common/preset.cpp @@ -365,8 +365,8 @@ common_presets common_preset_context::load_from_cache() const { auto cached_models = common_list_cached_models(); for (const auto & model : cached_models) { common_preset preset; - preset.name = model.to_string(); - preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string()); + preset.name = model; + preset.set_option(*this, "LLAMA_ARG_HF_REPO", model); out[preset.name] = preset; } diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index b0f1d6b936..8396492acc 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -979,37 +979,20 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { for (size_t i = 0; i < params.hf_repo.size(); i++) { common_params_model model; - // step 1: no `-hff` provided, we auto-detect based on the `-hf` flag if (params.hf_file.empty() || params.hf_file[i].empty()) { - auto auto_detected = common_get_hf_file(params.hf_repo[i], params.hf_token, false); - if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { - exit(1); - } - - model.name = params.hf_repo[i]; - model.hf_repo = auto_detected.repo; - model.hf_file = auto_detected.ggufFile; + model.hf_repo = params.hf_repo[i]; } else { + model.hf_repo = params.hf_repo[i]; model.hf_file = params.hf_file[i]; } - // step 2: construct the model cache path - std::string clean_fname = model.hf_repo + "_" + model.hf_file; - string_replace_all(clean_fname, "\\", "_"); - string_replace_all(clean_fname, "/", "_"); - model.path = fs_get_cache_file(clean_fname); - - // step 3: download the model if not exists - std::string model_endpoint = get_model_endpoint(); - model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; - - bool ok = common_download_model(model, params.hf_token, false); - if (!ok) { - fprintf(stderr, "error: failed to download model from %s\n", model.url.c_str()); + auto download_result = common_download_model(model, params.hf_token); + if (download_result.model_path.empty()) { + fprintf(stderr, "error: failed to download model from HuggingFace\n"); exit(1); } - params.model.push_back(model.path); + params.model.push_back(download_result.model_path); } }