
Fix memory corruption

Piotr Wilkin 3 months ago
parent
commit
9de7244c26
1 changed file with 23 additions and 18 deletions

ggml/src/ggml-cpu/ops.cpp (+23, -18)

@@ -10865,7 +10865,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     float * dst_data  = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
     float * output    = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
-    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs);  // [S_v * H_v, S_v * n_seqs, 1, 1]
+    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs);  // [S_v, S_v * H_v, 1, n_seqs]
 
     const int ith = params->ith;
     // const int nth = params->nth;  // nth is unused
@@ -10884,6 +10884,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
     float * state_data = (float *) src4->data;
 
+    // Init new state with initial state (will probably be zeroes)
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            for (int64_t i = 0; i < S_v; i++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    new_state[seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j] = 
+                        state_data[seq * src4->nb[3] / sizeof(float) + (head * S_v + i) * src4->nb[1] / sizeof(float) + j * src4->nb[0] / sizeof(float)];
+                }
+            }
+        }
+    }
+    print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "init_state", -1);
+
+
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(src2));
@@ -10896,12 +10910,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
     // int64_t total_params = n_seqs * H_v * num_chunks;
     // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
-
-    // Create helper lambda for state tensor access
-    const auto state_ptr = [state_data, src4] (int64_t seq, int64_t head, int64_t i, int64_t j) {
-        return state_data + (j * src4->nb[0] / sizeof(float)) + (i * src4->nb[1] / sizeof(float)) +
-            (head * src4->nb[2] / sizeof(float)) + (seq * src4->nb[3] / sizeof(float));
-    };
     
     float * attn          = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
     float * value         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
@@ -11048,7 +11056,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
-        delta_matmul_state_chunk_f32(k_cumdecay, state_data, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
+        delta_matmul_state_chunk_f32(k_cumdecay, new_state, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
         print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
 
         // v_new = v_i - v_prime
@@ -11056,7 +11064,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
 
         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        delta_matmul_state_chunk_f32(q_g_exp, state_data, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
+        delta_matmul_state_chunk_f32(q_g_exp, new_state, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
         print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
 
         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
@@ -11203,19 +11211,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
         print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
-
-        // Update state tensor (all sequences and heads)
+        GGML_LOG_INFO("\nFull output tensor: \n\n");
        for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                float * temp_state_ptr = temp_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v);
-
-                for (int64_t i = 0; i < S_v; i++) {
-                    for (int64_t j = 0; j < S_v; j++) {
-                        int64_t state_idx             = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
-                        new_state[state_idx]          = temp_state_ptr[i * S_v + j];
-                        *(state_ptr(seq, head, i, j)) = temp_state_ptr[i * S_v + j];
+                GGML_LOG_INFO("\n[ ");
+                for (int64_t i = 0; i < n_tokens; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        GGML_LOG_INFO("%.4f  ", output[seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d]);
                     }
                 }
+                GGML_LOG_INFO(" ]");
             }
         }
         print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
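
For reference, below is a minimal, self-contained sketch (not part of the commit) of the stride-aware copy that the new "Init new state" loop performs before the chunk loop starts: the source state is addressed through ggml-style byte strides nb[0..3], while the destination buffer is packed contiguously as [S_v, S_v, H_v, n_seqs]. The helper name copy_state_contiguous, its standalone signature, and the toy dimensions are illustrative assumptions, not code from ops.cpp.

// Illustrative sketch only (not from the commit): copy a strided ggml-style
// state tensor into a packed buffer, mirroring the "Init new state" loop above.
#include <cstdint>
#include <cstdio>
#include <vector>

// src is addressed with byte strides nb[0..3] (divided by sizeof(float), as in
// the diff); dst is packed as [S_v, S_v, H_v, n_seqs].
static void copy_state_contiguous(const float * src, const size_t nb[4], float * dst,
                                  int64_t S_v, int64_t H_v, int64_t n_seqs) {
    for (int64_t seq = 0; seq < n_seqs; seq++) {
        for (int64_t head = 0; head < H_v; head++) {
            for (int64_t i = 0; i < S_v; i++) {
                for (int64_t j = 0; j < S_v; j++) {
                    dst[seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j] =
                        src[seq * nb[3] / sizeof(float) +
                            (head * S_v + i) * nb[1] / sizeof(float) +
                            j * nb[0] / sizeof(float)];
                }
            }
        }
    }
}

int main() {
    // Toy dimensions with a densely packed source, so the expected output is 1 2 3 4.
    const int64_t S_v = 2, H_v = 1, n_seqs = 1;
    const size_t nb[4] = { sizeof(float), S_v * sizeof(float),
                           S_v * S_v * sizeof(float), S_v * S_v * H_v * sizeof(float) };
    std::vector<float> src = { 1.0f, 2.0f, 3.0f, 4.0f };
    std::vector<float> dst(S_v * S_v * H_v * n_seqs, 0.0f);
    copy_state_contiguous(src.data(), nb, dst.data(), S_v, H_v, n_seqs);
    for (float v : dst) {
        printf("%.1f ", v);
    }
    printf("\n");
    return 0;
}

The nb[1]-based addressing above assumes the source state tensor is laid out as [S_v, S_v * H_v, 1, n_seqs], matching the shape comment updated in the first hunk.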