Parcourir la source

Make graph_max_nodes vary by ubatch size (#17794)

* Make graph_max_nodes vary by ubatch size for models where chunking might explode the graph

* Update src/llama-context.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Add missing const

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Piotr Wilkin (ilintar) il y a 1 mois
Parent
commit
e4e9c4329c
2 fichiers modifiés avec 7 ajouts et 7 suppressions
  1. 6 6
      src/llama-context.cpp
  2. 1 1
      src/llama-context.h

+ 6 - 6
src/llama-context.cpp

@@ -248,7 +248,10 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-        const size_t max_nodes = this->graph_max_nodes();
+        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        const size_t max_nodes = this->graph_max_nodes(n_tokens);
 
         LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
@@ -300,9 +303,6 @@ llama_context::llama_context(
 
         cross.v_embd.clear();
 
-        const uint32_t n_seqs = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
         // avoid reserving graphs with zero outputs - assume one output per sequence
         n_outputs = n_seqs;
 
@@ -1386,9 +1386,9 @@ void llama_context::output_reorder() {
 // graph
 //
 
-uint32_t llama_context::graph_max_nodes() const {
+uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
     if (model.arch == LLM_ARCH_QWEN3NEXT) {
-        return std::max<uint32_t>(8192u, 32u*model.n_tensors());
+        return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }

+ 1 - 1
src/llama-context.h

@@ -197,7 +197,7 @@ private:
     //
 
 public:
-    uint32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes(uint32_t n_tokens) const;
 
     // can reuse the llm_graph_result instance of the context (for example to update a memory module)
     llm_graph_result * get_gf_res_reserve() const;