Check all inputs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2026-03-22 18:28:54 +00:00
parent 6ab630f5f8
commit 3645fee1ed
No known key found for this signature in database
2 changed files with 196 additions and 74 deletions

View File

@ -77,41 +77,127 @@ static fs::path get_repo_path(const std::string & repo_id) {
return get_cache_directory() / repo_to_folder_name(repo_id);
}
static void write_ref(const std::string & repo_id,
const std::string & ref,
const std::string & commit) {
fs::path refs_path = get_repo_path(repo_id) / "refs";
std::error_code ec;
fs::create_directories(refs_path, ec);
// Returns true when `c` is an ASCII hexadecimal digit: 0-9, a-f or A-F.
static bool is_hex_char(const char c) {
    if (c >= '0' && c <= '9') {
        return true;
    }
    if (c >= 'a' && c <= 'f') {
        return true;
    }
    return c >= 'A' && c <= 'F';
}
fs::path ref_path = refs_path / ref;
fs::path ref_path_tmp = refs_path / (ref + ".tmp");
{
std::ofstream file(ref_path_tmp);
if (!file) {
throw std::runtime_error("Failed to write ref file: " + ref_path.string());
}
file << commit;
static bool is_hex_string(const std::string & s, size_t expected_len) {
if (s.length() != expected_len) {
return false;
}
std::error_code rename_ec;
fs::rename(ref_path_tmp, ref_path, rename_ec);
if (rename_ec) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__,
ref_path_tmp.string().c_str(), ref_path.string().c_str());
fs::remove(ref_path_tmp, ec);
for (const char c : s) {
if (!is_hex_char(c)) {
return false;
}
}
return true;
}
// Returns true when `c` is an ASCII letter (either case) or an ASCII digit.
// The underscore is NOT included.
static bool is_alphanum(const char c) {
    const bool digit = c >= '0' && c <= '9';
    const bool lower = c >= 'a' && c <= 'z';
    const bool upper = c >= 'A' && c <= 'Z';
    return digit || lower || upper;
}
// Returns true for the separator characters '/', '.' and '-'.
static bool is_special_char(char c) {
    switch (c) {
        case '/':
        case '.':
        case '-':
            return true;
        default:
            return false;
    }
}
// Validates a repository id:
//  - non-empty and at most 256 characters long
//  - base chars [A-Za-z0-9_] are accepted anywhere
//  - special chars [/.-] are accepted only between two base chars, so the id
//    cannot start or end with one and two of them cannot be adjacent
//  - any other character is rejected
//  - exactly one '/' must be present
static bool is_valid_repo_id(const std::string & repo_id) {
    const size_t len = repo_id.length();
    if (len == 0 || len > 256) {
        return false;
    }
    int n_slash = 0;
    // starts as true so that a leading special char is rejected
    bool prev_is_special = true;

    for (size_t i = 0; i < len; ++i) {
        const char ch = repo_id[i];
        const bool is_base = is_alphanum(ch) || ch == '_';

        if (!is_base) {
            if (!is_special_char(ch) || prev_is_special) {
                return false;
            }
            if (ch == '/') {
                ++n_slash;
            }
        }
        prev_is_special = !is_base;
    }
    // a trailing special char leaves prev_is_special set
    return n_slash == 1 && !prev_is_special;
}
// Validates the shape of a token: total length between 37 and 256 characters,
// an "hf_" prefix, and only ASCII letters/digits after the prefix.
static bool is_valid_hf_token(const std::string & token) {
    const size_t len = token.length();
    const bool len_ok = len >= 37 && len <= 256;

    if (!len_ok || !string_starts_with(token, "hf_")) {
        return false;
    }
    // skip the 3-character "hf_" prefix; len >= 37 guarantees it is present
    return std::all_of(token.begin() + 3, token.end(), [](const char ch) {
        return is_alphanum(ch);
    });
}
// A commit hash is accepted only when it is exactly 40 hexadecimal characters.
static bool is_valid_commit(const std::string & hash) {
    const size_t commit_hex_len = 40;
    return is_hex_string(hash, commit_hex_len);
}
// An object id is accepted when it is a hexadecimal string of exactly 40 or
// exactly 64 characters.
static bool is_valid_oid(const std::string & oid) {
    if (is_hex_string(oid, 40)) {
        return true;
    }
    return is_hex_string(oid, 64);
}
// Checks that `path / subpath` stays inside `path`.
//
// The check is lexical: `path` is made absolute and both the base and the
// joined target are normalized ("." and ".." collapsed), then compared
// component by component. The subpath is valid when every component of the
// base is a leading component of the target, so "../x" or "a/../../x" are
// rejected, and "/a/bc" is not mistaken for a child of "/a/b".
//
// Returns false for an absolute `subpath`, since joining would discard `path`.
// A subpath that resolves to the base itself ("" or ".") is accepted.
static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) {
    if (subpath.is_absolute()) {
        return false; // never do a / b with b absolute
    }
    auto b = fs::absolute(path).lexically_normal();
    // A base such as "dir/" or "." normalizes with a trailing separator, which
    // iterates as a final empty component. That empty component can never
    // match the next component of the target, so every subpath would be
    // rejected. Drop it (but keep a bare root path untouched).
    if (!b.has_filename() && b.has_relative_path()) {
        b = b.parent_path();
    }
    auto t = (b / subpath).lexically_normal();
    auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end());
    return b_end == b.end();
}
// Writes `data` to `path` through a temporary sibling file.
//
// The content is first written to "<path>.tmp" and then renamed onto `path`.
// Missing parent directories are created first (that call uses the throwing
// overload of fs::create_directories).
//
// Throws std::runtime_error when the temporary file could not be written or
// the rename failed; the temporary file is removed (best effort) beforehand.
static void safe_write_file(const fs::path & path, const std::string & data) {
    const fs::path tmp_path = path.string() + ".tmp";

    const fs::path parent = path.parent_path();
    if (!parent.empty()) {
        fs::create_directories(parent);
    }

    bool written = false;
    {
        std::ofstream out(tmp_path);
        out << data;
        out.close();
        // covers open, write and close failures alike
        written = !out.fail();
    }

    std::error_code ec;
    if (written) {
        fs::rename(tmp_path, path, ec);
    }
    if (!written || ec) {
        fs::remove(tmp_path, ec);
        throw std::runtime_error("failed to write file: " + path.string());
    }
}
static nl::json api_get(const std::string & url,
const std::string & bearer_token) {
const std::string & token) {
auto [cli, parts] = common_http_client(url);
httplib::Headers headers = {
{"User-Agent", "llama-cpp/" + build_info},
{"Accept", "application/json"}
};
if (!bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + bearer_token);
if (is_valid_hf_token(token)) {
headers.emplace("Authorization", "Bearer " + token);
} else if (!token.empty()) {
LOG_WRN("%s: invalid token, authentication disabled\n", __func__);
}
if (auto res = cli.Get(parts.path, headers)) {
@ -130,11 +216,11 @@ static nl::json api_get(const std::string & url,
}
}
static std::string get_repo_ref(const std::string & repo_id,
const std::string & bearer_token) {
static std::string get_repo_commit(const std::string & repo_id,
const std::string & token) {
try {
auto endpoint = get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", bearer_token);
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
if (!json.is_object() ||
!json.contains("branches") || !json["branches"].is_array()) {
@ -142,6 +228,7 @@ static std::string get_repo_ref(const std::string & repo_id,
return {};
}
fs::path refs_path = get_repo_path(repo_id) / "refs";
std::string name;
std::string commit;
@ -154,6 +241,15 @@ static std::string get_repo_ref(const std::string & repo_id,
std::string _name = branch["name"].get<std::string>();
std::string _commit = branch["targetCommit"].get<std::string>();
if (!is_valid_subpath(refs_path, _name)) {
LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str());
continue;
}
if (!is_valid_commit(_commit)) {
LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str());
continue;
}
if (_name == "main") {
name = _name;
commit = _commit;
@ -171,34 +267,42 @@ static std::string get_repo_ref(const std::string & repo_id,
return {};
}
write_ref(repo_id, name, commit);
safe_write_file(refs_path / name, commit);
return commit;
} catch (const nl::json::exception & e) {
LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what());
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
} catch (const std::exception & e) {
LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what());
LOG_ERR("%s: error: %s\n", __func__, e.what());
}
return {};
}
hf_files get_repo_files(const std::string & repo_id,
const std::string & bearer_token) {
hf_files files;
std::string rev = get_repo_ref(repo_id, bearer_token);
if (rev.empty()) {
LOG_WRN("%s: failed to resolve commit hash for %s\n", __func__, repo_id.c_str());
const std::string & token) {
if (!is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}
std::string commit = get_repo_commit(repo_id, token);
if (commit.empty()) {
LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str());
return {};
}
fs::path blobs_path = get_repo_path(repo_id) / "blobs";
fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit;
hf_files files;
try {
auto endpoint = get_model_endpoint();
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true", bearer_token);
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
if (!json.is_array()) {
LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str());
return files;
return {};
}
for (const auto & item : json) {
@ -212,6 +316,11 @@ hf_files get_repo_files(const std::string & repo_id,
file.repo_id = repo_id;
file.path = item["path"].get<std::string>();
if (!is_valid_subpath(commit_path, file.path)) {
LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str());
continue;
}
if (item.contains("lfs") && item["lfs"].is_object()) {
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
file.oid = item["lfs"]["oid"].get<std::string>();
@ -220,26 +329,29 @@ hf_files get_repo_files(const std::string & repo_id,
file.oid = item["oid"].get<std::string>();
}
file.url = endpoint + repo_id + "/resolve/" + rev + "/" + file.path;
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
continue;
}
fs::path path = file.path;
fs::path repo_path = get_repo_path(repo_id);
fs::path snapshots_path = repo_path / "snapshots" / rev / path;
file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path;
file.final_path = snapshots_path.string();
file.local_path = file.final_path;
fs::path final_path = commit_path / file.path;
file.final_path = final_path.string();
if (!file.oid.empty() && !fs::exists(snapshots_path)) {
fs::path blob_path = repo_path / "blobs" / file.oid;
file.local_path = blob_path.string();
if (!file.oid.empty() && !fs::exists(final_path)) {
fs::path local_path = blobs_path / file.oid;
file.local_path = local_path.string();
} else {
file.local_path = file.final_path;
}
files.push_back(file);
}
} catch (const nl::json::exception & e) {
LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what());
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
} catch (const std::exception & e) {
LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what());
LOG_ERR("%s: error: %s\n", __func__, e.what());
}
return files;
}
@ -260,6 +372,10 @@ static std::string get_cached_ref(const fs::path & repo_path) {
if (!f || !std::getline(f, commit) || commit.empty()) {
continue;
}
if (!is_valid_commit(commit)) {
LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str());
continue;
}
if (entry.path().filename() == "main") {
return commit;
}
@ -275,6 +391,12 @@ hf_files get_cached_files(const std::string & repo_id) {
if (!fs::exists(cache_dir)) {
return {};
}
if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}
hf_files files;
for (const auto & repo : fs::directory_iterator(cache_dir)) {
@ -288,23 +410,23 @@ hf_files get_cached_files(const std::string & repo_id) {
}
std::string _repo_id = folder_name_to_repo(repo.path().filename().string());
if (_repo_id.empty()) {
if (!is_valid_repo_id(_repo_id)) {
continue;
}
if (!repo_id.empty() && _repo_id != repo_id) {
continue;
}
std::string commit = get_cached_ref(repo.path());
fs::path rev_path = snapshots_path / commit;
fs::path commit_path = snapshots_path / commit;
if (commit.empty() || !fs::is_directory(rev_path)) {
if (commit.empty() || !fs::is_directory(commit_path)) {
continue;
}
for (const auto & entry : fs::recursive_directory_iterator(rev_path)) {
for (const auto & entry : fs::recursive_directory_iterator(commit_path)) {
if (!entry.is_regular_file() && !entry.is_symlink()) {
continue;
}
fs::path path = entry.path().lexically_relative(rev_path);
fs::path path = entry.path().lexically_relative(commit_path);
if (!path.empty()) {
hf_file file;
@ -324,23 +446,23 @@ std::string finalize_file(const hf_file & file) {
static std::atomic<bool> symlinks_disabled{false};
std::error_code ec;
fs::path blob_path(file.local_path);
fs::path snapshot_path(file.final_path);
fs::path local_path(file.local_path);
fs::path final_path(file.final_path);
if (blob_path == snapshot_path || fs::exists(snapshot_path, ec)) {
if (local_path == final_path || fs::exists(final_path, ec)) {
return file.final_path;
}
if (!fs::exists(blob_path, ec)) {
if (!fs::exists(local_path, ec)) {
return file.final_path;
}
fs::create_directories(snapshot_path.parent_path(), ec);
fs::create_directories(final_path.parent_path(), ec);
if (!symlinks_disabled) {
fs::path target = fs::relative(blob_path, snapshot_path.parent_path(), ec);
fs::path target = fs::relative(local_path, final_path.parent_path(), ec);
if (!ec) {
fs::create_symlink(target, snapshot_path, ec);
fs::create_symlink(target, final_path, ec);
}
if (!ec) {
return file.final_path;
@ -352,10 +474,10 @@ std::string finalize_file(const hf_file & file) {
LOG_WRN("%s: switching to degraded mode\n", __func__);
}
fs::rename(blob_path, snapshot_path, ec);
fs::rename(local_path, final_path, ec);
if (ec) {
LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str());
fs::copy(blob_path, snapshot_path, ec);
fs::copy(local_path, final_path, ec);
if (ec) {
LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str());
}
@ -451,13 +573,13 @@ static bool migrate_single_file(const fs::path & old_cache,
}
fs::remove(etag_path, ec);
std::string snapshot_file = finalize_file(*file_info);
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_file.c_str());
std::string filename = finalize_file(*file_info);
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str());
return true;
}
void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline) {
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
fs::path old_cache = fs_get_cache_directory();
if (!fs::exists(old_cache)) {
return;
@ -480,7 +602,7 @@ void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offlin
}
auto repo_id = owner + "/" + repo;
auto files = get_repo_files(repo_id, bearer_token);
auto files = get_repo_files(repo_id, token);
if (files.empty()) {
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
@ -488,12 +610,12 @@ void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offlin
}
try {
std::ifstream manifest_stream(entry.path());
std::string content((std::istreambuf_iterator<char>(manifest_stream)), std::istreambuf_iterator<char>());
auto j = nl::json::parse(content);
for (const char* key : {"ggufFile", "mmprojFile"}) {
if (j.contains(key)) {
migrate_single_file(old_cache, owner, repo, j[key], files);
std::ifstream manifest(entry.path());
auto json = nl::json::parse(manifest);
for (const char * key : {"ggufFile", "mmprojFile"}) {
if (json.contains(key)) {
migrate_single_file(old_cache, owner, repo, json[key], files);
}
}
} catch (const std::exception & e) {

View File

@ -21,7 +21,7 @@ using hf_files = std::vector<hf_file>;
// Get files from HF API
hf_files get_repo_files(
const std::string & repo_id,
const std::string & bearer_token
const std::string & token
);
hf_files get_cached_files(const std::string & repo_id = {});
@ -30,6 +30,6 @@ hf_files get_cached_files(const std::string & repo_id = {});
std::string finalize_file(const hf_file & file);
// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline = false);
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
} // namespace hf_cache