common : add -hfd option for the draft model (#11318)

* common : add -hfd option for the draft model

* cont : fix env var

* cont : more fixes
Georgi Gerganov, 1 year ago
Commit 80d0d6b4b7
3 changed files with 24 additions and 6 deletions
  1. common/arg.cpp (+13, -4)
  2. common/common.h (+7, -1)
  3. examples/server/server.cpp (+4, -1)

common/arg.cpp (+13, -4)

@@ -133,7 +133,8 @@ static void common_params_handle_model_default(
         const std::string & model_url,
         std::string & hf_repo,
         std::string & hf_file,
-        const std::string & hf_token) {
+        const std::string & hf_token,
+        const std::string & model_default) {
     if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
         if (hf_file.empty()) {
@@ -163,7 +164,7 @@ static void common_params_handle_model_default(
             model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
     } else if (model.empty()) {
-        model = DEFAULT_MODEL_PATH;
+        model = model_default;
     }
 }
 
@@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // TODO: refactor model params in a common struct
-    common_params_handle_model_default(params.model,         params.model_url,         params.hf_repo,         params.hf_file,         params.hf_token);
-    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
+    common_params_handle_model_default(params.model,             params.model_url,             params.hf_repo,             params.hf_file,             params.hf_token, DEFAULT_MODEL_PATH);
+    common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
+    common_params_handle_model_default(params.vocoder.model,     params.vocoder.model_url,     params.vocoder.hf_repo,     params.vocoder.hf_file,     params.hf_token, "");
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -1629,6 +1631,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HF_REPO"));
+    add_opt(common_arg(
+        {"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
+        "Same as --hf-repo, but for the draft model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.speculative.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HFD_REPO"));
     add_opt(common_arg(
         {"-hff", "--hf-file"}, "FILE",
         "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",

common/common.h (+7, -1)

@@ -175,7 +175,11 @@ struct common_params_speculative {
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
 
-    std::string model = ""; // draft model for speculative decoding                          // NOLINT
+    std::string hf_repo = ""; // HF repo                                                     // NOLINT
+    std::string hf_file = ""; // HF file                                                     // NOLINT
+
+    std::string model = "";     // draft model for speculative decoding                      // NOLINT
+    std::string model_url = ""; // model url to download                                     // NOLINT
 };
 
 struct common_params_vocoder {
@@ -508,12 +512,14 @@ struct llama_model * common_load_model_from_url(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+
 struct llama_model * common_load_model_from_hf(
     const std::string & repo,
     const std::string & remote_path,
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+
 std::pair<std::string, std::string> common_get_hf_file(
     const std::string & hf_repo_with_tag,
     const std::string & hf_token);

examples/server/server.cpp (+4, -1)

@@ -1728,13 +1728,16 @@ struct server_context {
         add_bos_token = llama_vocab_get_add_bos(vocab);
         has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
-        if (!params_base.speculative.model.empty()) {
+        if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
 
             auto params_dft = params_base;
 
             params_dft.devices      = params_base.speculative.devices;
+            params_dft.hf_file      = params_base.speculative.hf_file;
+            params_dft.hf_repo      = params_base.speculative.hf_repo;
             params_dft.model        = params_base.speculative.model;
+            params_dft.model_url    = params_base.speculative.model_url;
             params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
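
With the check above, providing only `-hfd <user>/<model>[:quant]` is now enough to trigger draft-model loading, even when no local model path is given. A hedged sketch of how the speculative fields line up with the HF loader declared in `common/common.h`; the actual server code path goes through the common model-loading helpers and may differ in detail:

```cpp
#include "common.h"
#include "llama.h"

// Hedged sketch: map the new speculative.* fields onto the
// common_load_model_from_hf() declaration shown above
// (repo, remote_path, local_path, hf_token, model params).
static llama_model * load_draft_model(const common_params & params) {
    llama_model_params mparams = llama_model_default_params();

    return common_load_model_from_hf(
        params.speculative.hf_repo,  // set via -hfd / --hf-repo-draft
        params.speculative.hf_file,  // remote file within the repo
        params.speculative.model,    // local path (download cache location)
        params.hf_token,
        mparams);
}
```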