Просмотр исходного кода

Handle case with more than one token per seq with elegant loop plus completely not crazy change to max nodes ;)

Piotr Wilkin 3 месяцев назад
Родитель
Сommit
572864287e
2 измененных файлов с 68 добавлено и 46 удалено
  1. 1 1
      src/llama-context.cpp
  2. 67 45
      src/models/llm_build_qwen3next.cpp

+ 1 - 1
src/llama-context.cpp

@@ -1362,7 +1362,7 @@ void llama_context::output_reorder() {
 //
 
 uint32_t llama_context::graph_max_nodes() const {
-    return std::max<uint32_t>(1024u, 32u*model.n_tensors());
+    return std::max<uint32_t>(16384, 512u*model.n_tensors());
 }
 
 llm_graph_result * llama_context::get_gf_res_reserve() const {

+ 67 - 45
src/models/llm_build_qwen3next.cpp

@@ -412,7 +412,6 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];
 
-    GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
     GGML_ASSERT(v->ne[2] == n_tokens);
     GGML_ASSERT(k->ne[2] == n_tokens);
     GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -459,62 +458,85 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
     g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
     cb(g, "g_permute", il);
 
-    ggml_tensor * q_t = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
-    ggml_tensor * k_t = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
-    ggml_tensor * v_t = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
-    ggml_tensor * g_t = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
+    ggml_tensor * q_tokens = ggml_cont_4d(ctx, q, n_tokens, S_k, H_k, n_seqs);
+    ggml_tensor * k_tokens = ggml_cont_4d(ctx, k, n_tokens, S_k, H_k, n_seqs);
+    ggml_tensor * v_tokens = ggml_cont_4d(ctx, v, n_tokens, S_v, H_k, n_seqs);
+    ggml_tensor * g_tokens = ggml_cont_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
+    ggml_tensor * beta_tokens = ggml_cont_4d(ctx, beta, n_tokens, 1, H_k, n_seqs);
+
     state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
+    ggml_tensor * g_tokens_exp = ggml_exp(ctx, g_tokens);
+
+    ggml_tensor * final_output = nullptr;
+    ggml_tensor * q_t, * k_t, * v_t, * g_t_exp, * beta_t;
+    for (int i = 0; i < n_tokens; i++) { // this part is per token
+        if (n_tokens == 1) { // don't do unnecessary reshapes / views
+            q_t = q_tokens;
+            k_t = k_tokens;
+            v_t = v_tokens;
+            g_t_exp = g_tokens_exp;
+            beta_t = beta_tokens;
+        } else {
+            q_t = ggml_view_4d(ctx, q_tokens, 1, S_k, H_k, n_seqs, q_tokens->nb[1], q_tokens->nb[2], q_tokens->nb[3], i * ggml_element_size(q_tokens));
+            k_t = ggml_view_4d(ctx, k_tokens, 1, S_k, H_k, n_seqs, k_tokens->nb[1], k_tokens->nb[2], k_tokens->nb[3], i * ggml_element_size(k_tokens));
+            v_t = ggml_view_4d(ctx, v_tokens, 1, S_v, H_k, n_seqs, v_tokens->nb[1], v_tokens->nb[2], v_tokens->nb[3], i * ggml_element_size(v_tokens));
+            g_t_exp = ggml_view_4d(ctx, g_tokens_exp, 1, 1, H_k, n_seqs, g_tokens_exp->nb[1], g_tokens_exp->nb[2], g_tokens_exp->nb[3], i * ggml_element_size(g_tokens_exp));
+            beta_t = ggml_view_4d(ctx, beta_tokens, 1, 1, H_k, n_seqs, beta_tokens->nb[1], beta_tokens->nb[2], beta_tokens->nb[3], i * ggml_element_size(beta_tokens));
+        }
 
-    // Apply exponential to gate: exp(g)
-    ggml_tensor * g_exp = ggml_exp(ctx, g_t);
-    cb(g_exp, "g_exp", il);
+        // Apply gate to state: state = state * exp(g)
+        ggml_tensor * gated_state = ggml_mul(ctx, state, g_t_exp);
+        cb(gated_state, "gated_state", il);
 
-    // Apply gate to state: state = state * exp(g)
-    ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
-    cb(gated_state, "gated_state", il);
+        // Compute kv_memory from state and key
+        // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
+        
+        // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
+        // to make it compatible with k_expanded for element-wise multiplication
+        ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
+        cb(gated_state_reshaped, "gated_state_reshaped", il);
+        
+        ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
+        cb(state_k_product, "state_k_product", il);
 
-    // Compute kv_memory from state and key
-    // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
-    
-    // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
-    // to make it compatible with k_expanded for element-wise multiplication
-    ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
-    cb(gated_state_reshaped, "gated_state_reshaped", il);
-    
-    ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
-    cb(state_k_product, "state_k_product", il);
+        ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
+        cb(kv_memory, "kv_memory", il);
 
-    ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
-    cb(kv_memory, "kv_memory", il);
+        // Compute delta = (v - kv_memory) * beta
+        ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
+        ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
+        cb(delta, "delta", il);
 
-    // Compute delta = (v - kv_memory) * beta
-    ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
-    ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
-    cb(delta, "delta", il);
+        // Update state = state + k * delta
+        // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        ggml_tensor * delta_t = ggml_transpose(ctx, delta);
 
-    // Update state = state + k * delta
-    // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
-    ggml_tensor * delta_t = ggml_transpose(ctx, delta);
+        // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
+        ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
+        ggml_tensor * k_t_broadcast  = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
+        ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
+        cb(k_delta_product, "k_delta", il);
 
-    // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
-    ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
-    ggml_tensor * k_t_broadcast  = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
-    ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
-    cb(k_delta_product, "k_delta", il);
+        state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
+        cb(state, "updated_state", il);
 
-    ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
-    cb(updated_state, "updated_state", il);
-    
-    ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
-    cb(state_q_product, "state_q_product", il);
-    ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
-    cb(output, "output", il);
+        ggml_tensor * state_q_product = ggml_mul(ctx, state, q_t);
+        cb(state_q_product, "state_q_product", il);
+        
+        ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
+        cb(output, "output", il);
 
+        if (final_output == nullptr) {
+            final_output = output;
+        } else {
+            final_output = ggml_concat(ctx, final_output, output, 0);
+        }
+    }
+    
     // Concatenate output and updated_state into a single tensor
     // First, flatten both tensors to 1D
-    ggml_tensor * output_1d = ggml_cont_1d(ctx, output, ggml_nelements(output));
-    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
+    ggml_tensor * output_1d = ggml_cont_1d(ctx, final_output, ggml_nelements(final_output));
+    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, state, ggml_nelements(state));
     
     // Concatenate them: [output, updated_state]
     ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);