
common : add HF arg helpers (#6234)

* common : add HF arg helpers

* common : remove defaults
Georgi Gerganov 1 year ago
Parent
Commit
80bd33bc2c
2 changed files with 78 additions and 17 deletions
  1. common/common.cpp  +72 -13
  2. common/common.h  +6 -4
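
For orientation, a minimal sketch (not part of the commit) of how the new -hfr/--hf-repo and -hff/--hf-file values are combined into a download URL. The hf_url helper name is hypothetical; the URL scheme mirrors the comment in llama_load_model_from_hf added below.

    // Sketch only: https://huggingface.co/<repo>/resolve/main/<file>
    #include <string>

    static std::string hf_url(const std::string & repo, const std::string & file) {
        return "https://huggingface.co/" + repo + "/resolve/main/" + file;
    }

    // e.g. hf_url("ggml-org/models", "tinyllama-1.1b/ggml-model-f16.gguf")
    //   -> "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf"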

common/common.cpp  +72 -13

@@ -647,6 +647,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
         params.model = argv[i];
         return true;
     }
+    if (arg == "-md" || arg == "--model-draft") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.model_draft = argv[i];
+        return true;
+    }
+    if (arg == "-a" || arg == "--alias") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.model_alias = argv[i];
+        return true;
+    }
     if (arg == "-mu" || arg == "--model-url") {
     if (arg == "-mu" || arg == "--model-url") {
         if (++i >= argc) {
         if (++i >= argc) {
             invalid_param = true;
             invalid_param = true;
@@ -655,20 +671,20 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
         params.model_url = argv[i];
         return true;
     }
-    if (arg == "-md" || arg == "--model-draft") {
+    if (arg == "-hfr" || arg == "--hf-repo") {
         if (++i >= argc) {
             invalid_param = true;
             return true;
         }
-        params.model_draft = argv[i];
+        params.hf_repo = argv[i];
         return true;
     }
-    if (arg == "-a" || arg == "--alias") {
+    if (arg == "-hff" || arg == "--hf-file") {
         if (++i >= argc) {
             invalid_param = true;
             return true;
         }
-        params.model_alias = argv[i];
+        params.hf_file = argv[i];
         return true;
     }
     if (arg == "--lora") {
     if (arg == "--lora") {
@@ -1403,10 +1419,14 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("                        layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
     printf("                        model path (default: %s)\n", params.model.c_str());
-    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
-    printf("                        model download url (default: %s)\n", params.model_url.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
     printf("  -md FNAME, --model-draft FNAME\n");
-    printf("                        draft model for speculative decoding\n");
+    printf("                        draft model for speculative decoding (default: unused)\n");
+    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
+    printf("                        model download url (default: unused)\n");
+    printf("  -hfr REPO, --hf-repo REPO\n");
+    printf("                        Hugging Face model repository (default: unused)\n");
+    printf("  -hff FILE, --hf-file FILE\n");
+    printf("                        Hugging Face model file (default: unused)\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
@@ -1655,8 +1675,10 @@ void llama_batch_add(
 
 #ifdef LLAMA_USE_CURL
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                              struct llama_model_params params) {
+struct llama_model * llama_load_model_from_url(
+        const char * model_url,
+        const char * path_model,
+        const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
         fprintf(stderr, "%s: invalid model_url\n", __func__);
@@ -1850,25 +1872,62 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
     return llama_load_model_from_file(path_model, params);
 }
 
+struct llama_model * llama_load_model_from_hf(
+        const char * repo,
+        const char * model,
+        const char * path_model,
+        const struct llama_model_params & params) {
+    // construct hugging face model url:
+    //
+    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
+    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
+    //
+    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //
+
+    std::string model_url = "https://huggingface.co/";
+    model_url += repo;
+    model_url += "/resolve/main/";
+    model_url += model;
+
+    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+}
+
 #else
 
-struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
-                                              struct llama_model_params /*params*/) {
+struct llama_model * llama_load_model_from_url(
+        const char * /*model_url*/,
+        const char * /*path_model*/,
+        const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
     return nullptr;
 }
 }
 
 
+struct llama_model * llama_load_model_from_hf(
+        const char * /*repo*/,
+        const char * /*model*/,
+        const char * /*path_model*/,
+        const struct llama_model_params & /*params*/) {
+    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return nullptr;
+}
+
 #endif // LLAMA_USE_CURL
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
 
     llama_model * model = nullptr;
-    if (!params.model_url.empty()) {
+
+    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+    } else if (!params.model_url.empty()) {
         model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
@@ -1908,7 +1967,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     }
 
     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
-        const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+        const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
         float lora_scale = std::get<1>(params.lora_adapter[i]);
         int err = llama_model_apply_lora_from_file(model,
                                              lora_adapter.c_str(),
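
As a usage illustration (a sketch, not part of this commit): calling the new helper directly, with the repo/file values taken from the example in the URL-construction comment above. The local download path and error handling here are assumptions.

    #include "common.h"
    #include "llama.h"

    int main() {
        struct llama_model_params mparams = llama_model_default_params();

        // repo/file from the example in the comment above; the local path is arbitrary
        struct llama_model * model = llama_load_model_from_hf(
            "ggml-org/models",
            "tinyllama-1.1b/ggml-model-f16.gguf",
            "models/tinyllama-1.1b/ggml-model-f16.gguf",
            mparams);

        if (model == nullptr) {
            return 1;  // requires a build with LLAMA_USE_CURL, otherwise the stub returns nullptr
        }

        llama_free_model(model);
        return 0;
    }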

common/common.h  +6 -4

@@ -89,9 +89,11 @@ struct gpt_params {
     struct llama_sampling_params sparams;
 
     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_url         = ""; // model url to download
-    std::string model_draft       = "";                              // draft model for speculative decoding
+    std::string model_draft       = "";  // draft model for speculative decoding
     std::string model_alias       = "unknown"; // model alias
+    std::string model_url         = "";  // model url to download
+    std::string hf_repo           = "";  // HF repo
+    std::string hf_file           = "";  // HF file
     std::string prompt            = "";
     std::string prompt_file       = "";  // store the external prompt file name
     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
@@ -192,8 +194,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                                         struct llama_model_params     params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
 
 // Batch utils
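
Finally, a sketch (assumed usage, not from the commit) of driving the same download through gpt_params and llama_init_from_gpt_params, which now prefers hf_repo/hf_file over model_url and stores the downloaded file at params.model:

    #include "common.h"
    #include <tuple>

    int main() {
        gpt_params params;
        params.hf_repo = "ggml-org/models";                     // example values from the URL comment in common.cpp
        params.hf_file = "tinyllama-1.1b/ggml-model-f16.gguf";
        params.model   = "models/tinyllama-1.1b-f16.gguf";      // local path used for the downloaded file

        llama_model *   model = nullptr;
        llama_context * ctx   = nullptr;
        std::tie(model, ctx) = llama_init_from_gpt_params(params);
        if (model == nullptr || ctx == nullptr) {
            return 1;
        }

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }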