فهرست منبع

rpc : early register backend devices (#11262)

Early register RPC devices and do not propagate RPC specifics in the
llama model structures.

ref: #10609
Radoslav Gerganov 1 سال پیش
والد
کامیت
667d72846c
10فایلهای تغییر یافته به همراه61 افزوده شده و 55 حذف شده
  1. 26 1
      common/arg.cpp
  2. 0 1
      common/common.cpp
  3. 0 1
      common/common.h
  4. 33 4
      examples/llama-bench/llama-bench.cpp
  5. 2 0
      ggml/include/ggml-backend.h
  6. 0 1
      ggml/src/ggml-backend-impl.h
  7. 0 3
      include/llama.h
  8. 0 1
      src/llama-model.cpp
  9. 0 2
      src/llama-model.h
  10. 0 41
      src/llama.cpp

+ 26 - 1
common/arg.cpp

@@ -376,6 +376,30 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     return devices;
 }
 
+static void add_rpc_devices(std::string servers) {
+    auto rpc_servers = string_split<std::string>(servers, ',');
+    if (rpc_servers.empty()) {
+        throw std::invalid_argument("no RPC servers specified");
+    }
+    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+    if (!rpc_reg) {
+        throw std::invalid_argument("failed to find RPC backend");
+    }
+    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+    if (!ggml_backend_rpc_add_device_fn) {
+        throw std::invalid_argument("failed to find RPC device add function");
+    }
+    for (const auto & server : rpc_servers) {
+        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+        if (dev) {
+            ggml_backend_device_register(dev);
+        } else {
+            throw std::invalid_argument("failed to register RPC device");
+        }
+    }
+}
+
 bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     auto ctx_arg = common_params_parser_init(params, ex, print_usage);
     const common_params params_org = ctx_arg.params; // the example can modify the default params
@@ -1385,7 +1409,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             {"--rpc"}, "SERVERS",
             "comma separated list of RPC servers",
             [](common_params & params, const std::string & value) {
-                params.rpc_servers = value;
+                add_rpc_devices(value);
+                GGML_UNUSED(params);
             }
         ).set_env("LLAMA_ARG_RPC"));
     }

+ 0 - 1
common/common.cpp

@@ -1043,7 +1043,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
-    mparams.rpc_servers     = params.rpc_servers.c_str();
     mparams.main_gpu        = params.main_gpu;
     mparams.split_mode      = params.split_mode;
     mparams.tensor_split    = params.tensor_split;

+ 0 - 1
common/common.h

@@ -246,7 +246,6 @@ struct common_params {
     std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding           // NOLINT
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
     std::string logits_file          = ""; // file for saving *all* logits                                  // NOLINT
-    std::string rpc_servers          = ""; // comma separated list of RPC servers                           // NOLINT
 
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)

+ 33 - 4
examples/llama-bench/llama-bench.cpp

@@ -683,7 +683,7 @@ struct cmd_params_instance {
     bool               cpu_strict;
     int                poll;
     int                n_gpu_layers;
-    std::string        rpc_servers;
+    std::string        rpc_servers_str;
     llama_split_mode   split_mode;
     int                main_gpu;
     bool               no_kv_offload;
@@ -696,8 +696,37 @@ struct cmd_params_instance {
         llama_model_params mparams = llama_model_default_params();
 
         mparams.n_gpu_layers = n_gpu_layers;
-        if (!rpc_servers.empty()) {
-            mparams.rpc_servers = rpc_servers.c_str();
+        if (!rpc_servers_str.empty()) {
+            auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
+
+            // add RPC devices
+            if (!rpc_servers.empty()) {
+                ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+                if (!rpc_reg) {
+                    fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
+                    exit(1);
+                }
+
+                typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+                ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+                if (!ggml_backend_rpc_add_device_fn) {
+                    fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
+                    exit(1);
+                }
+                static std::vector<ggml_backend_dev_t> devices;
+                devices.clear();
+                for (const std::string & server : rpc_servers) {
+                    ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+                    if (dev) {
+                        devices.push_back(dev);
+                    } else {
+                        fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                        exit(1);
+                    }
+                }
+                devices.push_back(nullptr);
+                mparams.devices = devices.data();
+            }
         }
         mparams.split_mode   = split_mode;
         mparams.main_gpu     = main_gpu;
@@ -708,7 +737,7 @@ struct cmd_params_instance {
     }
 
     bool equal_mparams(const cmd_params_instance & other) const {
-        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers &&
+        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
                split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
                tensor_split == other.tensor_split;
     }

+ 2 - 0
ggml/include/ggml-backend.h

@@ -203,6 +203,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
     // Backend (reg) enumeration
     GGML_API size_t             ggml_backend_reg_count(void);
     GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);

+ 0 - 1
ggml/src/ggml-backend-impl.h

@@ -208,7 +208,6 @@ extern "C" {
 
     // Internal backend registry API
     GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Add backend dynamic loading support to the backend
 

+ 0 - 3
include/llama.h

@@ -288,9 +288,6 @@ extern "C" {
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;
 
-        // comma separated list of RPC servers to use for offloading
-        const char * rpc_servers;
-
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
         // If the provided progress_callback returns true, model loading continues.
         // If it returns false, model loading is immediately aborted.

+ 0 - 1
src/llama-model.cpp

@@ -3717,7 +3717,6 @@ struct llama_model_params llama_model_default_params() {
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides                =*/ nullptr,

+ 0 - 2
src/llama-model.h

@@ -323,8 +323,6 @@ struct llama_model {
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
-    std::vector<std::string> rpc_servers;
-
     // list of devices used in this model
     std::vector<ggml_backend_dev_t> devices;
 

+ 0 - 41
src/llama.cpp

@@ -9399,47 +9399,6 @@ static struct llama_model * llama_model_load_from_file_impl(
         };
     }
 
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(',')) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
-        }
-        model->rpc_servers.push_back(servers);
-    }
-
-    // add RPC devices
-    if (!model->rpc_servers.empty()) {
-        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
-        if (!rpc_reg) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
-            llama_model_free(model);
-            return nullptr;
-        }
-
-        typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-        ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-        if (!ggml_backend_rpc_add_device_fn) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
-            llama_model_free(model);
-            return nullptr;
-        }
-
-        for (const std::string & server : model->rpc_servers) {
-            ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (dev) {
-                model->devices.push_back(dev);
-            } else {
-                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
-                llama_model_free(model);
-                return nullptr;
-            }
-        }
-    }
-
     // create list of devices to use with this model
     if (params.devices) {
         for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {