common : add standard Hugging Face cache support
- Use HF API to find all files - Migrate all manifests to hugging face cache at startup Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
312cf03328
commit
6fd16ba05c
|
|
@ -63,6 +63,8 @@ add_library(${TARGET} STATIC
|
|||
debug.h
|
||||
download.cpp
|
||||
download.h
|
||||
hf-cache.cpp
|
||||
hf-cache.h
|
||||
http.h
|
||||
json-partial.cpp
|
||||
json-partial.h
|
||||
|
|
|
|||
100
common/arg.cpp
100
common/arg.cpp
|
|
@ -3,6 +3,7 @@
|
|||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
#include "hf-cache.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
|
|
@ -326,60 +327,48 @@ struct handle_model_result {
|
|||
common_params_model mmproj;
|
||||
};
|
||||
|
||||
static handle_model_result common_params_handle_model(
|
||||
struct common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline) {
|
||||
static handle_model_result common_params_handle_model(struct common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline) {
|
||||
handle_model_result result;
|
||||
// handle pre-fill default model path and url based on hf_repo and hf_file
|
||||
{
|
||||
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
model.name = model.docker_repo; // set name for consistency
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// short-hand to avoid specifying --hf-file -> default it to --model
|
||||
if (model.hf_file.empty()) {
|
||||
if (model.path.empty()) {
|
||||
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
|
||||
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
|
||||
exit(1); // error message already printed
|
||||
}
|
||||
model.name = model.hf_repo; // repo name with tag
|
||||
model.hf_repo = auto_detected.repo; // repo name without tag
|
||||
model.hf_file = auto_detected.ggufFile;
|
||||
if (!auto_detected.mmprojFile.empty()) {
|
||||
result.found_mmproj = true;
|
||||
result.mmproj.hf_repo = model.hf_repo;
|
||||
result.mmproj.hf_file = auto_detected.mmprojFile;
|
||||
}
|
||||
} else {
|
||||
model.hf_file = model.path;
|
||||
}
|
||||
}
|
||||
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
|
||||
// make sure model path is present (for caching purposes)
|
||||
if (model.path.empty()) {
|
||||
// this is to avoid different repo having same file name, or same file name in different subdirs
|
||||
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
|
||||
model.path = fs_get_cache_file(filename);
|
||||
}
|
||||
|
||||
} else if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
auto f = string_split<std::string>(model.url, '#').front();
|
||||
f = string_split<std::string>(f, '?').front();
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
model.name = model.docker_repo;
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// If -m was used with -hf, treat the model "path" as the hf_file to download
|
||||
if (model.hf_file.empty() && !model.path.empty()) {
|
||||
model.hf_file = model.path;
|
||||
model.path = "";
|
||||
}
|
||||
}
|
||||
common_download_model_opts opts;
|
||||
opts.download_mmproj = true;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, bearer_token, opts);
|
||||
|
||||
// then, download it if needed
|
||||
if (!model.url.empty()) {
|
||||
bool ok = common_download_model(model, bearer_token, offline);
|
||||
if (!ok) {
|
||||
if (download_result.model_path.empty()) {
|
||||
LOG_ERR("error: failed to download model from Hugging Face\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
model.name = model.hf_repo;
|
||||
model.path = download_result.model_path;
|
||||
|
||||
if (!download_result.mmproj_path.empty()) {
|
||||
result.found_mmproj = true;
|
||||
result.mmproj.path = download_result.mmproj_path;
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
auto f = string_split<std::string>(model.url, '#').front();
|
||||
f = string_split<std::string>(f, '?').front();
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
common_download_model_opts opts;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, bearer_token, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
|
@ -539,6 +528,13 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
|||
// parse the first time to get -hf option (used for remote preset)
|
||||
parse_cli_args();
|
||||
|
||||
// TODO: Remove later
|
||||
try {
|
||||
hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("HF cache migration failed: %s\n", e.what());
|
||||
}
|
||||
|
||||
// maybe handle remote preset
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
std::string cli_hf_repo = params.model.hf_repo;
|
||||
|
|
@ -1061,12 +1057,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
{"-cl", "--cache-list"},
|
||||
"show list of models in cache",
|
||||
[](common_params &) {
|
||||
printf("model cache directory: %s\n", fs_get_cache_directory().c_str());
|
||||
auto models = common_list_cached_models();
|
||||
printf("number of models in cache: %zu\n", models.size());
|
||||
for (size_t i = 0; i < models.size(); i++) {
|
||||
auto & model = models[i];
|
||||
printf("%4d. %s\n", (int) i + 1, model.to_string().c_str());
|
||||
printf("%4zu. %s\n", i + 1, models[i].c_str());
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
#include "arg.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "gguf.h" // for reading GGUF splits
|
||||
#include "log.h"
|
||||
#include "download.h"
|
||||
#include "hf-cache.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
|
@ -15,6 +15,7 @@
|
|||
#include <map>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
|
@ -35,8 +36,6 @@
|
|||
#endif
|
||||
#endif
|
||||
|
||||
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||
|
||||
// isatty
|
||||
#if defined(_WIN32)
|
||||
#include <io.h>
|
||||
|
|
@ -51,31 +50,6 @@ using json = nlohmann::ordered_json;
|
|||
//
|
||||
|
||||
// validate repo name format: owner/repo
|
||||
static bool validate_repo_name(const std::string & repo) {
|
||||
static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
|
||||
return std::regex_match(repo, repo_regex);
|
||||
}
|
||||
|
||||
static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
|
||||
// we use "=" to avoid clashing with other component, while still being allowed on windows
|
||||
std::string fname = "manifest=" + repo + "=" + tag + ".json";
|
||||
if (!validate_repo_name(repo)) {
|
||||
throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
|
||||
}
|
||||
string_replace_all(fname, "/", "=");
|
||||
return fs_get_cache_file(fname);
|
||||
}
|
||||
|
||||
static std::string read_file(const std::string & fname) {
|
||||
std::ifstream file(fname);
|
||||
if (!file) {
|
||||
throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
file.close();
|
||||
return content;
|
||||
}
|
||||
|
||||
static void write_file(const std::string & fname, const std::string & content) {
|
||||
const std::string fname_tmp = fname + ".tmp";
|
||||
std::ofstream file(fname_tmp);
|
||||
|
|
@ -132,7 +106,7 @@ static bool is_http_status_ok(int status) {
|
|||
|
||||
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
|
||||
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
|
||||
std::string tag = parts.size() > 1 ? parts.back() : "latest";
|
||||
std::string tag = parts.size() > 1 ? parts.back() : "";
|
||||
std::string hf_repo = parts[0];
|
||||
if (string_split<std::string>(hf_repo, '/').size() != 2) {
|
||||
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
|
||||
|
|
@ -290,7 +264,8 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
const common_header_list & custom_headers,
|
||||
bool skip_etag = false) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
|
|
@ -310,6 +285,11 @@ static int common_download_file_single_online(const std::string & url,
|
|||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
if (file_exists && skip_etag) {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
|
||||
std::string last_etag;
|
||||
if (file_exists) {
|
||||
last_etag = read_etag(path);
|
||||
|
|
@ -361,6 +341,12 @@ static int common_download_file_single_online(const std::string & url,
|
|||
}
|
||||
}
|
||||
|
||||
{ // silent
|
||||
std::error_code ec;
|
||||
std::filesystem::path p(path);
|
||||
std::filesystem::create_directories(p.parent_path(), ec);
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
int delay = retry_delay_seconds;
|
||||
|
||||
|
|
@ -391,7 +377,7 @@ static int common_download_file_single_online(const std::string & url,
|
|||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
if (!etag.empty()) {
|
||||
if (!etag.empty() && !skip_etag) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
return head->status;
|
||||
|
|
@ -440,9 +426,10 @@ int common_download_file_single(const std::string & url,
|
|||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
const common_header_list & headers,
|
||||
bool skip_etag) {
|
||||
if (!offline) {
|
||||
return common_download_file_single_online(url, path, bearer_token, headers);
|
||||
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(path)) {
|
||||
|
|
@ -454,193 +441,234 @@ int common_download_file_single(const std::string & url,
|
|||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
|
||||
// download multiple files from remote URLs to local paths
|
||||
// the input is a vector of pairs <url, path>
|
||||
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
// Prepare download in parallel
|
||||
std::vector<std::future<bool>> futures_download;
|
||||
futures_download.reserve(urls.size());
|
||||
|
||||
for (auto const & item : urls) {
|
||||
futures_download.push_back(
|
||||
std::async(
|
||||
std::launch::async,
|
||||
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
|
||||
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
|
||||
return is_http_status_ok(http_status);
|
||||
},
|
||||
item
|
||||
)
|
||||
);
|
||||
// "subdir/model-00001-of-00002.gguf" -> "subdir/model", 1, 2
|
||||
static std::tuple<std::string, int, int> get_gguf_split_info(const std::string & path) {
|
||||
if (path.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Wait for all downloads to complete
|
||||
for (auto & f : futures_download) {
|
||||
if (!f.get()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
static const std::regex re(R"(^(.+)-([0-9]+)-of-([0-9]+)\.gguf$)", std::regex::icase);
|
||||
|
||||
return true;
|
||||
std::smatch m;
|
||||
if (std::regex_match(path, m, re)) {
|
||||
return {m[1].str(), std::stoi(m[2].str()), std::stoi(m[3].str())};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
bool common_download_model(const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
// Basic validation of the model.url
|
||||
if (model.url.empty()) {
|
||||
LOG_ERR("%s: invalid model url\n", __func__);
|
||||
return false;
|
||||
}
|
||||
static hf_cache::hf_files get_split_files(const hf_cache::hf_files & all_files,
|
||||
const hf_cache::hf_file & primary_file) {
|
||||
hf_cache::hf_files result;
|
||||
auto [prefix, idx, count] = get_gguf_split_info(primary_file.path);
|
||||
|
||||
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check for additional GGUFs split to download
|
||||
int n_split = 0;
|
||||
{
|
||||
struct gguf_init_params gguf_params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ NULL,
|
||||
};
|
||||
auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
|
||||
if (!ctx_gguf) {
|
||||
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
|
||||
if (key_n_split >= 0) {
|
||||
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
|
||||
}
|
||||
|
||||
gguf_free(ctx_gguf);
|
||||
}
|
||||
|
||||
if (n_split > 1) {
|
||||
char split_prefix[PATH_MAX] = {0};
|
||||
char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
|
||||
|
||||
// Verify the first split file format
|
||||
// and extract split URL and PATH prefixes
|
||||
{
|
||||
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
|
||||
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
|
||||
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
|
||||
return false;
|
||||
if (count > 1) {
|
||||
for (const auto & f : all_files) {
|
||||
auto [sprefix, sidx, scount] = get_gguf_split_info(f.path);
|
||||
if (scount == count && sprefix == prefix) {
|
||||
result.push_back(f);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> urls;
|
||||
for (int idx = 1; idx < n_split; idx++) {
|
||||
char split_path[PATH_MAX] = {0};
|
||||
llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
|
||||
|
||||
char split_url[LLAMA_MAX_URL_LENGTH] = {0};
|
||||
llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
|
||||
|
||||
if (std::string(split_path) == model.path) {
|
||||
continue; // skip the already downloaded file
|
||||
}
|
||||
|
||||
urls.push_back({split_url, split_path});
|
||||
}
|
||||
|
||||
// Download in parallel
|
||||
common_download_file_multiple(urls, bearer_token, offline, headers);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & custom_headers) {
|
||||
// the returned hf_repo is without tag
|
||||
auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
|
||||
|
||||
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
|
||||
|
||||
// headers
|
||||
common_header_list headers = custom_headers;
|
||||
headers.push_back({"Accept", "application/json"});
|
||||
if (!bearer_token.empty()) {
|
||||
headers.push_back({"Authorization", "Bearer " + bearer_token});
|
||||
}
|
||||
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
|
||||
// User-Agent header is already set in common_remote_get_content, no need to set it here
|
||||
|
||||
// make the request
|
||||
common_remote_params params;
|
||||
params.headers = headers;
|
||||
long res_code = 0;
|
||||
std::string res_str;
|
||||
bool use_cache = false;
|
||||
std::string cached_response_path = get_manifest_path(hf_repo, tag);
|
||||
if (!offline) {
|
||||
try {
|
||||
auto res = common_remote_get_content(url, params);
|
||||
res_code = res.first;
|
||||
res_str = std::string(res.second.data(), res.second.size());
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
|
||||
}
|
||||
}
|
||||
if (res_code == 0) {
|
||||
if (std::filesystem::exists(cached_response_path)) {
|
||||
LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
|
||||
res_str = read_file(cached_response_path);
|
||||
res_code = 200;
|
||||
use_cache = true;
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
offline ? "error: failed to get manifest (offline mode)"
|
||||
: "error: failed to get manifest (check your internet connection)");
|
||||
}
|
||||
}
|
||||
std::string ggufFile;
|
||||
std::string mmprojFile;
|
||||
|
||||
if (res_code == 200 || res_code == 304) {
|
||||
try {
|
||||
auto j = json::parse(res_str);
|
||||
|
||||
if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
|
||||
ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
|
||||
}
|
||||
if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
|
||||
mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
|
||||
}
|
||||
if (!use_cache) {
|
||||
// if not using cached response, update the cache file
|
||||
write_file(cached_response_path, res_str);
|
||||
}
|
||||
} else if (res_code == 401) {
|
||||
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
|
||||
} else {
|
||||
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
|
||||
result.push_back(primary_file);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static hf_cache::hf_files filter_gguf_by_quant(const hf_cache::hf_files & files,
|
||||
const std::string & quant_tag) {
|
||||
hf_cache::hf_files matches;
|
||||
std::regex pattern(quant_tag + "[.-]", std::regex::icase);
|
||||
|
||||
for (const auto & f : files) {
|
||||
if (!string_ends_with(f.path, ".gguf")) {
|
||||
continue;
|
||||
}
|
||||
if (f.path.find("mmproj") != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (std::regex_search(f.path, pattern)) {
|
||||
matches.push_back(f);
|
||||
}
|
||||
}
|
||||
return matches;
|
||||
}
|
||||
|
||||
static void list_available_gguf_files(const hf_cache::hf_files & files) {
|
||||
LOG_INF("Available GGUF files:\n");
|
||||
for (const auto & f : files) {
|
||||
if (string_ends_with(f.path, ".gguf")) {
|
||||
LOG_INF(" - %s\n", f.path.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_file mmproj;
|
||||
bool has_primary = false;
|
||||
bool has_mmproj = false;
|
||||
hf_cache::hf_files files;
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const std::string & token,
|
||||
const common_download_model_opts & opts) {
|
||||
hf_plan plan;
|
||||
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
|
||||
|
||||
auto all = opts.offline ? hf_cache::get_cached_files(repo)
|
||||
: hf_cache::get_repo_files(repo, token);
|
||||
if (all.empty()) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
// check response
|
||||
if (ggufFile.empty()) {
|
||||
throw std::runtime_error("error: model does not have ggufFile");
|
||||
hf_cache::hf_files candidates;
|
||||
|
||||
if (!model.hf_file.empty()) {
|
||||
const hf_cache::hf_file * found_file = nullptr;
|
||||
for (const auto & f : all) {
|
||||
if (f.path == model.hf_file) {
|
||||
found_file = &f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found_file) {
|
||||
LOG_ERR("%s: --hf-file '%s' not found in repository\n", __func__, model.hf_file.c_str());
|
||||
list_available_gguf_files(all);
|
||||
return plan;
|
||||
}
|
||||
|
||||
plan.primary = *found_file;
|
||||
plan.has_primary = true;
|
||||
candidates = get_split_files(all, *found_file);
|
||||
} else {
|
||||
std::vector<std::string> search_priority = {!tag.empty() ? tag : "Q4_K_M", "Q4_0"};
|
||||
|
||||
for (const auto & q : search_priority) {
|
||||
candidates = filter_gguf_by_quant(all, q);
|
||||
if (!candidates.empty()) {
|
||||
candidates = get_split_files(all, candidates[0]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (candidates.empty()) {
|
||||
for (const auto & f : all) {
|
||||
if (string_ends_with(f.path, ".gguf") &&
|
||||
f.path.find("mmproj") == std::string::npos) {
|
||||
candidates = get_split_files(all, f);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidates.empty()) {
|
||||
LOG_ERR("%s: no GGUF files found in repository %s\n", __func__, repo.c_str());
|
||||
list_available_gguf_files(all);
|
||||
return plan;
|
||||
}
|
||||
|
||||
plan.primary = candidates[0];
|
||||
plan.has_primary = true;
|
||||
}
|
||||
|
||||
return { hf_repo, ggufFile, mmprojFile };
|
||||
for (const auto & f : candidates) {
|
||||
plan.files.push_back(f);
|
||||
}
|
||||
|
||||
if (opts.download_mmproj) {
|
||||
for (const auto & f : all) {
|
||||
if (string_ends_with(f.path, ".gguf") &&
|
||||
f.path.find("mmproj") != std::string::npos) {
|
||||
plan.mmproj = f;
|
||||
plan.has_mmproj = true;
|
||||
plan.files.push_back(f);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<std::string, std::string>> get_url_tasks(const common_params_model & model) {
|
||||
auto [prefix_url, idx, count] = get_gguf_split_info(model.url);
|
||||
|
||||
if (count <= 1) {
|
||||
return {{model.url, model.path}};
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> files;
|
||||
|
||||
size_t pos = prefix_url.rfind('/');
|
||||
std::string prefix_filename = (pos != std::string::npos) ? prefix_url.substr(pos + 1) : prefix_url;
|
||||
std::string prefix_path = (std::filesystem::path(model.path).parent_path() / prefix_filename).string();
|
||||
|
||||
for (int i = 1; i <= count; i++) {
|
||||
std::string suffix = string_format("-%05d-of-%05d.gguf", i, count);
|
||||
files.emplace_back(prefix_url + suffix, prefix_path + suffix);
|
||||
}
|
||||
return files;
|
||||
}
|
||||
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
const common_download_model_opts & opts,
|
||||
const common_header_list & headers) {
|
||||
common_download_model_result result;
|
||||
std::vector<std::pair<std::string, std::string>> to_download;
|
||||
hf_plan hf;
|
||||
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, bearer_token, opts);
|
||||
for (const auto & f : hf.files) {
|
||||
to_download.emplace_back(f.url, f.local_path);
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
to_download = get_url_tasks(model);
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (to_download.empty()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::future<bool>> futures;
|
||||
for (const auto & item : to_download) {
|
||||
futures.push_back(std::async(std::launch::async,
|
||||
[u = item.first, p = item.second, &bearer_token, offline = opts.offline, &headers, is_hf]() {
|
||||
int status = common_download_file_single(u, p, bearer_token, offline, headers, is_hf);
|
||||
return is_http_status_ok(status);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
for (auto & f : futures) {
|
||||
if (!f.get()) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
if (is_hf) {
|
||||
for (const auto & f : hf.files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
if (hf.has_primary) {
|
||||
result.model_path = hf_cache::finalize_file(hf.primary);
|
||||
}
|
||||
if (hf.has_mmproj) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
}
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
|
|
@ -764,29 +792,48 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
const std::vector<common_file_info> files = fs_list(cache_dir, false);
|
||||
for (const auto & file : files) {
|
||||
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
|
||||
common_cached_model_info model_info;
|
||||
model_info.manifest_path = file.path;
|
||||
std::string fname = file.name;
|
||||
string_replace_all(fname, ".json", ""); // remove extension
|
||||
auto parts = string_split<std::string>(fname, '=');
|
||||
if (parts.size() == 4) {
|
||||
// expect format: manifest=<user>=<model>=<tag>=<other>
|
||||
model_info.user = parts[1];
|
||||
model_info.model = parts[2];
|
||||
model_info.tag = parts[3];
|
||||
} else {
|
||||
// invalid format
|
||||
continue;
|
||||
std::vector<std::string> common_list_cached_models() {
|
||||
auto files = hf_cache::get_cached_files("");
|
||||
std::set<std::string> models;
|
||||
|
||||
for (const auto & f : files) {
|
||||
std::string tmp = f.path;
|
||||
|
||||
if (!string_remove_suffix(tmp, ".gguf")) {
|
||||
continue;
|
||||
}
|
||||
if (tmp.find("mmproj") != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
auto split_pos = tmp.find("-00001-of-");
|
||||
|
||||
if (split_pos == std::string::npos &&
|
||||
tmp.find("-of-") != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (split_pos != std::string::npos) {
|
||||
tmp.erase(split_pos);
|
||||
}
|
||||
auto sep_pos = tmp.find_last_of("-.");
|
||||
|
||||
if (sep_pos == std::string::npos || sep_pos == tmp.size() - 1) {
|
||||
continue;
|
||||
}
|
||||
tmp.erase(0, sep_pos + 1);
|
||||
|
||||
bool is_valid = true;
|
||||
for (char & c : tmp) {
|
||||
unsigned char uc = c;
|
||||
if (!std::isalnum(uc) && uc != '_') {
|
||||
is_valid = false;
|
||||
break;
|
||||
}
|
||||
model_info.size = 0; // TODO: get GGUF size, not manifest size
|
||||
models.push_back(model_info);
|
||||
c = std::toupper(uc);
|
||||
}
|
||||
if (is_valid) {
|
||||
models.insert(f.repo_id + ":" + tmp);
|
||||
}
|
||||
}
|
||||
return models;
|
||||
|
||||
return {models.begin(), models.end()};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "hf-cache.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -23,23 +25,16 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
|||
// example: "user/model" -> <"user/model", "latest">
|
||||
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
|
||||
|
||||
struct common_cached_model_info {
|
||||
std::string manifest_path;
|
||||
std::string user;
|
||||
std::string model;
|
||||
std::string tag;
|
||||
size_t size = 0; // GGUF size in bytes
|
||||
// return string representation like "user/model:tag"
|
||||
// if tag is "latest", it will be omitted
|
||||
std::string to_string() const {
|
||||
return user + "/" + model + (tag == "latest" ? "" : ":" + tag);
|
||||
}
|
||||
// Options for common_download_model
|
||||
struct common_download_model_opts {
|
||||
bool download_mmproj = false;
|
||||
bool offline = false;
|
||||
};
|
||||
|
||||
struct common_hf_file_res {
|
||||
std::string repo; // repo name with ":tag" removed
|
||||
std::string ggufFile;
|
||||
std::string mmprojFile;
|
||||
// Result of common_download_model
|
||||
struct common_download_model_result {
|
||||
std::string model_path; // path to downloaded model (empty on failure)
|
||||
std::string mmproj_path; // path to downloaded mmproj (empty if not downloaded)
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -47,37 +42,27 @@ struct common_hf_file_res {
|
|||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
||||
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
||||
*
|
||||
* Return pair of <repo, file> (with "repo" already having tag removed)
|
||||
*
|
||||
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
||||
* Tag is optional, it checks for Q4_K_M first, then Q4_0, then if not found, return the first GGUF file in repo
|
||||
*/
|
||||
common_hf_file_res common_get_hf_file(
|
||||
const std::string & hf_repo_with_tag,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers = {}
|
||||
);
|
||||
|
||||
// returns true if download succeeded
|
||||
bool common_download_model(
|
||||
common_download_model_result common_download_model(
|
||||
const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_download_model_opts & opts = {},
|
||||
const common_header_list & headers = {}
|
||||
);
|
||||
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
std::vector<std::string> common_list_cached_models();
|
||||
|
||||
// download single file from url to local path
|
||||
// returns status code or -1 on error
|
||||
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers = {});
|
||||
const common_header_list & headers = {},
|
||||
bool skip_etag = false);
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
|
|
|
|||
|
|
@ -0,0 +1,516 @@
|
|||
#include "hf-cache.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "http.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <ctime>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <regex> // migration only
|
||||
#include <string>
|
||||
|
||||
namespace nl = nlohmann;
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
namespace hf_cache {
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
// Resolve the root of the local Hugging Face hub cache, honoring the same
// environment variables, in the same priority order, as huggingface_hub:
// HF_HUB_CACHE > HUGGINGFACE_HUB_CACHE > HF_HOME/hub > XDG_CACHE_HOME >
// platform home directory. Throws when no usable location can be found.
static fs::path get_cache_directory() {
    namespace fs = std::filesystem;

    // returns nullptr for unset OR empty variables
    const auto env = [](const char * name) -> const char * {
        const char * v = std::getenv(name);
        return (v && *v) ? v : nullptr;
    };

    if (const char * v = env("HF_HUB_CACHE")) {
        return fs::path(v); // assume shell-expanded; add expand logic if you want full parity
    }
    if (const char * v = env("HUGGINGFACE_HUB_CACHE")) {
        return fs::path(v);
    }
    if (const char * v = env("HF_HOME")) {
        return fs::path(v) / "hub";
    }
    if (const char * v = env("XDG_CACHE_HOME")) {
        return fs::path(v) / "huggingface" / "hub";
    }
#if defined(_WIN32)
    if (const char * v = env("USERPROFILE")) {
        return fs::path(v) / ".cache" / "huggingface" / "hub";
    }
#else
    if (const char * v = env("HOME")) {
        return fs::path(v) / ".cache" / "huggingface" / "hub";
    }
#endif
    throw std::runtime_error("Failed to determine HF cache directory");
}
|
||||
|
||||
// Returns true when symlinks can be created in the cache directory.
// On Windows, symlink creation requires Developer Mode or elevated
// privileges, so we probe once (guarded by std::call_once) by creating and
// removing a throwaway directory symlink; the result is cached for the
// lifetime of the process. On other platforms symlinks are assumed to work.
static bool symlinks_supported() {
#ifdef _WIN32
    static bool supported = false;
    static std::once_flag once;
    std::call_once(once, []() {
        // include the PID in the probe name so concurrent processes don't collide
        fs::path link = get_cache_directory() / ("link_" + std::to_string(GetCurrentProcessId()));

        std::error_code ec;
        fs::create_directory_symlink("..", link, ec);
        supported = !ec;

        if (!ec) {
            fs::remove(link, ec);
        } else if (GetLastError() == ERROR_PRIVILEGE_NOT_HELD) {
            LOG_WRN("symlink creation requires Developer Mode or admin privileges on Windows\n");
        }
    });
    return supported;
#else
    return true;
#endif
}
|
||||
|
||||
// Convert a HF cache folder name back into a repo id:
// "models--owner--repo" -> "owner/repo".
// Returns an empty string for folders without the "models--" prefix.
static std::string folder_name_to_repo(const std::string & folder) {
    static const std::string prefix = "models--";
    if (folder.compare(0, prefix.size(), prefix) != 0) {
        return {};
    }
    std::string repo_id;
    const size_t n = folder.size();
    size_t i = prefix.size();
    while (i < n) {
        // every "--" separator becomes a single '/'
        if (folder[i] == '-' && i + 1 < n && folder[i + 1] == '-') {
            repo_id += '/';
            i += 2;
        } else {
            repo_id += folder[i];
            i += 1;
        }
    }
    return repo_id;
}
|
||||
|
||||
// Map a repo id to its HF cache folder name:
// "owner/repo" -> "models--owner--repo" (every '/' becomes "--").
static std::string repo_to_folder_name(const std::string & repo_id) {
    std::string folder = "models--";
    size_t start = 0;
    for (;;) {
        const size_t slash = repo_id.find('/', start);
        if (slash == std::string::npos) {
            folder.append(repo_id, start, std::string::npos);
            return folder;
        }
        folder.append(repo_id, start, slash - start);
        folder += "--";
        start = slash + 1;
    }
}
|
||||
|
||||
// Absolute path of the cache folder for a given repo id.
static fs::path get_repo_path(const std::string & repo_id) {
    const std::string folder = repo_to_folder_name(repo_id);
    return get_cache_directory() / folder;
}
|
||||
|
||||
static void write_ref(const std::string & repo_id,
|
||||
const std::string & ref,
|
||||
const std::string & commit) {
|
||||
fs::path refs_path = get_repo_path(repo_id) / "refs";
|
||||
std::error_code ec;
|
||||
fs::create_directories(refs_path, ec);
|
||||
|
||||
fs::path ref_path = refs_path / ref;
|
||||
fs::path ref_path_tmp = refs_path / (ref + ".tmp");
|
||||
{
|
||||
std::ofstream file(ref_path_tmp);
|
||||
if (!file) {
|
||||
throw std::runtime_error("Failed to write ref file: " + ref_path.string());
|
||||
}
|
||||
file << commit;
|
||||
}
|
||||
std::error_code rename_ec;
|
||||
fs::rename(ref_path_tmp, ref_path, rename_ec);
|
||||
if (rename_ec) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, ref_path_tmp.c_str(), ref_path.c_str());
|
||||
fs::remove(ref_path_tmp, ec);
|
||||
}
|
||||
}
|
||||
|
||||
// Query the HF Hub API (/api/models/<repo>/refs) for the branches of a repo
// and pick one: "main" is preferred; otherwise the first branch that carries
// both a name and a target commit is used as fallback. On success the chosen
// ref is persisted locally via write_ref() and the commit hash is returned;
// on any network or parse failure an empty string is returned.
static std::string get_repo_ref(const std::string & repo_id,
                                const std::string & bearer_token) {
    std::string url = get_model_endpoint() + "api/models/" + repo_id + "/refs";
    auto [cli, parts] = common_http_client(url);

    httplib::Headers headers;
    headers.emplace("User-Agent", "llama-cpp/" + build_info);
    headers.emplace("Accept", "application/json");
    if (!bearer_token.empty()) {
        headers.emplace("Authorization", "Bearer " + bearer_token);
    }
    cli.set_default_headers(headers);

    auto res = cli.Get(parts.path);
    if (!res || res->status != 200) {
        LOG_WRN("%s: API request failed for %s, status: %d\n", __func__, url.c_str(), res ? res->status : -1);
        return {};
    }

    try {
        auto j = nl::json::parse(res->body);

        if (!j.contains("branches") || !j["branches"].is_array()) {
            return {};
        }

        std::string name;
        std::string commit;

        for (const auto & branch : j["branches"]) {
            // skip malformed entries
            if (!branch.contains("name") || !branch.contains("targetCommit")) {
                continue;
            }
            std::string _name = branch["name"].get<std::string>();
            std::string _commit = branch["targetCommit"].get<std::string>();

            // "main" always wins; stop scanning as soon as it is found
            if (_name == "main") {
                name = _name;
                commit = _commit;
                break;
            }

            // otherwise remember the first valid branch as a fallback
            if (name.empty() || commit.empty()) {
                name = _name;
                commit = _commit;
            }
        }

        if (!name.empty() && !commit.empty()) {
            write_ref(repo_id, name, commit);
        }
        return commit;
    } catch (const std::exception & e) {
        LOG_WRN("%s: failed to parse API response: %s\n", __func__, e.what());
        return {};
    }
}
|
||||
|
||||
// List all files of a repo via the HF Hub API (/api/models/<repo>/tree).
// First resolves the current ref to a commit hash, then builds for every
// regular file:
//   - the download url (<endpoint><repo>/resolve/<rev>/<path>)
//   - the blob oid (sha256 for LFS files, git blob oid otherwise)
//   - its destination in the local HF cache: blobs/<oid> plus a
//     snapshots/<rev>/<path> link when symlinks are available, or the
//     snapshot path alone in degraded mode
// Returns an empty list on any network or parse error.
hf_files get_repo_files(const std::string & repo_id,
                        const std::string & bearer_token) {
    std::string rev = get_repo_ref(repo_id, bearer_token);
    if (rev.empty()) {
        LOG_WRN("%s: failed to resolve commit hash for %s\n", __func__, repo_id.c_str());
        return {};
    }

    std::string url = get_model_endpoint() + "api/models/" + repo_id + "/tree/" + rev + "?recursive=true";

    auto [cli, parts] = common_http_client(url);

    httplib::Headers headers;
    headers.emplace("User-Agent", "llama-cpp/" + build_info);
    headers.emplace("Accept", "application/json");
    if (!bearer_token.empty()) {
        headers.emplace("Authorization", "Bearer " + bearer_token);
    }
    cli.set_default_headers(headers);

    auto res = cli.Get(parts.path);
    if (!res || res->status != 200) {
        LOG_WRN("%s: API request failed for %s, status: %d\n", __func__, url.c_str(), res ? res->status : -1);
        return {};
    }

    std::string endpoint = get_model_endpoint(); // TODO
    bool use_symlinks = symlinks_supported();
    hf_files files;

    try {
        auto j = nl::json::parse(res->body);

        if (!j.is_array()) {
            LOG_DBG("%s: response is not an array\n", __func__);
            return files;
        }

        for (const auto & item : j) {
            // only regular files; directories etc. are skipped
            if (!item.contains("type") || item["type"] != "file") {
                continue;
            }
            if (!item.contains("path")) {
                continue;
            }

            hf_file file;
            file.repo_id = repo_id;
            file.path = item["path"].get<std::string>();

            // LFS files expose their sha256 under lfs.oid; non-LFS files
            // only carry the top-level git oid
            if (item.contains("lfs") && item["lfs"].is_object()) {
                if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
                    file.oid = item["lfs"]["oid"].get<std::string>();
                }
            } else if (item.contains("oid") && item["oid"].is_string()) {
                file.oid = item["oid"].get<std::string>();
            }

            file.url = endpoint + repo_id + "/resolve/" + rev + "/" + file.path;

            fs::path path = file.path;
            fs::path repo_path = get_repo_path(repo_id);
            fs::path snapshots_path = repo_path / "snapshots" / rev / path;
            fs::path blobs_path = repo_path / "blobs" / file.oid;

            if (use_symlinks) {
                // download into blobs/, expose through a snapshots/ symlink
                file.local_path = blobs_path.string();
                file.link_path = snapshots_path.string();
            } else { // degraded mode
                file.local_path = snapshots_path.string();
            }

            files.push_back(file);
        }
    } catch (const std::exception & e) {
        LOG_WRN("%s: failed to parse API response: %s\n", __func__, e.what());
        return {};
    }

    return files;
}
|
||||
|
||||
// Return the commit hash recorded under <repo_path>/refs, i.e. the first
// non-empty line of the first usable ref file found there. Returns an
// empty string when no ref is recorded.
static std::string get_cached_ref(const std::filesystem::path & repo_path) {
    namespace fs = std::filesystem;

    const fs::path refs_path = repo_path / "refs";
    if (!fs::is_directory(refs_path)) {
        return {};
    }
    for (const auto & entry : fs::directory_iterator(refs_path)) {
        if (!entry.is_regular_file()) {
            continue;
        }
        std::ifstream in(entry.path());
        std::string commit;
        if (in && std::getline(in, commit) && !commit.empty()) {
            return commit;
        }
    }
    return {};
}
|
||||
|
||||
// Scan the local HF cache and return every file of the cached snapshot of
// `repo_id`; with an empty repo_id, files of all cached repos are returned.
// Only repos with a resolvable ref and an existing snapshot directory for
// that commit are considered. No network access is performed.
hf_files get_cached_files(const std::string & repo_id) {
    fs::path cache_dir = get_cache_directory();
    if (!fs::exists(cache_dir)) {
        return {};
    }
    hf_files files;

    for (const auto & repo : fs::directory_iterator(cache_dir)) {
        if (!repo.is_directory()) {
            continue;
        }
        fs::path snapshots_path = repo.path() / "snapshots";

        if (!fs::exists(snapshots_path)) {
            continue;
        }
        // "models--owner--name" -> "owner/name"; skip non-model folders
        std::string _repo_id = folder_name_to_repo(repo.path().filename().string());

        if (_repo_id.empty()) {
            continue;
        }
        if (!repo_id.empty() && _repo_id != repo_id) {
            continue;
        }
        // the ref file tells us which snapshot (commit) is current
        std::string commit = get_cached_ref(repo.path());
        fs::path rev_path = snapshots_path / commit;

        if (commit.empty() || !fs::is_directory(rev_path)) {
            continue;
        }
        for (const auto & entry : fs::recursive_directory_iterator(rev_path)) {
            // symlinks are accepted: snapshot entries normally point into blobs/
            if (!entry.is_regular_file() && !entry.is_symlink()) {
                continue;
            }
            // path relative to the snapshot root == path inside the repo
            fs::path path = entry.path().lexically_relative(rev_path);

            if (!path.empty()) {
                hf_file file;
                file.repo_id = _repo_id;
                file.path = path.generic_string();
                file.local_path = entry.path().string();
                files.push_back(std::move(file));
            }
        }
    }

    return files;
}
|
||||
|
||||
std::string finalize_file(const hf_file & file) {
|
||||
if (file.link_path.empty()) {
|
||||
return file.local_path;
|
||||
}
|
||||
|
||||
fs::path link_path(file.link_path);
|
||||
fs::path local_path(file.local_path);
|
||||
|
||||
std::error_code ec;
|
||||
fs::create_directories(link_path.parent_path(), ec);
|
||||
fs::path target_path = fs::relative(local_path, link_path.parent_path(), ec);
|
||||
fs::create_symlink(target_path, link_path, ec);
|
||||
|
||||
if (fs::exists(link_path)) {
|
||||
return file.link_path;
|
||||
}
|
||||
|
||||
LOG_WRN("%s: failed to create symlink: %s\n", __func__, file.link_path.c_str());
|
||||
return file.local_path;
|
||||
}
|
||||
|
||||
// delete everything after this line, one day
|
||||
|
||||
// Parse a legacy manifest filename of the form
// "manifest=<owner>=<repo>=<tag>.json" and return {owner, repo}.
// Returns a pair of empty strings when the name does not match.
// Takes the filename by const reference: the argument is read-only, and a
// non-const reference would reject temporaries for no reason.
static std::pair<std::string, std::string> parse_manifest_name(const std::string & filename) {
    static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
    std::smatch match;
    if (std::regex_match(filename, match, re)) {
        return {match[1].str(), match[2].str()};
    }
    return {};
}
|
||||
|
||||
// Reconstruct the legacy flat-cache filename: "<owner>_<repo>_<filename>"
// with every '/' flattened to '_'.
static std::string make_old_cache_filename(const std::string & owner,
                                           const std::string & repo,
                                           const std::string & filename) {
    std::string flat;
    flat.reserve(owner.size() + repo.size() + filename.size() + 2);
    flat.append(owner).append("_").append(repo).append("_").append(filename);
    for (auto & ch : flat) {
        if (ch == '/') {
            ch = '_';
        }
    }
    return flat;
}
|
||||
|
||||
// Migrate one file described by a legacy manifest entry (`node` is the
// "ggufFile"/"mmprojFile" JSON object) from the old flat cache into the HF
// cache layout. `files` is the repo's current file list from the API and is
// used to locate the destination path and validate the sha256. Stale or
// orphaned old-cache entries are deleted. Returns true when something was
// migrated or cleaned up, false otherwise.
static bool migrate_single_file(const fs::path & old_cache,
                                const std::string & owner,
                                const std::string & repo,
                                const nl::json & node,
                                const hf_files & files) {

    // a usable manifest entry must name the file and carry its sha256
    if (!node.contains("rfilename") ||
        !node.contains("lfs") ||
        !node["lfs"].contains("sha256")) {
        return false;
    }

    std::string path = node["rfilename"];
    std::string sha256 = node["lfs"]["sha256"];

    // find the matching file in the live repo listing
    const hf_file * file_info = nullptr;
    for (const auto & f : files) {
        if (f.path == path) {
            file_info = &f;
            break;
        }
    }

    std::string old_filename = make_old_cache_filename(owner, repo, path);
    fs::path old_path = old_cache / old_filename;
    fs::path etag_path = old_path.string() + ".etag";

    if (!fs::exists(old_path)) {
        // nothing to migrate; drop a leftover .etag whose payload is gone
        if (fs::exists(etag_path)) {
            LOG_WRN("%s: %s is orphan, deleting...\n", __func__, etag_path.string().c_str());
            fs::remove(etag_path);
        }
        return false;
    }

    bool delete_old_path = false;

    if (!file_info) {
        // file no longer exists in the repo -> stale local copy
        LOG_WRN("%s: %s not found in current repo, deleting...\n", __func__, old_filename.c_str());
        delete_old_path = true;
    } else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) {
        // content changed upstream -> local copy is outdated
        LOG_WRN("%s: %s is not up to date (sha256 mismatch), deleting...\n", __func__, old_filename.c_str());
        delete_old_path = true;
    }

    std::error_code ec;

    if (delete_old_path) {
        fs::remove(old_path, ec);
        fs::remove(etag_path, ec);
        return true;
    }

    // move into the blobs/snapshots layout; fall back to copy+remove when
    // rename fails (e.g. across filesystems)
    fs::path new_path(file_info->local_path);
    fs::create_directories(new_path.parent_path(), ec);

    if (!fs::exists(new_path, ec)) {
        fs::rename(old_path, new_path, ec);
        if (ec) {
            fs::copy_file(old_path, new_path, ec);
            if (ec) {
                LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str());
                return false;
            }
        }
        fs::remove(old_path, ec);
    }
    fs::remove(etag_path, ec);

    // create the snapshot symlink (or keep the direct path in degraded mode)
    std::string snapshot_path = finalize_file(*file_info);
    LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), snapshot_path.c_str());

    return true;
}
|
||||
|
||||
// One-shot migration of llama.cpp's legacy flat model cache into the
// standard HF cache layout. Scans the old cache directory for
// "manifest=<owner>=<repo>=<tag>.json" files, re-fetches the repo's file
// list from the API, migrates the referenced gguf/mmproj files, and then
// removes the manifest. Skipped entirely when offline, since validation
// requires the API.
void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline) {
    fs::path old_cache = fs_get_cache_directory();
    if (!fs::exists(old_cache)) {
        return;
    }

    if (offline) {
        LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
        return; // -hf is not going to work
    }

    for (const auto & entry : fs::directory_iterator(old_cache)) {
        if (!entry.is_regular_file()) {
            continue;
        }
        auto filename = entry.path().filename().string();
        auto [owner, repo] = parse_manifest_name(filename);

        // not a manifest file -> leave it alone
        if (owner.empty() || repo.empty()) {
            continue;
        }

        auto repo_id = owner + "/" + repo;
        auto files = get_repo_files(repo_id, bearer_token);

        if (files.empty()) {
            LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
            continue;
        }

        try {
            // read the whole manifest and migrate each referenced file
            std::ifstream manifest_stream(entry.path());
            std::string content((std::istreambuf_iterator<char>(manifest_stream)), std::istreambuf_iterator<char>());
            auto j = nl::json::parse(content);
            for (const char* key : {"ggufFile", "mmprojFile"}) {
                if (j.contains(key)) {
                    migrate_single_file(old_cache, owner, repo, j[key], files);
                }
            }
        } catch (const std::exception & e) {
            LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
            continue;
        }
        // NOTE(review): this removes the entry currently referenced by the
        // live directory_iterator; whether the iterator remains valid is
        // implementation-defined — consider collecting paths first. TODO confirm.
        fs::remove(entry.path());
    }
}
|
||||
|
||||
} // namespace hf_cache
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
#pragma once

#include <filesystem>
#include <string>
#include <utility>
#include <vector>

// Ref: https://huggingface.co/docs/hub/local-cache.md

namespace hf_cache {

// One file of a HF repo, together with where it lives (or will live)
// in the local HF cache.
struct hf_file {
    std::string path;       // path inside the repo, e.g. "model-00001.gguf"
    std::string url;        // fully-resolved download url
    std::string local_path; // storage path: blobs/<oid>, or the snapshot path in degraded mode
    std::string link_path;  // snapshot symlink path; empty when symlinks are unavailable
    std::string oid;        // sha256 (LFS) or git blob oid; may be empty
    std::string repo_id;    // "owner/name"
};

using hf_files = std::vector<hf_file>;

// Get files from HF API (resolves the current commit, then lists the tree);
// returns an empty list on failure
hf_files get_repo_files(
    const std::string & repo_id,
    const std::string & bearer_token
);

// List files already present in the local cache; an empty repo_id matches
// every cached repo
hf_files get_cached_files(const std::string & repo_id);

// Create symlink if link_path is set and returns the snapshot path
std::string finalize_file(const hf_file & file);

// Migrate the legacy flat cache into the HF layout (no-op when offline)
// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & bearer_token, bool offline = false);

} // namespace hf_cache
|
||||
|
|
@ -365,8 +365,8 @@ common_presets common_preset_context::load_from_cache() const {
|
|||
auto cached_models = common_list_cached_models();
|
||||
for (const auto & model : cached_models) {
|
||||
common_preset preset;
|
||||
preset.name = model.to_string();
|
||||
preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
|
||||
preset.name = model;
|
||||
preset.set_option(*this, "LLAMA_ARG_HF_REPO", model);
|
||||
out[preset.name] = preset;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -979,37 +979,20 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
|||
for (size_t i = 0; i < params.hf_repo.size(); i++) {
|
||||
common_params_model model;
|
||||
|
||||
// step 1: no `-hff` provided, we auto-detect based on the `-hf` flag
|
||||
if (params.hf_file.empty() || params.hf_file[i].empty()) {
|
||||
auto auto_detected = common_get_hf_file(params.hf_repo[i], params.hf_token, false);
|
||||
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
|
||||
exit(1);
|
||||
}
|
||||
|
||||
model.name = params.hf_repo[i];
|
||||
model.hf_repo = auto_detected.repo;
|
||||
model.hf_file = auto_detected.ggufFile;
|
||||
model.hf_repo = params.hf_repo[i];
|
||||
} else {
|
||||
model.hf_repo = params.hf_repo[i];
|
||||
model.hf_file = params.hf_file[i];
|
||||
}
|
||||
|
||||
// step 2: construct the model cache path
|
||||
std::string clean_fname = model.hf_repo + "_" + model.hf_file;
|
||||
string_replace_all(clean_fname, "\\", "_");
|
||||
string_replace_all(clean_fname, "/", "_");
|
||||
model.path = fs_get_cache_file(clean_fname);
|
||||
|
||||
// step 3: download the model if not exists
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
|
||||
|
||||
bool ok = common_download_model(model, params.hf_token, false);
|
||||
if (!ok) {
|
||||
fprintf(stderr, "error: failed to download model from %s\n", model.url.c_str());
|
||||
auto download_result = common_download_model(model, params.hf_token);
|
||||
if (download_result.model_path.empty()) {
|
||||
fprintf(stderr, "error: failed to download model from HuggingFace\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params.model.push_back(model.path);
|
||||
params.model.push_back(download_result.model_path);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue