Просмотр исходного кода

added support for Authorization Bearer tokens when downloading model (#8307)

* added support for Authorization Bearer tokens

* removed auth_token, removed set_ function, other small fixes

* Update common/common.cpp

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Derrick T. Woolworth 1 год назад
Родитель
Сommit
86e7299ef5
2 измененных файлов с 41 добавлено и 9 удалено
  1. 37 7
      common/common.cpp
  2. 4 2
      common/common.h

+ 37 - 7
common/common.cpp

@@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
+void gpt_params_handle_hf_token(gpt_params & params) {
+    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
+        params.hf_token = std::getenv("HF_TOKEN");
+    }
+}
+
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
+    gpt_params_handle_hf_token(params);
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
@@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.model_url = argv[i];
         return true;
     }
+    if (arg == "-hft" || arg == "--hf-token") {
+        if (++i >= argc) {
+          invalid_param = true;
+          return true;
+        }
+        params.hf_token = argv[i];
+        return true;
+    }
     if (arg == "-hfr" || arg == "--hf-repo") {
         CHECK_ARG
         params.hf_repo = argv[i];
@@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "-mu,   --model-url MODEL_URL",  "model download url (default: unused)" });
     options.push_back({ "*",           "-hfr,  --hf-repo REPO",         "Hugging Face model repository (default: unused)" });
     options.push_back({ "*",           "-hff,  --hf-file FILE",         "Hugging Face model file (default: unused)" });
+    options.push_back({ "*",           "-hft,  --hf-token TOKEN",       "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
 
     options.push_back({ "retrieval" });
     options.push_back({ "retrieval",   "       --context-file FNAME",   "file to load context from (repeat to specify multiple files)" });
@@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     llama_model * model = nullptr;
 
     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else if (!params.model_url.empty()) {
-        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
@@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }
 
-static bool llama_download_file(const std::string & url, const std::string & path) {
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
 
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
 
+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+      std::string auth_header = "Authorization: Bearer ";
+      auth_header += hf_token.c_str();
+      struct curl_slist *http_headers = NULL;
+      http_headers = curl_slist_append(http_headers, auth_header.c_str());
+      curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
 #if defined(_WIN32)
     // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
     //   operating system. Currently implemented under MS-Windows.
@@ -2415,6 +2441,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat
 struct llama_model * llama_load_model_from_url(
         const char * model_url,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
@@ -2422,7 +2449,7 @@ struct llama_model * llama_load_model_from_url(
         return NULL;
     }
 
-    if (!llama_download_file(model_url, path_model)) {
+    if (!llama_download_file(model_url, path_model, hf_token)) {
         return NULL;
     }
 
@@ -2470,14 +2497,14 @@ struct llama_model * llama_load_model_from_url(
         // Prepare download in parallel
         std::vector<std::future<bool>> futures_download;
         for (int idx = 1; idx < n_split; idx++) {
-            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
+            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
                 char split_path[PATH_MAX] = {0};
                 llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
 
                 char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
                 llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
 
-                return llama_download_file(split_url, split_path);
+                return llama_download_file(split_url, split_path, hf_token);
             }, idx));
         }
 
@@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * repo,
         const char * model,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // construct hugging face model url:
     //
@@ -2511,7 +2539,7 @@ struct llama_model * llama_load_model_from_hf(
     model_url += "/resolve/main/";
     model_url += model;
 
-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
 }
 
 #else
@@ -2519,6 +2547,7 @@ struct llama_model * llama_load_model_from_hf(
 struct llama_model * llama_load_model_from_url(
         const char * /*model_url*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
@@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * /*repo*/,
         const char * /*model*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
     return nullptr;

+ 4 - 2
common/common.h

@@ -108,6 +108,7 @@ struct gpt_params {
     std::string model_draft          = ""; // draft model for speculative decoding
     std::string model_alias          = "unknown"; // model alias
     std::string model_url            = ""; // model url to download
+    std::string hf_token             = ""; // HF token
     std::string hf_repo              = ""; // HF repo
     std::string hf_file              = ""; // HF file
     std::string prompt               = "";
@@ -256,6 +257,7 @@ struct gpt_params {
     bool spm_infill = false; // suffix/prefix/middle pattern for infill
 };
 
+void gpt_params_handle_hf_token(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
@@ -311,8 +313,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 
 // Batch utils