From 05b3caaa485bd242ad431447e35c988285365e9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Fri, 10 Apr 2026 22:17:00 +0200 Subject: [PATCH] common : add callback interface for download progress (#21735) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- common/arg.cpp | 17 ++-- common/download.cpp | 146 ++++++++++++++++++------------ common/download.h | 34 +++++-- tools/llama-bench/llama-bench.cpp | 4 +- 4 files changed, 126 insertions(+), 75 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index c0cc576f29..3d0183ed70 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -291,14 +291,16 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa hf_tag = "default"; } - const bool offline = params.offline; std::string model_endpoint = get_model_endpoint(); auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini"; // prepare local path for caching auto preset_fname = clean_file_name(hf_repo + "_preset.ini"); auto preset_path = fs_get_cache_file(preset_fname); - const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline); + common_download_opts opts; + opts.bearer_token = params.hf_token; + opts.offline = params.offline; + const int status = common_download_file_single(preset_url, preset_path, opts); const bool has_preset = status >= 200 && status < 400; // remote preset is optional, so we don't error out if not found @@ -341,10 +343,10 @@ static handle_model_result common_params_handle_model(struct common_params_model model.hf_file = model.path; model.path = ""; } - common_download_model_opts opts; - opts.download_mmproj = true; + common_download_opts opts; + opts.bearer_token = bearer_token; opts.offline = offline; - auto download_result = common_download_model(model, bearer_token, opts); + auto download_result = common_download_model(model, opts, true); if (download_result.model_path.empty()) { LOG_ERR("error: failed to download model from Hugging Face\n"); @@ -365,9 +367,10 @@ static handle_model_result common_params_handle_model(struct common_params_model model.path = fs_get_cache_file(string_split(f, '/').back()); } - common_download_model_opts opts; + common_download_opts opts; + opts.bearer_token = bearer_token; opts.offline = offline; - auto download_result = common_download_model(model, bearer_token, opts); + auto download_result = common_download_model(model, opts); if (download_result.model_path.empty()) { LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); exit(1); diff --git a/common/download.cpp b/common/download.cpp index b9e5097123..ccf6fb6867 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -114,7 +114,7 @@ std::pair common_download_split_repo_tag(const std::st return {hf_repo, tag}; } -class ProgressBar { +class ProgressBar : public common_download_callback { static inline std::mutex mutex; static inline std::map lines; static inline int max_line = 0; @@ -138,7 +138,11 @@ class ProgressBar { } public: - ProgressBar(const std::string & url = "") : filename(url) { + ProgressBar() = default; + + void on_start(const common_download_progress & p) override { + filename = p.url; + if (auto pos = filename.rfind('/'); pos != std::string::npos) { filename = filename.substr(pos + 1); } @@ -156,13 +160,13 @@ public: } } - ~ProgressBar() { + void on_done(const common_download_progress &, bool) override { std::lock_guard lock(mutex); cleanup(this); } - void update(size_t current, size_t total) { - if (!total || !is_output_a_tty()) { + void on_update(const common_download_progress & p) override { + if (!p.total || !is_output_a_tty()) { return; } @@ -175,8 +179,8 @@ public: int lines_up = max_line - lines[this]; size_t bar = (55 - len) * 2; - size_t pct = (100 * current) / total; - size_t pos = (bar * current) / total; + size_t pct = (100 * p.downloaded) / p.total; + size_t pos = (bar * p.downloaded) / p.total; if (lines_up > 0) { std::cout << "\033[" << lines_up << "A"; @@ -193,7 +197,7 @@ public: } std::cout << '\r' << std::flush; - if (current == total) { + if (p.downloaded == p.total) { cleanup(this); } } @@ -206,8 +210,8 @@ static bool common_pull_file(httplib::Client & cli, const std::string & resolve_path, const std::string & path_tmp, bool supports_ranges, - size_t existing_size, - size_t & total_size) { + common_download_progress & p, + common_download_callback * callback) { std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app); if (!ofs.is_open()) { LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str()); @@ -215,29 +219,27 @@ static bool common_pull_file(httplib::Client & cli, } httplib::Headers headers; - if (supports_ranges && existing_size > 0) { - headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-"); + if (supports_ranges && p.downloaded > 0) { + headers.emplace("Range", "bytes=" + std::to_string(p.downloaded) + "-"); } const char * func = __func__; // avoid __func__ inside a lambda - size_t downloaded = existing_size; size_t progress_step = 0; - ProgressBar bar(resolve_path); auto res = cli.Get(resolve_path, headers, [&](const httplib::Response &response) { - if (existing_size > 0 && response.status != 206) { + if (p.downloaded > 0 && response.status != 206) { LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status); return false; } - if (existing_size == 0 && response.status != 200) { + if (p.downloaded == 0 && response.status != 200) { LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status); return false; } - if (total_size == 0 && response.has_header("Content-Length")) { + if (p.total == 0 && response.has_header("Content-Length")) { try { size_t content_length = std::stoull(response.get_header_value("Content-Length")); - total_size = existing_size + content_length; + p.total = p.downloaded + content_length; } catch (const std::exception &e) { LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what()); } @@ -250,11 +252,13 @@ static bool common_pull_file(httplib::Client & cli, LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str()); return false; } - downloaded += len; + p.downloaded += len; progress_step += len; - if (progress_step >= total_size / 1000 || downloaded == total_size) { - bar.update(downloaded, total_size); + if (progress_step >= p.total / 1000 || p.downloaded == p.total) { + if (callback) { + callback->on_update(p); + } progress_step = 0; } return true; @@ -275,11 +279,10 @@ static bool common_pull_file(httplib::Client & cli, // download one single file from remote URL to local path // returns status code or -1 on error -static int common_download_file_single_online(const std::string & url, - const std::string & path, - const std::string & bearer_token, - const common_header_list & custom_headers, - bool skip_etag = false) { +static int common_download_file_single_online(const std::string & url, + const std::string & path, + const common_download_opts & opts, + bool skip_etag) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -293,14 +296,14 @@ static int common_download_file_single_online(const std::string & url, auto [cli, parts] = common_http_client(url); httplib::Headers headers; - for (const auto & h : custom_headers) { + for (const auto & h : opts.headers) { headers.emplace(h.first, h.second); } if (headers.find("User-Agent") == headers.end()) { headers.emplace("User-Agent", "llama-cpp/" + build_info); } - if (!bearer_token.empty()) { - headers.emplace("Authorization", "Bearer " + bearer_token); + if (!opts.bearer_token.empty()) { + headers.emplace("Authorization", "Bearer " + opts.bearer_token); } cli.set_default_headers(headers); @@ -326,10 +329,11 @@ static int common_download_file_single_online(const std::string & url, etag = head->get_header_value("ETag"); } - size_t total_size = 0; + common_download_progress p; + p.url = url; if (head->has_header("Content-Length")) { try { - total_size = std::stoull(head->get_header_value("Content-Length")); + p.total = std::stoull(head->get_header_value("Content-Length")); } catch (const std::exception& e) { LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what()); } @@ -357,13 +361,17 @@ static int common_download_file_single_online(const std::string & url, { // silent std::error_code ec; - std::filesystem::path p(path); - std::filesystem::create_directories(p.parent_path(), ec); + std::filesystem::create_directories(std::filesystem::path(path).parent_path(), ec); } + bool success = false; const std::string path_temporary = path + ".downloadInProgress"; int delay = retry_delay_seconds; + if (opts.callback) { + opts.callback->on_start(p); + } + for (int i = 0; i < max_attempts; ++i) { if (i) { LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay); @@ -378,28 +386,38 @@ static int common_download_file_single_online(const std::string & url, existing_size = std::filesystem::file_size(path_temporary); } else if (remove(path_temporary.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); - return -1; + break; } } + p.downloaded = existing_size; + LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n", __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); - if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) { + if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, p, opts.callback)) { if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return -1; + break; } if (!etag.empty() && !skip_etag) { write_etag(path, etag); } - return head->status; + success = true; + break; } } - LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); - return -1; // max attempts reached + if (opts.callback) { + opts.callback->on_done(p, success); + } + if (!success) { + LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); + return -1; // max attempts reached + } + + return head->status; } std::pair> common_remote_get_content(const std::string & url, @@ -438,12 +456,15 @@ std::pair> common_remote_get_content(const std::string int common_download_file_single(const std::string & url, const std::string & path, - const std::string & bearer_token, - bool offline, - const common_header_list & headers, + const common_download_opts & opts, bool skip_etag) { - if (!offline) { - return common_download_file_single_online(url, path, bearer_token, headers, skip_etag); + if (!opts.offline) { + ProgressBar tty_cb; + common_download_opts online_opts = opts; + if (!online_opts.callback) { + online_opts.callback = &tty_cb; + } + return common_download_file_single_online(url, path, online_opts, skip_etag); } if (!std::filesystem::exists(path)) { @@ -452,6 +473,16 @@ int common_download_file_single(const std::string & url, } LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); + + // notify the callback that the file was cached + if (opts.callback) { + common_download_progress p; + p.url = url; + p.cached = true; + opts.callback->on_start(p); + opts.callback->on_done(p, true); + } + return 304; // Not Modified - fake cached response } @@ -631,16 +662,16 @@ struct hf_plan { hf_cache::hf_file mmproj; }; -static hf_plan get_hf_plan(const common_params_model & model, - const std::string & token, - const common_download_model_opts & opts) { +static hf_plan get_hf_plan(const common_params_model & model, + const common_download_opts & opts, + bool download_mmproj) { hf_plan plan; hf_cache::hf_files all; auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); if (!opts.offline) { - all = hf_cache::get_repo_files(repo, token); + all = hf_cache::get_repo_files(repo, opts.bearer_token); } if (all.empty()) { all = hf_cache::get_cached_files(repo); @@ -675,7 +706,7 @@ static hf_plan get_hf_plan(const common_params_model & model, plan.primary = primary; plan.model_files = get_split_files(all, primary); - if (opts.download_mmproj) { + if (download_mmproj) { plan.mmproj = find_best_mmproj(all, primary.path); } @@ -710,10 +741,9 @@ static std::vector get_url_tasks(const common_params_model & mode return tasks; } -common_download_model_result common_download_model(const common_params_model & model, - const std::string & bearer_token, - const common_download_model_opts & opts, - const common_header_list & headers) { +common_download_model_result common_download_model(const common_params_model & model, + const common_download_opts & opts, + bool download_mmproj) { common_download_model_result result; std::vector tasks; hf_plan hf; @@ -721,7 +751,7 @@ common_download_model_result common_download_model(const common_params_model bool is_hf = !model.hf_repo.empty(); if (is_hf) { - hf = get_hf_plan(model, bearer_token, opts); + hf = get_hf_plan(model, opts, download_mmproj); for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } @@ -742,8 +772,8 @@ common_download_model_result common_download_model(const common_params_model std::vector> futures; for (const auto & task : tasks) { futures.push_back(std::async(std::launch::async, - [&task, &bearer_token, offline = opts.offline, &headers, is_hf]() { - int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf); + [&task, &opts, is_hf]() { + int status = common_download_file_single(task.url, task.path, opts, is_hf); return is_http_status_ok(status); } )); @@ -879,7 +909,9 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string local_path = fs_get_cache_file(model_filename); const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - const int http_status = common_download_file_single(blob_url, local_path, token, false, {}); + common_download_opts opts; + opts.bearer_token = token; + const int http_status = common_download_file_single(blob_url, local_path, opts); if (!is_http_status_ok(http_status)) { throw std::runtime_error("Failed to download Docker Model"); } diff --git a/common/download.h b/common/download.h index 0a933521fa..48d5ff8a01 100644 --- a/common/download.h +++ b/common/download.h @@ -8,6 +8,21 @@ struct common_params_model; using common_header = std::pair; using common_header_list = std::vector; +struct common_download_progress { + std::string url; + size_t downloaded = 0; + size_t total = 0; + bool cached = false; +}; + +class common_download_callback { +public: + virtual ~common_download_callback() = default; + virtual void on_start(const common_download_progress & p) = 0; + virtual void on_update(const common_download_progress & p) = 0; + virtual void on_done(const common_download_progress & p, bool ok) = 0; +}; + struct common_remote_params { common_header_list headers; long timeout = 0; // in seconds, 0 means no timeout @@ -31,10 +46,12 @@ struct common_cached_model_info { } }; -// Options for common_download_model -struct common_download_model_opts { - bool download_mmproj = false; - bool offline = false; +// Options for common_download_model and common_download_file_single +struct common_download_opts { + std::string bearer_token; + common_header_list headers; + bool offline = false; + common_download_callback * callback = nullptr; }; // Result of common_download_model @@ -69,9 +86,8 @@ struct common_download_model_result { // returns result with model_path and mmproj_path (empty on failure) common_download_model_result common_download_model( const common_params_model & model, - const std::string & bearer_token, - const common_download_model_opts & opts = {}, - const common_header_list & headers = {} + const common_download_opts & opts = {}, + bool download_mmproj = false ); // returns list of cached models @@ -82,9 +98,7 @@ std::vector common_list_cached_models(); // skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) int common_download_file_single(const std::string & url, const std::string & path, - const std::string & bearer_token, - bool offline, - const common_header_list & headers = {}, + const common_download_opts & opts = {}, bool skip_etag = false); // resolve and download model from Docker registry diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 4f0443532b..b15a26a987 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -1014,7 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { model.hf_file = params.hf_file[i]; } - auto download_result = common_download_model(model, params.hf_token); + common_download_opts opts; + opts.bearer_token = params.hf_token; + auto download_result = common_download_model(model, opts); if (download_result.model_path.empty()) { fprintf(stderr, "error: failed to download model from HuggingFace\n"); exit(1);