|
|
@@ -10964,7 +10964,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
};
|
|
|
|
|
|
// Allocate per-chunk arrays containing all sequences and heads
|
|
|
- float * temp_state = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
|
|
|
float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
|
|
|
float * attn_inter = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
|
|
|
float * v_new = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
|
|
|
@@ -10972,9 +10971,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
float * g_diff_exp = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
|
|
|
float * g_last = (float *) malloc(H_v * n_seqs * sizeof(float));
|
|
|
|
|
|
- // Initialize temp_state with zeros for all sequences and heads (state should be empty initially)
|
|
|
- memset(temp_state, 0, S_v * S_v * H_v * n_seqs * sizeof(float));
|
|
|
-
|
|
|
// Create temporary arrays for entire chunk
|
|
|
float * q_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
|
|
|
float * k_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
|
|
|
@@ -11177,14 +11173,14 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
for (int64_t head = 0; head < H_v; head++) {
|
|
|
for (int i = 0; i < S_v; i++) {
|
|
|
for (int j = 0; j < S_v; j++) {
|
|
|
- temp_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
|
|
|
+ new_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
|
|
|
state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] +
|
|
|
kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- print_debug_info(temp_state, S_v * S_v * H_v * n_seqs, "temp_state", chunk);
|
|
|
+ print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);
|
|
|
|
|
|
// Free temporary memory
|
|
|
free(q_chunk_data);
|
|
|
@@ -11225,7 +11221,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
|
|
|
// }
|
|
|
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
|
|
|
|
|
|
- free(temp_state);
|
|
|
free(core_attn_out);
|
|
|
free(attn_inter);
|
|
|
free(v_new);
|