Просмотр исходного кода

kv-cache : better estimate of n_kv for multi-sequence batches (#15610)

ggml-ci
Georgi Gerganov 4 месяцев назад
Родитель
Сommit
1bded5a3b3
2 измененных файлов с 15 добавлено и 16 удалено
  1. 12 13
      src/llama-kv-cache.cpp
  2. 3 3
      src/llama-kv-cache.h

+ 12 - 13
src/llama-kv-cache.cpp

@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
             GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
         }
 
-        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
-        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
 
         res.strm[s] = seq_to_stream[seq_id];
         res.idxs[s].reserve(n_tokens);
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
     return result;
 }
 
-uint32_t llama_kv_cache::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     uint32_t result = 0;
 
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const auto & cells = v_cells[sinfo.strm[s]];
 
         result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
     }
@@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+                ggml_row_size(v->type, hparams.n_embd_head_v),          // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),                   // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),           // v->nb[3]
                 ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
-            ggml_row_size(v->type, kv_size),                          // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1]
+            ggml_row_size(v->type, kv_size),                        // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),           // v->nb[3]
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
     }
 
     kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
-    n_kv = kv->get_n_kv();
+    n_kv = kv->get_n_kv(sinfos[i_cur]);
 
     return true;
 }

+ 3 - 3
src/llama-kv-cache.h

@@ -38,8 +38,8 @@ public:
         using idx_vec_t = std::vector<uint32_t>;
 
         // number of streams: ns = s1 - s0 + 1
-        llama_seq_id s0;
-        llama_seq_id s1;
+        uint32_t s0;
+        uint32_t s1;
 
         std::vector<llama_seq_id> strm; // [ns]
         std::vector<idx_vec_t>    idxs; // [ns]
@@ -139,7 +139,7 @@ public:
     // graph_build API
     //
 
-    uint32_t get_n_kv() const;
+    uint32_t get_n_kv(const slot_info & sinfo) const;
 
     // TODO: temporary
     bool get_supports_set_rows() const;