Преглед изворни кода

graph : reuse SSM graphs (#16490)

* graph : reuse hybrid graphs

* graph : reuse recurrent graphs

* graph : fix reuse check for recurrent inputs

* memory : move the recurrent state into the memory context

* Revert "memory : move the recurrent state into the memory context"

This reverts commit 00f115fe810815d4a22a6dee0acc346131e970e1.

* cont : fix build
Georgi Gerganov пре 1 месец
родитељ
комит
c560316440
3 измењених фајлова са 78 додато и 7 уклоњено
  1. 63 4
      src/llama-graph.cpp
  2. 14 2
      src/llama-graph.h
  3. 1 1
      src/llama-memory-hybrid.cpp

+ 63 - 4
src/llama-graph.cpp

@@ -254,6 +254,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= s_copy->ne[0] == mctx->get_n_rs();
+
+    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
+    return res;
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -461,8 +479,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+    return res;
 }
 
 //
@@ -1850,6 +1906,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
@@ -1918,10 +1977,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }

+ 14 - 2
src/llama-graph.h

@@ -225,6 +225,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * s_copy;  // I32 [n_rs]
 
     // views of s_copy, computed once per graph
@@ -233,6 +235,10 @@ public:
     ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -365,22 +371,28 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
+            const llama_cparams & cparams,
             std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
-            std::unique_ptr<llm_graph_input_rs>              inp_rs,
-            const llama_memory_hybrid_context *              mctx) :
+            std::unique_ptr<llm_graph_input_rs>      inp_rs,
+            const llama_memory_hybrid_context *      mctx) :
         inp_attn(std::move(inp_attn)),
         inp_rs(std::move(inp_rs)),
+        cparams(cparams),
         mctx(mctx) { }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
     std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
     llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
     llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
+    const llama_cparams cparams;
+
     const llama_memory_hybrid_context * mctx;
 };
 

+ 1 - 1
src/llama-memory-hybrid.cpp

@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
     ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
-    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }