common : add callback interface for download progress (#21735)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2026-04-10 22:17:00 +02:00 committed by GitHub
parent e62fa13c24
commit 05b3caaa48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 126 additions and 75 deletions

View File

@ -291,14 +291,16 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
hf_tag = "default"; hf_tag = "default";
} }
const bool offline = params.offline;
std::string model_endpoint = get_model_endpoint(); std::string model_endpoint = get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini"; auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching // prepare local path for caching
auto preset_fname = clean_file_name(hf_repo + "_preset.ini"); auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname); auto preset_path = fs_get_cache_file(preset_fname);
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline); common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400; const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found // remote preset is optional, so we don't error out if not found
@ -341,10 +343,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
model.hf_file = model.path; model.hf_file = model.path;
model.path = ""; model.path = "";
} }
common_download_model_opts opts; common_download_opts opts;
opts.download_mmproj = true; opts.bearer_token = bearer_token;
opts.offline = offline; opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts); auto download_result = common_download_model(model, opts, true);
if (download_result.model_path.empty()) { if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from Hugging Face\n"); LOG_ERR("error: failed to download model from Hugging Face\n");
@ -365,9 +367,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back()); model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
} }
common_download_model_opts opts; common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline; opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts); auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) { if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
exit(1); exit(1);

View File

@ -114,7 +114,7 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
return {hf_repo, tag}; return {hf_repo, tag};
} }
class ProgressBar { class ProgressBar : public common_download_callback {
static inline std::mutex mutex; static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines; static inline std::map<const ProgressBar *, int> lines;
static inline int max_line = 0; static inline int max_line = 0;
@ -138,7 +138,11 @@ class ProgressBar {
} }
public: public:
ProgressBar(const std::string & url = "") : filename(url) { ProgressBar() = default;
void on_start(const common_download_progress & p) override {
filename = p.url;
if (auto pos = filename.rfind('/'); pos != std::string::npos) { if (auto pos = filename.rfind('/'); pos != std::string::npos) {
filename = filename.substr(pos + 1); filename = filename.substr(pos + 1);
} }
@ -156,13 +160,13 @@ public:
} }
} }
~ProgressBar() { void on_done(const common_download_progress &, bool) override {
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
cleanup(this); cleanup(this);
} }
void update(size_t current, size_t total) { void on_update(const common_download_progress & p) override {
if (!total || !is_output_a_tty()) { if (!p.total || !is_output_a_tty()) {
return; return;
} }
@ -175,8 +179,8 @@ public:
int lines_up = max_line - lines[this]; int lines_up = max_line - lines[this];
size_t bar = (55 - len) * 2; size_t bar = (55 - len) * 2;
size_t pct = (100 * current) / total; size_t pct = (100 * p.downloaded) / p.total;
size_t pos = (bar * current) / total; size_t pos = (bar * p.downloaded) / p.total;
if (lines_up > 0) { if (lines_up > 0) {
std::cout << "\033[" << lines_up << "A"; std::cout << "\033[" << lines_up << "A";
@ -193,7 +197,7 @@ public:
} }
std::cout << '\r' << std::flush; std::cout << '\r' << std::flush;
if (current == total) { if (p.downloaded == p.total) {
cleanup(this); cleanup(this);
} }
} }
@ -206,8 +210,8 @@ static bool common_pull_file(httplib::Client & cli,
const std::string & resolve_path, const std::string & resolve_path,
const std::string & path_tmp, const std::string & path_tmp,
bool supports_ranges, bool supports_ranges,
size_t existing_size, common_download_progress & p,
size_t & total_size) { common_download_callback * callback) {
std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app); std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
if (!ofs.is_open()) { if (!ofs.is_open()) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str()); LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
@ -215,29 +219,27 @@ static bool common_pull_file(httplib::Client & cli,
} }
httplib::Headers headers; httplib::Headers headers;
if (supports_ranges && existing_size > 0) { if (supports_ranges && p.downloaded > 0) {
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-"); headers.emplace("Range", "bytes=" + std::to_string(p.downloaded) + "-");
} }
const char * func = __func__; // avoid __func__ inside a lambda const char * func = __func__; // avoid __func__ inside a lambda
size_t downloaded = existing_size;
size_t progress_step = 0; size_t progress_step = 0;
ProgressBar bar(resolve_path);
auto res = cli.Get(resolve_path, headers, auto res = cli.Get(resolve_path, headers,
[&](const httplib::Response &response) { [&](const httplib::Response &response) {
if (existing_size > 0 && response.status != 206) { if (p.downloaded > 0 && response.status != 206) {
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status); LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
return false; return false;
} }
if (existing_size == 0 && response.status != 200) { if (p.downloaded == 0 && response.status != 200) {
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status); LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
return false; return false;
} }
if (total_size == 0 && response.has_header("Content-Length")) { if (p.total == 0 && response.has_header("Content-Length")) {
try { try {
size_t content_length = std::stoull(response.get_header_value("Content-Length")); size_t content_length = std::stoull(response.get_header_value("Content-Length"));
total_size = existing_size + content_length; p.total = p.downloaded + content_length;
} catch (const std::exception &e) { } catch (const std::exception &e) {
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what()); LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
} }
@ -250,11 +252,13 @@ static bool common_pull_file(httplib::Client & cli,
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str()); LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
return false; return false;
} }
downloaded += len; p.downloaded += len;
progress_step += len; progress_step += len;
if (progress_step >= total_size / 1000 || downloaded == total_size) { if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
bar.update(downloaded, total_size); if (callback) {
callback->on_update(p);
}
progress_step = 0; progress_step = 0;
} }
return true; return true;
@ -275,11 +279,10 @@ static bool common_pull_file(httplib::Client & cli,
// download one single file from remote URL to local path // download one single file from remote URL to local path
// returns status code or -1 on error // returns status code or -1 on error
static int common_download_file_single_online(const std::string & url, static int common_download_file_single_online(const std::string & url,
const std::string & path, const std::string & path,
const std::string & bearer_token, const common_download_opts & opts,
const common_header_list & custom_headers, bool skip_etag) {
bool skip_etag = false) {
static const int max_attempts = 3; static const int max_attempts = 3;
static const int retry_delay_seconds = 2; static const int retry_delay_seconds = 2;
@ -293,14 +296,14 @@ static int common_download_file_single_online(const std::string & url,
auto [cli, parts] = common_http_client(url); auto [cli, parts] = common_http_client(url);
httplib::Headers headers; httplib::Headers headers;
for (const auto & h : custom_headers) { for (const auto & h : opts.headers) {
headers.emplace(h.first, h.second); headers.emplace(h.first, h.second);
} }
if (headers.find("User-Agent") == headers.end()) { if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info); headers.emplace("User-Agent", "llama-cpp/" + build_info);
} }
if (!bearer_token.empty()) { if (!opts.bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + bearer_token); headers.emplace("Authorization", "Bearer " + opts.bearer_token);
} }
cli.set_default_headers(headers); cli.set_default_headers(headers);
@ -326,10 +329,11 @@ static int common_download_file_single_online(const std::string & url,
etag = head->get_header_value("ETag"); etag = head->get_header_value("ETag");
} }
size_t total_size = 0; common_download_progress p;
p.url = url;
if (head->has_header("Content-Length")) { if (head->has_header("Content-Length")) {
try { try {
total_size = std::stoull(head->get_header_value("Content-Length")); p.total = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) { } catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what()); LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
} }
@ -357,13 +361,17 @@ static int common_download_file_single_online(const std::string & url,
{ // silent { // silent
std::error_code ec; std::error_code ec;
std::filesystem::path p(path); std::filesystem::create_directories(std::filesystem::path(path).parent_path(), ec);
std::filesystem::create_directories(p.parent_path(), ec);
} }
bool success = false;
const std::string path_temporary = path + ".downloadInProgress"; const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds; int delay = retry_delay_seconds;
if (opts.callback) {
opts.callback->on_start(p);
}
for (int i = 0; i < max_attempts; ++i) { for (int i = 0; i < max_attempts; ++i) {
if (i) { if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay); LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
@ -378,28 +386,38 @@ static int common_download_file_single_online(const std::string & url,
existing_size = std::filesystem::file_size(path_temporary); existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) { } else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return -1; break;
} }
} }
p.downloaded = existing_size;
LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n", LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(), __func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str()); path_temporary.c_str(), etag.c_str());
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) { if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, p, opts.callback)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1; break;
} }
if (!etag.empty() && !skip_etag) { if (!etag.empty() && !skip_etag) {
write_etag(path, etag); write_etag(path, etag);
} }
return head->status; success = true;
break;
} }
} }
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); if (opts.callback) {
return -1; // max attempts reached opts.callback->on_done(p, success);
}
if (!success) {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}
return head->status;
} }
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
@ -438,12 +456,15 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
int common_download_file_single(const std::string & url, int common_download_file_single(const std::string & url,
const std::string & path, const std::string & path,
const std::string & bearer_token, const common_download_opts & opts,
bool offline,
const common_header_list & headers,
bool skip_etag) { bool skip_etag) {
if (!offline) { if (!opts.offline) {
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag); ProgressBar tty_cb;
common_download_opts online_opts = opts;
if (!online_opts.callback) {
online_opts.callback = &tty_cb;
}
return common_download_file_single_online(url, path, online_opts, skip_etag);
} }
if (!std::filesystem::exists(path)) { if (!std::filesystem::exists(path)) {
@ -452,6 +473,16 @@ int common_download_file_single(const std::string & url,
} }
LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
// notify the callback that the file was cached
if (opts.callback) {
common_download_progress p;
p.url = url;
p.cached = true;
opts.callback->on_start(p);
opts.callback->on_done(p, true);
}
return 304; // Not Modified - fake cached response return 304; // Not Modified - fake cached response
} }
@ -631,16 +662,16 @@ struct hf_plan {
hf_cache::hf_file mmproj; hf_cache::hf_file mmproj;
}; };
static hf_plan get_hf_plan(const common_params_model & model, static hf_plan get_hf_plan(const common_params_model & model,
const std::string & token, const common_download_opts & opts,
const common_download_model_opts & opts) { bool download_mmproj) {
hf_plan plan; hf_plan plan;
hf_cache::hf_files all; hf_cache::hf_files all;
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
if (!opts.offline) { if (!opts.offline) {
all = hf_cache::get_repo_files(repo, token); all = hf_cache::get_repo_files(repo, opts.bearer_token);
} }
if (all.empty()) { if (all.empty()) {
all = hf_cache::get_cached_files(repo); all = hf_cache::get_cached_files(repo);
@ -675,7 +706,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.primary = primary; plan.primary = primary;
plan.model_files = get_split_files(all, primary); plan.model_files = get_split_files(all, primary);
if (opts.download_mmproj) { if (download_mmproj) {
plan.mmproj = find_best_mmproj(all, primary.path); plan.mmproj = find_best_mmproj(all, primary.path);
} }
@ -710,10 +741,9 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
return tasks; return tasks;
} }
common_download_model_result common_download_model(const common_params_model & model, common_download_model_result common_download_model(const common_params_model & model,
const std::string & bearer_token, const common_download_opts & opts,
const common_download_model_opts & opts, bool download_mmproj) {
const common_header_list & headers) {
common_download_model_result result; common_download_model_result result;
std::vector<download_task> tasks; std::vector<download_task> tasks;
hf_plan hf; hf_plan hf;
@ -721,7 +751,7 @@ common_download_model_result common_download_model(const common_params_model
bool is_hf = !model.hf_repo.empty(); bool is_hf = !model.hf_repo.empty();
if (is_hf) { if (is_hf) {
hf = get_hf_plan(model, bearer_token, opts); hf = get_hf_plan(model, opts, download_mmproj);
for (const auto & f : hf.model_files) { for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path}); tasks.push_back({f.url, f.local_path});
} }
@ -742,8 +772,8 @@ common_download_model_result common_download_model(const common_params_model
std::vector<std::future<bool>> futures; std::vector<std::future<bool>> futures;
for (const auto & task : tasks) { for (const auto & task : tasks) {
futures.push_back(std::async(std::launch::async, futures.push_back(std::async(std::launch::async,
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() { [&task, &opts, is_hf]() {
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf); int status = common_download_file_single(task.url, task.path, opts, is_hf);
return is_http_status_ok(status); return is_http_status_ok(status);
} }
)); ));
@ -879,7 +909,9 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename); std::string local_path = fs_get_cache_file(model_filename);
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
const int http_status = common_download_file_single(blob_url, local_path, token, false, {}); common_download_opts opts;
opts.bearer_token = token;
const int http_status = common_download_file_single(blob_url, local_path, opts);
if (!is_http_status_ok(http_status)) { if (!is_http_status_ok(http_status)) {
throw std::runtime_error("Failed to download Docker Model"); throw std::runtime_error("Failed to download Docker Model");
} }

View File

@ -8,6 +8,21 @@ struct common_params_model;
using common_header = std::pair<std::string, std::string>; using common_header = std::pair<std::string, std::string>;
using common_header_list = std::vector<common_header>; using common_header_list = std::vector<common_header>;
struct common_download_progress {
std::string url;
size_t downloaded = 0;
size_t total = 0;
bool cached = false;
};
class common_download_callback {
public:
virtual ~common_download_callback() = default;
virtual void on_start(const common_download_progress & p) = 0;
virtual void on_update(const common_download_progress & p) = 0;
virtual void on_done(const common_download_progress & p, bool ok) = 0;
};
struct common_remote_params { struct common_remote_params {
common_header_list headers; common_header_list headers;
long timeout = 0; // in seconds, 0 means no timeout long timeout = 0; // in seconds, 0 means no timeout
@ -31,10 +46,12 @@ struct common_cached_model_info {
} }
}; };
// Options for common_download_model // Options for common_download_model and common_download_file_single
struct common_download_model_opts { struct common_download_opts {
bool download_mmproj = false; std::string bearer_token;
bool offline = false; common_header_list headers;
bool offline = false;
common_download_callback * callback = nullptr;
}; };
// Result of common_download_model // Result of common_download_model
@ -69,9 +86,8 @@ struct common_download_model_result {
// returns result with model_path and mmproj_path (empty on failure) // returns result with model_path and mmproj_path (empty on failure)
common_download_model_result common_download_model( common_download_model_result common_download_model(
const common_params_model & model, const common_params_model & model,
const std::string & bearer_token, const common_download_opts & opts = {},
const common_download_model_opts & opts = {}, bool download_mmproj = false
const common_header_list & headers = {}
); );
// returns list of cached models // returns list of cached models
@ -82,9 +98,7 @@ std::vector<common_cached_model_info> common_list_cached_models();
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) // skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
int common_download_file_single(const std::string & url, int common_download_file_single(const std::string & url,
const std::string & path, const std::string & path,
const std::string & bearer_token, const common_download_opts & opts = {},
bool offline,
const common_header_list & headers = {},
bool skip_etag = false); bool skip_etag = false);
// resolve and download model from Docker registry // resolve and download model from Docker registry

View File

@ -1014,7 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
model.hf_file = params.hf_file[i]; model.hf_file = params.hf_file[i];
} }
auto download_result = common_download_model(model, params.hf_token); common_download_opts opts;
opts.bearer_token = params.hf_token;
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) { if (download_result.model_path.empty()) {
fprintf(stderr, "error: failed to download model from HuggingFace\n"); fprintf(stderr, "error: failed to download model from HuggingFace\n");
exit(1); exit(1);