diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp index 41661c3d94..8aab2d117c 100644 --- a/common/hf-cache.cpp +++ b/common/hf-cache.cpp @@ -13,6 +13,7 @@ #include // migration only #include #include +#include namespace nl = nlohmann; @@ -106,16 +107,6 @@ static fs::path get_repo_path(const std::string & repo_id) { return get_cache_directory() / repo_to_folder_name(repo_id); } -static void set_default_headers(httplib::Client & cli, const std::string & bearer_token) { - httplib::Headers headers; - headers.emplace("User-Agent", "llama-cpp/" + build_info); - headers.emplace("Accept", "application/json"); - if (!bearer_token.empty()) { - headers.emplace("Authorization", "Bearer " + bearer_token); - } - cli.set_default_headers(headers); -} - static void write_ref(const std::string & repo_id, const std::string & ref, const std::string & commit) { @@ -135,36 +126,59 @@ static void write_ref(const std::string & repo_id, std::error_code rename_ec; fs::rename(ref_path_tmp, ref_path, rename_ec); if (rename_ec) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, ref_path_tmp.c_str(), ref_path.c_str()); + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, + ref_path_tmp.string().c_str(), ref_path.string().c_str()); fs::remove(ref_path_tmp, ec); } } +static nl::json api_get(const std::string & url, + const std::string & bearer_token) { + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers = { + {"User-Agent", "llama-cpp/" + build_info}, + {"Accept", "application/json"} + }; + if (!bearer_token.empty()) { + headers.emplace("Authorization", "Bearer " + bearer_token); + } + + if (auto res = cli.Get(parts.path, headers)) { + auto body = res->body; + + if (res->status == 200) { + return nl::json::parse(res->body); + } + try { + body = nl::json::parse(res->body)["error"].get(); + } catch (...) { } + + throw std::runtime_error("GET failed (" + std::to_string(res->status) + "): " + body); + } else { + throw std::runtime_error("HTTPLIB failed: " + httplib::to_string(res.error())); + } +} + static std::string get_repo_ref(const std::string & repo_id, const std::string & bearer_token) { - std::string url = get_model_endpoint() + "api/models/" + repo_id + "/refs"; - auto [cli, parts] = common_http_client(url); - - set_default_headers(cli, bearer_token); - - auto res = cli.Get(parts.path); - if (!res || res->status != 200) { - LOG_WRN("%s: API request failed for %s, status: %d\n", __func__, url.c_str(), res ? res->status : -1); - return {}; - } - try { - auto j = nl::json::parse(res->body); + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", bearer_token); - if (!j.contains("branches") || !j["branches"].is_array()) { + if (!json.is_object() || + !json.contains("branches") || !json["branches"].is_array()) { + LOG_WRN("%s: missing 'branches' for '%s'\n", __func__, repo_id.c_str()); return {}; } std::string name; std::string commit; - for (const auto & branch : j["branches"]) { - if (!branch.contains("name") || !branch.contains("targetCommit")) { + for (const auto & branch : json["branches"]) { + if (!branch.is_object() || + !branch.contains("name") || !branch["name"].is_string() || + !branch.contains("targetCommit") || !branch["targetCommit"].is_string()) { continue; } std::string _name = branch["name"].get(); @@ -182,53 +196,45 @@ static std::string get_repo_ref(const std::string & repo_id, } } - if (!name.empty() && !commit.empty()) { - write_ref(repo_id, name, commit); + if (name.empty() || commit.empty()) { + LOG_WRN("%s: no valid branch for '%s'\n", __func__, repo_id.c_str()); + return {}; } + + write_ref(repo_id, name, commit); return commit; + + } catch (const nl::json::exception & e) { + LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what()); } catch (const std::exception & e) { - LOG_WRN("%s: failed to parse API response: %s\n", __func__, e.what()); - return {}; + LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what()); } + return {}; } hf_files get_repo_files(const std::string & repo_id, const std::string & bearer_token) { + hf_files files; std::string rev = get_repo_ref(repo_id, bearer_token); + if (rev.empty()) { LOG_WRN("%s: failed to resolve commit hash for %s\n", __func__, repo_id.c_str()); return {}; } - std::string url = get_model_endpoint() + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true"; - - auto [cli, parts] = common_http_client(url); - - set_default_headers(cli, bearer_token); - - auto res = cli.Get(parts.path); - if (!res || res->status != 200) { - LOG_WRN("%s: API request failed for %s, status: %d\n", __func__, url.c_str(), res ? res->status : -1); - return {}; - } - - std::string endpoint = get_model_endpoint(); // TODO - bool use_symlinks = symlinks_supported(); - hf_files files; - try { - auto j = nl::json::parse(res->body); + auto endpoint = get_model_endpoint(); + auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true", bearer_token); - if (!j.is_array()) { - LOG_DBG("%s: response is not an array\n", __func__); + if (!json.is_array()) { + LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str()); return files; } - for (const auto & item : j) { - if (!item.contains("type") || item["type"] != "file") { - continue; - } - if (!item.contains("path")) { + for (const auto & item : json) { + if (!item.is_object() || + !item.contains("type") || !item["type"].is_string() || item["type"] != "file" || + !item.contains("path") || !item["path"].is_string()) { continue; } @@ -251,7 +257,7 @@ hf_files get_repo_files(const std::string & repo_id, fs::path snapshots_path = repo_path / "snapshots" / rev / path; fs::path blobs_path = repo_path / "blobs" / file.oid; - if (use_symlinks) { + if (symlinks_supported()) { file.local_path = blobs_path.string(); file.link_path = snapshots_path.string(); } else { // degraded mode @@ -260,11 +266,11 @@ hf_files get_repo_files(const std::string & repo_id, files.push_back(file); } + } catch (const nl::json::exception & e) { + LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what()); } catch (const std::exception & e) { - LOG_WRN("%s: failed to parse API response: %s\n", __func__, e.what()); - return {}; + LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what()); } - return files; } @@ -369,13 +375,9 @@ static std::pair parse_manifest_name(std::string & fil static std::string make_old_cache_filename(const std::string & owner, const std::string & repo, const std::string & filename) { - std::string name = owner + "_" + repo + "_" + filename; - for (char & c : name) { - if (c == '/') { - c = '_'; - } - } - return name; + auto result = owner + "_" + repo + "_" + filename; + string_replace_all(result, "/", "_"); + return result; } static bool migrate_single_file(const fs::path & old_cache, @@ -447,8 +449,8 @@ static bool migrate_single_file(const fs::path & old_cache, } fs::remove(etag_path, ec); - std::string snapshot_path = finalize_file(*file_info); - LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_path.c_str()); + std::string snapshot_file = finalize_file(*file_info); + LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_file.c_str()); return true; }