
Valgrind debugging session / multi-chunk support

Piotr Wilkin, 2 months ago
commit 16b3f9c300
1 file changed, 80 additions, 60 deletions
ggml/src/ggml-cpu/ops.cpp  +80 -60
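
The commit threads a num_chunks parameter through the DeltaNet chunk helpers so the scratch buffers hold one block per (sequence, head, chunk) triple instead of one per (sequence, head) pair. As a rough sketch of the block indexing the new loops use (an illustrative helper, not part of the commit):

    #include <cstdint>

    // Mirrors the pointer arithmetic in the hunks below: the chunk index is
    // folded into the head dimension as (head * num_chunks + chunk). Note the
    // per-sequence stride in the diff remains chunk_size * chunk_size * H_v,
    // which matches a full [seq][head][chunk] layout only when n_seqs == 1 or
    // num_chunks == 1.
    static inline int64_t attn_block_offset(int64_t seq, int64_t head, int64_t chunk,
                                            int64_t chunk_size, int64_t H_v,
                                            int64_t num_chunks) {
        return seq * (chunk_size * chunk_size * H_v) +
               (head * num_chunks + chunk) * (chunk_size * chunk_size);
    }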

@@ -10530,62 +10530,73 @@ static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * k
 }
 
 // Helper function to apply triangular updates to entire chunk (all sequences and heads)
-static void delta_apply_triangular_updates_chunk_f32(float * attn, const int64_t chunk_size,
-                                                    const int64_t n_seqs, const int64_t H_v) {
+static void delta_apply_triangular_updates_chunk_f32(float *       attn,
+                                                     const int64_t chunk_size,
+                                                     const int64_t n_seqs,
+                                                     const int64_t H_v,
+                                                     int           num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
-            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-            
-            // Apply triangular updates following the Python reference exactly:
-            // for i in range(1, chunk_size):
-            //     row = attn[..., i, :i].clone()
-            //     sub = attn[..., :i, :i].clone()
-            //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-            for (int64_t i = 1; i < chunk_size; i++) {
-                // Create temporary storage for row and sub to avoid modifying during computation
-                float * row = (float *) malloc(i * sizeof(float));
-                float * sub = (float *) malloc(i * i * sizeof(float));
-                
-                // Copy row = attn[..., i, :i]
-                for (int64_t j = 0; j < i; j++) {
-                    row[j] = attn_ptr[i * chunk_size + j];
-                }
-                
-                // Copy sub = attn[..., :i, :i]
-                for (int64_t k = 0; k < i; k++) {
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + (head * num_chunks + i) * (chunk_size * chunk_size);
+
+                // Apply triangular updates following the Python reference exactly:
+                // for i in range(1, chunk_size):
+                //     row = attn[..., i, :i].clone()
+                //     sub = attn[..., :i, :i].clone()
+                //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+                for (int64_t i = 1; i < chunk_size; i++) {
+                    // Create temporary storage for row and sub to avoid modifying during computation
+                    float * row = (float *) malloc(i * sizeof(float));
+                    float * sub = (float *) malloc(i * i * sizeof(float));
+
+                    // Copy row = attn[..., i, :i]
                     for (int64_t j = 0; j < i; j++) {
-                        sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                        row[j] = attn_ptr[i * chunk_size + j];
                     }
-                }
-                
-                // Compute updates for each j in :i
-                for (int64_t j = 0; j < i; j++) {
-                    // Compute (row.unsqueeze(-1) * sub).sum(-2)
-                    float sum_val = 0.0f;
+
+                    // Copy sub = attn[..., :i, :i]
                     for (int64_t k = 0; k < i; k++) {
-                        sum_val += row[k] * sub[k * i + j];
+                        for (int64_t j = 0; j < i; j++) {
+                            sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                        }
+                    }
+
+                    // Compute updates for each j in :i
+                    for (int64_t j = 0; j < i; j++) {
+                        // Compute (row.unsqueeze(-1) * sub).sum(-2)
+                        float sum_val = 0.0f;
+                        for (int64_t k = 0; k < i; k++) {
+                            sum_val += row[k] * sub[k * i + j];
+                        }
+
+                        // Update: attn[..., i, j] = row[j] + sum_val
+                        attn_ptr[i * chunk_size + j] = row[j] + sum_val;
                     }
-                                       
-                    // Update: attn[..., i, j] = row[j] + sum_val
-                    attn_ptr[i * chunk_size + j] = row[j] + sum_val;
+
+                    free(row);
+                    free(sub);
                 }
-                
-                free(row);
-                free(sub);
             }
         }
     }
 }
 
 // Helper function to add identity matrix to entire chunk (all sequences and heads)
-static void delta_add_identity_matrix_chunk_f32(float * matrix, const int64_t chunk_size,
-                                                const int64_t n_seqs, const int64_t H_v) {
+static void delta_add_identity_matrix_chunk_f32(float *       matrix,
+                                                const int64_t chunk_size,
+                                                const int64_t n_seqs,
+                                                const int64_t H_v,
+                                                int           num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
-            float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-            // Add identity matrix directly
-            for (int64_t i = 0; i < chunk_size; i++) {
-                matrix_ptr[i * chunk_size + i] += 1.0f;
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) +
+                                     (head * num_chunks + i) * (chunk_size * chunk_size);
+                // Add identity matrix directly
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    matrix_ptr[i * chunk_size + i] += 1.0f;
+                }
             }
         }
     }
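
For a single chunk_size x chunk_size block, the triangular update quoted in the comments above reduces to the following self-contained routine (an equivalent sketch using std::vector in place of the committed per-iteration malloc/free):

    #include <cstdint>
    #include <vector>

    // For each row i: attn[i, :i] = row + (row.unsqueeze(-1) * attn[:i, :i]).sum(-2).
    // Only the row needs a snapshot: rows [0, i) are never written while row i
    // is updated, so the sub-block can be read in place.
    static void triangular_update_block(float * attn, int64_t chunk_size) {
        for (int64_t i = 1; i < chunk_size; i++) {
            std::vector<float> row(attn + i * chunk_size, attn + i * chunk_size + i);
            for (int64_t j = 0; j < i; j++) {
                float sum_val = 0.0f;
                for (int64_t k = 0; k < i; k++) {
                    sum_val += row[k] * attn[k * chunk_size + j];
                }
                attn[i * chunk_size + j] = row[j] + sum_val;
            }
        }
    }

Dropping the sub copy is safe because the rows being read all lie above row i; it also removes the chunk_size - 1 malloc/free pairs per block that Valgrind has to track in the committed version.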
@@ -10617,15 +10628,19 @@ static void delta_compute_value_f32(const float * attn,
                                     const int64_t chunk_size,
                                     const int64_t v_head_dim,
                                     const int64_t n_heads,
-                                    const int64_t n_seqs) {
+                                    const int64_t n_seqs,
+                                    int           num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < n_heads; head++) {
-            delta_matmul_f32(
-                attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * head, 
-                v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head, 
-                value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head, 
-                chunk_size, v_head_dim, chunk_size);
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < n_heads; head++) {
+                delta_matmul_f32(
+                    attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * (head * num_chunks + i), 
+                    v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head, 
+                    value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head, 
+                    chunk_size, v_head_dim, chunk_size);
+            }
         }
+
     }
 }
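
delta_compute_value_f32 computes value = attn @ v_beta one (seq, chunk, head) block at a time; in this hunk only the attn offset carries the (head * num_chunks + i) chunk index, while the v_beta and value offsets are still per-head. Judging from the call site, delta_matmul_f32(A, B, C, m, n, k) computes C = A * B with A of shape m x k and B of shape k x n (an inference from the arguments, not the actual definition). A naive reference version consistent with that contract:

    #include <cstdint>

    // C (m x n) = A (m x k) * B (k x n), all row-major.
    static void matmul_f32_ref(const float * A, const float * B, float * C,
                               int64_t m, int64_t n, int64_t k) {
        for (int64_t r = 0; r < m; r++) {
            for (int64_t c = 0; c < n; c++) {
                float acc = 0.0f;
                for (int64_t t = 0; t < k; t++) {
                    acc += A[r * k + t] * B[t * n + c];
                }
                C[r * n + c] = acc;
            }
        }
    }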
 
 
@@ -10913,11 +10928,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // int64_t total_params = n_seqs * H_v * num_chunks;
     // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
 
-    float * attn          = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
-    float * value         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-    float * k_cumdecay    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    float * attn          = (float *) malloc(chunk_size * chunk_size * H_v * num_chunks * n_seqs * sizeof(float));
+    float * value         = (float *) malloc(chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof(float));
+    float * k_cumdecay    = (float *) malloc(chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof(float));
     bool *  mask          = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
-    float * g =             (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+    float * g =             (float *) malloc(chunk_size * H_v * num_chunks * n_seqs * sizeof(float));
 
     // Create upper triangular mask for causal attention (exclude diagonal)
     for (int64_t i = 0; i < chunk_size; i++) {
@@ -10934,18 +10949,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // This corresponds to the reference implementation:
     // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
     // attn = attn + torch.eye(chunk_size)
-    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v);
-    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v);
+    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
+    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
 
     // Compute value = attn @ v_beta
-    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs);
+    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs, num_chunks);
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
                 delta_compute_k_cumdecay_f32(attn + (chunk_size * chunk_size * H_v) * seq + (chunk_size * chunk_size) * head,
                     (float *) src7->data + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                     g + (chunk_size * H_v) * seq + chunk_size * head,
                     k_cumdecay + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                     chunk_size, S_v);
+            }
         }
     }
     print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
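
The scratch buffers above are now sized with a num_chunks factor. The hunks shown do not include how num_chunks is derived; for chunked processing over n_tokens the conventional derivation is a ceiling division (an assumption, not visible in this diff):

    // Hypothetical: number of chunk_size-sized chunks needed to cover n_tokens.
    const int64_t num_chunks = (n_tokens + chunk_size - 1) / chunk_size;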
@@ -10996,7 +11013,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
                 // Compute q_g_exp = q * g.exp()
                 for (int64_t i = 0; i < chunk_size; i++) {
-                    for (int64_t d = 0; d < S_v; d++) {
+                    for (int64_t d = 0; d <  S_v; d++) {
                         q_g_exp_ptr[i * S_v + d] = q_ptr[i * S_v + d] * expf(g_ptr[i]);
                     }
                 }
@@ -11196,8 +11213,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             for (int64_t head = 0; head < H_v; head++) {
                 float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
+                // Compute number of tokens for this chunk (chunk_size unless this is the last chunk)
+                int64_t n_tokens_chunk = chunk == num_chunks - 1 ? n_tokens % chunk_size : chunk_size;
+                
                 // Store output for this chunk
-                for (int64_t i = 0; i < n_tokens; i++) {
+                for (int64_t i = 0; i < n_tokens_chunk; i++) {
                     for (int64_t d = 0; d < S_v; d++) {
                         int64_t output_idx =
                             seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
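
One caveat in the last-chunk bound introduced above: n_tokens % chunk_size evaluates to 0 when n_tokens is an exact multiple of chunk_size, which would leave the final chunk storing no tokens. A common guard looks like this (a sketch, not what the commit does):

    // Hypothetical: fall back to a full chunk when the remainder is zero.
    const int64_t rem            = n_tokens % chunk_size;
    const int64_t n_tokens_chunk = (chunk == num_chunks - 1 && rem != 0) ? rem : chunk_size;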