Piotr Wilkin 4 months ago
Parent
Commit
e0c5dff2a7
4 changed files with 223 additions and 25 deletions
  1. examples/model-conversion/scripts/causal/run-org-model.py (+1 -1)
  2. ggml/include/ggml.h (+11 -1)
  3. ggml/src/ggml.c (+210 -22)
  4. src/llama-model.cpp (+1 -1)

+ 1 - 1
examples/model-conversion/scripts/causal/run-org-model.py

@@ -193,7 +193,7 @@ print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
 
 with torch.no_grad():
-    outputs = model(input_ids)
+    outputs = model(input_ids.to("cuda"))
     logits = outputs.logits
 
     # Extract logits for the last token (next token prediction)
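Note: hard-coding .to("cuda") assumes the model itself was loaded on a CUDA device; a device-agnostic alternative would be input_ids.to(model.device), since Hugging Face models expose the device their parameters live on.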

+ 11 - 1
ggml/include/ggml.h

@@ -2300,10 +2300,20 @@ extern "C" {
             struct ggml_tensor  * conv_bias,
             struct ggml_tensor  * beta,
             struct ggml_tensor  * state,
-            int                   chunk_size,
             bool                  use_qk_l2norm,
             float                 scale);
 
+    GGML_API struct ggml_tensor * ggml_delta_net_op(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            bool                  use_qk_l2norm,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
             struct ggml_context * ctx,
             struct ggml_tensor  * r,
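For orientation, a minimal sketch of calling the new entry point directly (shapes and values are illustrative, not from this commit; ggml_delta_net remains the convenience wrapper that derives these tensors from the raw projections):

    // illustrative shapes: S_k = S_v = 128 head dim, H_v = 16 heads, 32 tokens, 1 sequence
    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  16, 32, 1);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  16, 32, 1);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  16, 32, 1);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,   1,  16, 32, 1);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,   1,  16, 32, 1);
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 128, 16, 1);

    // result packs [S_v*H_v, n_tokens] output rows followed by the updated state rows,
    // i.e. ne = [S_v*H_v, n_tokens + S_v*n_seqs]
    struct ggml_tensor * out = ggml_delta_net_op(ctx, q, k, v, g, beta, state,
                                                 /*use_qk_l2norm=*/true, /*scale=*/1.0f);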

+ 210 - 22
ggml/src/ggml.c

@@ -5431,9 +5431,9 @@ struct ggml_tensor * ggml_delta_net(
         struct ggml_tensor  * conv_bias,
         struct ggml_tensor  * beta,
         struct ggml_tensor  * state,
-        int                   chunk_size,
         bool                  use_qk_l2norm,
         float                 scale) {
+    
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(q));
@@ -5474,10 +5474,12 @@ struct ggml_tensor * ggml_delta_net(
     
     // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
+    
+    // Concatenate q, k, v for convolution processing
     struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
     mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
 
-    u_int32_t dim = (S_v * H_v) + 2 * (H_k * S_k);
+    uint32_t dim = (S_v * H_v) + 2 * (H_k * S_k);
 
     mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
     struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
@@ -5566,33 +5568,219 @@ struct ggml_tensor * ggml_delta_net(
         q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, 1);
         k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, 1);
     }
+
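+    // Broadcast q/k/v, the gate, beta and the recurrent state across n_seqs so the
+    // per-sequence kernel below sees fully materialized [..., n_tokens, n_seqs] inputs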
+    struct ggml_tensor * v_reshape = ggml_reshape_4d(ctx, v_conv, S_v, H_v, n_tokens, 1);
+    struct ggml_tensor * v_broadcast = ggml_repeat_4d(ctx, v_reshape, S_v, H_v, n_tokens, n_seqs);
+    struct ggml_tensor * g_reshape = ggml_reshape_4d(ctx, g, 1, H_v, n_tokens, n_seqs);
+    q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, H_v, n_tokens, n_seqs);
+    k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, H_v, n_tokens, n_seqs);
+    struct ggml_tensor * beta_reshape = ggml_reshape_4d(ctx, beta_sigmoid, 1, H_v, n_tokens, 1);
+    struct ggml_tensor * beta_broadcast = ggml_repeat_4d(ctx, beta_reshape, 1, H_v, n_tokens, n_seqs);
+    struct ggml_tensor * state_broadcast = ggml_repeat_4d(ctx, state, S_v, S_v, H_v, n_seqs);
     
-    // concat output and new_state
-    const int64_t ne[4] = { S_v * H_v, n_tokens + H_v * n_seqs, 1, 1 };
+    // Call tensor-level kernel with convolved and processed tensors
+    return ggml_delta_net_op(ctx, q_broadcast, k_broadcast, v_broadcast, g_reshape, beta_broadcast, state_broadcast, use_qk_l2norm, scale);
+}
+
+struct ggml_tensor * ggml_delta_net_op(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 scale) {
+    
+    // Validate dimensions
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+    
+    const int64_t S_k = q->ne[0];  // head dimension for q/k
+    const int64_t H_v = q->ne[1];  // number of heads (already processed to match v)
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs = q->ne[3];
+    
+    const int64_t S_v = v->ne[0];  // head dimension for v
+    
+    // Validate dimensions (q and k should now have same head count as v)
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_v && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] == 1 && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+    
+    // Create output tensor: [S_v, H_v, n_tokens, n_seqs]
+    struct ggml_tensor * output = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, H_v, n_tokens, n_seqs);
+    
+    // Create new state tensor: [S_v, S_v, H_v, n_seqs]
+    struct ggml_tensor * new_state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, S_v, H_v, n_seqs);
+    
+    // Process each sequence independently
+    for (int64_t seq_idx = 0; seq_idx < n_seqs; ++seq_idx) {
+        // Extract current sequence data
+        struct ggml_tensor * q_seq = ggml_view_4d(ctx, q, S_k, H_v, n_tokens, 1,
+                                                  q->nb[1], q->nb[2], q->nb[3], 
+                                                  seq_idx * q->nb[3]);
+        struct ggml_tensor * k_seq = ggml_view_4d(ctx, k, S_k, H_v, n_tokens, 1,
+                                                  k->nb[1], k->nb[2], k->nb[3],
+                                                  seq_idx * k->nb[3]);
+        struct ggml_tensor * v_seq = ggml_view_4d(ctx, v, S_v, H_v, n_tokens, 1,
+                                                  v->nb[1], v->nb[2], v->nb[3],
+                                                  seq_idx * v->nb[3]);
+        struct ggml_tensor * g_seq = ggml_view_4d(ctx, g, 1, H_v, n_tokens, 1,
+                                                  g->nb[1], g->nb[2], g->nb[3],
+                                                  seq_idx * g->nb[3]);
+        struct ggml_tensor * beta_seq = ggml_view_4d(ctx, beta, 1, H_v, n_tokens, 1,
+                                                     beta->nb[1], beta->nb[2], beta->nb[3],
+                                                     seq_idx * beta->nb[3]);
+        struct ggml_tensor * state_seq = ggml_view_4d(ctx, state, S_v, S_v, H_v, 1,
+                                                      state->nb[1], state->nb[2], state->nb[3],
+                                                      seq_idx * state->nb[3]);
+        
+        // Process each head
+        for (int64_t head_idx = 0; head_idx < H_v; ++head_idx) {
+            // Extract current head data; rows advance by the token stride (nb[2])
+            // and the head offset uses the head stride (nb[1])
+            struct ggml_tensor * q_head = ggml_view_3d(ctx, q_seq, S_k, n_tokens, 1,
+                                                       q_seq->nb[2], q_seq->nb[3],
+                                                       head_idx * q_seq->nb[1]);
+            struct ggml_tensor * k_head = ggml_view_3d(ctx, k_seq, S_k, n_tokens, 1,
+                                                       k_seq->nb[2], k_seq->nb[3],
+                                                       head_idx * k_seq->nb[1]);
+            struct ggml_tensor * v_head = ggml_view_3d(ctx, v_seq, S_v, n_tokens, 1,
+                                                       v_seq->nb[2], v_seq->nb[3],
+                                                       head_idx * v_seq->nb[1]);
+            struct ggml_tensor * g_head = ggml_view_3d(ctx, g_seq, 1, n_tokens, 1,
+                                                       g_seq->nb[2], g_seq->nb[3],
+                                                       head_idx * g_seq->nb[1]);
+            struct ggml_tensor * beta_head = ggml_view_3d(ctx, beta_seq, 1, n_tokens, 1,
+                                                          beta_seq->nb[2], beta_seq->nb[3],
+                                                          head_idx * beta_seq->nb[1]);
+            struct ggml_tensor * state_head = ggml_view_3d(ctx, state_seq, S_v, S_v, 1,
+                                                           state_seq->nb[1], state_seq->nb[2],
+                                                           head_idx * state_seq->nb[2]);
+            
+            // Make each head view contiguous so token t's vector sits at offset t*S,
+            // which is what the per-token 1d views below assume
+            q_head = ggml_cont(ctx, q_head);        // [S_k, n_tokens]
+            k_head = ggml_cont(ctx, k_head);        // [S_k, n_tokens]
+            v_head = ggml_cont(ctx, v_head);        // [S_v, n_tokens]
+            g_head = ggml_cont(ctx, g_head);        // [1, n_tokens]
+            beta_head = ggml_cont(ctx, beta_head);  // [1, n_tokens]
+            
+            // Process each token - apply L2 normalization and scaling per token as original
+            struct ggml_tensor * current_state = state_head;
+            struct ggml_tensor * output_head = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v, n_tokens);
+            
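+            // Per-token gated delta rule:
+            //   S     <- exp(g_t) * S
+            //   delta <- beta_t * (v_t - S k_t)
+            //   S     <- S + delta k_t^T
+            //   o_t   <- S (scale * q_t)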
+            for (int64_t t = 0; t < n_tokens; ++t) {
+                // Extract current token data
+                struct ggml_tensor * q_t = ggml_view_1d(ctx, q_head, S_k, t * S_k * sizeof(float));
+                struct ggml_tensor * k_t = ggml_view_1d(ctx, k_head, S_k, t * S_k * sizeof(float));
+                struct ggml_tensor * v_t = ggml_view_1d(ctx, v_head, S_v, t * S_v * sizeof(float));
+                struct ggml_tensor * g_t = ggml_view_1d(ctx, g_head, 1, t * sizeof(float));
+                struct ggml_tensor * beta_t = ggml_view_1d(ctx, beta_head, 1, t * sizeof(float));
+                
+                // Apply L2 normalization if requested - per token as in original
+                if (use_qk_l2norm) {
+                    // Compute L2 norm for q_t and k_t
+                    struct ggml_tensor * q_norm = ggml_l2_norm(ctx, q_t, 1e-6f);
+                    struct ggml_tensor * k_norm = ggml_l2_norm(ctx, k_t, 1e-6f);
+                    q_t = q_norm;
+                    k_t = k_norm;
+                }
+                
+                // Apply scaling to query - per token as in original
+                q_t = ggml_scale(ctx, q_t, scale);
+                
+                // Apply gate decay to state: state = state * exp(g_t)
+                struct ggml_tensor * g_exp = ggml_exp(ctx, g_t);
+                // g_exp is a single element; ggml_mul broadcasts it across the state
+                current_state = ggml_mul(ctx, current_state, g_exp);
+                
+                // Compute kv_mem = state @ k_t^T
+                struct ggml_tensor * k_t_reshaped = ggml_reshape_2d(ctx, k_t, S_k, 1);
+                struct ggml_tensor * kv_mem = ggml_mul_mat(ctx, current_state, k_t_reshaped);  // [S_v, 1]
+                kv_mem = ggml_reshape_1d(ctx, kv_mem, S_v);
+                
+                // Compute delta = (v_t - kv_mem) * beta_t
+                struct ggml_tensor * v_minus_kv = ggml_sub(ctx, v_t, kv_mem);
+                // Broadcast beta_t through multiplication (GGML auto-broadcasts)
+                struct ggml_tensor * delta = ggml_mul(ctx, v_minus_kv, beta_t);
+                
+                // Update state: state += outer(delta, k_t), i.e. state[v][k] += delta[v] * k_t[k]
+                struct ggml_tensor * k_t_reshaped_2 = ggml_reshape_2d(ctx, k_t, 1, S_k);
+                struct ggml_tensor * delta_reshaped = ggml_reshape_2d(ctx, delta, 1, S_v);
+                struct ggml_tensor * outer_product = ggml_mul_mat(ctx, k_t_reshaped_2, delta_reshaped);  // [S_k, S_v]
+                
+                // Handle S_k != S_v case
+                if (S_k == S_v) {
+                    current_state = ggml_add(ctx, current_state, outer_product);
+                } else {
+                    // For S_k != S_v, handle dimension mismatch
+                    if (S_k < S_v) {
+                        // Pad outer_product with zeros (a fresh tensor is not
+                        // zero-initialized, so scale it by 0 explicitly)
+                        struct ggml_tensor * padding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v - S_k, S_v);
+                        padding = ggml_scale(ctx, padding, 0.0f);
+                        outer_product = ggml_concat(ctx, outer_product, padding, 0);
+                    } else if (S_k > S_v) {
+                        // Truncate outer_product
+                        outer_product = ggml_view_2d(ctx, outer_product, S_v, S_v, outer_product->nb[1], 0);
+                    }
+                    current_state = ggml_add(ctx, current_state, outer_product);
+                }
+                
+                // Compute output = current_state @ q_t^T
+                struct ggml_tensor * q_t_reshaped = ggml_reshape_2d(ctx, q_t, S_k, 1);
+                struct ggml_tensor * output_t = ggml_mul_mat(ctx, current_state, q_t_reshaped);  // [S_v, 1]
+                output_t = ggml_reshape_1d(ctx, output_t, S_v);
+                
+                // Store output for this token using view and copy
+                struct ggml_tensor * output_slice = ggml_view_1d(ctx, output_head, S_v, t * S_v * sizeof(float));
+                output_slice = ggml_cpy(ctx, output_t, output_slice);
+            }
+            
+            // Copy the per-head output into the final output tensor; output_head is
+            // already [S_v, n_tokens], so no permutation is needed. The view's row
+            // stride is the token stride nb[2], offset by sequence and head.
+            struct ggml_tensor * output_slice = ggml_view_2d(ctx, output, S_v, n_tokens,
+                                                            output->nb[2],
+                                                            seq_idx * output->nb[3] + head_idx * output->nb[1]);
+            output_slice = ggml_cpy(ctx, output_head, output_slice);
+            
+            // Copy the final per-head state into new_state, offset by sequence and head
+            struct ggml_tensor * state_slice = ggml_view_2d(ctx, new_state, S_v, S_v,
+                                                           new_state->nb[1],
+                                                           seq_idx * new_state->nb[3] + head_idx * new_state->nb[2]);
+            state_slice = ggml_cpy(ctx, current_state, state_slice);
+        }
+    }
+    
+    // Concatenate output and new_state into final result
+    const int64_t ne[4] = { S_v * H_v, n_tokens + S_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
     
-    // Set operation parameters for the delta rule computation
-    int32_t params[8] = {
-        chunk_size,
-        use_qk_l2norm ? 1 : 0,
-        0, 0,  // reserved
-        0, 0, 0  // scale and other params
-    };
-    memcpy(params + 4, &scale, sizeof(float));
-    ggml_set_op_params(result, params, sizeof(params));
+    // Copy output data using proper tensor operations
+    struct ggml_tensor * output_flat = ggml_reshape_2d(ctx, output, S_v * H_v, n_tokens);
+    struct ggml_tensor * output_result_slice = ggml_view_2d(ctx, result, S_v * H_v, n_tokens, 
+                                                           result->nb[1], 0);
+    output_result_slice = ggml_cpy(ctx, output_flat, output_result_slice);
     
-    // Use custom operation for the gated delta rule computation
-    result->op = GGML_OP_DELTA_NET;
-    result->src[0] = q_broadcast;
-    result->src[1] = k_broadcast;
-    result->src[2] = v_conv;
-    result->src[3] = g;
-    result->src[4] = beta_sigmoid;
-    result->src[5] = state;
+    // Copy new_state data; flatten it to rows of S_v*H_v so it packs directly below
+    // the output rows (S_v * n_seqs extra rows, matching the shape declared above)
+    struct ggml_tensor * new_state_flat = ggml_reshape_2d(ctx, new_state, S_v * H_v, S_v * n_seqs);
+    struct ggml_tensor * state_result_slice = ggml_view_2d(ctx, result, S_v * H_v, S_v * n_seqs,
+                                                          result->nb[1], n_tokens * result->nb[1]);
+    state_result_slice = ggml_cpy(ctx, new_state_flat, state_result_slice);
     
     return result;
 }
-
 // ggml_rwkv_wkv7
 
 struct ggml_tensor * ggml_rwkv_wkv7(
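For reference, here is a minimal scalar sketch (plain C, independent of ggml; all names are illustrative) of the recurrence the per-token graph above assembles for one head of one sequence. It omits the optional q/k L2 normalization and uses the same state layout as the graph code, S[v_index * d_k + k_index]:

    #include <math.h>

    /* One head of one sequence: q, k are [n_tokens][d_k], v is [n_tokens][d_v],
     * S is the running state (the graph code assumes d_k == d_v), and g, beta
     * are per-token scalars. */
    static void delta_net_head(int n_tokens, int d_k, int d_v,
                               const float * q, const float * k, const float * v,
                               const float * g, const float * beta,
                               float * S, float * out, float scale) {
        for (int t = 0; t < n_tokens; ++t) {
            const float * qt = q + t * d_k;
            const float * kt = k + t * d_k;
            const float * vt = v + t * d_v;

            // decay: S <- exp(g_t) * S
            const float decay = expf(g[t]);
            for (int i = 0; i < d_v * d_k; ++i) S[i] *= decay;

            // delta_t = beta_t * (v_t - S k_t); S <- S + delta_t k_t^T
            for (int i = 0; i < d_v; ++i) {
                float kv = 0.0f;
                for (int j = 0; j < d_k; ++j) kv += S[i * d_k + j] * kt[j];
                const float delta = beta[t] * (vt[i] - kv);
                for (int j = 0; j < d_k; ++j) S[i * d_k + j] += delta * kt[j];
            }

            // o_t = S (scale * q_t)
            for (int i = 0; i < d_v; ++i) {
                float acc = 0.0f;
                for (int j = 0; j < d_k; ++j) acc += S[i * d_k + j] * scale * qt[j];
                out[t * d_v + i] = acc;
            }
        }
    }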

+ 1 - 1
src/llama-model.cpp

@@ -19018,6 +19018,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         cur = build_lora_mm(model.output, cur);
 
         cb(cur, "result_output", -1);
+        ggml_set_output(cur);
         res->t_logits = cur;
 
         ggml_build_forward_expand(gf, cur);
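A note on the ggml_set_output call added above: it sets GGML_TENSOR_FLAG_OUTPUT on the tensor, which tells the graph allocator that the buffer must survive graph execution and may not be reused for other nodes. A minimal sketch of the pattern (hypothetical tensor names):

    // mark a node as a graph output so its data can be read back after compute
    struct ggml_tensor * logits = ggml_mul_mat(ctx, w_out, hidden);  // hypothetical projection
    ggml_set_output(logits);                // allocator keeps this buffer intact
    ggml_build_forward_expand(gf, logits);  // make it a root of the graph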
@@ -19223,7 +19224,6 @@ private:
                                               conv_bias,       // conv_bias tensor (can be nullptr)
                                               beta,            // beta tensor
                                               state,           // state tensor
-                                              64,              // chunk_size (adjust as needed)
                                               true,            // use_qk_l2norm
                                               1.0f             // scale (adjust based on your model)
         );
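Since the comment leaves scale open, note that for an attention-style readout the conventional choice is the inverse square root of the q/k head dimension; a sketch, assuming that convention applies here:

    // assumption: standard 1/sqrt(d) query scaling; verify against the reference model
    const float scale = 1.0f / sqrtf((float) S_k);  // S_k = q/k head dimension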