Просмотр исходного кода

graph : refactor context to not pass gf explicitly (#14629)

ggml-ci
Georgi Gerganov 6 месяцев назад
Родитель
Сommit
8f974bc1e9
5 измененных файлов с 151 добавлено и 183 удалено
  1. 2 2
      src/llama-context.cpp
  2. 2 2
      src/llama-context.h
  3. 15 19
      src/llama-graph.cpp
  4. 20 44
      src/llama-graph.h
  5. 112 116
      src/llama-model.cpp

+ 2 - 2
src/llama-context.cpp

@@ -694,7 +694,7 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -1363,7 +1363,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_params llama_context::graph_params(
-                      llm_graph_result_i * res,
+                        llm_graph_result * res,
                       const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
             llm_graph_type   gtype) const {

+ 2 - 2
src/llama-context.h

@@ -94,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_i * process_ubatch(
+    llm_graph_result * process_ubatch(
                 const llama_ubatch & ubatch,
                     llm_graph_type   gtype,
             llama_memory_context_i * mctx,
@@ -199,7 +199,7 @@ public:
 
 private:
     llm_graph_params graph_params(
-                      llm_graph_result_i * res,
+                        llm_graph_result * res,
                       const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
                           llm_graph_type   gtype) const;

+ 15 - 19
src/llama-graph.cpp

@@ -486,6 +486,10 @@ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
     return inputs.back().get();
 }
 
+void llm_graph_result::set_params(const llm_graph_params & params) {
+    this->params = params;
+}
+
 //
 // llm_graph_context
 //
@@ -527,9 +531,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx             (params.mctx),
     cross            (params.cross),
     cb_func          (params.cb),
-    res              (static_cast<llm_graph_result *>(params.res)),
-    ctx0             (res->get_ctx()) {
-        res->params = params;
+    res              (params.res),
+    ctx0             (res->get_ctx()),
+    gf               (res->get_gf()) {
+        res->set_params(params);
     }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -1119,7 +1124,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 ggml_tensor * llm_graph_context::build_attn_mha(
-         ggml_cgraph * gf,
          ggml_tensor * q,
          ggml_tensor * k,
          ggml_tensor * v,
@@ -1253,7 +1257,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1281,7 +1284,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1337,7 +1340,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1370,7 +1372,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1390,7 +1392,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1437,7 +1438,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1470,7 +1471,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1492,7 +1492,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1550,7 +1550,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
             int32_t   state_size,
@@ -1608,21 +1607,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
 
 ggml_tensor * llm_graph_context::build_rs(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         ggml_tensor * s,
             int32_t   state_size,
             int32_t   n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     llm_graph_input_rs * inp,
-           ggml_cgraph * gf,
     const llama_ubatch & ubatch,
-                 int   il) const {
+                   int   il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
@@ -1632,7 +1629,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1672,7 +1669,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 }
 
 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,

+ 20 - 44
src/llama-graph.h

@@ -371,31 +371,11 @@ public:
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 //   these are used by the llama_context to extact the relevant data, based on the compute parameters
 
-// TODO: this interface seems redundant - remove it
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
-
-    virtual ggml_tensor * get_tokens()      const = 0;
-    virtual ggml_tensor * get_logits()      const = 0;
-    virtual ggml_tensor * get_embd()        const = 0;
-    virtual ggml_tensor * get_embd_pooled() const = 0;
-
-    virtual ggml_cgraph  * get_gf()  = 0;
-    virtual ggml_context * get_ctx() = 0;
-
-    virtual void reset() = 0;
-
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-
-    virtual bool can_reuse(const llm_graph_params & params) = 0;
-};
-
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
-
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
+class llm_graph_result;
+
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
@@ -418,8 +398,7 @@ struct llm_graph_params {
 
     llm_graph_cb cb;
 
-    // TODO: temporary
-    llm_graph_result_i * res;
+    llm_graph_result * res;
 
     // return true if the "other" params would result in a graph with the same topology as with the current params
     //   having the same topology allows us to reuse the graph in some cases
@@ -464,35 +443,37 @@ struct llm_graph_params {
     }
 };
 
-class llm_graph_result : public llm_graph_result_i {
+class llm_graph_result {
 public:
     llm_graph_result(int64_t max_nodes);
 
     virtual ~llm_graph_result() = default;
 
-    ggml_tensor * get_tokens()      const override { return t_tokens; }
-    ggml_tensor * get_logits()      const override { return t_logits; }
-    ggml_tensor * get_embd()        const override { return t_embd; }
-    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
+    ggml_tensor * get_tokens()      const { return t_tokens; }
+    ggml_tensor * get_logits()      const { return t_logits; }
+    ggml_tensor * get_embd()        const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
 
-    ggml_cgraph  * get_gf()  override { return gf; }
-    ggml_context * get_ctx() override { return ctx_compute.get(); }
+    ggml_cgraph  * get_gf()  const { return gf; }
+    ggml_context * get_ctx() const { return ctx_compute.get(); }
 
     int64_t get_max_nodes() const;
 
-    void reset() override;
+    void reset();
 
-    void set_inputs(const llama_ubatch * ubatch) override;
+    void set_inputs(const llama_ubatch * ubatch);
 
     // try to update the existing graph result using the new graph parameters in order to reuse it
     // this can only be done if we determine that the resulting graph using the new graph parameters
     //   would be identical to the existing graph. in that case, we simply have to update the memory
     //   contexts of the input tensors of the graph and we can reuse it for another computation
     // return true if the graph was updated and can be reused
-    bool can_reuse(const llm_graph_params & params) override;
+    bool can_reuse(const llm_graph_params & params);
 
     llm_graph_input_i * add_input(llm_graph_input_ptr input);
 
+    void set_params(const llm_graph_params & params);
+
     // important graph nodes
     ggml_tensor * t_tokens      = nullptr;
     ggml_tensor * t_logits      = nullptr;
@@ -510,6 +491,7 @@ public:
 
     int64_t max_nodes;
 
+private:
     // keep a copy of the previous graph parameters
     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
     // note: these are updated after constructing the new graph
@@ -519,6 +501,8 @@ public:
     int debug = 0;
 };
 
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
+
 //
 // llm_graph_context
 //
@@ -576,6 +560,7 @@ struct llm_graph_context {
     llm_graph_result * res;
 
     ggml_context * ctx0 = nullptr;
+    ggml_cgraph  * gf   = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;
@@ -661,7 +646,6 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-             ggml_cgraph * gf,
              ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
              ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -674,7 +658,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -689,7 +672,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -705,7 +687,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -720,7 +701,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -742,7 +722,6 @@ struct llm_graph_context {
     //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //         `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
                 int32_t   state_size,
@@ -757,7 +736,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
                 int32_t   state_size,
                 int32_t   n_seqs,
@@ -765,9 +743,8 @@ struct llm_graph_context {
 
     ggml_tensor * build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-               ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-                     int   il) const;
+                       int   il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
              ggml_tensor * token_shift,
@@ -784,7 +761,6 @@ struct llm_graph_context {
     //
 
     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,

Разница между файлами не показана из-за своего большого размера
+ 112 - 116
src/llama-model.cpp


Некоторые файлы не были показаны из-за большого количества измененных файлов