diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp index 4bde42cf86..ad68c55674 100644 --- a/common/hf-cache.cpp +++ b/common/hf-cache.cpp @@ -77,41 +77,127 @@ static fs::path get_repo_path(const std::string & repo_id) { return get_cache_directory() / repo_to_folder_name(repo_id); } -static void write_ref(const std::string & repo_id, - const std::string & ref, - const std::string & commit) { - fs::path refs_path = get_repo_path(repo_id) / "refs"; - std::error_code ec; - fs::create_directories(refs_path, ec); +static bool is_hex_char(const char c) { + return (c >= 'A' && c <= 'F') || + (c >= 'a' && c <= 'f') || + (c >= '0' && c <= '9'); +} - fs::path ref_path = refs_path / ref; - fs::path ref_path_tmp = refs_path / (ref + ".tmp"); - { - std::ofstream file(ref_path_tmp); - if (!file) { - throw std::runtime_error("Failed to write ref file: " + ref_path.string()); - } - file << commit; +static bool is_hex_string(const std::string & s, size_t expected_len) { + if (s.length() != expected_len) { + return false; } - std::error_code rename_ec; - fs::rename(ref_path_tmp, ref_path, rename_ec); - if (rename_ec) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, - ref_path_tmp.string().c_str(), ref_path.string().c_str()); - fs::remove(ref_path_tmp, ec); + for (const char c : s) { + if (!is_hex_char(c)) { + return false; + } + } + return true; +} + +static bool is_alphanum(const char c) { + return (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9'); +} + +static bool is_special_char(char c) { + return c == '/' || c == '.' || c == '-'; +} + +// base chars [A-Za-z0-9_] are always valid +// special chars [/.-] must be surrounded by base chars +// exactly one '/' required +static bool is_valid_repo_id(const std::string & repo_id) { + if (repo_id.empty() || repo_id.length() > 256) { + return false; + } + int slash = 0; + bool special = true; + + for (const char c : repo_id) { + if (is_alphanum(c) || c == '_') { + special = false; + } else if (is_special_char(c)) { + if (special) { + return false; + } + slash += (c == '/'); + special = true; + } else { + return false; + } + } + return !special && slash == 1; +} + +static bool is_valid_hf_token(const std::string & token) { + if (token.length() < 37 || token.length() > 256 || + !string_starts_with(token, "hf_")) { + return false; + } + for (size_t i = 3; i < token.length(); ++i) { + if (!is_alphanum(token[i])) { + return false; + } + } + return true; +} + +static bool is_valid_commit(const std::string & hash) { + return is_hex_string(hash, 40); +} + +static bool is_valid_oid(const std::string & oid) { + return is_hex_string(oid, 40) || is_hex_string(oid, 64); +} + +static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) { + if (subpath.is_absolute()) { + return false; // never do a / b with b absolute + } + auto b = fs::absolute(path).lexically_normal(); + auto t = (b / subpath).lexically_normal(); + auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end()); + + return b_end == b.end(); +} + +static void safe_write_file(const fs::path & path, const std::string & data) { + fs::path path_tmp = path.string() + ".tmp"; + + if (path.has_parent_path()) { + fs::create_directories(path.parent_path()); + } + + std::ofstream file(path_tmp); + file << data; + file.close(); + + std::error_code ec; + + if (!file.fail()) { + fs::rename(path_tmp, path, ec); + } + if (file.fail() || ec) { + fs::remove(path_tmp, ec); + throw std::runtime_error("failed to write file: " + path.string()); } } static nl::json api_get(const std::string & url, - const std::string & bearer_token) { + const std::string & token) { auto [cli, parts] = common_http_client(url); httplib::Headers headers = { {"User-Agent", "llama-cpp/" + build_info}, {"Accept", "application/json"} }; - if (!bearer_token.empty()) { - headers.emplace("Authorization", "Bearer " + bearer_token); + + if (is_valid_hf_token(token)) { + headers.emplace("Authorization", "Bearer " + token); + } else if (!token.empty()) { + LOG_WRN("%s: invalid token, authentication disabled\n", __func__); } if (auto res = cli.Get(parts.path, headers)) { @@ -130,11 +216,11 @@ static nl::json api_get(const std::string & url, } } -static std::string get_repo_ref(const std::string & repo_id, - const std::string & bearer_token) { +static std::string get_repo_commit(const std::string & repo_id, + const std::string & token) { try { auto endpoint = get_model_endpoint(); - auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", bearer_token); + auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token); if (!json.is_object() || !json.contains("branches") || !json["branches"].is_array()) { @@ -142,6 +228,7 @@ static std::string get_repo_ref(const std::string & repo_id, return {}; } + fs::path refs_path = get_repo_path(repo_id) / "refs"; std::string name; std::string commit; @@ -154,6 +241,15 @@ static std::string get_repo_ref(const std::string & repo_id, std::string _name = branch["name"].get(); std::string _commit = branch["targetCommit"].get(); + if (!is_valid_subpath(refs_path, _name)) { + LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str()); + continue; + } + if (!is_valid_commit(_commit)) { + LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str()); + continue; + } + if (_name == "main") { name = _name; commit = _commit; @@ -171,34 +267,42 @@ static std::string get_repo_ref(const std::string & repo_id, return {}; } - write_ref(repo_id, name, commit); + safe_write_file(refs_path / name, commit); return commit; } catch (const nl::json::exception & e) { - LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what()); + LOG_ERR("%s: JSON error: %s\n", __func__, e.what()); } catch (const std::exception & e) { - LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what()); + LOG_ERR("%s: error: %s\n", __func__, e.what()); } return {}; } hf_files get_repo_files(const std::string & repo_id, - const std::string & bearer_token) { - hf_files files; - std::string rev = get_repo_ref(repo_id, bearer_token); - - if (rev.empty()) { - LOG_WRN("%s: failed to resolve commit hash for %s\n", __func__, repo_id.c_str()); + const std::string & token) { + if (!is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); return {}; } + std::string commit = get_repo_commit(repo_id, token); + if (commit.empty()) { + LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str()); + return {}; + } + + fs::path blobs_path = get_repo_path(repo_id) / "blobs"; + fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit; + + hf_files files; + try { auto endpoint = get_model_endpoint(); - auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true", bearer_token); + auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); if (!json.is_array()) { LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str()); - return files; + return {}; } for (const auto & item : json) { @@ -212,6 +316,11 @@ hf_files get_repo_files(const std::string & repo_id, file.repo_id = repo_id; file.path = item["path"].get(); + if (!is_valid_subpath(commit_path, file.path)) { + LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str()); + continue; + } + if (item.contains("lfs") && item["lfs"].is_object()) { if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) { file.oid = item["lfs"]["oid"].get(); @@ -220,26 +329,29 @@ hf_files get_repo_files(const std::string & repo_id, file.oid = item["oid"].get(); } - file.url = endpoint + repo_id + "/resolve/" + rev + "/" + file.path; + if (!file.oid.empty() && !is_valid_oid(file.oid)) { + LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str()); + continue; + } - fs::path path = file.path; - fs::path repo_path = get_repo_path(repo_id); - fs::path snapshots_path = repo_path / "snapshots" / rev / path; + file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path; - file.final_path = snapshots_path.string(); - file.local_path = file.final_path; + fs::path final_path = commit_path / file.path; + file.final_path = final_path.string(); - if (!file.oid.empty() && !fs::exists(snapshots_path)) { - fs::path blob_path = repo_path / "blobs" / file.oid; - file.local_path = blob_path.string(); + if (!file.oid.empty() && !fs::exists(final_path)) { + fs::path local_path = blobs_path / file.oid; + file.local_path = local_path.string(); + } else { + file.local_path = file.final_path; } files.push_back(file); } } catch (const nl::json::exception & e) { - LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what()); + LOG_ERR("%s: JSON error: %s\n", __func__, e.what()); } catch (const std::exception & e) { - LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what()); + LOG_ERR("%s: error: %s\n", __func__, e.what()); } return files; } @@ -260,6 +372,10 @@ static std::string get_cached_ref(const fs::path & repo_path) { if (!f || !std::getline(f, commit) || commit.empty()) { continue; } + if (!is_valid_commit(commit)) { + LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str()); + continue; + } if (entry.path().filename() == "main") { return commit; } @@ -275,6 +391,12 @@ hf_files get_cached_files(const std::string & repo_id) { if (!fs::exists(cache_dir)) { return {}; } + + if (!repo_id.empty() && !is_valid_repo_id(repo_id)) { + LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str()); + return {}; + } + hf_files files; for (const auto & repo : fs::directory_iterator(cache_dir)) { @@ -288,23 +410,23 @@ hf_files get_cached_files(const std::string & repo_id) { } std::string _repo_id = folder_name_to_repo(repo.path().filename().string()); - if (_repo_id.empty()) { + if (!is_valid_repo_id(_repo_id)) { continue; } if (!repo_id.empty() && _repo_id != repo_id) { continue; } std::string commit = get_cached_ref(repo.path()); - fs::path rev_path = snapshots_path / commit; + fs::path commit_path = snapshots_path / commit; - if (commit.empty() || !fs::is_directory(rev_path)) { + if (commit.empty() || !fs::is_directory(commit_path)) { continue; } - for (const auto & entry : fs::recursive_directory_iterator(rev_path)) { + for (const auto & entry : fs::recursive_directory_iterator(commit_path)) { if (!entry.is_regular_file() && !entry.is_symlink()) { continue; } - fs::path path = entry.path().lexically_relative(rev_path); + fs::path path = entry.path().lexically_relative(commit_path); if (!path.empty()) { hf_file file; @@ -324,23 +446,23 @@ std::string finalize_file(const hf_file & file) { static std::atomic symlinks_disabled{false}; std::error_code ec; - fs::path blob_path(file.local_path); - fs::path snapshot_path(file.final_path); + fs::path local_path(file.local_path); + fs::path final_path(file.final_path); - if (blob_path == snapshot_path || fs::exists(snapshot_path, ec)) { + if (local_path == final_path || fs::exists(final_path, ec)) { return file.final_path; } - if (!fs::exists(blob_path, ec)) { + if (!fs::exists(local_path, ec)) { return file.final_path; } - fs::create_directories(snapshot_path.parent_path(), ec); + fs::create_directories(final_path.parent_path(), ec); if (!symlinks_disabled) { - fs::path target = fs::relative(blob_path, snapshot_path.parent_path(), ec); + fs::path target = fs::relative(local_path, final_path.parent_path(), ec); if (!ec) { - fs::create_symlink(target, snapshot_path, ec); + fs::create_symlink(target, final_path, ec); } if (!ec) { return file.final_path; @@ -352,10 +474,10 @@ std::string finalize_file(const hf_file & file) { LOG_WRN("%s: switching to degraded mode\n", __func__); } - fs::rename(blob_path, snapshot_path, ec); + fs::rename(local_path, final_path, ec); if (ec) { LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str()); - fs::copy(blob_path, snapshot_path, ec); + fs::copy(local_path, final_path, ec); if (ec) { LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str()); } @@ -451,13 +573,13 @@ static bool migrate_single_file(const fs::path & old_cache, } fs::remove(etag_path, ec); - std::string snapshot_file = finalize_file(*file_info); - LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_file.c_str()); + std::string filename = finalize_file(*file_info); + LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str()); return true; } -void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline) { +void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) { fs::path old_cache = fs_get_cache_directory(); if (!fs::exists(old_cache)) { return; @@ -480,7 +602,7 @@ void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offlin } auto repo_id = owner + "/" + repo; - auto files = get_repo_files(repo_id, bearer_token); + auto files = get_repo_files(repo_id, token); if (files.empty()) { LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str()); @@ -488,12 +610,12 @@ void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offlin } try { - std::ifstream manifest_stream(entry.path()); - std::string content((std::istreambuf_iterator(manifest_stream)), std::istreambuf_iterator()); - auto j = nl::json::parse(content); - for (const char* key : {"ggufFile", "mmprojFile"}) { - if (j.contains(key)) { - migrate_single_file(old_cache, owner, repo, j[key], files); + std::ifstream manifest(entry.path()); + auto json = nl::json::parse(manifest); + + for (const char * key : {"ggufFile", "mmprojFile"}) { + if (json.contains(key)) { + migrate_single_file(old_cache, owner, repo, json[key], files); } } } catch (const std::exception & e) { diff --git a/common/hf-cache.h b/common/hf-cache.h index 7934ec7970..ee2e98494a 100644 --- a/common/hf-cache.h +++ b/common/hf-cache.h @@ -21,7 +21,7 @@ using hf_files = std::vector; // Get files from HF API hf_files get_repo_files( const std::string & repo_id, - const std::string & bearer_token + const std::string & token ); hf_files get_cached_files(const std::string & repo_id = {}); @@ -30,6 +30,6 @@ hf_files get_cached_files(const std::string & repo_id = {}); std::string finalize_file(const hf_file & file); // TODO: Remove later -void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline = false); +void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false); } // namespace hf_cache