소스 검색

e steps forward, pi steps back

Piotr Wilkin 3 달 전
부모
커밋
875de2bcc2
1개의 변경된 파일, 2개의 추가 및 7개의 삭제
  1. 2 7
      ggml/src/ggml-cpu/ops.cpp

+ 2 - 7
ggml/src/ggml-cpu/ops.cpp

@@ -10964,7 +10964,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         };
 
         // Allocate per-chunk arrays containing all sequences and heads
-        float * temp_state    = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
         float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
         float * attn_inter    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
         float * v_new         = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
@@ -10972,9 +10971,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         float * g_diff_exp    = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
         float * g_last        = (float *) malloc(H_v * n_seqs * sizeof(float));
 
-        // Initialize temp_state with zeros for all sequences and heads (state should be empty initially)
-        memset(temp_state, 0, S_v * S_v * H_v * n_seqs * sizeof(float));
-    
         // Create temporary arrays for entire chunk
         float * q_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
         float * k_chunk_data    = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
@@ -11177,14 +11173,14 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             for (int64_t head = 0; head < H_v; head++) {
                 for (int i = 0; i < S_v; i++) {
                     for (int j = 0; j < S_v; j++) {
-                        temp_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] = 
+                        new_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] = 
                             state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] + 
                             kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
                     }
                 }
             }
         }
-        print_debug_info(temp_state, S_v * S_v * H_v * n_seqs, "temp_state", chunk);
+        print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);
 
         // Free temporary memory
         free(q_chunk_data);
@@ -11225,7 +11221,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         // }
         print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
 
-        free(temp_state);
         free(core_attn_out);
         free(attn_inter);
         free(v_new);