common : add minimalist multi-thread progress bar (#17602)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
2eaa2c65cb
commit
b8ee22cfde
|
|
@ -12,6 +12,8 @@
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <future>
|
#include <future>
|
||||||
|
#include <map>
|
||||||
|
#include <mutex>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
@ -472,36 +474,79 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
||||||
|
|
||||||
#elif defined(LLAMA_USE_HTTPLIB)
|
#elif defined(LLAMA_USE_HTTPLIB)
|
||||||
|
|
||||||
static bool is_output_a_tty() {
|
class ProgressBar {
|
||||||
|
static inline std::mutex mutex;
|
||||||
|
static inline std::map<const ProgressBar *, int> lines;
|
||||||
|
static inline int max_line = 0;
|
||||||
|
|
||||||
|
static void cleanup(const ProgressBar * line) {
|
||||||
|
lines.erase(line);
|
||||||
|
if (lines.empty()) {
|
||||||
|
max_line = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool is_output_a_tty() {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
return _isatty(_fileno(stdout));
|
return _isatty(_fileno(stdout));
|
||||||
#else
|
#else
|
||||||
return isatty(1);
|
return isatty(1);
|
||||||
#endif
|
#endif
|
||||||
}
|
|
||||||
|
|
||||||
static void print_progress(size_t current, size_t total) {
|
|
||||||
if (!is_output_a_tty()) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!total) {
|
public:
|
||||||
return;
|
ProgressBar() = default;
|
||||||
|
|
||||||
|
~ProgressBar() {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
cleanup(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t width = 50;
|
void update(size_t current, size_t total) {
|
||||||
size_t pct = (100 * current) / total;
|
if (!is_output_a_tty()) {
|
||||||
size_t pos = (width * current) / total;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::cout << "["
|
if (!total) {
|
||||||
<< std::string(pos, '=')
|
return;
|
||||||
<< (pos < width ? ">" : "")
|
}
|
||||||
<< std::string(width - pos, ' ')
|
|
||||||
<< "] " << std::setw(3) << pct << "% ("
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
<< current / (1024 * 1024) << " MB / "
|
|
||||||
<< total / (1024 * 1024) << " MB)\r";
|
if (lines.find(this) == lines.end()) {
|
||||||
std::cout.flush();
|
lines[this] = max_line++;
|
||||||
}
|
std::cout << "\n";
|
||||||
|
}
|
||||||
|
int lines_up = max_line - lines[this];
|
||||||
|
|
||||||
|
size_t width = 50;
|
||||||
|
size_t pct = (100 * current) / total;
|
||||||
|
size_t pos = (width * current) / total;
|
||||||
|
|
||||||
|
std::cout << "\033[s";
|
||||||
|
|
||||||
|
if (lines_up > 0) {
|
||||||
|
std::cout << "\033[" << lines_up << "A";
|
||||||
|
}
|
||||||
|
std::cout << "\033[2K\r["
|
||||||
|
<< std::string(pos, '=')
|
||||||
|
<< (pos < width ? ">" : "")
|
||||||
|
<< std::string(width - pos, ' ')
|
||||||
|
<< "] " << std::setw(3) << pct << "% ("
|
||||||
|
<< current / (1024 * 1024) << " MB / "
|
||||||
|
<< total / (1024 * 1024) << " MB) "
|
||||||
|
<< "\033[u";
|
||||||
|
|
||||||
|
std::cout.flush();
|
||||||
|
|
||||||
|
if (current == total) {
|
||||||
|
cleanup(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ProgressBar(const ProgressBar &) = delete;
|
||||||
|
ProgressBar & operator=(const ProgressBar &) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
static bool common_pull_file(httplib::Client & cli,
|
static bool common_pull_file(httplib::Client & cli,
|
||||||
const std::string & resolve_path,
|
const std::string & resolve_path,
|
||||||
|
|
@ -523,6 +568,7 @@ static bool common_pull_file(httplib::Client & cli,
|
||||||
const char * func = __func__; // avoid __func__ inside a lambda
|
const char * func = __func__; // avoid __func__ inside a lambda
|
||||||
size_t downloaded = existing_size;
|
size_t downloaded = existing_size;
|
||||||
size_t progress_step = 0;
|
size_t progress_step = 0;
|
||||||
|
ProgressBar bar;
|
||||||
|
|
||||||
auto res = cli.Get(resolve_path, headers,
|
auto res = cli.Get(resolve_path, headers,
|
||||||
[&](const httplib::Response &response) {
|
[&](const httplib::Response &response) {
|
||||||
|
|
@ -554,7 +600,7 @@ static bool common_pull_file(httplib::Client & cli,
|
||||||
progress_step += len;
|
progress_step += len;
|
||||||
|
|
||||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||||
print_progress(downloaded, total_size);
|
bar.update(downloaded, total_size);
|
||||||
progress_step = 0;
|
progress_step = 0;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -562,8 +608,6 @@ static bool common_pull_file(httplib::Client & cli,
|
||||||
nullptr
|
nullptr
|
||||||
);
|
);
|
||||||
|
|
||||||
std::cout << "\n";
|
|
||||||
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
|
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
|
||||||
return false;
|
return false;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue