浏览代码

More food for Mr. Chunky.

Piotr Wilkin 2 月之前
父节点
当前提交
5d0a237a1c
共有 1 个文件被更改,包括 55 次插入和 54 次删除
  1. +55 −54
      ggml/src/ggml-cpu/ops.cpp

+ 55 - 54
ggml/src/ggml-cpu/ops.cpp

@@ -10769,10 +10769,11 @@ static void delta_compute_q_k_attn_chunk_f32(const float * q, const float * k, c
 // Helper function for matrix multiplication with state tensors for entire chunk
 static void delta_matmul_state_chunk_f32(const float * a, const float * state, float * dst,
                                         const int64_t rows_a, const int64_t cols_a, const int64_t cols_state,
-                                        const int64_t n_seqs, const int64_t H_v) {
+                                        const int64_t n_seqs, const int64_t H_v, int chunk, int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int64_t head = 0; head < H_v; head++) {
-            const float * a_ptr = a + seq * (rows_a * cols_a * H_v) + head * (rows_a * cols_a);
+            const float * a_ptr = chunk < 0 ? a + seq * (rows_a * cols_a * H_v) + head * (rows_a * cols_a) :
+                a + seq * (rows_a * cols_a * H_v * num_chunks) + (head * num_chunks + chunk) * (rows_a * cols_a);
             const float * state_ptr = state + seq * (cols_a * cols_state * H_v) + head * (cols_a * cols_state);
             float * dst_ptr = dst + seq * (rows_a * cols_state * H_v) + head * (rows_a * cols_state);
             delta_matmul_state_f32(a_ptr, state_ptr, dst_ptr, rows_a, cols_a, cols_state);
@@ -10817,10 +10818,10 @@ static void delta_update_recurrent_state_chunk_f32(const float * state, const fl
 
 // Helper function for element-wise tensor subtraction for entire chunk
 static void delta_tensor_subtract_chunk_f32(const float * a, const float * b, float * dst, const int64_t size,
-                                           const int64_t n_seqs, const int64_t H_v) {
+                                           const int64_t n_seqs, const int64_t H_v, int num_chunks, int chunk) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int64_t head = 0; head < H_v; head++) {
-            const float * a_ptr = a + seq * (size * H_v) + head * size;
+            const float * a_ptr = a + seq * (size * num_chunks * H_v) + (head * num_chunks + chunk) * size;
             const float * b_ptr = b + seq * (size * H_v) + head * size;
             float * dst_ptr = dst + seq * (size * H_v) + head * size;
             delta_tensor_subtract_f32(a_ptr, b_ptr, dst_ptr, size);
@@ -10889,8 +10890,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
     float * dst_data  = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
-    float * output    = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
-    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs);  // [S_v, S_v * H_v, 1, n_seqs]
+    float * output    = dst_data; // [S_v, H_v, n_tokens, n_seqs] - only real sequence length, not padded
+    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs);  // [S_v, S_v, H_v, n_seqs]
 
     const int ith = params->ith;
     // const int nth = params->nth;  // nth is unused
@@ -10968,7 +10969,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             } 
         }
     }
-    print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
+    print_debug_info(k_cumdecay, chunk_size * S_v * H_v * num_chunks * n_seqs, "k_cumdecay", -1);
 
     // Process each chunk with all sequences and heads together
     for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
@@ -10984,27 +10985,27 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         };
 
         // Allocate per-chunk arrays containing all sequences and heads
-        float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * attn_inter    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * v_new         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * v_prime       = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * g_diff_exp    = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
-        float * g_last        = (float *) malloc(H_v * n_seqs * sizeof(float));
+        float * pc_core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_attn_inter    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_v_new         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_v_prime       = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_g_diff_exp    = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+        float * pc_g_last        = (float *) malloc(H_v * n_seqs * sizeof(float));
 
         // Create temporary arrays for entire chunk
-        float * q_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * k_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * q_g_exp         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-        float * attn_v_new      = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_q_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_k_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_q_g_exp         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+        float * pc_attn_v_new      = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
 
         // Fill temporary arrays with data from all sequences and heads
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
-                float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * q_ptr = pc_q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * k_ptr = pc_k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
                 float * g_ptr = g + (chunk_size * H_v * num_chunks) * seq + chunk_size * (head * num_chunks + chunk);
 
-                float * q_g_exp_ptr = q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * q_g_exp_ptr = pc_q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
                 // Fill q, k, decay_mask, and g data
                 for (int64_t i = 0; i < chunk_size; i++) {
@@ -11024,8 +11025,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
 
-        print_debug_info(q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
-        print_debug_info(k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
+        print_debug_info(pc_q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
+        print_debug_info(pc_k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
 
         // Step 4: Compute NEW attention matrix for this chunk: attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         // Note: decay_mask[:, :, i] means we need to use the decay_mask for this specific chunk
@@ -11035,8 +11036,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
                 float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
-                const float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
-                const float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                const float * q_ptr = pc_q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                const float * k_ptr = pc_k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
                 float * k_trans = (float *) malloc(chunk_size * S_v * sizeof(float));
                 for (int i = 0; i < S_v; i++) {
@@ -11072,41 +11073,41 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
-        delta_matmul_state_chunk_f32(k_cumdecay, new_state, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
-        print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
+        delta_matmul_state_chunk_f32(k_cumdecay, new_state, pc_v_prime, chunk_size, S_v, S_v, n_seqs, H_v, chunk, num_chunks);
+        print_debug_info(pc_v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
 
         // v_new = v_i - v_prime
-        delta_tensor_subtract_chunk_f32(value, v_prime, v_new, chunk_size * S_v, n_seqs, H_v);
-        print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
+        delta_tensor_subtract_chunk_f32(value, pc_v_prime, pc_v_new, chunk_size * S_v, n_seqs, H_v, num_chunks, chunk);
+        print_debug_info(pc_v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
 
         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        delta_matmul_state_chunk_f32(q_g_exp, new_state, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
-        print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
+        delta_matmul_state_chunk_f32(pc_q_g_exp, new_state, pc_attn_inter, chunk_size, S_v, S_v, n_seqs, H_v, -1, -1);
+        print_debug_info(pc_attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
 
         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
         // Use regular matrix multiplication for attn @ v_new
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
                 const float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
-                const float * v_new_ptr = v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
-                float * attn_v_new_ptr = attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                const float * v_new_ptr = pc_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * attn_v_new_ptr = pc_attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
                                 
                 // Compute attn @ v_new: [chunk_size, chunk_size] @ [chunk_size, S_v] -> [chunk_size, S_v]
                 delta_matmul_f32(attn_ptr, v_new_ptr, attn_v_new_ptr, chunk_size, S_v, chunk_size);
             }
         }
-        print_debug_info(attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
-        delta_tensor_add_chunk_f32(attn_inter, attn_v_new, core_attn_out, chunk_size * S_v, n_seqs, H_v);
-        print_debug_info(core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);
+        print_debug_info(pc_attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
+        delta_tensor_add_chunk_f32(pc_attn_inter, pc_attn_v_new, pc_core_attn_out, chunk_size * S_v, n_seqs, H_v);
+        print_debug_info(pc_core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);
 
         // Prepare g_last and g_diff_exp for state update
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                float * g_ptr = g + seq * (chunk_size * H_v) + head * chunk_size;
+                float * g_ptr = g + seq * (chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * chunk_size;
                 float g_last_val         = g_ptr[chunk_size - 1];
-                g_last[seq * H_v + head] = expf(g_last_val);
+                pc_g_last[seq * H_v + head] = expf(g_last_val);
 
-                float * g_diff_exp_ptr = g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
+                float * g_diff_exp_ptr = pc_g_diff_exp + seq * (chunk_size * H_v) + head * chunk_size;
                 for (int64_t i = 0; i < chunk_size; i++) {
                     float diff        = g_last_val - g_ptr[i];
                     g_diff_exp_ptr[i] = expf(diff);
@@ -11114,8 +11115,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
 
-        print_debug_info(g_last, H_v * n_seqs, "g_last_chunk", chunk);
-        print_debug_info(g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
+        print_debug_info(pc_g_last, H_v * n_seqs, "g_last_chunk", chunk);
+        print_debug_info(pc_g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
 
         float * k_g_diffexp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11123,7 +11124,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < S_v; j++) {
                         k_g_diffexp[seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v) + i * S_v + j] = 
-                            k_chunk(seq, head, i, j) * g_diff_exp[seq * (chunk_size * H_v) + head * chunk_size + i];
+                            k_chunk(seq, head, i, j) * pc_g_diff_exp[seq * (chunk_size * H_v) + head * chunk_size + i];
                     }
                 }
             }
@@ -11165,7 +11166,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
                 delta_matmul_f32(k_g_diffexp_T + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head, 
-                    v_new + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
+                    pc_v_new + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                     kgd_mul_vnew + (S_v * S_v * H_v) * seq + (S_v * S_v) * head,
                     S_v, S_v, chunk_size);
             }
@@ -11194,7 +11195,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 for (int i = 0; i < S_v; i++) {
                     for (int j = 0; j < S_v; j++) {
                         new_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] = 
-                            state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] + 
+                            state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * pc_g_last[seq * H_v + head] + 
                             kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
                     }
                 }
@@ -11203,10 +11204,10 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);
 
         // Free temporary memory
-        free(q_chunk_data);
-        free(k_chunk_data);
-        free(q_g_exp);
-        free(attn_v_new);
+        free(pc_q_chunk_data);
+        free(pc_k_chunk_data);
+        free(pc_q_g_exp);
+        free(pc_attn_v_new);
         free(kgd_mul_vnew);
         free(k_g_diffexp_T);
         free(k_g_diffexp);
@@ -11214,7 +11215,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         // Store output for this chunk (all sequences and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
+                float * core_attn_out_ptr = pc_core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
                 // Compute number of tokens for this chunk (chunk_size unless this is the last chunk)
                 int64_t n_tokens_chunk = chunk == num_chunks - 1 ? n_tokens % chunk_size : chunk_size;
@@ -11244,12 +11245,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         // }
         print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
 
-        free(core_attn_out);
-        free(attn_inter);
-        free(v_new);
-        free(v_prime);
-        free(g_diff_exp);
-        free(g_last);
+        free(pc_core_attn_out);
+        free(pc_attn_inter);
+        free(pc_v_new);
+        free(pc_v_prime);
+        free(pc_g_diff_exp);
+        free(pc_g_last);
     }
 
     GGML_ASSERT(output + S_v * H_v * n_tokens * n_seqs == new_state);