Piotr Wilkin, 2 months ago
Parent
Commit
6798b69bcc
2 changed files with 30 additions and 23 deletions
  1. +26 -23  ggml/src/ggml-cpu/ops.cpp
  2. +4 -0    src/models/llm_build_qwen3next.cpp

+ 26 - 23
ggml/src/ggml-cpu/ops.cpp

@@ -10854,6 +10854,7 @@ static void print_debug_info(float * data, size_t size, const char * name, int64
 #endif
 }
 
+// chunked version of delta_net (for prompt processing)
 void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // q (already normalized and scaled)
     const struct ggml_tensor * src1 = dst->src[1];  // k (already normalized)
@@ -10870,13 +10871,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     const int64_t original_n_tokens = (int64_t) dst->op_params[3];  // Get original sequence length
     const int64_t n_tokens          = original_n_tokens;            // Use the original sequence length
     const int64_t n_seqs            = src0->ne[3];                  // q tensor has n_seqs in dim 3
-
+    // Calculate chunk size, padding, and number of chunks
+    const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
+    const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
 
     // Add assertions to verify tensor dimensions
-    GGML_ASSERT(src0->ne[3] == n_seqs);  // q tensor
-    GGML_ASSERT(src1->ne[3] == n_seqs);  // k tensor
-    GGML_ASSERT(src2->ne[3] == n_seqs);  // v tensor
-    GGML_ASSERT(src3->ne[3] == n_seqs);  // g tensor
+    GGML_ASSERT(src0->ne[3] == n_seqs && src0->ne[2] == num_chunks * H_v);  // q tensor
+    GGML_ASSERT(src1->ne[3] == n_seqs && src1->ne[2] == num_chunks * H_v);  // k tensor
+    GGML_ASSERT(src2->ne[3] == n_seqs && src2->ne[2] == num_chunks * H_v);  // v tensor
+    GGML_ASSERT(src3->ne[3] == n_seqs && src3->ne[2] == num_chunks * H_v);  // g tensor
+    GGML_ASSERT(src5->ne[3] == n_seqs && src5->ne[2] == num_chunks * H_v);  // decay mask tensor
+    GGML_ASSERT(src6->ne[3] == n_seqs && src6->ne[2] == num_chunks * H_v);  // v_beta tensor
+    GGML_ASSERT(src7->ne[3] == n_seqs && src7->ne[2] == num_chunks * H_v);  // k_beta tensor
+
     GGML_ASSERT(src4->ne[3] == n_seqs);  // state tensor
 
     float * dst_data  = (float *) dst->data;
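
A minimal standalone sketch (not part of the commit) of the padding arithmetic hoisted above, assuming a chunk size of 64 purely for illustration; the real value comes from GGML_DELTA_NET_CHUNK:

    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    int main() {
        const int64_t chunk_size = 64;  // illustrative stand-in for GGML_DELTA_NET_CHUNK
        for (const int64_t n_tokens : {1, 64, 65, 200}) {
            // Same formulas as in ggml_compute_forward_delta_net_f32 above.
            const int64_t pad_size   = (chunk_size - n_tokens % chunk_size) % chunk_size;
            const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
            std::printf("n_tokens=%3lld  pad_size=%2lld  num_chunks=%lld\n",
                        (long long) n_tokens, (long long) pad_size, (long long) num_chunks);
        }
        return 0;
    }

For example, 64 tokens need no padding and form one chunk, while 65 tokens pad up to 128 and form two.
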
@@ -10894,11 +10902,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         return;
     }
 
-    // Calculate chunk size
-    const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
-    const int64_t pad_size = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t num_chunks = (n_tokens + pad_size) / chunk_size;
-
     float * state_data = (float *) src4->data;
 
     // Init new state with initial state (will probably be zeroes)
@@ -10970,14 +10973,14 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // Process each chunk with all sequences and heads together
     for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
         // Create lambdas for tensor access similar to recurrent function
-        const auto q_chunk = [chunk, src0](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
-            return ggml_get_f32_nd(src0, i, chunk * chunk_size + token_idx, head, seq);
+        const auto q_chunk = [chunk, src0, num_chunks](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+            return ggml_get_f32_nd(src0, i, token_idx, head * num_chunks + chunk, seq);
         };
-        const auto k_chunk = [chunk, src1](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
-            return ggml_get_f32_nd(src1, i, chunk * chunk_size + token_idx, head, seq);
+        const auto k_chunk = [chunk, src1, num_chunks](int64_t seq, int64_t head, int64_t token_idx, int64_t i) {
+            return ggml_get_f32_nd(src1, i, token_idx, head * num_chunks + chunk, seq);
         };
-        const auto g_chunk = [chunk, src3](int64_t seq, int64_t head, int64_t token_idx) {
-            return ggml_get_f32_nd(src3, chunk * chunk_size + token_idx, 0, head, seq);
+        const auto g_chunk = [chunk, src3, num_chunks](int64_t seq, int64_t head, int64_t token_idx) {
+            return ggml_get_f32_nd(src3, token_idx, 0, head * num_chunks + chunk, seq);
         };
 
         // Allocate per-chunk arrays containing all sequences and heads
@@ -10999,7 +11002,8 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             for (int64_t head = 0; head < H_v; head++) {
                 float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
                 float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
-                float * g_ptr       = g + seq * (chunk_size * H_v) + head * chunk_size;
+                float * g_ptr = g + (chunk_size * H_v * num_chunks) * seq + chunk_size * (head * num_chunks + chunk);
+
                 float * q_g_exp_ptr = q_g_exp + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
                 // Fill q, k, decay_mask, and g data
@@ -11030,7 +11034,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         // Now compute attention for all sequences and heads together
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
                 const float * q_ptr = q_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
                 const float * k_ptr = k_chunk_data + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
@@ -11046,13 +11050,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         }
         print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "q_k_trans", chunk);
 
-
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
                 for (int64_t i = 0; i < chunk_size; i++) {
                     for (int64_t j = 0; j < chunk_size; j++) {
-                        float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-                        const float * decay_mask_ptr = (float *) src5->data + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                        float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
+                        const float * decay_mask_ptr = (float *) src5->data + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
                         float attn_val = attn_ptr[i * chunk_size + j] * decay_mask_ptr[i * chunk_size + j];
                         // Apply simple causal attention mask (upper triangular with diagonal=1)
                         // This corresponds to: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
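
A toy illustration (not from this commit) of the causal masking described in the comment above: entries strictly above the diagonal, i.e. where torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1) is set, are presumably zeroed so a token only attends to itself and earlier tokens within the chunk, as in the reference implementation:

    #include <cstdio>

    int main() {
        const int n = 4;                     // toy chunk size for illustration
        float attn[4][4];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                attn[i][j] = 1.0f;           // stand-in for decayed q*k^T values

        // Zero everything strictly above the diagonal (j > i): token i may only
        // attend to tokens j <= i inside the chunk.
        for (int i = 0; i < n; i++)
            for (int j = i + 1; j < n; j++)
                attn[i][j] = 0.0f;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                std::printf("%.0f ", attn[i][j]);
            std::printf("\n");
        }
        return 0;
    }
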
@@ -11065,7 +11068,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
         
-        print_debug_info(attn, chunk_size * chunk_size * H_v * n_seqs, "attn_step4_new_chunk", chunk);
+        print_debug_info(attn, chunk_size * chunk_size * H_v * num_chunks * n_seqs, "attn_step4_new_chunk", chunk);
 
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
@@ -11084,7 +11087,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         // Use regular matrix multiplication for attn @ v_new
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                const float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
+                const float * attn_ptr = attn + seq * (chunk_size * chunk_size * num_chunks * H_v) + (head * num_chunks + chunk) * (chunk_size * chunk_size);
                 const float * v_new_ptr = v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
                 float * attn_v_new_ptr = attn_v_new + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
                                 

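Taken together, the ops.cpp changes switch the per-chunk accessors from addressing tokens along the whole padded sequence to addressing them locally within a chunk, with the chunk index folded into dim 2 next to the head. A minimal sketch of that remapping (plain functions in ggml_get_f32_nd argument order; not part of the commit):

    #include <cstdint>
    #include <cstdio>

    struct Idx4 { int64_t i0, i1, i2, i3; };  // argument order of ggml_get_f32_nd

    // Old addressing: layout [S, n_tokens_padded, H, n_seqs]; the token
    // coordinate runs over the whole padded sequence.
    Idx4 q_index_old(int64_t seq, int64_t head, int64_t chunk,
                     int64_t token_idx, int64_t i, int64_t chunk_size) {
        return { i, chunk * chunk_size + token_idx, head, seq };
    }

    // New addressing: layout [S, chunk_size, H * num_chunks, n_seqs]; the token
    // coordinate is local to its chunk and the chunk is folded into dim 2
    // (head-major, chunk varying fastest within a head).
    Idx4 q_index_new(int64_t seq, int64_t head, int64_t chunk,
                     int64_t token_idx, int64_t i, int64_t num_chunks) {
        return { i, token_idx, head * num_chunks + chunk, seq };
    }

    int main() {
        // Example: seq 0, head 2, chunk 1, token 5, feature 7; chunk_size 64, 3 chunks.
        const Idx4 o = q_index_old(0, 2, 1, 5, 7, 64);
        const Idx4 n = q_index_new(0, 2, 1, 5, 7, 3);
        std::printf("old: (%lld, %lld, %lld, %lld)\n",
                    (long long) o.i0, (long long) o.i1, (long long) o.i2, (long long) o.i3);
        std::printf("new: (%lld, %lld, %lld, %lld)\n",
                    (long long) n.i0, (long long) n.i1, (long long) n.i2, (long long) n.i3);
        return 0;
    }

The flat buffers (g, attn, decay_mask) get the matching offset head * num_chunks + chunk, as seen in the hunks above.
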
+ 4 - 0
src/models/llm_build_qwen3next.cpp

@@ -320,6 +320,10 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     cb(k_beta, "k_beta", il);
     k = ggml_reshape_4d(ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
     cb(k_beta, "k_reshape", il);
+    q = ggml_reshape_4d(ctx, q, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
+    cb(q, "q_reshape", il);
+    v = ggml_reshape_4d(ctx, v, S_v, GGML_DELTA_NET_CHUNK, H_v * num_chunks, n_seqs);
+    cb(v, "v_reshape", il);
     g = ggml_reshape_4d(ctx, g, GGML_DELTA_NET_CHUNK, 1, H_k * num_chunks, n_seqs);
     cb(g, "g_reshape", il);
     struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
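
On the graph side this gives q and v the same chunked view already applied to k. A minimal standalone sketch of the reshape pattern (toy dimensions, a locally created ggml context, and a chunk size of 64 are illustrative assumptions, not the model's real values):

    #include "ggml.h"

    int main() {
        struct ggml_init_params ip = { /*mem_size =*/ 16u * 1024 * 1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        const int64_t S_v = 128, H = 4, n_seqs = 1;
        const int64_t chunk_size = 64;                    // stand-in for GGML_DELTA_NET_CHUNK
        const int64_t n_tokens_padded = 2 * chunk_size;
        const int64_t num_chunks = n_tokens_padded / chunk_size;

        struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, n_tokens_padded, H, n_seqs);

        // Same element count, new grouping: tokens local to a chunk in dim 1,
        // (head, chunk) folded into dim 2.
        struct ggml_tensor * q_chunked = ggml_reshape_4d(ctx, q, S_v, chunk_size, H * num_chunks, n_seqs);

        GGML_ASSERT(ggml_nelements(q_chunked) == ggml_nelements(q));
        ggml_free(ctx);
        return 0;
    }

With this view, one slice along dim 2 is exactly the (head, chunk) block that the CPU op indexes with head * num_chunks + chunk.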