Browse Source

kv-cache : optimize KQ mask construction (#18842)

* kv-cache : optimize KQ mask construction

* cont : add explanation + improve

* cont : fix
Georgi Gerganov 1 week ago
parent
commit
2fbde785bc
4 changed files with 239 additions and 98 deletions
  1. 0 36
      src/llama-hparams.cpp
  2. 38 1
      src/llama-hparams.h
  3. 201 59
      src/llama-kv-cache.cpp
  4. 0 2
      src/llama-kv-cache.h

+ 0 - 36
src/llama-hparams.cpp

@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_SYMMETRIC:
-            {
-                const int32_t half_n_swa = (int32_t) n_swa / 2;
-                const int32_t pos_diff = p1 - p0;
-
-                // Mask if outside the symmetric window
-                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
-}
-
 bool llama_hparams::use_mrope() const {
     return rope_sections[0] > 0 && rope_sections[1] > 0;
 }

+ 38 - 1
src/llama-hparams.h

@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <array>
+#include <cassert>
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
@@ -274,9 +275,45 @@ struct llama_hparams {
     uint32_t n_layer_kv() const;
 
     // note that this function uses different SWA parameters from those in the hparams
+    // note: inlined on purpose for performance reasons
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
-    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
+        assert(p0 >= 0 && p1 >= 0);
+
+        switch (swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) n_swa) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) n_swa / 2;
+                    const int32_t pos_diff = p1 - p0;
+
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
+                    }
+                } break;
+        }
+
+        return false;
+    }
+
 
     bool use_mrope() const;
 };

+ 201 - 59
src/llama-kv-cache.cpp

@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
                         const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                         // SWA mask
-                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                        if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                             can_use = true;
                         }
                     }
@@ -1237,90 +1237,236 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
     }
 }
 
-void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens = ubatch->n_tokens;
+struct args_set_input_kq_mask {
+    const llama_hparams & hparams;
+    const llama_ubatch  * ubatch;
 
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    float * data = (float *) dst->data;
+    const std::vector<llama_kv_cells> & v_cells;
+    const std::vector<uint32_t>       & seq_to_stream;
 
-    const int64_t n_kv     = dst->ne[0];
-    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+    uint32_t       n_swa;
+    llama_swa_type swa_type;
 
-    GGML_ASSERT(n_tokens%n_stream == 0);
+    int64_t n_kv;
+    int64_t n_stream;
+    int64_t n_tps;
+};
 
-    // n_tps == n_tokens_per_stream
-    const int64_t n_tps = n_tokens/n_stream;
+template<bool causal, bool swa, bool is_2d, bool alibi>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+  //const auto & hparams = args.hparams;
+    const auto & ubatch  = args.ubatch;
 
