Просмотр исходного кода

graph : normalize Q, K, V shapes + sync cross attention (#12449)

* graph : normalize Q, K, V shapes and add comments

ggml-ci

* context : synchronize before getting cross attention data

* model : fix command-r attention norm check
Georgi Gerganov 10 месяцев назад
Родитель
Сommit
75422e8bc4
4 измененных файлов с 299 добавлено и 197 удалено
  1. 2 0
      src/llama-context.cpp
  2. 1 1
      src/llama-graph.cpp
  3. 12 12
      src/llama-graph.h
  4. 284 184
      src/llama-model.cpp

+ 2 - 0
src/llama-context.cpp

@@ -1143,6 +1143,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
 
+        synchronize();
+
         cross.n_embd = t_embd->ne[0];
         cross.n_enc  = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);

+ 1 - 1
src/llama-graph.cpp

@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
         // note: storing RoPE-ed version of K in the KV cache
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
 
-        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
         ggml_tensor * v_cache_view = nullptr;
 

+ 12 - 12
src/llama-graph.h

@@ -487,9 +487,9 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
              ggml_cgraph * gf,
-             ggml_tensor * q,
-             ggml_tensor * k,
-             ggml_tensor * v,
+             ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
+             ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
+             ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
              ggml_tensor * kq_b,
              ggml_tensor * kq_mask,
                     bool   v_trans,
@@ -502,9 +502,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;
@@ -516,9 +516,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;
@@ -530,9 +530,9 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
                   float   kq_scale,
                     int   il) const;

Разница между файлами не показана из-за своего большого размера
+ 284 - 184
src/llama-model.cpp


Некоторые файлы не были показаны из-за большого количества измененных файлов