@@ -248,7 +248,10 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-    const size_t max_nodes = this->graph_max_nodes();
+    const uint32_t n_seqs = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    const size_t max_nodes = this->graph_max_nodes(n_tokens);
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
@@ -300,9 +303,6 @@ llama_context::llama_context(
 
     cross.v_embd.clear();
 
-    const uint32_t n_seqs = cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
     // avoid reserving graphs with zero outputs - assume one output per sequence
     n_outputs = n_seqs;
 
@@ -1386,9 +1386,9 @@ void llama_context::output_reorder() {
 // graph
 //
 
-uint32_t llama_context::graph_max_nodes() const {
+uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
     if (model.arch == LLM_ARCH_QWEN3NEXT) {
-        return std::max<uint32_t>(8192u, 32u*model.n_tensors());
+        return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
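
For reference, below is a minimal standalone sketch (not part of the patch) of the updated node-budget heuristic. It mirrors the new QWEN3NEXT branch of graph_max_nodes(): the per-token factor (40) and per-tensor factor (32) are taken from the diff above, while the tensor count and the helper name are made-up placeholders standing in for model.n_tensors().

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Hypothetical helper mirroring the new QWEN3NEXT branch of graph_max_nodes().
// The node budget now grows with the ubatch size instead of using a fixed 8192 floor.
static uint32_t graph_max_nodes_qwen3next(uint32_t n_tokens, uint32_t n_tensors) {
    return std::max<uint32_t>(n_tokens * 40, 32u * n_tensors);
}

int main() {
    const uint32_t n_tensors = 1000; // placeholder, stands in for model.n_tensors()
    for (uint32_t n_tokens : {64u, 512u, 4096u}) {
        std::printf("n_tokens = %4u -> max_nodes = %u\n",
                    n_tokens, graph_max_nodes_qwen3next(n_tokens, n_tensors));
    }
    return 0;
}

The intent appears to be that the reserved node count stays proportional to the graph actually built for a ubatch of n_tokens (hence also computing n_tokens before the graph_max_nodes() call in the constructor), rather than relying on a fixed minimum that a large ubatch could outgrow.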