-    std::fill(data, data + ggml_nelements(dst), -INFINITY);
-
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    // TODO: optimize this section
-    for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            for (uint32_t ii = 0; ii < n_tps; ++ii) {
-                const uint32_t i = s*n_tps + ii;
+    const auto & v_cells       = args.v_cells;
+    const auto & seq_to_stream = args.seq_to_stream;
 
-                const llama_seq_id seq_id = ubatch->seq_id[i][0];
+    const uint32_t       n_swa    = args.n_swa;
+    const llama_swa_type swa_type = args.swa_type;
 
-                const auto & cells = v_cells[seq_to_stream[seq_id]];
+    const int64_t n_kv     = args.n_kv;
+    const int64_t n_stream = args.n_stream;
+    const int64_t n_tps    = args.n_tps;
 
-                const llama_pos p1 = ubatch->pos[i];
+    // the min position in the batch for each sequence
+    llama_pos seq_pos_min[LLAMA_MAX_SEQ];
+    std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
 
-                // for M-RoPE
-                const bool is_2d = ubatch->is_pos_2d();
-                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
-                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
+    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
+        const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
-                const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
+        seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
+    }
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    if (cells.is_empty(j)) {
-                        continue;
-                    }
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        // bookeeping of the KQ mask cells that could change for other tokens of the same sequence
+        std::unordered_map<llama_seq_id, uint32_t>              seq_srct;
+        std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
+
+        for (uint32_t ii = 0; ii < n_tps; ++ii) {
+            const uint32_t i = s*n_tps + ii;
+
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+            const auto & cells = v_cells.at(seq_to_stream[seq_id]);
+
+                  llama_pos p0 = -1;
+            const llama_pos p1 = ubatch->pos[i];
+
+            // for M-RoPE
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
+
+            const uint64_t idst = n_kv*i;
+
+            // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
+            // the only cells that could change are the ones that are with similar positions as the
+            //   ones in the batch (i.e. due to causal masking, SWA, etc.)
+            // keep track of those cells and shortcut the loop to save time
+            // note: this optimization is not compatible with Alibi position encoding
+            // ref:  https://github.com/ggml-org/llama.cpp/pull/18842
+            bool prev = false;
 
-                    // mask the token if not the same sequence
-                    if (!cells.seq_has(j, seq_id)) {
-                        continue;
+            auto & idxs = seq_idxs[seq_id];
+
+            if (!alibi) {
+                if (seq_srct.find(seq_id) != seq_srct.end()) {
+                    const uint32_t srct = seq_srct[seq_id];
+
+                    const uint64_t idst_prev = n_kv*srct;
+
+                    std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
+
+                    prev = true;
+                } else {
+                    idxs.clear();
+                    idxs.reserve(ubatch->n_tokens + n_swa + 32);
+
+                    seq_srct[seq_id] = i;
+                }
+            }
+
+            for (uint32_t jj = 0; jj < n_kv; ++jj) {
+                uint32_t j = jj;
+
+                // we have an exiting mask for this sequence -> update just seq_idxs
+                if (!alibi) {
+                    if (prev) {
+                        if (jj >= idxs.size()) {
+                            break;
+                        }
+
+                        j = idxs[jj];
                     }
+                }
+
+                if (cells.is_empty(j)) {
+                    goto skip;
+                }
+
+                // mask the token if not the same sequence
+                if (!cells.seq_has(j, seq_id)) {
+                    goto skip;
+                }
+
+                p0 = cells.pos_get(j);
 
-                    const llama_pos p0 = cells.pos_get(j);
+                if (!alibi) {
+                    if (!prev) {
+                        // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
+                        if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
+                            idxs.push_back(j);
+                        }
+                    }
+                }
 
+                if (causal) {
                     // mask future tokens
-                    if (causal_attn && p0 > p1) {
-                        continue;
+                    if (p0 > p1) {
+                        goto skip;
                     }
 
                     // M-RoPE causal mask
-                    if (causal_attn && is_2d && p0 == p1) {
-                        const auto & p0_ext = cells.ext_get(j);
-                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
-                            continue;
+                    if (is_2d) {
+                        if (p0 == p1) {
+                            const auto & p0_ext = cells.ext_get(j);
+
+                            if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                                goto skip;
+                            }
                         }
                     }
+                }
 
-                    // apply SWA if any
-                    if (is_masked_swa(p0, p1)) {
-                        continue;
+                // apply SWA if any
+                if (swa) {
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        goto skip;
                     }
+                }
 
-                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                if (alibi) {
+                    data[idst + j] = -std::abs(p0 - p1);
+                } else {
+                    data[idst + j] = 0.0f;
                 }
+
+                continue;
+skip:
+                data[idst + j] = -INFINITY;
             }
         }
     }
 }
 
+template<bool causal, bool swa, bool is_2d>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool alibi = args.hparams.use_alibi;
+    if (alibi) {
+        set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
+    }
+}
+
+template<bool causal, bool swa>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool is_2d = args.ubatch->is_pos_2d();
+    if (is_2d) {
+        set_input_kq_mask_impl<causal, swa, true> (args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, false>(args, data);
+    }
+}
+
+template<bool causal>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
+    if (swa) {
+        set_input_kq_mask_impl<causal, true> (args, data);
+    } else {
+        set_input_kq_mask_impl<causal, false>(args, data);
+    }
+}
+
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    float * data = (float *) dst->data;
+
+    const int64_t n_kv     = dst->ne[0];
+    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_stream == 0);
+
+    // n_tps == n_tokens_per_stream
+    const int64_t n_tps = n_tokens/n_stream;
+
+    //const int64_t t_start = ggml_time_us();
+
+    const args_set_input_kq_mask args = {
+        /*.hparams          =*/ hparams,
+        /*.ubatch           =*/ ubatch,
+        /*.v_cells          =*/ v_cells,
+        /*.seq_to_stream    =*/ seq_to_stream,
+        /*.n_swa            =*/ n_swa,
+        /*.swa_type         =*/ swa_type,
+        /*.n_kv             =*/ n_kv,
+        /*.n_stream         =*/ n_stream,
+        /*.n_tps            =*/ n_tps,
+    };
+
+    if (causal_attn) {
+        set_input_kq_mask_impl<true> (args, data);
+    } else {
+        set_input_kq_mask_impl<false>(args, data);
+    }
+
+    //const int64_t t_end = ggml_time_us();
+
+    //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
+}
+
 void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
@@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     return gf;
 }
 
-bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
-}
-
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);
 

+ 0 - 2
src/llama-kv-cache.h

@@ -257,8 +257,6 @@ private:
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
-
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
                    ggml_context * ctx,