Bläddra i källkod

Load all MoE experts during warmup (#11571)

* llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup

* common : use new API to enable warmup mode during model warmup

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
fairydreaming 10 månader sedan
förälder
incheckning
8fcb563613
6 ändrade filer med 22 tillägg och 2 borttagningar
  1. 3 0
      common/common.cpp
  2. 4 0
      include/llama.h
  3. 12 1
      src/llama-context.cpp
  4. 1 0
      src/llama-context.h
  5. 1 0
      src/llama-cparams.h
  6. 1 1
      src/llama-graph.cpp

+ 3 - 0
common/common.cpp

@@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) {
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
+        llama_set_warmup(lctx, true);
+
         std::vector<llama_token> tmp;
         llama_token bos = llama_vocab_bos(vocab);
         llama_token eos = llama_vocab_eos(vocab);
@@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         llama_kv_self_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
+        llama_set_warmup(lctx, false);
     }
 
     iparams.model.reset(model);

+ 4 - 0
include/llama.h

@@ -945,6 +945,10 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
+    // Set whether the model is in warmup mode or not
+    // If true, all model tensors are activated during llama_decode() to load and cache their weights.
+    LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
 

+ 12 - 1
src/llama-context.cpp

@@ -39,6 +39,7 @@ llama_context::llama_context(
     cparams.flash_attn       = params.flash_attn;
     cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
+    cparams.warmup           = false;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -948,6 +949,12 @@ void llama_context::set_causal_attn(bool value) {
     cparams.causal_attn = value;
 }
 
+void llama_context::set_warmup(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.warmup = value;
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -1594,7 +1601,7 @@ void llama_context::output_reorder() {
 //
 
 int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(8192, 5*model.n_tensors());
+    return std::max<int32_t>(65536, 5*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_init() {
@@ -2372,6 +2379,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
     ctx->set_causal_attn(causal_attn);
 }
 
+void llama_set_warmup(llama_context * ctx, bool warmup) {
+    ctx->set_warmup(warmup);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }

+ 1 - 0
src/llama-context.h

@@ -64,6 +64,7 @@ struct llama_context {
 
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
+    void set_warmup(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,

+ 1 - 0
src/llama-cparams.h

@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool warmup;
 
     enum llama_pooling_type pooling_type;
 

+ 1 - 1
src/llama-graph.cpp

@@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_embd_head_v    (hparams.n_embd_head_v),
     n_embd_v_gqa     (hparams.n_embd_v_gqa()),
     n_expert         (hparams.n_expert),
-    n_expert_used    (hparams.n_expert_used),
+    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),