Преглед изворни кода

llama : fix fattn reserve call n_seqs parameter (#15699)

ggml-ci
Diego Devesa пре 4 месеци
родитељ
комит
274966226f
1 измењених фајлова са 7 додато и 6 уклоњено
  1. 7 6
      src/llama-context.cpp

+ 7 - 6
src/llama-context.cpp

@@ -281,9 +281,15 @@ llama_context::llama_context(
         }
         }
 
 
         cross.v_embd.clear();
         cross.v_embd.clear();
+
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
         // resolve automatic Flash Attention use
         // resolve automatic Flash Attention use
         if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
         if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
             if (!gf) {
             if (!gf) {
                 throw std::runtime_error("failed to split graph for Flash Attention check");
                 throw std::runtime_error("failed to split graph for Flash Attention check");
             }
             }
@@ -324,11 +330,6 @@ llama_context::llama_context(
         }
         }
 
 
         // reserve worst-case graph
         // reserve worst-case graph
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
         int n_splits_pp = -1;
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
         int n_nodes_pp  = -1;