Bläddra i källkod

arg: add --cache-list argument to list cached models (#17073)

* arg: add --cache-list argument to list cached models

* new manifest naming format

* improve naming

* Update common/arg.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen 2 månader sedan
förälder
incheckning
aa3b7a90b4
5 ändrade filer med 117 tillägg och 9 borttagningar
  1. 14 0
      common/arg.cpp
  2. 33 0
      common/common.cpp
  3. 7 0
      common/common.h
  4. 45 5
      common/download.cpp
  5. 18 4
      common/download.h

+ 14 - 0
common/arg.cpp

@@ -740,6 +740,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             exit(0);
             exit(0);
         }
         }
     ));
     ));
+    add_opt(common_arg(
+        {"-cl", "--cache-list"},
+        "show list of models in cache",
+        [](common_params &) {
+            printf("model cache directory: %s\n", fs_get_cache_directory().c_str());
+            auto models = common_list_cached_models();
+            printf("number of models in cache: %zu\n", models.size());
+            for (size_t i = 0; i < models.size(); i++) {
+                auto & model = models[i];
+                printf("%4d. %s\n", (int) i + 1, model.to_string().c_str());
+            }
+            exit(0);
+        }
+    ));
     add_opt(common_arg(
     add_opt(common_arg(
         {"--completion-bash"},
         {"--completion-bash"},
         "print source-able bash completion script for llama.cpp",
         "print source-able bash completion script for llama.cpp",

+ 33 - 0
common/common.cpp

@@ -908,6 +908,39 @@ std::string fs_get_cache_file(const std::string & filename) {
     return cache_directory + filename;
     return cache_directory + filename;
 }
 }
 
 
+std::vector<common_file_info> fs_list_files(const std::string & path) {
+    std::vector<common_file_info> files;
+    if (path.empty()) return files;
+
+    std::filesystem::path dir(path);
+    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
+        return files;
+    }
+
+    for (const auto & entry : std::filesystem::directory_iterator(dir)) {
+        try {
+            // Only include regular files (skip directories)
+            const auto & p = entry.path();
+            if (std::filesystem::is_regular_file(p)) {
+                common_file_info info;
+                info.path = p.string();
+                info.name = p.filename().string();
+                try {
+                    info.size = static_cast<size_t>(std::filesystem::file_size(p));
+                } catch (const std::filesystem::filesystem_error &) {
+                    info.size = 0;
+                }
+                files.push_back(std::move(info));
+            }
+        } catch (const std::filesystem::filesystem_error &) {
+            // skip entries we cannot inspect
+            continue;
+        }
+    }
+
+    return files;
+}
+
 
 
 //
 //
 // Model utils
 // Model utils

+ 7 - 0
common/common.h

@@ -611,6 +611,13 @@ bool fs_create_directory_with_parents(const std::string & path);
 std::string fs_get_cache_directory();
 std::string fs_get_cache_directory();
 std::string fs_get_cache_file(const std::string & filename);
 std::string fs_get_cache_file(const std::string & filename);
 
 
+struct common_file_info {
+    std::string path;
+    std::string name;
+    size_t      size = 0; // in bytes
+};
+std::vector<common_file_info> fs_list_files(const std::string & path);
+
 //
 //
 // Model utils
 // Model utils
 //
 //

+ 45 - 5
common/download.cpp

@@ -50,6 +50,22 @@ using json = nlohmann::ordered_json;
 // downloader
 // downloader
 //
 //
 
 
+// validate repo name format: owner/repo
+static bool validate_repo_name(const std::string & repo) {
+    static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
+    return std::regex_match(repo, repo_regex);
+}
+
+static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
+    // we use "=" to avoid clashing with other component, while still being allowed on windows
+    std::string fname = "manifest=" + repo + "=" + tag + ".json";
+    if (!validate_repo_name(repo)) {
+        throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
+    }
+    string_replace_all(fname, "/", "=");
+    return fs_get_cache_file(fname);
+}
+
 static std::string read_file(const std::string & fname) {
 static std::string read_file(const std::string & fname) {
     std::ifstream file(fname);
     std::ifstream file(fname);
     if (!file) {
     if (!file) {
@@ -829,17 +845,13 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
     // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
     // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
     // User-Agent header is already set in common_remote_get_content, no need to set it here
     // User-Agent header is already set in common_remote_get_content, no need to set it here
 
 
-    // we use "=" to avoid clashing with other component, while still being allowed on windows
-    std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json";
-    string_replace_all(cached_response_fname, "/", "_");
-    std::string cached_response_path = fs_get_cache_file(cached_response_fname);
-
     // make the request
     // make the request
     common_remote_params params;
     common_remote_params params;
     params.headers = headers;
     params.headers = headers;
     long res_code = 0;
     long res_code = 0;
     std::string res_str;
     std::string res_str;
     bool use_cache = false;
     bool use_cache = false;
+    std::string cached_response_path = get_manifest_path(hf_repo, tag);
     if (!offline) {
     if (!offline) {
         try {
         try {
             auto res = common_remote_get_content(url, params);
             auto res = common_remote_get_content(url, params);
@@ -895,6 +907,33 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
     return { hf_repo, ggufFile, mmprojFile };
     return { hf_repo, ggufFile, mmprojFile };
 }
 }
 
 
+std::vector<common_cached_model_info> common_list_cached_models() {
+    std::vector<common_cached_model_info> models;
+    const std::string cache_dir = fs_get_cache_directory();
+    const std::vector<common_file_info> files = fs_list_files(cache_dir);
+    for (const auto & file : files) {
+        if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
+            common_cached_model_info model_info;
+            model_info.manifest_path = file.path;
+            std::string fname = file.name;
+            string_replace_all(fname, ".json", ""); // remove extension
+            auto parts = string_split<std::string>(fname, '=');
+            if (parts.size() == 4) {
+                // expect format: manifest=<user>=<model>=<tag>=<other>
+                model_info.user  = parts[1];
+                model_info.model = parts[2];
+                model_info.tag   = parts[3];
+            } else {
+                // invalid format
+                continue;
+            }
+            model_info.size = 0; // TODO: get GGUF size, not manifest size
+            models.push_back(model_info);
+        }
+    }
+    return models;
+}
+
 //
 //
 // Docker registry functions
 // Docker registry functions
 //
 //
@@ -959,6 +998,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
         std::string token = common_docker_get_token(repo);  // Get authentication token
         std::string token = common_docker_get_token(repo);  // Get authentication token
 
 
         // Get manifest
         // Get manifest
+        // TODO: cache the manifest response so that it appears in the model list
         const std::string    url_prefix = "https://registry-1.docker.io/v2/" + repo;
         const std::string    url_prefix = "https://registry-1.docker.io/v2/" + repo;
         std::string          manifest_url = url_prefix + "/manifests/" + tag;
         std::string          manifest_url = url_prefix + "/manifests/" + tag;
         common_remote_params manifest_params;
         common_remote_params manifest_params;

+ 18 - 4
common/download.h

@@ -8,16 +8,23 @@ struct common_params_model;
 // download functionalities
 // download functionalities
 //
 //
 
 
+struct common_cached_model_info {
+    std::string manifest_path;
+    std::string user;
+    std::string model;
+    std::string tag;
+    size_t      size = 0; // GGUF size in bytes
+    std::string to_string() const {
+        return user + "/" + model + ":" + tag;
+    }
+};
+
 struct common_hf_file_res {
 struct common_hf_file_res {
     std::string repo; // repo name with ":tag" removed
     std::string repo; // repo name with ":tag" removed
     std::string ggufFile;
     std::string ggufFile;
     std::string mmprojFile;
     std::string mmprojFile;
 };
 };
 
 
-// resolve and download model from Docker registry
-// return local path to downloaded model file
-std::string common_docker_resolve_model(const std::string & docker);
-
 /**
 /**
  * Allow getting the HF file from the HF repo with tag (like ollama), for example:
  * Allow getting the HF file from the HF repo with tag (like ollama), for example:
  * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
  * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
@@ -39,3 +46,10 @@ bool common_download_model(
     const common_params_model & model,
     const common_params_model & model,
     const std::string & bearer_token,
     const std::string & bearer_token,
     bool offline);
     bool offline);
+
+// returns list of cached models
+std::vector<common_cached_model_info> common_list_cached_models();
+
+// resolve and download model from Docker registry
+// return local path to downloaded model file
+std::string common_docker_resolve_model(const std::string & docker);