@@ -10694,15 +10694,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     float * new_state = dst_data + (S_v * H_v * n_tokens); // [S_v * H_v, S_v * n_seqs, 1, 1]

     const int ith = params->ith;
-    // const int nth = params->nth; // nth is unused
-
-    // TODO: parallelize across heads and sequences
-    if (ith != 0) {
-        return;
-    }
+    const int nth = params->nth;

     // Clear output and new state section
-    memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    if (ith == 0) {
+        memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    }

     // Calculate chunk size
     const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
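A note on this hunk: the zero-initialization is now done by thread 0 only, but the other threads fall straight through to the main loop, so a late memset can wipe out results they have already written. The usual ggml-cpu pattern is to separate single-threaded init from the parallel phase with the internal barrier. A minimal sketch of that pattern, assuming ggml_barrier and params->threadpool are visible at this point in the file, as they are for other CPU ops:

    // Thread 0 clears the output and new-state buffers ...
    if (ith == 0) {
        memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
    }
    // ... and no thread proceeds until the buffers are actually zeroed.
    // (Assumed available here: ggml's internal ggml_barrier helper.)
    ggml_barrier(params->threadpool);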
@@ -10730,9 +10727,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     GGML_ASSERT(ggml_is_contiguous(src7));
     GGML_ASSERT(ggml_is_contiguous(src8));

+    const int64_t total_items = n_seqs * H_v * num_chunks;
+    const int64_t per_thread  = (total_items + nth - 1) / nth; // ceil: the remainder must still be assigned
+
     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int64_t head = 0; head < H_v; head++) {
             for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
+                const int64_t tidx = seq * (H_v * num_chunks) + head * num_chunks + chunk;
+                if (tidx < ith * per_thread || tidx >= (ith + 1) * per_thread) {
+                    continue; // this (seq, head, chunk) item belongs to another thread
+                }
                 float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
                 float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
                 float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
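The second hunk flattens the (seq, head, chunk) triple into a single index and gives each thread one contiguous slice of it. The ceiling division matters: a plain total_items / nth would leave the last total_items % nth items assigned to no thread, and with fewer items than threads every slice would be empty. A small standalone sketch (hypothetical sizes, not taken from this patch) showing that the slices cover every item exactly once:

    #include <stdio.h>

    // Illustration of the flat work split used above, with made-up sizes:
    // 3 sequences x 4 heads x 5 chunks = 60 items across 8 threads.
    int main(void) {
        const long long n_seqs = 3, H_v = 4, num_chunks = 5, nth = 8;
        const long long total_items = n_seqs * H_v * num_chunks;     // 60
        const long long per_thread  = (total_items + nth - 1) / nth; // ceil(60/8) = 8

        long long covered = 0;
        for (long long ith = 0; ith < nth; ith++) {
            long long first = ith * per_thread;
            long long last  = (ith + 1) * per_thread; // exclusive
            if (first > total_items) first = total_items;
            if (last  > total_items) last  = total_items;
            covered += last - first;
            printf("thread %lld: items [%lld, %lld)\n", ith, first, last);
        }
        printf("covered %lld of %lld items\n", covered, total_items); // 60 of 60
        return 0;
    }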