Check all inputs
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
6ab630f5f8
commit
3645fee1ed
|
|
@ -77,41 +77,127 @@ static fs::path get_repo_path(const std::string & repo_id) {
|
|||
return get_cache_directory() / repo_to_folder_name(repo_id);
|
||||
}
|
||||
|
||||
static void write_ref(const std::string & repo_id,
|
||||
const std::string & ref,
|
||||
const std::string & commit) {
|
||||
fs::path refs_path = get_repo_path(repo_id) / "refs";
|
||||
std::error_code ec;
|
||||
fs::create_directories(refs_path, ec);
|
||||
static bool is_hex_char(const char c) {
|
||||
return (c >= 'A' && c <= 'F') ||
|
||||
(c >= 'a' && c <= 'f') ||
|
||||
(c >= '0' && c <= '9');
|
||||
}
|
||||
|
||||
fs::path ref_path = refs_path / ref;
|
||||
fs::path ref_path_tmp = refs_path / (ref + ".tmp");
|
||||
{
|
||||
std::ofstream file(ref_path_tmp);
|
||||
if (!file) {
|
||||
throw std::runtime_error("Failed to write ref file: " + ref_path.string());
|
||||
}
|
||||
file << commit;
|
||||
static bool is_hex_string(const std::string & s, size_t expected_len) {
|
||||
if (s.length() != expected_len) {
|
||||
return false;
|
||||
}
|
||||
std::error_code rename_ec;
|
||||
fs::rename(ref_path_tmp, ref_path, rename_ec);
|
||||
if (rename_ec) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__,
|
||||
ref_path_tmp.string().c_str(), ref_path.string().c_str());
|
||||
fs::remove(ref_path_tmp, ec);
|
||||
for (const char c : s) {
|
||||
if (!is_hex_char(c)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool is_alphanum(const char c) {
|
||||
return (c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9');
|
||||
}
|
||||
|
||||
static bool is_special_char(char c) {
|
||||
return c == '/' || c == '.' || c == '-';
|
||||
}
|
||||
|
||||
// base chars [A-Za-z0-9_] are always valid
|
||||
// special chars [/.-] must be surrounded by base chars
|
||||
// exactly one '/' required
|
||||
static bool is_valid_repo_id(const std::string & repo_id) {
|
||||
if (repo_id.empty() || repo_id.length() > 256) {
|
||||
return false;
|
||||
}
|
||||
int slash = 0;
|
||||
bool special = true;
|
||||
|
||||
for (const char c : repo_id) {
|
||||
if (is_alphanum(c) || c == '_') {
|
||||
special = false;
|
||||
} else if (is_special_char(c)) {
|
||||
if (special) {
|
||||
return false;
|
||||
}
|
||||
slash += (c == '/');
|
||||
special = true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return !special && slash == 1;
|
||||
}
|
||||
|
||||
static bool is_valid_hf_token(const std::string & token) {
|
||||
if (token.length() < 37 || token.length() > 256 ||
|
||||
!string_starts_with(token, "hf_")) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 3; i < token.length(); ++i) {
|
||||
if (!is_alphanum(token[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool is_valid_commit(const std::string & hash) {
|
||||
return is_hex_string(hash, 40);
|
||||
}
|
||||
|
||||
static bool is_valid_oid(const std::string & oid) {
|
||||
return is_hex_string(oid, 40) || is_hex_string(oid, 64);
|
||||
}
|
||||
|
||||
static bool is_valid_subpath(const fs::path & path, const fs::path & subpath) {
|
||||
if (subpath.is_absolute()) {
|
||||
return false; // never do a / b with b absolute
|
||||
}
|
||||
auto b = fs::absolute(path).lexically_normal();
|
||||
auto t = (b / subpath).lexically_normal();
|
||||
auto [b_end, _] = std::mismatch(b.begin(), b.end(), t.begin(), t.end());
|
||||
|
||||
return b_end == b.end();
|
||||
}
|
||||
|
||||
static void safe_write_file(const fs::path & path, const std::string & data) {
|
||||
fs::path path_tmp = path.string() + ".tmp";
|
||||
|
||||
if (path.has_parent_path()) {
|
||||
fs::create_directories(path.parent_path());
|
||||
}
|
||||
|
||||
std::ofstream file(path_tmp);
|
||||
file << data;
|
||||
file.close();
|
||||
|
||||
std::error_code ec;
|
||||
|
||||
if (!file.fail()) {
|
||||
fs::rename(path_tmp, path, ec);
|
||||
}
|
||||
if (file.fail() || ec) {
|
||||
fs::remove(path_tmp, ec);
|
||||
throw std::runtime_error("failed to write file: " + path.string());
|
||||
}
|
||||
}
|
||||
|
||||
static nl::json api_get(const std::string & url,
|
||||
const std::string & bearer_token) {
|
||||
const std::string & token) {
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers = {
|
||||
{"User-Agent", "llama-cpp/" + build_info},
|
||||
{"Accept", "application/json"}
|
||||
};
|
||||
if (!bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + bearer_token);
|
||||
|
||||
if (is_valid_hf_token(token)) {
|
||||
headers.emplace("Authorization", "Bearer " + token);
|
||||
} else if (!token.empty()) {
|
||||
LOG_WRN("%s: invalid token, authentication disabled\n", __func__);
|
||||
}
|
||||
|
||||
if (auto res = cli.Get(parts.path, headers)) {
|
||||
|
|
@ -130,11 +216,11 @@ static nl::json api_get(const std::string & url,
|
|||
}
|
||||
}
|
||||
|
||||
static std::string get_repo_ref(const std::string & repo_id,
|
||||
const std::string & bearer_token) {
|
||||
static std::string get_repo_commit(const std::string & repo_id,
|
||||
const std::string & token) {
|
||||
try {
|
||||
auto endpoint = get_model_endpoint();
|
||||
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", bearer_token);
|
||||
auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token);
|
||||
|
||||
if (!json.is_object() ||
|
||||
!json.contains("branches") || !json["branches"].is_array()) {
|
||||
|
|
@ -142,6 +228,7 @@ static std::string get_repo_ref(const std::string & repo_id,
|
|||
return {};
|
||||
}
|
||||
|
||||
fs::path refs_path = get_repo_path(repo_id) / "refs";
|
||||
std::string name;
|
||||
std::string commit;
|
||||
|
||||
|
|
@ -154,6 +241,15 @@ static std::string get_repo_ref(const std::string & repo_id,
|
|||
std::string _name = branch["name"].get<std::string>();
|
||||
std::string _commit = branch["targetCommit"].get<std::string>();
|
||||
|
||||
if (!is_valid_subpath(refs_path, _name)) {
|
||||
LOG_WRN("%s: skip invalid branch: %s\n", __func__, _name.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!is_valid_commit(_commit)) {
|
||||
LOG_WRN("%s: skip invalid commit: %s\n", __func__, _commit.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_name == "main") {
|
||||
name = _name;
|
||||
commit = _commit;
|
||||
|
|
@ -171,34 +267,42 @@ static std::string get_repo_ref(const std::string & repo_id,
|
|||
return {};
|
||||
}
|
||||
|
||||
write_ref(repo_id, name, commit);
|
||||
safe_write_file(refs_path / name, commit);
|
||||
return commit;
|
||||
|
||||
} catch (const nl::json::exception & e) {
|
||||
LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what());
|
||||
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what());
|
||||
LOG_ERR("%s: error: %s\n", __func__, e.what());
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
hf_files get_repo_files(const std::string & repo_id,
|
||||
const std::string & bearer_token) {
|
||||
hf_files files;
|
||||
std::string rev = get_repo_ref(repo_id, bearer_token);
|
||||
|
||||
if (rev.empty()) {
|
||||
LOG_WRN("%s: failed to resolve commit hash for %s\n", __func__, repo_id.c_str());
|
||||
const std::string & token) {
|
||||
if (!is_valid_repo_id(repo_id)) {
|
||||
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string commit = get_repo_commit(repo_id, token);
|
||||
if (commit.empty()) {
|
||||
LOG_WRN("%s: failed to resolve commit for %s\n", __func__, repo_id.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
fs::path blobs_path = get_repo_path(repo_id) / "blobs";
|
||||
fs::path commit_path = get_repo_path(repo_id) / "snapshots" / commit;
|
||||
|
||||
hf_files files;
|
||||
|
||||
try {
|
||||
auto endpoint = get_model_endpoint();
|
||||
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true", bearer_token);
|
||||
auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token);
|
||||
|
||||
if (!json.is_array()) {
|
||||
LOG_WRN("%s: response is not an array for '%s'\n", __func__, repo_id.c_str());
|
||||
return files;
|
||||
return {};
|
||||
}
|
||||
|
||||
for (const auto & item : json) {
|
||||
|
|
@ -212,6 +316,11 @@ hf_files get_repo_files(const std::string & repo_id,
|
|||
file.repo_id = repo_id;
|
||||
file.path = item["path"].get<std::string>();
|
||||
|
||||
if (!is_valid_subpath(commit_path, file.path)) {
|
||||
LOG_WRN("%s: skip invalid path: %s\n", __func__, file.path.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.contains("lfs") && item["lfs"].is_object()) {
|
||||
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
|
||||
file.oid = item["lfs"]["oid"].get<std::string>();
|
||||
|
|
@ -220,26 +329,29 @@ hf_files get_repo_files(const std::string & repo_id,
|
|||
file.oid = item["oid"].get<std::string>();
|
||||
}
|
||||
|
||||
file.url = endpoint + repo_id + "/resolve/" + rev + "/" + file.path;
|
||||
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
|
||||
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
fs::path path = file.path;
|
||||
fs::path repo_path = get_repo_path(repo_id);
|
||||
fs::path snapshots_path = repo_path / "snapshots" / rev / path;
|
||||
file.url = endpoint + repo_id + "/resolve/" + commit + "/" + file.path;
|
||||
|
||||
file.final_path = snapshots_path.string();
|
||||
file.local_path = file.final_path;
|
||||
fs::path final_path = commit_path / file.path;
|
||||
file.final_path = final_path.string();
|
||||
|
||||
if (!file.oid.empty() && !fs::exists(snapshots_path)) {
|
||||
fs::path blob_path = repo_path / "blobs" / file.oid;
|
||||
file.local_path = blob_path.string();
|
||||
if (!file.oid.empty() && !fs::exists(final_path)) {
|
||||
fs::path local_path = blobs_path / file.oid;
|
||||
file.local_path = local_path.string();
|
||||
} else {
|
||||
file.local_path = file.final_path;
|
||||
}
|
||||
|
||||
files.push_back(file);
|
||||
}
|
||||
} catch (const nl::json::exception & e) {
|
||||
LOG_ERR("%s: JSON error '%s': %s\n", __func__, repo_id.c_str(), e.what());
|
||||
LOG_ERR("%s: JSON error: %s\n", __func__, e.what());
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: API error '%s': %s\n", __func__, repo_id.c_str(), e.what());
|
||||
LOG_ERR("%s: error: %s\n", __func__, e.what());
|
||||
}
|
||||
return files;
|
||||
}
|
||||
|
|
@ -260,6 +372,10 @@ static std::string get_cached_ref(const fs::path & repo_path) {
|
|||
if (!f || !std::getline(f, commit) || commit.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (!is_valid_commit(commit)) {
|
||||
LOG_WRN("%s: skip invalid commit: %s\n", __func__, commit.c_str());
|
||||
continue;
|
||||
}
|
||||
if (entry.path().filename() == "main") {
|
||||
return commit;
|
||||
}
|
||||
|
|
@ -275,6 +391,12 @@ hf_files get_cached_files(const std::string & repo_id) {
|
|||
if (!fs::exists(cache_dir)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
|
||||
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
hf_files files;
|
||||
|
||||
for (const auto & repo : fs::directory_iterator(cache_dir)) {
|
||||
|
|
@ -288,23 +410,23 @@ hf_files get_cached_files(const std::string & repo_id) {
|
|||
}
|
||||
std::string _repo_id = folder_name_to_repo(repo.path().filename().string());
|
||||
|
||||
if (_repo_id.empty()) {
|
||||
if (!is_valid_repo_id(_repo_id)) {
|
||||
continue;
|
||||
}
|
||||
if (!repo_id.empty() && _repo_id != repo_id) {
|
||||
continue;
|
||||
}
|
||||
std::string commit = get_cached_ref(repo.path());
|
||||
fs::path rev_path = snapshots_path / commit;
|
||||
fs::path commit_path = snapshots_path / commit;
|
||||
|
||||
if (commit.empty() || !fs::is_directory(rev_path)) {
|
||||
if (commit.empty() || !fs::is_directory(commit_path)) {
|
||||
continue;
|
||||
}
|
||||
for (const auto & entry : fs::recursive_directory_iterator(rev_path)) {
|
||||
for (const auto & entry : fs::recursive_directory_iterator(commit_path)) {
|
||||
if (!entry.is_regular_file() && !entry.is_symlink()) {
|
||||
continue;
|
||||
}
|
||||
fs::path path = entry.path().lexically_relative(rev_path);
|
||||
fs::path path = entry.path().lexically_relative(commit_path);
|
||||
|
||||
if (!path.empty()) {
|
||||
hf_file file;
|
||||
|
|
@ -324,23 +446,23 @@ std::string finalize_file(const hf_file & file) {
|
|||
static std::atomic<bool> symlinks_disabled{false};
|
||||
|
||||
std::error_code ec;
|
||||
fs::path blob_path(file.local_path);
|
||||
fs::path snapshot_path(file.final_path);
|
||||
fs::path local_path(file.local_path);
|
||||
fs::path final_path(file.final_path);
|
||||
|
||||
if (blob_path == snapshot_path || fs::exists(snapshot_path, ec)) {
|
||||
if (local_path == final_path || fs::exists(final_path, ec)) {
|
||||
return file.final_path;
|
||||
}
|
||||
|
||||
if (!fs::exists(blob_path, ec)) {
|
||||
if (!fs::exists(local_path, ec)) {
|
||||
return file.final_path;
|
||||
}
|
||||
|
||||
fs::create_directories(snapshot_path.parent_path(), ec);
|
||||
fs::create_directories(final_path.parent_path(), ec);
|
||||
|
||||
if (!symlinks_disabled) {
|
||||
fs::path target = fs::relative(blob_path, snapshot_path.parent_path(), ec);
|
||||
fs::path target = fs::relative(local_path, final_path.parent_path(), ec);
|
||||
if (!ec) {
|
||||
fs::create_symlink(target, snapshot_path, ec);
|
||||
fs::create_symlink(target, final_path, ec);
|
||||
}
|
||||
if (!ec) {
|
||||
return file.final_path;
|
||||
|
|
@ -352,10 +474,10 @@ std::string finalize_file(const hf_file & file) {
|
|||
LOG_WRN("%s: switching to degraded mode\n", __func__);
|
||||
}
|
||||
|
||||
fs::rename(blob_path, snapshot_path, ec);
|
||||
fs::rename(local_path, final_path, ec);
|
||||
if (ec) {
|
||||
LOG_WRN("%s: failed to move file to snapshots: %s\n", __func__, ec.message().c_str());
|
||||
fs::copy(blob_path, snapshot_path, ec);
|
||||
fs::copy(local_path, final_path, ec);
|
||||
if (ec) {
|
||||
LOG_ERR("%s: failed to copy file to snapshots: %s\n", __func__, ec.message().c_str());
|
||||
}
|
||||
|
|
@ -451,13 +573,13 @@ static bool migrate_single_file(const fs::path & old_cache,
|
|||
}
|
||||
fs::remove(etag_path, ec);
|
||||
|
||||
std::string snapshot_file = finalize_file(*file_info);
|
||||
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_file.c_str());
|
||||
std::string filename = finalize_file(*file_info);
|
||||
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline) {
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
|
||||
fs::path old_cache = fs_get_cache_directory();
|
||||
if (!fs::exists(old_cache)) {
|
||||
return;
|
||||
|
|
@ -480,7 +602,7 @@ void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offlin
|
|||
}
|
||||
|
||||
auto repo_id = owner + "/" + repo;
|
||||
auto files = get_repo_files(repo_id, bearer_token);
|
||||
auto files = get_repo_files(repo_id, token);
|
||||
|
||||
if (files.empty()) {
|
||||
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
|
||||
|
|
@ -488,12 +610,12 @@ void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offlin
|
|||
}
|
||||
|
||||
try {
|
||||
std::ifstream manifest_stream(entry.path());
|
||||
std::string content((std::istreambuf_iterator<char>(manifest_stream)), std::istreambuf_iterator<char>());
|
||||
auto j = nl::json::parse(content);
|
||||
for (const char* key : {"ggufFile", "mmprojFile"}) {
|
||||
if (j.contains(key)) {
|
||||
migrate_single_file(old_cache, owner, repo, j[key], files);
|
||||
std::ifstream manifest(entry.path());
|
||||
auto json = nl::json::parse(manifest);
|
||||
|
||||
for (const char * key : {"ggufFile", "mmprojFile"}) {
|
||||
if (json.contains(key)) {
|
||||
migrate_single_file(old_cache, owner, repo, json[key], files);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ using hf_files = std::vector<hf_file>;
|
|||
// Get files from HF API
|
||||
hf_files get_repo_files(
|
||||
const std::string & repo_id,
|
||||
const std::string & bearer_token
|
||||
const std::string & token
|
||||
);
|
||||
|
||||
hf_files get_cached_files(const std::string & repo_id = {});
|
||||
|
|
@ -30,6 +30,6 @@ hf_files get_cached_files(const std::string & repo_id = {});
|
|||
std::string finalize_file(const hf_file & file);
|
||||
|
||||
// TODO: Remove later
|
||||
void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline = false);
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
|
||||
|
||||
} // namespace hf_cache
|
||||
|
|
|
|||
Loading…
Reference in New Issue