Piotr Wilkin пре 3 месеци
родитељ
комит
477c1616ad
1 измењених фајлова са 11 додато и 7 уклоњено
  1. 11 7
      ggml/src/ggml-cpu/ops.cpp

+ 11 - 7
ggml/src/ggml-cpu/ops.cpp

@@ -10694,15 +10694,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     float * new_state = dst_data + (S_v * H_v * n_tokens);  // [S_v * H_v, S_v * n_seqs, 1, 1]
 
     const int ith = params->ith;
-    // const int nth = params->nth;  // nth is unused
-
-    // TODO: parallelize across heads and sequences
-    if (ith != 0) {
-        return;
-    }
+    const int nth = params->nth;  // nth is unused
 
     // Clear output and new state section
-    memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    if (ith == 0) {
+        memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    }
 
     // Calculate chunk size
     const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
@@ -10730,9 +10727,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     GGML_ASSERT(ggml_is_contiguous(src7));
     GGML_ASSERT(ggml_is_contiguous(src8));
 
+    int64_t total_params = n_seqs * H_v * num_chunks;
+    int64_t per_thread = total_params / nth;
+
     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int64_t head = 0; head < H_v; head++) {
             for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
+                int64_t tidx = seq * (H_v * num_chunks) + head * num_chunks + chunk;
+                if (tidx < ith * per_thread || tidx >= (ith + 1) * per_thread) {
+                    continue; // not our thread;
+                }
                 float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
                 float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
                 float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));