Ver Fonte

common : add minimalist multi-thread progress bar (#17602)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët há 1 mês atrás
pai
commit
b8ee22cfde
1 ficheiros alterados com 69 adições e 25 exclusões
  1. 69 25
      common/download.cpp

+ 69 - 25
common/download.cpp

@@ -12,6 +12,8 @@
 #include <filesystem>
 #include <fstream>
 #include <future>
+#include <map>
+#include <mutex>
 #include <regex>
 #include <string>
 #include <thread>
@@ -472,36 +474,79 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 
 #elif defined(LLAMA_USE_HTTPLIB)
 
-static bool is_output_a_tty() {
+class ProgressBar {
+    static inline std::mutex mutex;
+    static inline std::map<const ProgressBar *, int> lines;
+    static inline int max_line = 0;
+
+    static void cleanup(const ProgressBar * line) {
+        lines.erase(line);
+        if (lines.empty()) {
+            max_line = 0;
+        }
+    }
+
+    static bool is_output_a_tty() {
 #if defined(_WIN32)
-    return _isatty(_fileno(stdout));
+        return _isatty(_fileno(stdout));
 #else
-    return isatty(1);
+        return isatty(1);
 #endif
-}
+    }
 
-static void print_progress(size_t current, size_t total) {
-    if (!is_output_a_tty()) {
-        return;
+public:
+    ProgressBar() = default;
+
+    ~ProgressBar() {
+        std::lock_guard<std::mutex> lock(mutex);
+        cleanup(this);
     }
 
-    if (!total) {
-        return;
+    void update(size_t current, size_t total) {
+        if (!is_output_a_tty()) {
+            return;
+        }
+
+        if (!total) {
+            return;
+        }
+
+        std::lock_guard<std::mutex> lock(mutex);
+
+        if (lines.find(this) == lines.end()) {
+            lines[this] = max_line++;
+            std::cout << "\n";
+        }
+        int lines_up = max_line - lines[this];
+
+        size_t width = 50;
+        size_t pct = (100 * current) / total;
+        size_t pos = (width * current) / total;
+
+        std::cout << "\033[s";
+
+        if (lines_up > 0) {
+            std::cout << "\033[" << lines_up << "A";
+        }
+        std::cout << "\033[2K\r["
+            << std::string(pos, '=')
+            << (pos < width ? ">" : "")
+            << std::string(width - pos, ' ')
+            << "] " << std::setw(3) << pct << "%  ("
+            << current / (1024 * 1024) << " MB / "
+            << total / (1024 * 1024) << " MB) "
+            << "\033[u";
+
+        std::cout.flush();
+
+        if (current == total) {
+             cleanup(this);
+        }
     }
 
-    size_t width = 50;
-    size_t pct = (100 * current) / total;
-    size_t pos = (width * current) / total;
-
-    std::cout << "["
-              << std::string(pos, '=')
-              << (pos < width ? ">" : "")
-              << std::string(width - pos, ' ')
-              << "] " << std::setw(3) << pct << "%  ("
-              << current / (1024 * 1024) << " MB / "
-              << total / (1024 * 1024) << " MB)\r";
-    std::cout.flush();
-}
+    ProgressBar(const ProgressBar &) = delete;
+    ProgressBar & operator=(const ProgressBar &) = delete;
+};
 
 static bool common_pull_file(httplib::Client & cli,
                              const std::string & resolve_path,
@@ -523,6 +568,7 @@ static bool common_pull_file(httplib::Client & cli,
     const char * func = __func__; // avoid __func__ inside a lambda
     size_t downloaded = existing_size;
     size_t progress_step = 0;
+    ProgressBar bar;
 
     auto res = cli.Get(resolve_path, headers,
         [&](const httplib::Response &response) {
@@ -554,7 +600,7 @@ static bool common_pull_file(httplib::Client & cli,
             progress_step += len;
 
             if (progress_step >= total_size / 1000 || downloaded == total_size) {
-                print_progress(downloaded, total_size);
+                bar.update(downloaded, total_size);
                 progress_step = 0;
             }
             return true;
@@ -562,8 +608,6 @@ static bool common_pull_file(httplib::Client & cli,
         nullptr
     );
 
-    std::cout << "\n";
-
     if (!res) {
         LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
         return false;