Browse Source

Cleanup ggml_delta_net

Piotr Wilkin 3 months ago
parent
commit
2b0673c315
1 changed file with 33 additions and 118 deletions

+ 33 - 118
ggml/src/ggml-delta.c
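
For orientation when reading the hunks below: the tensor ops assemble a delta-rule update per timestep, in which the current state is read out at the key, the stored value is subtracted from the incoming value, and the correction is gated by sigmoid(beta). Here is a minimal scalar C sketch of that recurrence, assuming a key-by-value state matrix; the function name and layout are illustrative only (not the ggml API), and the rank-1 write-back at the end is the textbook delta rule rather than code visible in this excerpt:

    #include <stddef.h>

    // Hypothetical scalar sketch of one delta-rule step for a single head:
    //   kv    = S^T k              (value currently stored under key k)
    //   delta = beta * (v - kv)    (gated correction)
    //   S    += k * delta^T        (write the correction back under k)
    // The diff expresses the same steps over batched tensors with
    // ggml_mul_mat / ggml_sub / ggml_mul.
    static void delta_rule_step(float * S, const float * k, const float * v,
                                float beta, size_t d_k, size_t d_v) {
        for (size_t j = 0; j < d_v; ++j) {
            float kv = 0.0f;                      // read-out: (S^T k)_j
            for (size_t i = 0; i < d_k; ++i) {
                kv += S[i * d_v + j] * k[i];
            }
            const float delta = beta * (v[j] - kv);
            for (size_t i = 0; i < d_k; ++i) {    // rank-1 state update
                S[i * d_v + j] += k[i] * delta;
            }
        }
    }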

@@ -52,10 +52,8 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);
        
     GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
-    
-    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-    report_tensor_size("beta_sigmoid", beta_sigmoid);
-    
+       
+    // Merge q, k, v into qkv
     struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
     report_tensor_size("mixed_qkv_qk", mixed_qkv);
     mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
@@ -68,6 +66,7 @@ struct ggml_tensor * ggml_delta_net(
     struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
     report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);
 
+    // Apply convolution
     struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
     report_tensor_size("conv_out", conv_out);
 
@@ -85,68 +84,36 @@ struct ggml_tensor * ggml_delta_net(
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
     report_tensor_size("conv_out_transposed", conv_out);
 
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
-                                               S_k,                  // ne0
-                                               H_k,                  // ne1
-                                               conv_out->ne[1],      // ne2 = sequence length (1)
-                                               conv_out->ne[2],      // ne3 = batch (1)
-                                               H_k * sizeof(float),  // nb1 = stride along H_k
-                                               conv_out->nb[1],      // nb2 = stride along sequence dim
-                                               conv_out->nb[2],      // nb3 = stride along batch dim
-                                               0                     // offset in bytes
-    );
+    // Beta sigmoid
+    struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
+    report_tensor_size("beta_sigmoid", beta_sigmoid);
+
+    // Gate calculations are done elsewhere in llama-model.cpp
+
+    // Re-split the qkv tensors
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2], 
+                                               H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], 0);
     report_tensor_size("q_conv_view", q_conv);
 
-    // k projection view
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
-                                               S_k,                       // ne0
-                                               H_k,                       // ne1
-                                               conv_out->ne[1],           // ne2
-                                               conv_out->ne[2],           // ne3
-                                               H_k * sizeof(float),       // nb1
-                                               conv_out->nb[1],           // nb2
-                                               conv_out->nb[2],           // nb3
-                                               S_k * H_k * sizeof(q->type)  // offset = skip q_out
-    );
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, S_k, H_k, conv_out->ne[1], conv_out->ne[2],
+                                               H_k * sizeof(float), conv_out->nb[1], conv_out->nb[2], S_k * H_k * sizeof(q->type));
     report_tensor_size("k_conv_view", k_conv);
 
-    // v projection view
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
-                                               S_v,                             // ne0
-                                               H_v,                             // ne1
-                                               conv_out->ne[1],                 // ne2
-                                               conv_out->ne[2],                 // ne3
-                                               H_v * sizeof(float),             // nb1
-                                               conv_out->nb[1],                 // nb2
-                                               conv_out->nb[2],                 // nb3
-                                               (2 * S_k * H_k) * sizeof(q->type)// offset = skip q_out + k_out
-    );
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, S_v, H_v, conv_out->ne[1], conv_out->ne[2], H_v * sizeof(float),
+                                               conv_out->nb[1], conv_out->nb[2], (2 * S_k * H_k) * sizeof(q->type));
     report_tensor_size("v_conv_view", v_conv);
 
-    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
-    report_tensor_size("q_conv_permuted", q_conv);
-    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
-    report_tensor_size("k_conv_permuted", k_conv);
-    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
-    report_tensor_size("v_conv_permuted", v_conv);
-
-    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, batch_size, n_tokens);
-    report_tensor_size("q_conv_reshaped", q_conv);
-    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, batch_size, n_tokens);
-    report_tensor_size("k_conv_reshaped", k_conv);
-    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
-    report_tensor_size("v_conv_reshaped", v_conv);
-    
     struct ggml_tensor * q_broadcast = q_conv;
     struct ggml_tensor * k_broadcast = k_conv;
     
+    // if the key and value head counts differ, repeat to force tensors into matching shapes
     if (H_k != H_v) {
         GGML_ASSERT(H_v % H_k == 0);
         int64_t repeat_factor = H_v / H_k;
         
-        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
+        q_broadcast = ggml_cont_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("q_broadcast_reshape1", q_broadcast);
-        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
+        k_broadcast = ggml_cont_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("k_broadcast_reshape1", k_broadcast);
         
         q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
@@ -160,24 +127,14 @@ struct ggml_tensor * ggml_delta_net(
         report_tensor_size("k_broadcast_reshape2", k_broadcast);
     }
 
-    struct ggml_tensor * v_reshape = ggml_reshape_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
+    struct ggml_tensor * v_reshape = ggml_cont_4d(ctx, v_conv, S_v, H_v, n_tokens, batch_size);
     report_tensor_size("v_reshape", v_reshape);
-    struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
-    report_tensor_size("v_broadcast", v_broadcast);
-    struct ggml_tensor * g_reshape = g;
-    report_tensor_size("g_reshape", g_reshape);
-    q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
-    report_tensor_size("q_broadcast_final", q_broadcast);
-    k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
-    report_tensor_size("k_broadcast_final", k_broadcast);
-    struct ggml_tensor * beta_reshape = ggml_reshape_4d(ctx, beta_sigmoid, 1, H_v, n_tokens, batch_size);
-    report_tensor_size("beta_reshape", beta_reshape);
-    struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
+    struct ggml_tensor * beta_broadcast = ggml_cont_4d(ctx, beta, 1, H_v, n_tokens, batch_size);
     report_tensor_size("beta_broadcast", beta_broadcast);
     struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
     report_tensor_size("state_broadcast", state_broadcast);
     
-    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
+    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_reshape, g, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
 }
 
 struct ggml_tensor * ggml_delta_net_op(
@@ -212,9 +169,10 @@ struct ggml_tensor * ggml_delta_net_op(
     const int64_t batch_size = q->ne[3];  
     
     const int64_t S_v = v->ne[0];  
-    const int64_t H_v = v->ne[1];  
+    const int64_t H_v = v->ne[1];
 
     GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
+    GGML_ASSERT(H_k == H_v); // we broadcasted the tensors in the main function to guarantee this
     
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
     GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
@@ -289,71 +247,28 @@ struct ggml_tensor * ggml_delta_net_op(
         struct ggml_tensor * state_t = state_2d;
         report_tensor_size("state_t", state_t);
         
-        struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
+        struct ggml_tensor * state_t_transposed = ggml_cont(ctx, ggml_transpose(ctx, state_t));
         report_tensor_size("state_t_transposed", state_t_transposed);
-        
+       
         struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
         report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);
         
-        struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
+        struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, state_t_transposed, k_t_final_reshaped);
         report_tensor_size("kv_mem", kv_mem);
                 
         struct ggml_tensor * v_t_final = v_t_reshaped;
         struct ggml_tensor * beta_t_final = beta_t_reshaped;
-        
-        if (H_k != H_v) {
-            struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
-            struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
-            v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
-            
-            struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
-            struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
-            beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
-        }
-        
-        struct ggml_tensor * kv_mem_reshaped;
-        if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
-            kv_mem_reshaped = kv_mem;
-        } else if (kv_mem->ne[0] == S_v) {
-            kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
-        } else {
-            report_tensor_size("kv_mem_before_reshape", kv_mem);
-            kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
-        }
-        kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
+                
+        struct ggml_tensor * kv_mem_reshaped = ggml_transpose(ctx, kv_mem);
         report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
-        
-        struct ggml_tensor * kv_mem_final;
-        if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
-            kv_mem_final = kv_mem_reshaped;
-        } else {
-            kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
-        }
-        report_tensor_size("kv_mem_final", kv_mem_final);
-        
-        struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
+                
+        struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_reshaped), beta_t_final);
         report_tensor_size("delta", delta);
         
         struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
         report_tensor_size("delta_reshaped", delta_reshaped);
-        
-        if (H_k == H_v) {
-            k_t_final = k_t_reshaped;
-        } else {
-            int64_t repeat_factor = H_v / H_k;
-            GGML_ASSERT(H_v % H_k == 0);
-            
-            k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
-            report_tensor_size("k_t_final_reshape1", k_t_final);
-            
-            k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
-            report_tensor_size("k_t_final_repeat", k_t_final);
-            
-            k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
-            report_tensor_size("k_t_final_reshape2", k_t_final);
-        }
-        
-        k_t_final = ggml_cont(ctx, k_t_final);
+                
+        k_t_final = ggml_cont(ctx, k_t_reshaped);
         report_tensor_size("k_t_final_cont", k_t_final);
         
         struct ggml_tensor * k_t_for_outer;