
lora: make sure model keeps track of associated adapters (#18490)

* lora: make sure model keeps track of associated adapters

* deprecate llama_adapter_lora_free

* minor : std::unordered_set over std::set

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen · 2 weeks ago · commit a7e6ddb8bd
7 changed files with 24 additions and 22 deletions:

  1. include/llama-cpp.h (+3 −1)
  2. include/llama.h (+2 −1)
  3. src/llama-adapter.cpp (+7 −13)
  4. src/llama-adapter.h (+1 −3)
  5. src/llama-context.cpp (+3 −1)
  6. src/llama-model.cpp (+5 −1)
  7. src/llama-model.h (+3 −2)

+ 3 - 1
include/llama-cpp.h

@@ -21,7 +21,9 @@ struct llama_sampler_deleter {
 };
 
 struct llama_adapter_lora_deleter {
-    void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
+    void operator()(llama_adapter_lora *) {
+        // llama_adapter_lora_free is deprecated
+    }
 };
 
 typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
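
With the deleter reduced to a no-op, a `llama_adapter_lora_ptr` no longer owns the adapter it wraps; destruction is deferred to the owning model. A minimal sketch of what this means for callers (the adapter path is a placeholder):

    #include "llama-cpp.h"

    void load_adapter(llama_model * model) {
        // llama_adapter_lora_init registers the adapter with `model`
        llama_adapter_lora_ptr lora(llama_adapter_lora_init(model, "adapter.gguf"));

        // when `lora` goes out of scope the no-op deleter runs and frees nothing;
        // the adapter stays alive until the model itself is destroyed
    }

The flip side is that such a pointer must not be dereferenced after its model has been freed.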

+ 2 - 1
include/llama.h

@@ -646,7 +646,8 @@ extern "C" {
 
     // Manually free a LoRA adapter
     // NOTE: loaded adapters will be freed when the associated model is deleted
-    LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
+    LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter),
+            "adapters are now freed together with the associated model");
 
     // Get the invocation tokens if the current lora is an alora
     LLAMA_API uint64_t            llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
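
In practice, the deprecation means callers can drop the explicit cleanup entirely. A hedged end-to-end sketch against the public C API (file names are placeholders):

    #include "llama.h"

    int main(void) {
        llama_backend_init();

        struct llama_model_params mparams = llama_model_default_params();
        struct llama_model * model = llama_model_load_from_file("model.gguf", mparams);

        // the adapter is registered with `model` on creation
        struct llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf");
        (void) adapter; // apply to a context via llama_set_adapter_lora as before

        // no llama_adapter_lora_free needed: it is now a deprecated no-op,
        // and freeing the model releases every associated adapter
        llama_model_free(model);

        llama_backend_free();
        return 0;
    }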

+ 7 - 13
src/llama-adapter.cpp

@@ -146,11 +146,9 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     return nullptr;
 }
 
-static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
-    llama_model & model = adapter.model;
-
     ggml_context * ctx_init;
     gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
@@ -413,17 +411,17 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) {
         }
     }
 
-    // update number of nodes used
-    model.n_lora_nodes += adapter.get_n_nodes();
+    // register adapter with model
+    model.loras.insert(&adapter);
 
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
 llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
-    llama_adapter_lora * adapter = new llama_adapter_lora(*model);
+    llama_adapter_lora * adapter = new llama_adapter_lora();
 
     try {
-        llama_adapter_lora_init_impl(path_lora, *adapter);
+        llama_adapter_lora_init_impl(*model, path_lora, *adapter);
         return adapter;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
@@ -473,12 +471,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
     return snprintf(buf, buf_size, "%s", it->second.c_str());
 }
 
-void llama_adapter_lora_free(llama_adapter_lora * adapter) {
-    // update number of nodes used
-    GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes());
-    adapter->model.n_lora_nodes -= adapter->get_n_nodes();
-
-    delete adapter;
+void llama_adapter_lora_free(llama_adapter_lora *) {
+    // deprecated: adapters are freed by llama_model's destructor
 }
 
 uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {

+ 1 - 3
src/llama-adapter.h

@@ -59,8 +59,6 @@ struct llama_adapter_lora_weight {
 };
 
 struct llama_adapter_lora {
-    llama_model & model;
-
     // map tensor name to lora_a_b
     std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
 
@@ -75,7 +73,7 @@ struct llama_adapter_lora {
     // activated lora (aLoRA)
     std::vector<llama_token> alora_invocation_tokens;
 
-    llama_adapter_lora(llama_model & model) : model(model) {}
+    llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
     llama_adapter_lora_weight * get_weight(ggml_tensor * w);

+ 3 - 1
src/llama-context.cpp

@@ -1955,7 +1955,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
         return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
-    res += model.n_lora_nodes;
+    for (const auto & lora : model.loras) {
+        res += lora->get_n_nodes();
+    }
     return res;
 }
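
The cached `n_lora_nodes` counter is gone; the node budget is now summed from the live adapter set on every call. A self-contained sketch of the same computation with simplified stand-in types (the node counts are made up):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <unordered_set>

    struct adapter {
        uint32_t n_nodes;
        uint32_t get_n_nodes() const { return n_nodes; }
    };

    int main() {
        std::unordered_set<adapter *> loras = { new adapter{64}, new adapter{128} };

        const uint32_t n_tensors = 500;
        uint32_t res = std::max<uint32_t>(1024u, 8u * n_tensors); // base budget: 4000
        for (const auto * lora : loras) {
            res += lora->get_n_nodes(); // each adapter contributes its extra nodes
        }
        std::printf("max graph nodes: %u\n", res); // 4192

        for (auto * lora : loras) {
            delete lora;
        }
        return 0;
    }

Recomputing on demand trades a little work per call for never having to keep a counter in sync as adapters are added or removed.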
 

+ 5 - 1
src/llama-model.cpp

@@ -468,7 +468,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi
     pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
-llama_model::~llama_model() = default;
+llama_model::~llama_model() {
+    for (auto * lora : loras) {
+        delete lora;
+    }
+}
 
 void llama_model::load_stats(llama_model_loader & ml) {
     pimpl->n_elements = ml.n_elements;

+ 3 - 2
src/llama-model.h

@@ -11,6 +11,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 struct llama_cparams;
@@ -476,8 +477,8 @@ struct llama_model {
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
-    // for keeping track of extra nodes used by lora adapters
-    uint32_t n_lora_nodes = 0;
+    // for keeping track of associated LoRA adapters
+    std::unordered_set<llama_adapter_lora *> loras;
 
     int64_t t_load_us  = 0;
     int64_t t_start_us = 0;
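
A note on the container choice (the third bullet in the commit message): the registry only needs membership tests, insertion, and iteration, never ordering, so `std::unordered_set` with O(1) average insert and erase is a better fit than the tree-based `std::set`.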