common : add minimalist multi-thread progress bar (#17602)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
2eaa2c65cb
commit
b8ee22cfde
|
|
@ -12,6 +12,8 @@
|
|||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <future>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
|
@ -472,36 +474,79 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
|||
|
||||
#elif defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
static bool is_output_a_tty() {
|
||||
class ProgressBar {
|
||||
static inline std::mutex mutex;
|
||||
static inline std::map<const ProgressBar *, int> lines;
|
||||
static inline int max_line = 0;
|
||||
|
||||
static void cleanup(const ProgressBar * line) {
|
||||
lines.erase(line);
|
||||
if (lines.empty()) {
|
||||
max_line = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_output_a_tty() {
|
||||
#if defined(_WIN32)
|
||||
return _isatty(_fileno(stdout));
|
||||
return _isatty(_fileno(stdout));
|
||||
#else
|
||||
return isatty(1);
|
||||
return isatty(1);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void print_progress(size_t current, size_t total) {
|
||||
if (!is_output_a_tty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!total) {
|
||||
return;
|
||||
public:
|
||||
ProgressBar() = default;
|
||||
|
||||
~ProgressBar() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
cleanup(this);
|
||||
}
|
||||
|
||||
size_t width = 50;
|
||||
size_t pct = (100 * current) / total;
|
||||
size_t pos = (width * current) / total;
|
||||
void update(size_t current, size_t total) {
|
||||
if (!is_output_a_tty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::cout << "["
|
||||
<< std::string(pos, '=')
|
||||
<< (pos < width ? ">" : "")
|
||||
<< std::string(width - pos, ' ')
|
||||
<< "] " << std::setw(3) << pct << "% ("
|
||||
<< current / (1024 * 1024) << " MB / "
|
||||
<< total / (1024 * 1024) << " MB)\r";
|
||||
std::cout.flush();
|
||||
}
|
||||
if (!total) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
if (lines.find(this) == lines.end()) {
|
||||
lines[this] = max_line++;
|
||||
std::cout << "\n";
|
||||
}
|
||||
int lines_up = max_line - lines[this];
|
||||
|
||||
size_t width = 50;
|
||||
size_t pct = (100 * current) / total;
|
||||
size_t pos = (width * current) / total;
|
||||
|
||||
std::cout << "\033[s";
|
||||
|
||||
if (lines_up > 0) {
|
||||
std::cout << "\033[" << lines_up << "A";
|
||||
}
|
||||
std::cout << "\033[2K\r["
|
||||
<< std::string(pos, '=')
|
||||
<< (pos < width ? ">" : "")
|
||||
<< std::string(width - pos, ' ')
|
||||
<< "] " << std::setw(3) << pct << "% ("
|
||||
<< current / (1024 * 1024) << " MB / "
|
||||
<< total / (1024 * 1024) << " MB) "
|
||||
<< "\033[u";
|
||||
|
||||
std::cout.flush();
|
||||
|
||||
if (current == total) {
|
||||
cleanup(this);
|
||||
}
|
||||
}
|
||||
|
||||
ProgressBar(const ProgressBar &) = delete;
|
||||
ProgressBar & operator=(const ProgressBar &) = delete;
|
||||
};
|
||||
|
||||
static bool common_pull_file(httplib::Client & cli,
|
||||
const std::string & resolve_path,
|
||||
|
|
@ -523,6 +568,7 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
const char * func = __func__; // avoid __func__ inside a lambda
|
||||
size_t downloaded = existing_size;
|
||||
size_t progress_step = 0;
|
||||
ProgressBar bar;
|
||||
|
||||
auto res = cli.Get(resolve_path, headers,
|
||||
[&](const httplib::Response &response) {
|
||||
|
|
@ -554,7 +600,7 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
progress_step += len;
|
||||
|
||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||
print_progress(downloaded, total_size);
|
||||
bar.update(downloaded, total_size);
|
||||
progress_step = 0;
|
||||
}
|
||||
return true;
|
||||
|
|
@ -562,8 +608,6 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
nullptr
|
||||
);
|
||||
|
||||
std::cout << "\n";
|
||||
|
||||
if (!res) {
|
||||
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
|
||||
return false;
|
||||
|
|
|
|||
Loading…
Reference in New Issue