common : add minimalist multi-thread progress bar (#17602)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-12-12 12:44:35 +01:00 committed by GitHub
parent 2eaa2c65cb
commit b8ee22cfde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 69 additions and 25 deletions

View File

@ -12,6 +12,8 @@
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <future> #include <future>
#include <map>
#include <mutex>
#include <regex> #include <regex>
#include <string> #include <string>
#include <thread> #include <thread>
@ -472,15 +474,35 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
#elif defined(LLAMA_USE_HTTPLIB) #elif defined(LLAMA_USE_HTTPLIB)
static bool is_output_a_tty() { class ProgressBar {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
static inline int max_line = 0;
static void cleanup(const ProgressBar * line) {
lines.erase(line);
if (lines.empty()) {
max_line = 0;
}
}
static bool is_output_a_tty() {
#if defined(_WIN32) #if defined(_WIN32)
return _isatty(_fileno(stdout)); return _isatty(_fileno(stdout));
#else #else
return isatty(1); return isatty(1);
#endif #endif
} }
static void print_progress(size_t current, size_t total) { public:
ProgressBar() = default;
~ProgressBar() {
std::lock_guard<std::mutex> lock(mutex);
cleanup(this);
}
void update(size_t current, size_t total) {
if (!is_output_a_tty()) { if (!is_output_a_tty()) {
return; return;
} }
@ -489,19 +511,42 @@ static void print_progress(size_t current, size_t total) {
return; return;
} }
std::lock_guard<std::mutex> lock(mutex);
if (lines.find(this) == lines.end()) {
lines[this] = max_line++;
std::cout << "\n";
}
int lines_up = max_line - lines[this];
size_t width = 50; size_t width = 50;
size_t pct = (100 * current) / total; size_t pct = (100 * current) / total;
size_t pos = (width * current) / total; size_t pos = (width * current) / total;
std::cout << "[" std::cout << "\033[s";
if (lines_up > 0) {
std::cout << "\033[" << lines_up << "A";
}
std::cout << "\033[2K\r["
<< std::string(pos, '=') << std::string(pos, '=')
<< (pos < width ? ">" : "") << (pos < width ? ">" : "")
<< std::string(width - pos, ' ') << std::string(width - pos, ' ')
<< "] " << std::setw(3) << pct << "% (" << "] " << std::setw(3) << pct << "% ("
<< current / (1024 * 1024) << " MB / " << current / (1024 * 1024) << " MB / "
<< total / (1024 * 1024) << " MB)\r"; << total / (1024 * 1024) << " MB) "
<< "\033[u";
std::cout.flush(); std::cout.flush();
}
if (current == total) {
cleanup(this);
}
}
ProgressBar(const ProgressBar &) = delete;
ProgressBar & operator=(const ProgressBar &) = delete;
};
static bool common_pull_file(httplib::Client & cli, static bool common_pull_file(httplib::Client & cli,
const std::string & resolve_path, const std::string & resolve_path,
@ -523,6 +568,7 @@ static bool common_pull_file(httplib::Client & cli,
const char * func = __func__; // avoid __func__ inside a lambda const char * func = __func__; // avoid __func__ inside a lambda
size_t downloaded = existing_size; size_t downloaded = existing_size;
size_t progress_step = 0; size_t progress_step = 0;
ProgressBar bar;
auto res = cli.Get(resolve_path, headers, auto res = cli.Get(resolve_path, headers,
[&](const httplib::Response &response) { [&](const httplib::Response &response) {
@ -554,7 +600,7 @@ static bool common_pull_file(httplib::Client & cli,
progress_step += len; progress_step += len;
if (progress_step >= total_size / 1000 || downloaded == total_size) { if (progress_step >= total_size / 1000 || downloaded == total_size) {
print_progress(downloaded, total_size); bar.update(downloaded, total_size);
progress_step = 0; progress_step = 0;
} }
return true; return true;
@ -562,8 +608,6 @@ static bool common_pull_file(httplib::Client & cli,
nullptr nullptr
); );
std::cout << "\n";
if (!res) { if (!res) {
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1); LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
return false; return false;