Remove comments as half of them are wrong anyways

Piotr Wilkin, 4 months ago
Commit 9832f2934a
1 changed file with 5 additions and 223 deletions

ggml/src/ggml-delta.c (+5 -223)

@@ -51,14 +51,11 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[3] == n_tokens);
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[3] == n_tokens);
        
-    // Validate g dimensions - g should be [S_v, H_v, n_tokens, batch_size] based on actual tensor layout
     GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
     
-    // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
     report_tensor_size("beta_sigmoid", beta_sigmoid);
     
-    // Concatenate q, k, v for convolution processing
     struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q, k, 1);
     report_tensor_size("mixed_qkv_qk", mixed_qkv);
     mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
@@ -71,29 +68,23 @@ struct ggml_tensor * ggml_delta_net(
     struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, conv_weight->ne[0] - 1, 0, 0, 0);
     report_tensor_size("mixed_qkv_padded", mixed_qkv_padded);
 
-    // Apply SSM convolution
     struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
     report_tensor_size("conv_out", conv_out);
 
-    // Apply bias if provided
     if (conv_bias) {
         conv_out = ggml_add(ctx, conv_out, conv_bias);
         report_tensor_size("conv_out_bias", conv_out);
     }
 
-    // Apply SiLU activation
     conv_out = ggml_silu(ctx, conv_out);
     report_tensor_size("conv_out_silu", conv_out);
 
-    // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
     conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, batch_size, 1);
     report_tensor_size("conv_out_reshaped", conv_out);
 
-    // Transpose to get the right layout: [dim, n_tokens, 1] -> [dim, 1, n_tokens, 1]
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
     report_tensor_size("conv_out_transposed", conv_out);
 
-    // q projection view
     struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
                                                S_k,                  // ne0
                                                H_k,                  // ne1
@@ -132,7 +123,6 @@ struct ggml_tensor * ggml_delta_net(
     );
     report_tensor_size("v_conv_view", v_conv);
 
-    // Transpose each component back to original layout: [S_v, 1, token_split_size, 1] -> [S_v, token_split_size, 1, 1]
     q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
     report_tensor_size("q_conv_permuted", q_conv);
     k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
@@ -147,29 +137,23 @@ struct ggml_tensor * ggml_delta_net(
     v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, batch_size, n_tokens);
     report_tensor_size("v_conv_reshaped", v_conv);
     
-    // NOW we repeat query and key to match value head dimensions if needed (after convolution)
     struct ggml_tensor * q_broadcast = q_conv;
     struct ggml_tensor * k_broadcast = k_conv;
     
     if (H_k != H_v) {
-        // Calculate the repeat factor: H_v / H_k
         GGML_ASSERT(H_v % H_k == 0);
         int64_t repeat_factor = H_v / H_k;
         
-        // Repeat query and key along the head dimension
-        // First reshape to separate the repeat dimension: [S_k, H_k, n_tokens, 1] -> [S_k, 1, H_k, n_tokens]
         q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("q_broadcast_reshape1", q_broadcast);
         k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, batch_size, H_k, n_tokens);
         report_tensor_size("k_broadcast_reshape1", k_broadcast);
         
-        // Repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
         q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
         report_tensor_size("q_broadcast_repeat", q_broadcast);
         k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, batch_size * repeat_factor, H_k, n_tokens);
         report_tensor_size("k_broadcast_repeat", k_broadcast);
         
-        // Reshape back to original dimensions but with H_v heads: [S_k, H_v, n_tokens, 1]
         q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
         report_tensor_size("q_broadcast_reshape2", q_broadcast);
         k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, batch_size);
@@ -180,7 +164,6 @@ struct ggml_tensor * ggml_delta_net(
     report_tensor_size("v_reshape", v_reshape);
     report_tensor_size("v_reshape", v_reshape);
     struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
     struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, batch_size);
     report_tensor_size("v_broadcast", v_broadcast);
     report_tensor_size("v_broadcast", v_broadcast);
-    // g already has correct dimensions [S_v, H_v, n_tokens, batch_size], no need to reshape
     struct ggml_tensor * g_reshape = g;
     struct ggml_tensor * g_reshape = g;
     report_tensor_size("g_reshape", g_reshape);
     report_tensor_size("g_reshape", g_reshape);
     q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
     q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, batch_size);
@@ -191,13 +174,9 @@ struct ggml_tensor * ggml_delta_net(
     report_tensor_size("beta_reshape", beta_reshape);
     report_tensor_size("beta_reshape", beta_reshape);
     struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
     struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, batch_size);
     report_tensor_size("beta_broadcast", beta_broadcast);
     report_tensor_size("beta_broadcast", beta_broadcast);
-    // The state should be repeated along the sequence dimension only
-    // Original state: [S_v, S_v, H_v, 1] -> should become [S_v, S_v, H_v, n_seqs]
-    // Use ggml_cont to ensure the state is contiguous, not ggml_repeat_4d which would repeat along all dimensions
     struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
     struct ggml_tensor * state_broadcast = ggml_cont(ctx, state);
     report_tensor_size("state_broadcast", state_broadcast);
     report_tensor_size("state_broadcast", state_broadcast);
     
     
-    // Call tensor-level kernel with convolved and processed tensors
     return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
     return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
 }
 }
 
 
@@ -220,7 +199,6 @@ struct ggml_tensor * ggml_delta_net_op(
     report_tensor_size("beta_input", beta);
     report_tensor_size("beta_input", beta);
     report_tensor_size("state_input", state);
     report_tensor_size("state_input", state);
     
     
-    // Validate dimensions
     GGML_ASSERT(ggml_is_contiguous(q));
     GGML_ASSERT(ggml_is_contiguous(q));
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(v));
@@ -228,17 +206,16 @@ struct ggml_tensor * ggml_delta_net_op(
     GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));
     
-    const int64_t S_k = q->ne[0];  // head dimension for q/k
-    const int64_t H_k = q->ne[1];  // number of heads (already processed to match v)
+    const int64_t S_k = q->ne[0];  
+    const int64_t H_k = q->ne[1];  
     const int64_t n_tokens = q->ne[2];
-    const int64_t batch_size = q->ne[3];  // batch size, not n_seqs
+    const int64_t batch_size = q->ne[3];  
     
-    const int64_t S_v = v->ne[0];  // head dimension for v
-    const int64_t H_v = v->ne[1];  // head dimension for v
+    const int64_t S_v = v->ne[0];  
+    const int64_t H_v = v->ne[1];  
 
     GGML_LOG_INFO("S_k = %ld, S_v = %ld, H_k = %ld, H_v = %ld\n", S_k, S_v, H_k, H_v);
     
-    // Validate dimensions - match Python implementation layout
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == batch_size);
     GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == batch_size);
     GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[3] == n_tokens && g->ne[2] == batch_size);
@@ -250,13 +227,9 @@ struct ggml_tensor * ggml_delta_net_op(
     
     struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v * S_v, H_v, 1, n_tokens);
     
-    // Copy initial state to new_state
     new_state = ggml_cpy(ctx, state, new_state);
     report_tensor_size("new_state_copied", new_state);
     
-    // Process all sequences and heads together using tensor operations
-    
-    // Apply L2 normalization if requested - per head, token, and sequence
     if (use_qk_l2norm) {
         q = ggml_l2_norm(ctx, q, 1e-6f);
         report_tensor_size("q_l2norm", q);
@@ -264,20 +237,13 @@ struct ggml_tensor * ggml_delta_net_op(
         report_tensor_size("k_l2norm", k);
         report_tensor_size("k_l2norm", k);
     }
     }
     
     
-    // Apply scaling to query - across all tokens, sequences and heads
     q = ggml_scale(ctx, q, scale);
     q = ggml_scale(ctx, q, scale);
     report_tensor_size("q_scaled", q);
     report_tensor_size("q_scaled", q);
     
     
-    // Process the gated delta rule using tensor operations
-    
-    // Reshape for matrix operations: [S_v, S_v, H_v, 1] -> [S_v * S_v, H_v]
     struct ggml_tensor * state_flat = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
     struct ggml_tensor * state_flat = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
     report_tensor_size("state_flat", state_flat);
     report_tensor_size("state_flat", state_flat);
     
     
-    // Process each token sequentially due to recurrent nature
     for (int64_t t = 0; t < n_tokens; ++t) {
     for (int64_t t = 0; t < n_tokens; ++t) {
-        // Extract current token's data across all batches and heads
-        // q, k, v are [S_k, H_k, n_tokens, batch_size] layout in GGML
         struct ggml_tensor * q_t = ggml_view_3d(ctx, q, S_k, H_k, batch_size,
         struct ggml_tensor * q_t = ggml_view_3d(ctx, q, S_k, H_k, batch_size,
                                                q->nb[1], q->nb[2], t * q->nb[2]);
                                                q->nb[1], q->nb[2], t * q->nb[2]);
         report_tensor_size("q_t_view", q_t);
         report_tensor_size("q_t_view", q_t);
@@ -291,403 +257,222 @@ struct ggml_tensor * ggml_delta_net_op(
                                                   beta->nb[1], beta->nb[2], t * beta->nb[2]);
         report_tensor_size("beta_t_view", beta_t);
                 
-        // Simplified approach: follow Python implementation exactly
-        // In Python: kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-        // This means: for each batch and head, multiply state by k_t and sum over the last dimension
-        
-        // First, reshape tensors to match GGML layout for head-wise processing
-        // q_t: [S_k, H_k, batch_size] -> reshape to [S_k, H_k * batch_size]
         struct ggml_tensor * q_t_reshaped = ggml_reshape_2d(ctx, q_t, S_k, H_k * batch_size);
         report_tensor_size("q_t_reshaped", q_t_reshaped);
         
-        // k_t: [S_k, H_k, batch_size] -> reshape to [S_k, H_k * batch_size]
         struct ggml_tensor * k_t_reshaped = ggml_reshape_2d(ctx, k_t, S_k, H_k * batch_size);
         report_tensor_size("k_t_reshaped", k_t_reshaped);
         
-        // v_t: [S_v, H_v, batch_size] -> reshape to [S_v, H_v * batch_size]
         struct ggml_tensor * v_t_reshaped = ggml_reshape_2d(ctx, v_t, S_v, H_v * batch_size);
         report_tensor_size("v_t_reshaped", v_t_reshaped);
         
-        // beta_t: [1, H_v, batch_size] -> reshape to [1, H_v * batch_size]
         struct ggml_tensor * beta_t_reshaped = ggml_reshape_2d(ctx, beta_t, 1, H_v * batch_size);
         report_tensor_size("beta_t_reshaped", beta_t_reshaped);
         
-        // Handle head dimension mismatch - repeat k_t if needed
         struct ggml_tensor * k_t_final = k_t_reshaped;
         if (H_k != H_v) {
             GGML_ASSERT(H_v % H_k == 0);
             
-            // Reshape k_t to separate head and batch dimensions: [S_k, H_k, batch_size, 1]
             struct ggml_tensor * k_t_4d = ggml_reshape_4d(ctx, k_t_reshaped, S_k, H_k, 1, batch_size);
             report_tensor_size("k_t_4d", k_t_4d);
             
-            // Repeat along head dimension: [S_k, H_v, batch_size, 1]
             k_t_final = ggml_repeat_4d(ctx, k_t_4d, S_k, H_v, 1, batch_size);
             report_tensor_size("k_t_final_repeated", k_t_final);
             
-            // Reshape back to 2D: [S_k, H_v * batch_size]
             k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
             report_tensor_size("k_t_final_2d", k_t_final);
         }
         
-        // Simplified kv_mem computation: state @ k_t^T for each head
-        // For now, let's use a simpler approach that matches the Python logic more closely
-        // kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
-        
-        // Reshape state to [S_v * S_v, H_v] for easier processing
         struct ggml_tensor * state_2d = ggml_reshape_2d(ctx, new_state, S_v * S_v, H_v);
         report_tensor_size("state_2d", state_2d);
         
-        // The state is already in the correct format for matrix operations
         struct ggml_tensor * state_t = state_2d;
         report_tensor_size("state_t", state_t);
         
-        // Simple kv_mem computation for this token
-        // kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2)
-        // In GGML, we need to implement: (state_t * k_t_broadcast).sum(dim=1)
-        // state_t: [S_v * S_v, H_v], k_t_final: [S_k, H_v * batch_size]
-        
-        // For the correct matrix multiplication, we need:
-        // state_t: [S_v * S_v, H_v]
-        // k_t_final: [S_k, H_v * batch_size]
-        // We want: state_t @ k_t_transposed where k_t_transposed is [H_v * batch_size, S_k]
-        
-        // But first, let's check if we can do a simpler approach
-        // Since we have H_v = 16 and batch_size = 1, we have:
-        // state_t: [16384, 16] and k_t_final: [128, 16]
-        
-        // For matrix multiplication, we need: [16384, 16] @ [16, 128] = [16384, 128]
-        // So we need to transpose k_t_final to get [16, 128]
-        
-        // For GGML matrix multiplication, we need to satisfy ggml_can_mul_mat requirements:
-        // t0->ne[0] == t1->ne[0] (first dimensions must be equal)
-        // t1->ne[2]%t0->ne[2] == 0 (broadcastable along 3rd dimension)
-        // t1->ne[3]%t0->ne[3] == 0 (broadcastable along 4th dimension)
-        
-        // We need to reshape state_t from [S_v * S_v, H_v, 1, 1] to [H_v, S_v * S_v, 1, 1]
-        // and k_t_final from [S_k, H_v * batch_size] to [H_v, S_k, 1, 1]
-        
-        // First, transpose state_t to get [H_v, S_v * S_v, 1, 1]
         struct ggml_tensor * state_t_transposed = ggml_transpose(ctx, state_t);
         report_tensor_size("state_t_transposed", state_t_transposed);
         
-        // Reshape k_t_final from [S_k, H_v * batch_size] to [H_v, S_k, 1, 1]
         struct ggml_tensor * k_t_final_reshaped = ggml_reshape_4d(ctx, k_t_final, H_v, S_k, batch_size, 1);
         report_tensor_size("k_t_final_reshaped", k_t_final_reshaped);
         
-        // Now we can do matrix multiplication: k_t_final_reshaped^T @ state_t_transposed^T
-        // But GGML doesn't allow transposed first argument, so we need to swap the order
-        // and transpose the result if needed
         struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, k_t_final_reshaped, state_t_transposed);
         report_tensor_size("kv_mem", kv_mem);
                 
-        // Compute delta = (v_t - kv_mem) * beta_t
-        // kv_mem: [batch_size, S_v] (result of state @ k_t^T)
-        // v_t: [batch_size, H_v, S_v] -> reshape to [batch_size * H_v, S_v]
-        // beta_t: [batch_size, H_v, 1] -> reshape to [batch_size * H_v, 1]
-        
-        // Handle head dimension mismatch for v_t and beta_t
         struct ggml_tensor * v_t_final = v_t_reshaped;
         struct ggml_tensor * beta_t_final = beta_t_reshaped;
         
         if (H_k != H_v) {
-            // Repeat v_t and beta_t along head dimension to match H_v
-            // v_t: [S_v, H_k, batch_size] -> [S_v, H_k, batch_size, 1] -> repeat -> [S_v, H_v, batch_size, 1]
             struct ggml_tensor * v_t_4d = ggml_reshape_4d(ctx, v_t_reshaped, S_v, H_k, 1, batch_size);
             struct ggml_tensor * v_t_repeated = ggml_repeat_4d(ctx, v_t_4d, S_v, H_v, 1, batch_size);
             v_t_final = ggml_reshape_2d(ctx, v_t_repeated, S_v, H_v * batch_size);
             
-            // beta_t: [1, H_k, batch_size] -> [1, H_k, batch_size, 1] -> repeat -> [1, H_v, batch_size, 1]
             struct ggml_tensor * beta_t_4d = ggml_reshape_4d(ctx, beta_t_reshaped, 1, H_k, 1, batch_size);
             struct ggml_tensor * beta_t_repeated = ggml_repeat_4d(ctx, beta_t_4d, 1, H_v, 1, batch_size);
             beta_t_final = ggml_reshape_2d(ctx, beta_t_repeated, 1, H_v * batch_size);
         }
         
-        // Ensure kv_mem has correct dimensions for subtraction
-        // kv_mem dimensions from trace: [128, 16384, 1, 1]
-        // We need to reshape it to match v_t_final: [128, 16, 1, 1]
-        
-        // First, let's reshape kv_mem to the correct dimensions
         struct ggml_tensor * kv_mem_reshaped;
         if (kv_mem->ne[0] == S_v && kv_mem->ne[1] == H_v * batch_size) {
-            // Perfect match
             kv_mem_reshaped = kv_mem;
         } else if (kv_mem->ne[0] == S_v) {
-            // We have the right first dimension, need to fix the second dimension
             kv_mem_reshaped = ggml_view_2d(ctx, kv_mem, S_v, H_v * batch_size, kv_mem->nb[1], 0);
         } else {
-            // Handle other dimension mismatches
             report_tensor_size("kv_mem_before_reshape", kv_mem);
             kv_mem_reshaped = ggml_reshape_2d(ctx, kv_mem, S_v, H_v * batch_size);
         }
         kv_mem_reshaped = ggml_cont(ctx, kv_mem_reshaped);
         report_tensor_size("kv_mem_reshaped", kv_mem_reshaped);
         
-        // Now ensure kv_mem_reshaped has the same dimensions as v_t_final
         struct ggml_tensor * kv_mem_final;
         if (kv_mem_reshaped->ne[0] == v_t_final->ne[0] && kv_mem_reshaped->ne[1] == v_t_final->ne[1]) {
             kv_mem_final = kv_mem_reshaped;
         } else {
-            // Use repeat to match dimensions if they're compatible
             kv_mem_final = ggml_repeat(ctx, kv_mem_reshaped, v_t_final);
         }
         report_tensor_size("kv_mem_final", kv_mem_final);
         
-        // Compute delta = (v_t - kv_mem) * beta_t
         struct ggml_tensor * delta = ggml_mul(ctx, ggml_sub(ctx, v_t_final, kv_mem_final), beta_t_final);
         report_tensor_size("delta", delta);
         
-        // Update state: state = state + outer(k_t, delta)
         struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, S_v, H_v * batch_size);
         report_tensor_size("delta_reshaped", delta_reshaped);
         
-        // Handle the outer product for all heads and batches
-        // We need to compute outer(k_t, delta) where:
-        // k_t is [S_k * H_k, batch_size] -> reshape to [S_k, H_k * batch_size]
-        // delta is [S_v, H_v * batch_size]
-        // For outer product, we want k_t @ delta^T
-        
-        // First, handle head dimension mismatch for k_t (reuse existing k_t_final variable)
         if (H_k == H_v) {
             k_t_final = k_t_reshaped;
         } else {
-            // Need to repeat k along the head dimension to match H_v
             int64_t repeat_factor = H_v / H_k;
             GGML_ASSERT(H_v % H_k == 0);
             
-            // Reshape to separate repeat dimension: [S_k, 1, H_k, batch_size]
             k_t_final = ggml_reshape_3d(ctx, k_t_reshaped, S_k, 1, H_k * batch_size);
             report_tensor_size("k_t_final_reshape1", k_t_final);
             
-            // Repeat along the new dimension: [S_k, repeat_factor, H_k, batch_size]
             k_t_final = ggml_repeat_4d(ctx, k_t_final, S_k, repeat_factor, H_k, batch_size);
             report_tensor_size("k_t_final_repeat", k_t_final);
             
-            // Reshape back: [S_k, H_v * batch_size]
             k_t_final = ggml_reshape_2d(ctx, k_t_final, S_k, H_v * batch_size);
             report_tensor_size("k_t_final_reshape2", k_t_final);
         }
         
-        // Make k_t_final contiguous
         k_t_final = ggml_cont(ctx, k_t_final);
         report_tensor_size("k_t_final_cont", k_t_final);
         
-        // Handle dimension mismatch between S_k and S_v
         struct ggml_tensor * k_t_for_outer;
         if (S_k == S_v) {
             k_t_for_outer = k_t_final;
         } else if (S_k < S_v) {
-            // Pad k_t to match S_v
             struct ggml_tensor * padding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v - S_k, H_v * batch_size);
             report_tensor_size("k_t_padding", padding);
             k_t_for_outer = ggml_concat(ctx, k_t_final, padding, 0);
             report_tensor_size("k_t_for_outer_padded", k_t_for_outer);
         } else {
-            // Truncate k_t to match S_v
             k_t_for_outer = ggml_view_2d(ctx, k_t_final, S_v, H_v * batch_size, k_t_final->nb[1], 0);
             report_tensor_size("k_t_for_outer_truncated", k_t_for_outer);
         }
         
-        // Make sure k_t_for_outer is contiguous
         k_t_for_outer = ggml_cont(ctx, k_t_for_outer);
         report_tensor_size("k_t_for_outer_cont", k_t_for_outer);
         
-        // Compute outer product: k_t_for_outer @ delta_reshaped^T
-        // k_t_for_outer: [S_v, H_v * batch_size]
-        // delta_reshaped: [S_v, H_v * batch_size]
-        // For outer product, we want: k_t_for_outer @ delta_reshaped^T
-        
-        // We need to satisfy ggml_can_mul_mat requirements:
-        // t0->ne[0] == t1->ne[0] (first dimensions must be equal)
-        // t1->ne[2]%t0->ne[2] == 0 (broadcastable along 3rd dimension)
-        // t1->ne[3]%t0->ne[3] == 0 (broadcastable along 4th dimension)
-        
-        // First, reshape k_t_for_outer to [S_v, H_v * batch_size, 1, 1]
         struct ggml_tensor * k_t_reshaped_4d = ggml_reshape_4d(ctx, k_t_for_outer, S_v, H_v, 1, batch_size);
         report_tensor_size("k_t_reshaped_4d", k_t_reshaped_4d);
         
-        // Transpose delta_reshaped to get [H_v * batch_size, S_v]
         struct ggml_tensor * delta_transposed = ggml_transpose(ctx, delta_reshaped);
         report_tensor_size("delta_transposed", delta_transposed);
         
-        // Make delta_transposed contiguous before reshaping
         delta_transposed = ggml_cont(ctx, delta_transposed);
         report_tensor_size("delta_transposed_cont", delta_transposed);
         
-        // Reshape delta_transposed to [H_v * batch_size, S_v, 1, 1]
         struct ggml_tensor * delta_reshaped_4d = ggml_reshape_4d(ctx, delta_transposed, H_v, S_v, 1, batch_size);
         report_tensor_size("delta_reshaped_4d", delta_reshaped_4d);
         
-        // For outer product k @ delta^T, we need: [S_v, H_v * batch_size] @ [H_v * batch_size, S_v] = [S_v, S_v]
-        // But GGML requires the first dimensions to be equal for matrix multiplication
-        // So we need to transpose the first tensor: k_t_reshaped_4d^T @ delta_reshaped_4d
-        // [H_v * batch_size, S_v] @ [H_v * batch_size, S_v] - this won't work
-        
-        // Instead, we need to do: delta_reshaped_4d^T @ k_t_reshaped_4d^T
-        // But GGML doesn't allow transposed first argument, so we need to swap the order
-        // and transpose the result if needed
-        
-        // Let's do: delta_reshaped_4d^T @ k_t_reshaped_4d
-        // [S_v, H_v * batch_size] @ [S_v, H_v * batch_size] - this won't work either
-        
-        // The correct approach is: k_t_reshaped_4d @ delta_reshaped_4d^T
-        // But we need to make the first dimensions equal by transposing k_t_reshaped_4d
         struct ggml_tensor * k_t_transposed = ggml_transpose(ctx, k_t_reshaped_4d);
         struct ggml_tensor * k_t_transposed = ggml_transpose(ctx, k_t_reshaped_4d);
         report_tensor_size("k_t_transposed", k_t_transposed);
         
-        // [H_v * batch_size, S_v] @ [H_v * batch_size, S_v] - still won't work
-        
-        // Let's try a different approach: use the transpose of the result
-        // We want: k @ delta^T = (delta @ k^T)^T
         struct ggml_tensor * temp_product = ggml_mul_mat(ctx, delta_reshaped_4d, k_t_transposed);
         struct ggml_tensor * temp_product = ggml_mul_mat(ctx, delta_reshaped_4d, k_t_transposed);
         report_tensor_size("temp_product", temp_product);
         
-        // Transpose the result to get the final outer product
         struct ggml_tensor * outer_product_raw = ggml_transpose(ctx, temp_product);
         report_tensor_size("outer_product_raw", outer_product_raw);
         
-        // Make outer_product_raw contiguous before reshaping
         struct ggml_tensor * outer_product_cont = ggml_cont(ctx, outer_product_raw);
         report_tensor_size("outer_product_cont", outer_product_cont);
         
-        // Reshape to 2D: [S_v, S_v]
         struct ggml_tensor * outer_product = ggml_reshape_2d(ctx, outer_product_cont, S_v, S_v);
         report_tensor_size("outer_product", outer_product);
         
-        // Now we need to reshape outer_product to match state_flat dimensions
-        // outer_product: [S_v, S_v] -> reshape to [S_v * S_v, H_v * batch_size]
         struct ggml_tensor * outer_product_reshaped;
         if (outer_product->ne[0] == S_v && outer_product->ne[1] == S_v) {
-            // Perfect match for a single head/sequence
             outer_product_reshaped = ggml_reshape_2d(ctx, outer_product, S_v * S_v, 1);
         } else {
-            // Handle whatever dimensions we got
             outer_product_reshaped = ggml_reshape_2d(ctx, outer_product,
                                                     outer_product->ne[0] * outer_product->ne[1], 1);
         }
         report_tensor_size("outer_product_reshaped", outer_product_reshaped);
         
-        // Repeat outer_product_reshaped to match the number of heads and batches
         struct ggml_tensor * outer_product_repeated = ggml_repeat(ctx, outer_product_reshaped, state_flat);
         report_tensor_size("outer_product_repeated", outer_product_repeated);
         
-        // Update state
         state_flat = ggml_add(ctx, state_flat, outer_product_repeated);
         report_tensor_size("state_flat_updated", state_flat);
         
-        // Simplified approach: follow Python implementation more closely
-        // In Python: output = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-        // This means: for each batch and head, multiply state by q_t and sum over the last dimension
-        
-        // First, let's work with the original q_t (already processed to match H_v)
         struct ggml_tensor * q_t_final = q_t;
         struct ggml_tensor * q_t_final = q_t;
         report_tensor_size("q_t_final", q_t_final);
         
-        // Make q_t_final contiguous for matrix operations
         q_t_final = ggml_cont(ctx, q_t_final);
         report_tensor_size("q_t_final_cont", q_t_final);
         
-        // This is equivalent to: state @ q_t^T where q_t is reshaped appropriately
-        
-        // Simple approach: reshape q_t to [S_k, H_v * batch_size] and state to [S_v * S_v, H_v * batch_size]
-        // Then compute: state^T @ q_t
-        // But we need to handle the GGML requirements
-        
-        // Make state_flat contiguous
         struct ggml_tensor * state_flat_cont = ggml_cont(ctx, state_flat);
         struct ggml_tensor * state_flat_cont = ggml_cont(ctx, state_flat);
         report_tensor_size("state_flat_cont", state_flat_cont);
         
-        // Reshape q_t to [S_k, H_v * batch_size] for matrix multiplication
         struct ggml_tensor * q_t_matrix = ggml_reshape_2d(ctx, q_t_final, S_k, H_v * batch_size);
         report_tensor_size("q_t_matrix", q_t_matrix);
         
-        // state_flat_cont: [S_v * S_v, H_v * batch_size] = [16384, 16]
-        // q_t_matrix: [S_k, H_v * batch_size] = [128, 16]
-        
-        // For GGML, we need: q_t_matrix^T @ state_flat_cont^T
-        // But GGML doesn't allow transposed first argument, so we use the property: A @ B = (B^T @ A^T)^T
-        
-        // Transpose q_t_matrix to get [H_v * batch_size, S_k] = [16, 128]
         struct ggml_tensor * q_t_matrix_transposed = ggml_transpose(ctx, q_t_matrix);
         struct ggml_tensor * q_t_matrix_transposed = ggml_transpose(ctx, q_t_matrix);
         report_tensor_size("q_t_matrix_transposed", q_t_matrix_transposed);
         
-        // Transpose state_flat_cont to get [H_v * batch_size, S_v * S_v] = [16, 16384]
         struct ggml_tensor * state_flat_transposed = ggml_transpose(ctx, state_flat_cont);
         report_tensor_size("state_flat_transposed", state_flat_transposed);
         
-        // [16, 128] @ [16, 16384] - this won't work because first dimensions don't match
-        
-        // Instead, let's do: state_flat_transposed^T @ q_t_matrix_transposed^T
-        // But we need to transpose both again
         struct ggml_tensor * q_t_matrix_final = ggml_transpose(ctx, q_t_matrix_transposed);
         struct ggml_tensor * q_t_matrix_final = ggml_transpose(ctx, q_t_matrix_transposed);
         report_tensor_size("q_t_matrix_final", q_t_matrix_final);
         
         struct ggml_tensor * state_flat_final = ggml_transpose(ctx, state_flat_transposed);
         report_tensor_size("state_flat_final", state_flat_final);
         
-        // [128, 16] @ [16384, 16] - this won't work either
-        
-        // Let me try a different approach: use element-wise multiplication and sum
-        // We want: (state * q_t.unsqueeze(-1)).sum(dim=-2)
-        
-        // First, reshape q_t to broadcast with state
         struct ggml_tensor * q_t_broadcast = ggml_repeat(ctx, q_t_final, state_flat_cont);
         struct ggml_tensor * q_t_broadcast = ggml_repeat(ctx, q_t_final, state_flat_cont);
         report_tensor_size("q_t_broadcast", q_t_broadcast);
         
-        // Element-wise multiplication
         struct ggml_tensor * state_q_product = ggml_mul(ctx, state_flat_cont, q_t_broadcast);
         report_tensor_size("state_q_product", state_q_product);
                
-        
-        // Reshape state_q_product to [S_v * S_v, H_v, batch_size]
         struct ggml_tensor * state_q_3d = ggml_reshape_3d(ctx, state_q_product, S_v * S_v, H_v, batch_size);
         struct ggml_tensor * state_q_3d = ggml_reshape_3d(ctx, state_q_product, S_v * S_v, H_v, batch_size);
         report_tensor_size("state_q_3d", state_q_3d);
-        // Ensure contiguous layout so byte-strides are consistent for subsequent views/slices.
         state_q_3d = ggml_cont(ctx, state_q_3d);
         report_tensor_size("state_q_3d_cont", state_q_3d);
         
-        // Create a proper ones vector: ggml_new_tensor_1d already creates a zero-filled tensor,
-        // so ggml_exp on it will produce ones (exp(0) = 1).
         struct ggml_tensor * ones_vector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, H_v);
         struct ggml_tensor * ones_vector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, H_v);
         ones_vector = ggml_exp(ctx, ones_vector);      // exp(0) = 1
         report_tensor_size("ones_vector", ones_vector);
         
-        // Reshape to [H_v, 1] for matrix multiplication
         struct ggml_tensor * ones_col = ggml_reshape_2d(ctx, ones_vector, H_v, 1);
         report_tensor_size("ones_col", ones_col);
         
-        // Prepare per-batch results
         struct ggml_tensor * output_parts[batch_size];
         for (int64_t b = 0; b < batch_size; b++) {
-            // Extract slice for this batch: [S_v * S_v, H_v]
-            // Use the contiguous state_q_3d so nb and offsets are reliable.
             struct ggml_tensor * batch_slice = ggml_view_3d(ctx, state_q_3d, S_v * S_v, H_v, 1,
                                                            state_q_3d->nb[1], state_q_3d->nb[2], b * state_q_3d->nb[2]);
             batch_slice = ggml_cont(ctx, batch_slice);
             report_tensor_size("batch_slice", batch_slice);
             
-            // Multiply by ones and sum across H_v:
-            // ones_col: [H_v, 1], batch_slice^T: [H_v, S_v * S_v] -> ones_col @ batch_slice^T = [1, S_v * S_v]
             struct ggml_tensor * batch_slice_t = ggml_transpose(ctx, batch_slice);
             report_tensor_size("batch_slice_t", batch_slice_t);
             struct ggml_tensor * batch_sum = ggml_mul_mat(ctx, ones_col, batch_slice_t);
             report_tensor_size("batch_sum", batch_sum);
             
-            // Reshape [1, S_v*S_v] -> [S_v, S_v]
             struct ggml_tensor * batch_result = ggml_reshape_2d(ctx, batch_sum, S_v, S_v);
             report_tensor_size("batch_result", batch_result);
             output_parts[b] = batch_result;
         }
         
-        // Concatenate results from all batches into [S_v * S_v, batch_size]
         struct ggml_tensor * output_concat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v * S_v, batch_size);
         for (int64_t b = 0; b < batch_size; b++) {
             struct ggml_tensor * batch_output = ggml_view_2d(ctx, output_concat, S_v * S_v, 1,
@@ -695,12 +480,10 @@ struct ggml_tensor * ggml_delta_net_op(
             batch_output = ggml_cpy(ctx, output_parts[b], batch_output);
         }
         
-        // Reshape concatenated result to [S_v, S_v] for this token (batch_size typically 1)
         struct ggml_tensor * output_t_reshaped = ggml_reshape_2d(ctx, output_concat, S_v, S_v);
         struct ggml_tensor * output_t = ggml_cont(ctx, output_t_reshaped);
         report_tensor_size("output_t", output_t);
               
-        // Store output for this token
         struct ggml_tensor * output_slice = ggml_view_3d(ctx, output, S_v, S_v, batch_size,
                                                         output->nb[1], output->nb[2], t * output->nb[2]);
         report_tensor_size("output_slice", output_slice);
@@ -712,4 +495,3 @@ struct ggml_tensor * ggml_delta_net_op(
     report_tensor_size("result_final", result);
     report_tensor_size("result_final", result);
     return result;
     return result;
 }
 }
-// ggml_rwkv_wkv7
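
For orientation, the per-token update that the remaining loop assembles from ggml ops corresponds to the delta rule described in the removed reference comments. A minimal sketch per head, read off those comments and the surviving ops (notation assumed here: S is the running state matrix, k_t, v_t, q_t the current key/value/query vectors, beta_t the gate scalar; the g tensor and the sigmoid applied to beta earlier in the file are not written out):

$$
\begin{aligned}
m_t      &= S_{t-1}^{\top} k_t, \\
\delta_t &= \beta_t \,(v_t - m_t), \\
S_t      &= S_{t-1} + k_t\, \delta_t^{\top}, \\
o_t      &= S_t^{\top} q_t .
\end{aligned}
$$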